
Compile

op_system.compile

Compile normalized RHS specifications into efficient, runnable callables.

This module is domain-agnostic and intentionally does not import flepimop2.

Contract

  • Accepts NormalizedRhs (from op_system.specs) and produces a CompiledRhs object containing an eval_fn suitable for numerical backends.
  • Raises built-in exceptions with standardized messages.

Security note

This module uses eval() on code objects compiled from user-provided strings. To reduce risk, expressions are parsed and validated with a conservative AST whitelist, and evaluation runs with empty builtins.
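The whitelist-plus-empty-builtins approach can be sketched as follows. This is not the module's actual validator: the allowed node set, the safe_compile/safe_eval names, and the exp/log function table are assumptions made for illustration.

```python
import ast
import math
import types

# Hypothetical whitelist: only these AST node types may appear in an expression.
_ALLOWED_NODES = (
    ast.Expression, ast.BinOp, ast.UnaryOp,
    ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow, ast.USub, ast.UAdd,
    ast.Name, ast.Load, ast.Constant, ast.Call,
)

# Hypothetical table of functions an expression is allowed to call.
_ALLOWED_FUNCS = {"exp": math.exp, "log": math.log}


def safe_compile(expr: str) -> types.CodeType:
    """Parse expr, reject any node outside the whitelist, return a code object."""
    tree = ast.parse(expr, mode="eval")
    for node in ast.walk(tree):
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"disallowed syntax: {type(node).__name__}")
        # Only numeric literals; string or bytes constants are rejected.
        if isinstance(node, ast.Constant) and not isinstance(node.value, (int, float)):
            raise ValueError("disallowed constant")
        # Only plain name(...) calls to whitelisted functions.
        if isinstance(node, ast.Call):
            if not (isinstance(node.func, ast.Name) and node.func.id in _ALLOWED_FUNCS):
                raise ValueError("disallowed call")
    return compile(tree, "<expr>", "eval")


def safe_eval(code: types.CodeType, variables: dict) -> float:
    # Empty __builtins__, so names like open or __import__ cannot be resolved.
    env = {"__builtins__": {}, **_ALLOWED_FUNCS}
    return eval(code, env, dict(variables))
```

Attribute access (x.__class__), subscripts, and keyword arguments are all rejected simply because their node types are absent from the whitelist.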

BodyEvalFn

Bases: Protocol

Callable that evaluates history signal bodies at (t, y, **params).

Returns a mapping from signal_id to the evaluated body array/value.
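A callable satisfies this protocol structurally: it takes (t, y, **params) and returns a signal_id-keyed mapping. In the sketch below, BodyEvalFnLike, the state order [S, I, R], and the "infectious_pressure" signal are hypothetical; only the call shape comes from the description above.

```python
from collections.abc import Mapping
from typing import Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class BodyEvalFnLike(Protocol):
    """Hypothetical restatement of the BodyEvalFn call shape."""

    def __call__(self, t: object, y: object, **params: object) -> Mapping[str, object]: ...


def body_eval(t: object, y: object, **params: object) -> Mapping[str, object]:
    # Hypothetical signal body: beta * I, with y laid out as [S, I, R].
    y_arr = np.asarray(y, dtype=np.float64)
    beta = float(params["beta"])
    return {"infectious_pressure": beta * y_arr[1]}
```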

CompiledRhs(state_names, param_names, eval_fn, meta=(lambda: MappingProxyType({}))(), operators=tuple(), factorize_axes=tuple(), block_axes=tuple(), pytree_eval_fn=None, template_shapes=None, block_pytree_eval_fn=None, block_template_shapes=None, history_requirements=tuple(), history_eval_fn=None, body_eval_fn=None, block_history_eval_fn=None, block_body_eval_fn=None, _rhs=None) dataclass

Container for a compiled RHS evaluation function.

Instances produced by compile_rhs() retain a private reference to their source NormalizedRhs, so the container can be pickled and rehydrated by re-running the compile pipeline on load. eval_fn itself is a closure (and on the vectorized path captures compiled code objects), so it is dropped from the pickle and rebuilt by compile_rhs() in __setstate__. Round-tripping a CompiledRhs therefore costs one compile on load and yields a functionally equivalent instance whose eval_fn produces identical outputs for identical inputs.

__getstate__()

Return picklable state.

The compiled eval_fn is a closure (and on the vectorized path captures compiled types.CodeType objects), which is not portably picklable. Instead, only the source NormalizedRhs is serialized, and __setstate__ recompiles from it.

Raises:

| Type | Description |
| --- | --- |
| `TypeError` | If the source `NormalizedRhs` was not retained (i.e. the instance was constructed directly rather than via `compile_rhs()`). |

Source code in src/op_system/compile.py
```python
def __getstate__(self) -> dict[str, Any]:
    """Return picklable state.

    The compiled ``eval_fn`` is a closure (and on the vectorized path
    captures compiled :class:`types.CodeType` objects), which is not
    portably picklable. Instead we serialize just the source
    :class:`NormalizedRhs` and let :meth:`__setstate__` recompile.

    Raises:
        TypeError: If the source ``NormalizedRhs`` was not retained
            (i.e. the instance was constructed directly rather than
            via :func:`compile_rhs`).
    """
    if self._rhs is None:
        msg = (
            "CompiledRhs is not picklable: the source NormalizedRhs was "
            "not retained. Construct via compile_rhs() to produce a "
            "picklable CompiledRhs."
        )
        raise TypeError(msg)
    return {"_rhs": self._rhs}
```

__setstate__(state)

Restore by recompiling from the pickled NormalizedRhs.

Source code in src/op_system/compile.py
```python
def __setstate__(self, state: Mapping[str, Any]) -> None:
    """Restore by recompiling from the pickled :class:`NormalizedRhs`."""
    rhs = state["_rhs"]
    rebuilt = compile_rhs(rhs)
    # frozen+slots dataclass: bypass __setattr__ via object.__setattr__
    object.__setattr__(self, "state_names", rebuilt.state_names)
    object.__setattr__(self, "param_names", rebuilt.param_names)
    object.__setattr__(self, "eval_fn", rebuilt.eval_fn)
    object.__setattr__(self, "meta", rebuilt.meta)
    object.__setattr__(self, "operators", rebuilt.operators)
    object.__setattr__(self, "factorize_axes", rebuilt.factorize_axes)
    object.__setattr__(self, "block_axes", rebuilt.block_axes)
    object.__setattr__(self, "pytree_eval_fn", rebuilt.pytree_eval_fn)
    object.__setattr__(self, "template_shapes", rebuilt.template_shapes)
    object.__setattr__(self, "block_pytree_eval_fn", rebuilt.block_pytree_eval_fn)
    object.__setattr__(self, "block_template_shapes", rebuilt.block_template_shapes)
    object.__setattr__(self, "history_requirements", rebuilt.history_requirements)
    object.__setattr__(self, "history_eval_fn", rebuilt.history_eval_fn)
    object.__setattr__(self, "body_eval_fn", rebuilt.body_eval_fn)
    object.__setattr__(self, "block_history_eval_fn", rebuilt.block_history_eval_fn)
    object.__setattr__(self, "block_body_eval_fn", rebuilt.block_body_eval_fn)
    object.__setattr__(self, "_rhs", rhs)
```
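The same pickling pattern can be shown in miniature: a frozen dataclass that pickles only its source and recompiles on load. Compiled, compile_expr, and build_fn are hypothetical stand-ins for CompiledRhs, compile_rhs, and the closure-building machinery; to keep the sketch short it uses a plain frozen dataclass without slots.

```python
import pickle
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


def build_fn(source: str) -> Callable[[float], float]:
    # Compile once and close over the code object; a closure like this
    # cannot be pickled by the standard pickle module.
    code = compile(source, "<rhs>", "eval")

    def fn(x: float) -> float:
        return eval(code, {"__builtins__": {}}, {"x": x})

    return fn


@dataclass(frozen=True)
class Compiled:
    """Hypothetical stand-in for CompiledRhs: only _source is pickled."""

    fn: Callable[[float], float]
    _source: Optional[str] = field(default=None, repr=False)

    def __getstate__(self) -> dict:
        if self._source is None:
            raise TypeError("not picklable: source was not retained")
        return {"_source": self._source}

    def __setstate__(self, state: dict) -> None:
        # pickle skips __init__ on load, so every field is set here.
        rebuilt = compile_expr(state["_source"])
        # Frozen dataclass: assign through object.__setattr__.
        object.__setattr__(self, "fn", rebuilt.fn)
        object.__setattr__(self, "_source", rebuilt._source)


def compile_expr(source: str) -> Compiled:
    return Compiled(fn=build_fn(source), _source=source)
```

For example, pickle.loads(pickle.dumps(compile_expr("2 * x + 1"))) yields a fresh instance whose fn is rebuilt from the stored string, while pickling Compiled(fn=build_fn("x")) directly raises TypeError, mirroring the directly-constructed CompiledRhs case.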

bind(params)

Bind parameter values and return a 2-arg RHS: rhs(t, y) -> dydt.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `params` | `Mapping[str, object]` | Mapping of parameter names to values. | required |

Returns:

| Type | Description |
| --- | --- |
| `Callable[[object, object], Float64Array]` | A callable `rhs(t, y)` that evaluates the RHS with `params` fixed. |

Source code in src/op_system/compile.py
```python
def bind(
    self, params: Mapping[str, object]
) -> Callable[[object, object], Float64Array]:
    """Bind parameter values and return a 2-arg RHS: rhs(t, y) -> dydt.

    Args:
        params: Mapping of parameter names to values.

    Returns:
        A callable `rhs(t, y)` that evaluates the RHS with `params` fixed.
    """
    params_dict = dict(params)

    def rhs(t: object, y: object) -> Float64Array:
        return self.eval_fn(t, y, **params_dict)

    return rhs
```
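bind exists so that a compiled RHS can be handed to integrators that expect a plain fun(t, y), such as scipy.integrate.solve_ivp. The sketch below mirrors bind's body around a hypothetical SIR eval_fn and drives it with a toy forward-Euler loop; neither the model nor the integrator is part of op_system.

```python
from collections.abc import Mapping
from typing import Callable

import numpy as np


def eval_fn(t: object, y: object, **params: object) -> np.ndarray:
    """Hypothetical SIR right-hand side with runtime parameter kwargs."""
    s, i, r = np.asarray(y, dtype=np.float64)
    beta, gamma = float(params["beta"]), float(params["gamma"])
    infection = beta * s * i / (s + i + r)
    recovery = gamma * i
    return np.array([-infection, infection - recovery, recovery])


def bind(params: Mapping[str, object]) -> Callable[[object, object], np.ndarray]:
    # Copy, so later mutation of the caller's mapping does not leak into rhs.
    params_dict = dict(params)

    def rhs(t: object, y: object) -> np.ndarray:
        return eval_fn(t, y, **params_dict)

    return rhs


def forward_euler(rhs, y0, t0: float, t1: float, steps: int) -> np.ndarray:
    # Minimal fixed-step integrator that only needs the two-argument rhs(t, y).
    y = np.asarray(y0, dtype=np.float64)
    dt = (t1 - t0) / steps
    for k in range(steps):
        y = y + dt * rhs(t0 + k * dt, y)
    return y
```

Because bind copies the mapping with dict(params), changing the caller's dictionary after binding does not alter the bound RHS.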

EvalFn

Bases: Protocol

Callable RHS evaluator supporting runtime parameter kwargs.

Accepts a flat (n_state,) state array and returns a flat (n_state,) derivative array in the same array namespace.
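Conforming to EvalFn only requires the call shape: a flat (n_state,) array in, a flat (n_state,) derivative out, with parameters passed as keyword arguments. EvalFnLike and the per-entry logistic model below are hypothetical illustrations.

```python
from typing import Any, Protocol

import numpy as np


class EvalFnLike(Protocol):
    """Hypothetical restatement of the EvalFn call shape."""

    def __call__(self, t: object, y: object, **params: Any) -> object: ...


def logistic_eval_fn(t: object, y: object, **params: Any) -> np.ndarray:
    # Flat (n_state,) in, flat (n_state,) out: one logistic equation per entry.
    y_arr = np.asarray(y, dtype=np.float64)
    r = float(params["r"])
    k = float(params["k"])
    return r * y_arr * (1.0 - y_arr / k)


f: EvalFnLike = logistic_eval_fn  # structural match; no inheritance needed
```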

HistoryEvalFn

Bases: Protocol

Callable RHS evaluator with history provider support.

Accepts y as a StateDict and a history_provider object implementing a query(signal_id, body, **options) method. Returns a StateDict of derivatives.
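A history provider only needs a query(signal_id, body, **options) method. The sketch below assumes a simple provider that records past samples and returns the most recent one at or before t - lag (a zero-order hold, not interpolation). RecordingHistory, the t and lag option names, the argument order of delayed_decay, and the delay model dX/dt = -k X(t - tau) are all hypothetical.

```python
from collections.abc import Mapping

import numpy as np


class RecordingHistory:
    """Hypothetical history provider: stores (t, value) samples per signal."""

    def __init__(self) -> None:
        self._samples: dict = {}

    def record(self, signal_id: str, t: float, value: np.ndarray) -> None:
        self._samples.setdefault(signal_id, []).append((t, value))

    def query(self, signal_id: str, body: object, **options: object) -> np.ndarray:
        # Assumed options: t (current time) and lag (delay).
        t = float(options["t"])
        lag = float(options["lag"])
        past = [v for ts, v in self._samples.get(signal_id, []) if ts <= t - lag]
        # Latest sample at or before t - lag; fall back to the current body.
        return past[-1] if past else np.asarray(body, dtype=np.float64)


def delayed_decay(
    t: float, y: Mapping[str, object], history_provider, **params: object
) -> dict:
    # dX/dt = -k * X(t - tau); the lagged X comes from the provider.
    x = np.asarray(y["X"], dtype=np.float64)
    x_lag = history_provider.query("X", x, t=t, lag=params["tau"])
    return {"X": -float(params["k"]) * np.asarray(x_lag)}
```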

PytreeEvalFn

Bases: Protocol

Callable RHS evaluator operating on shaped PyTree state dicts.

Accepts y as a StateDict (mapping from state-template base name to a shaped array with the template's natural N-D shape) and returns a StateDict of the same structure containing the derivative. Enables the engine to skip the flatten/unflatten step entirely and expose the full tensor structure to JAX/XLA.
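The flatten/unflatten step that the pytree path avoids can be sketched as follows. The shapes dictionary is assumed to play the role of template_shapes (template name to N-D shape), and the two-template S/I model on an (age=2, region=3) grid is hypothetical.

```python
from collections.abc import Mapping

import numpy as np


def flatten(state: Mapping[str, np.ndarray], shapes: Mapping[str, tuple]) -> np.ndarray:
    # Concatenate each template's array, in the key order of shapes, into one (n_state,) vector.
    return np.concatenate([np.asarray(state[k]).reshape(-1) for k in shapes])


def unflatten(flat: np.ndarray, shapes: Mapping[str, tuple]) -> dict:
    # Slice the flat vector back into shaped arrays, one per template.
    out = {}
    offset = 0
    for name, shape in shapes.items():
        size = int(np.prod(shape, dtype=np.int64))
        out[name] = flat[offset:offset + size].reshape(shape)
        offset += size
    return out


def pytree_rhs(t: float, y: Mapping[str, np.ndarray]) -> dict:
    # Hypothetical pytree RHS: infection force summed over the age axis.
    beta = 0.2
    force = beta * y["I"].sum(axis=0, keepdims=True)  # shape (1, 3)
    new_inf = force * y["S"]                           # broadcasts to (2, 3)
    return {"S": -new_inf, "I": new_inf}


def flat_rhs(t: float, y_flat: np.ndarray, shapes: Mapping[str, tuple]) -> np.ndarray:
    # What a flat EvalFn-style wrapper has to do on every call.
    return flatten(pytree_rhs(t, unflatten(y_flat, shapes)), shapes)
```

Calling pytree_rhs directly keeps the (2, 3) structure visible to the array backend, while flat_rhs pays for a reshape and concatenate on every evaluation.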

compile_rhs(rhs, *, xp=None)

Compile a normalized RHS into a runnable evaluation function.

Always uses the vectorized eval path, which operates on shaped buffers (one tensor expression per state template), for specs that declare axes. Specs without axes (genuinely scalar models) fall back to the scalar path. If an axis-indexed spec cannot be vectorized, UnsupportedFeatureError is raised rather than silently falling back to the catastrophically slow scalar path.

The returned eval_fn is namespace-polymorphic: it infers its array namespace from the input y at call time (y.__array_namespace__()), and returns arrays in that same namespace. Calling it with JAX arrays (or tracers) yields a JAX-native computation suitable for jax.jit / jax.vmap without any correctness wrapping.
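Per-call namespace resolution can be sketched with the Array API's __array_namespace__ hook. The namespace_of helper, its NumPy fallback for inputs that lack the hook (plain lists, older NumPy releases), and the tanh decay model are assumptions for illustration, not the compiled code's actual logic.

```python
import numpy as np


def namespace_of(y: object):
    # Array API standard: conforming arrays expose __array_namespace__().
    # Hypothetical fallback to NumPy for inputs that do not.
    get_ns = getattr(y, "__array_namespace__", None)
    return get_ns() if get_ns is not None else np


def eval_fn(t: object, y: object, **params: object):
    """Hypothetical saturating decay dy/dt = -k * tanh(y); imports no backend of its own."""
    xp = namespace_of(y)
    k = params["k"]
    return -k * xp.tanh(xp.asarray(y))
```

Because xp is looked up from y on every call, the same function returns NumPy arrays for NumPy input and, with an array library that implements the hook, arrays from that library instead.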

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rhs` | `NormalizedRhs` | Normalized RHS produced by `op_system.specs.normalize_rhs`. | required |
| `xp` | `object \| None` | Deprecated. Formerly the compile-time array backend namespace. Now ignored; the namespace is resolved per call from the input `y`. Will be removed in a future release. | `None` |
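A deprecated, ignored keyword-only argument is typically handled by warning only when the caller actually passes it. The sketch below shows one plausible shape for a helper like _warn_on_deprecated_xp; the message text, the DeprecationWarning category, and the compile_spec stand-in are assumptions, since the real helper's body is not shown here.

```python
import warnings
from typing import Optional


def warn_on_deprecated_xp(xp: Optional[object]) -> None:
    # Warn only when the caller explicitly passed something for xp.
    if xp is not None:
        warnings.warn(
            "compile_rhs(xp=...) is deprecated and ignored; the array "
            "namespace is inferred from y at call time.",
            DeprecationWarning,
            stacklevel=3,  # point at the user's call site, not this helper
        )


def compile_spec(spec: str, *, xp: Optional[object] = None) -> str:
    """Hypothetical stand-in for compile_rhs: warns, then ignores xp."""
    warn_on_deprecated_xp(xp)
    return spec  # placeholder for the real compile pipeline
```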

Returns:

| Type | Description |
| --- | --- |
| `CompiledRhs` | A `CompiledRhs` containing an `eval_fn(t, y, **params) -> dydt`. For axis-indexed specs the returned object also carries `pytree_eval_fn` and `template_shapes`. If the spec declares axes but the vectorizer cannot build a plan, an `UnsupportedFeatureError` is raised (see the bail reason in the detail message) rather than silently degrading to the scalar path. |

Source code in src/op_system/compile.py
```python
def compile_rhs(rhs: NormalizedRhs, *, xp: object | None = None) -> CompiledRhs:
    """Compile a normalized RHS into a runnable evaluation function.

    Always uses the vectorized eval path that operates on shaped buffers
    (one tensor expression per state template) for specs that declare axes.
    Specs without axes (genuinely scalar models) fall back to the scalar path.
    Raises :class:`UnsupportedFeatureError` if an axis-indexed spec cannot be
    vectorized, rather than silently falling back to the catastrophically slow
    scalar path.

    The returned ``eval_fn`` is **namespace-polymorphic**: it infers its
    array namespace from the input ``y`` at call time
    (``y.__array_namespace__()``), and returns arrays in that same
    namespace. Calling it with JAX arrays (or tracers) yields a JAX-native
    computation suitable for ``jax.jit`` / ``jax.vmap`` without any
    correctness wrapping.

    Args:
        rhs: Normalized RHS produced by `op_system.specs.normalize_rhs`.
        xp: **Deprecated.** Formerly the compile-time array backend
            namespace. Now ignored; the namespace is resolved per call
            from the input ``y``. Will be removed in a future release.

    Returns:
        A `CompiledRhs` containing an `eval_fn(t, y, **params) -> dydt`.
        For axis-indexed specs the returned object also carries
        ``pytree_eval_fn`` and ``template_shapes``.  If the spec declares
        axes but the vectorizer cannot build a plan, an
        ``UnsupportedFeatureError`` is raised (see bail reason in the detail
        message) rather than silently degrading to the scalar path.
    """
    _warn_on_deprecated_xp(xp)
    _validate_rhs_type(rhs)

    raw_history_requirements = _history_requirements_from_ir(
        aliases_ir=rhs.aliases_ir,
        equations_ir=rhs.equations_ir,
    )
    _validate_history_kinds(raw_history_requirements)

    vec, plan, eval_fn, pytree_eval_fn, template_shapes = _build_primary_eval_artifacts(
        rhs
    )

    # Apply both time-varying and synth-const wrappers to ``pytree_eval_fn``
    # BEFORE ``_build_history_artifacts`` consumes it.  history_eval_fn /
    # body_eval_fn capture the pytree_eval_fn reference at construction
    # time, so any later re-wrapping would not propagate into them; that
    # would leave runtime history-body calls missing time-varying param
    # slicing and synthesized constants (e.g. __op_system_mask__* one-hot
    # arrays for pinned transition selectors).
    synth_consts: Mapping[str, object] | None = None
    if eval_fn is not None:
        eval_fn, pytree_eval_fn = _wrap_time_varying_artifacts(
            rhs=rhs,
            eval_fn=eval_fn,
            pytree_eval_fn=pytree_eval_fn,
        )
        eval_fn, pytree_eval_fn, synth_consts = _apply_synth_const_wrappers(
            rhs=rhs,
            eval_fn=eval_fn,
            pytree_eval_fn=pytree_eval_fn,
        )

    history_requirements, history_eval_fn, body_eval_fn = _build_history_artifacts(
        rhs=rhs,
        plan=plan,
        pytree_eval_fn=pytree_eval_fn,
    )

    eval_fn = _resolve_eval_fn(
        rhs=rhs,
        eval_fn=eval_fn,
        history_requirements=history_requirements,
    )

    block_axes = analyze_block_axes(rhs)

    # ------------------------------------------------------------------
    # Block-stripped compile: produce a per-block-coord pytree_eval_fn by
    # stripping the first factorize axis from the RHS and re-running the
    # vectorizer.  Engines can jax.vmap this over the block axis instead
    # of baking literal axis indices that break under vmap.
    # ------------------------------------------------------------------
    block_pytree_eval_fn, block_template_shapes = _build_block_pytree_artifacts(
        rhs=rhs,
        vec=vec,
        pytree_eval_fn=pytree_eval_fn,
        block_axes=block_axes,
        synth_consts=synth_consts,
    )

    # Build per-block history / body eval fns when both the block compile and
    # history path succeeded.  These wrap ``block_pytree_eval_fn`` exactly as
    # ``history_eval_fn`` / ``body_eval_fn`` wrap ``pytree_eval_fn``.
    block_history_eval_fn, block_body_eval_fn = _build_block_history_artifacts(
        block_pytree_eval_fn=block_pytree_eval_fn,
        history_requirements=history_requirements,
    )

    return CompiledRhs(
        state_names=rhs.state_names,
        param_names=tuple(rhs.param_names),
        eval_fn=eval_fn,
        meta=rhs.meta,
        operators=_parse_operator_descriptors(rhs.meta),
        factorize_axes=_parse_factorize_axes(rhs.meta),
        block_axes=block_axes,
        pytree_eval_fn=pytree_eval_fn,
        template_shapes=template_shapes,
        block_pytree_eval_fn=block_pytree_eval_fn,
        block_template_shapes=block_template_shapes,
        history_requirements=history_requirements,
        history_eval_fn=history_eval_fn,
        body_eval_fn=body_eval_fn,
        block_history_eval_fn=block_history_eval_fn,
        block_body_eval_fn=block_body_eval_fn,
        _rhs=rhs,
    )
```
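The comment about applying the wrappers before _build_history_artifacts comes down to ordinary closure semantics: a closure captures the function object it was given at construction time, not the variable that later gets rebound. In this toy illustration, add_offset and make_history_fn are hypothetical stand-ins for the real wrappers and history builder.

```python
from typing import Callable

Fn = Callable[[float], float]


def base(x: float) -> float:
    return x


def add_offset(fn: Fn) -> Fn:
    # Stand-in for a wrapper such as the time-varying or synth-const wrappers.
    def wrapped(x: float) -> float:
        return fn(x) + 100.0

    return wrapped


def make_history_fn(fn: Fn) -> Fn:
    # Stand-in for a builder that captures fn by reference at construction time.
    def history_fn(x: float) -> float:
        return 2.0 * fn(x)

    return history_fn


# Wrong order: history_late is built first, so it keeps the unwrapped base.
history_late = make_history_fn(base)
primary_late = add_offset(base)

# Order used by compile_rhs: wrap first, then build the history function.
primary = add_offset(base)
history = make_history_fn(primary)
```

Here history_late(1.0) returns 2.0 because the offset never reaches it, while history(1.0) returns 202.0.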