"""Numerical integration utilities for ``bayesflow.utils.integrate`` (fixed, scheduled, adaptive, and scipy-backed solvers)."""

from collections.abc import Callable, Sequence
from typing import Dict, Tuple, Optional
from functools import partial

import keras

import numpy as np
from typing import Literal, Union

from bayesflow.adapters import Adapter
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from bayesflow.utils.logging import warning, debug


# Scalars and tensors accepted anywhere a numeric value is expected.
ArrayLike = int | float | Tensor
# A named collection of integration variables, e.g. {"x": <tensor>}.
# Uses the builtin generic `dict` (PEP 585) for consistency with the
# PEP 604 union syntax already used above.
StateDict = dict[str, ArrayLike]


# Method names recognized by the deterministic (ODE) integrators below.
DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"]
# Method names reserved for stochastic (SDE) integrators.
STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"]


def _check_all_nans(state: StateDict):
    """Return a scalar boolean tensor that is True iff every element of every value in ``state`` is NaN."""
    flags = [keras.ops.all(keras.ops.isnan(value)) for value in state.values()]
    return keras.ops.all(keras.ops.stack(flags))


def euler_step(
    fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
) -> Tuple[StateDict, ArrayLike, None, ArrayLike]:
    """Advance ``state`` by one explicit Euler step.

    ``fn`` is called as ``fn(time, **state_subset)`` and must return a dict of
    derivatives keyed like ``state``. Returns ``(new_state, new_time, None, 0.0)``
    to match the (state, time, fsal, error) signature of the other step functions.
    """
    derivatives = fn(time, **filter_kwargs(state, fn))

    # Start from a copy so state entries without a derivative are carried over.
    updated = state.copy()
    updated.update({name: state[name] + step_size * slope for name, slope in derivatives.items()})

    return updated, time + step_size, None, 0.0


def add_scaled(state, ks, coeffs, h):
    """Return the intermediate Runge-Kutta state ``state + h * sum_i coeffs[i] * ks[i]``.

    ``ks`` is a sequence of derivative dicts keyed like ``state``; the linear
    combination is applied entry-wise per state key.
    """
    combined = {}
    for name, base in state.items():
        total = keras.ops.zeros_like(base)
        for coeff, slopes in zip(coeffs, ks):
            total = total + coeff * slopes[name]
        combined[name] = base + h * total
    return combined


def rk45_step(
    fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    k1: StateDict = None,
    use_adaptive_step_size: bool = True,
) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]:
    """
    Perform one Dormand-Prince 5(4) step with embedded error estimation [1].

    Parameters
    ----------
    fn : Callable
        Right-hand side; called as ``fn(time, **state_subset)`` and expected
        to return a dict of derivatives keyed like ``state``.
    state : StateDict
        Current state values.
    time : ArrayLike
        Current integration time.
    step_size : ArrayLike
        Step size ``h`` (may be negative for backward integration).
    k1 : StateDict, optional
        Derivative at ``(time, state)`` carried over from the previous step
        ("First Same As Last" reuse); recomputed when None.
    use_adaptive_step_size : bool
        When False, skip the embedded 4th-order solution and return a zero
        error estimate (and None instead of the FSAL stage).

    Returns
    -------
    tuple
        ``(new_state, new_time, k7, err)`` where ``k7`` is the derivative at
        the new state (reusable as the next step's ``k1``) and ``err`` is the
        maximum over state entries of the 2-norm (along the last axis) of the
        difference between the 5th- and 4th-order solutions.

    [1] Dormand (1996), Numerical Methods for Differential Equations: A Computational Approach
    """
    h = step_size

    if k1 is None:  # reuse k1 from the previous step (FSAL) when available
        k1 = fn(time, **filter_kwargs(state, fn))
    # Intermediate stages at the Dormand-Prince nodes c = 1/5, 3/10, 4/5, 8/9, 1.
    k2 = fn(time + h * (1 / 5), **filter_kwargs(add_scaled(state, [k1], [1 / 5], h), fn))
    k3 = fn(time + h * (3 / 10), **filter_kwargs(add_scaled(state, [k1, k2], [3 / 40, 9 / 40], h), fn))
    k4 = fn(time + h * (4 / 5), **filter_kwargs(add_scaled(state, [k1, k2, k3], [44 / 45, -56 / 15, 32 / 9], h), fn))
    k5 = fn(
        time + h * (8 / 9),
        **filter_kwargs(
            add_scaled(state, [k1, k2, k3, k4], [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], h), fn
        ),
    )
    k6 = fn(
        time + h,
        **filter_kwargs(
            add_scaled(state, [k1, k2, k3, k4, k5], [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], h),
            fn,
        ),
    )

    # 5th order solution
    new_state = {}
    for key in k1.keys():
        new_state[key] = state[key] + h * (
            35 / 384 * k1[key] + 500 / 1113 * k3[key] + 125 / 192 * k4[key] - 2187 / 6784 * k5[key] + 11 / 84 * k6[key]
        )

    new_time = time + h
    if not use_adaptive_step_size:
        return new_state, new_time, None, 0.0

    # FSAL stage: derivative at the new point, reusable as next step's k1.
    k7 = fn(time + h, **filter_kwargs(new_state, fn))

    # 4th order embedded solution; err_state holds the per-key difference
    # between the 5th- and 4th-order results.
    err_state = {}
    for key in k1.keys():
        y4 = state[key] + h * (
            5179 / 57600 * k1[key]
            + 7571 / 16695 * k3[key]
            + 393 / 640 * k4[key]
            - 92097 / 339200 * k5[key]
            + 187 / 2100 * k6[key]
            + 1 / 40 * k7[key]
        )
        err_state[key] = new_state[key] - y4

    # Reduce to a single scalar error: worst 2-norm over all state entries.
    err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()])
    err = keras.ops.max(err_norm)

    return new_state, new_time, k7, err


def tsit5_step(
    fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    k1: StateDict = None,
    use_adaptive_step_size: bool = True,
) -> Tuple[StateDict, ArrayLike, StateDict | None, ArrayLike]:
    """
    Perform a single step of the Tsitouras 5/4 Runge-Kutta method [1].

    Parameters
    ----------
    fn : Callable
        Right-hand side; called as ``fn(time, **state_subset)`` and expected
        to return a dict of derivatives keyed like ``state``.
    state : StateDict
        Current state values.
    time : ArrayLike
        Current integration time.
    step_size : ArrayLike
        Step size ``h`` (may be negative for backward integration).
    k1 : StateDict, optional
        Derivative at ``(time, state)`` carried over from the previous step
        ("First Same As Last" reuse); recomputed when None.
    use_adaptive_step_size : bool
        When False, skip the error estimate and return a zero error
        (and None instead of the FSAL stage).

    Returns
    -------
    tuple
        ``(new_state, new_time, k7, err)`` where ``k7`` is the derivative at
        the new state (reusable as the next step's ``k1``) and ``err`` is the
        maximum over state entries of the 2-norm (along the last axis) of the
        embedded error estimate.

    [1] Tsitouras (2011), Runge--Kutta pairs of order 5(4) satisfying only the first column simplifying assumption
    """
    h = step_size

    # Butcher tableau coefficients
    c2 = 0.161
    c3 = 0.327
    c4 = 0.9
    c5 = 0.9800255409045097

    if k1 is None:  # reuse k1 from the previous step (FSAL) when available
        k1 = fn(time, **filter_kwargs(state, fn))
    k2 = fn(time + h * c2, **filter_kwargs(add_scaled(state, [k1], [0.161], h), fn))
    k3 = fn(
        time + h * c3, **filter_kwargs(add_scaled(state, [k1, k2], [-0.0084806554923570, 0.3354806554923570], h), fn)
    )
    k4 = fn(
        time + h * c4,
        **filter_kwargs(
            add_scaled(state, [k1, k2, k3], [2.897153057105494, -6.359448489975075, 4.362295432869581], h), fn
        ),
    )
    k5 = fn(
        time + h * c5,
        **filter_kwargs(
            add_scaled(
                state,
                [k1, k2, k3, k4],
                [5.325864828439257, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
                h,
            ),
            fn,
        ),
    )
    k6 = fn(
        time + h,
        **filter_kwargs(
            add_scaled(
                state,
                [k1, k2, k3, k4, k5],
                [5.86145544294270, -12.92096931784711, 8.159367898576159, -0.07158497328140100, -0.02826905039406838],
                h,
            ),
            fn,
        ),
    )

    # 5th order solution: b coefficients
    new_state = {}
    for key in state.keys():
        new_state[key] = state[key] + h * (
            0.09646076681806523 * k1[key]
            + 0.01 * k2[key]
            + 0.4798896504144996 * k3[key]
            + 1.379008574103742 * k4[key]
            - 3.290069515436081 * k5[key]
            + 2.324710524099774 * k6[key]
        )

    new_time = time + h
    if not use_adaptive_step_size:
        return new_state, new_time, None, 0.0

    # FSAL stage: derivative at the new point, reusable as next step's k1.
    k7 = fn(time + h, **filter_kwargs(new_state, fn))

    # Embedded error estimate: unlike rk45_step this forms the order-difference
    # directly from the (b - b~) coefficients rather than building the 4th-order
    # solution explicitly.
    err_state = {}
    for key in state.keys():
        err_state[key] = h * (
            0.007880878010261995 * k3[key]
            + 0.5823571654525552 * k5[key]
            + 0.015151515151515152 * k7[key]
            - 0.00178001105222577714 * k1[key]
            - 0.0008164344596567469 * k2[key]
            - 0.1447110071732629 * k4[key]
            - 0.45808210592918697 * k6[key]
        )

    # Reduce to a single scalar error: worst 2-norm over all state entries.
    err_norm = keras.ops.stack([keras.ops.norm(v, ord=2, axis=-1) for v in err_state.values()])
    err = keras.ops.max(err_norm)

    return new_state, new_time, k7, err


def integrate_fixed(
    fn: Callable,
    state: StateDict,
    start_time: ArrayLike,
    stop_time: ArrayLike,
    steps: int,
    method: str,
    **kwargs,
) -> StateDict:
    if steps <= 0:
        raise ValueError("Number of steps must be positive.")

    match method:
        case "euler":
            step_fn = partial(euler_step, fn, **filter_kwargs(kwargs, rk45_step))
        case "rk45":
            step_fn = partial(rk45_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=False)
        case "tsit5":
            step_fn = partial(tsit5_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=False)
        case str() as name:
            raise ValueError(f"Unknown integration method name: {name!r}")
        case other:
            raise TypeError(f"Invalid integration method: {other!r}")

    step_size = (stop_time - start_time) / steps

    def body(_loop_var, _loop_state):
        _state, _time = _loop_state
        _state, _time, _, _ = step_fn(_state, _time, step_size)
        return _state, _time

    state, _ = keras.ops.fori_loop(
        0,
        steps,
        body,
        (state, start_time),
    )
    return state


def integrate_scheduled(
    fn: Callable,
    state: StateDict,
    steps: Tensor | np.ndarray,
    method: str,
    **kwargs,
) -> StateDict:
    match method:
        case "euler":
            step_fn = partial(euler_step, fn, **filter_kwargs(kwargs, rk45_step))
        case "rk45":
            step_fn = partial(rk45_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=False)
        case "tsit5":
            step_fn = partial(tsit5_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=False)
        case str() as name:
            raise ValueError(f"Unknown integration method name: {name!r}")
        case other:
            raise TypeError(f"Invalid integration method: {other!r}")

    def body(_loop_var, _loop_state):
        _time = steps[_loop_var]
        step_size = steps[_loop_var + 1] - steps[_loop_var]
        _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size)
        return _loop_state

    state = keras.ops.fori_loop(
        0,
        keras.ops.shape(steps)[0] - 1,
        body,
        state,
    )
    return state


def integrate_adaptive(
    fn: Callable,
    state: dict[str, ArrayLike],
    start_time: ArrayLike,
    stop_time: ArrayLike,
    min_steps: int,
    max_steps: int,
    method: str,
    **kwargs,
) -> dict[str, ArrayLike]:
    """
    Integrate ``state`` from ``start_time`` to ``stop_time`` with adaptive
    step-size control (accept/reject steps with FSAL derivative reuse).

    Parameters
    ----------
    fn : Callable
        Right-hand side; called as ``fn(time, **state_subset)``.
    state : dict[str, ArrayLike]
        Initial state values.
    start_time, stop_time : ArrayLike
        Integration interval endpoints; backward integration is handled via
        the sign terms below.
    min_steps : int
        Minimum number of accepted steps before the loop may terminate.
    max_steps : int
        Hard cap on the number of accepted steps; must exceed ``min_steps``.
    method : str
        One of "rk45" or "tsit5"; "euler" is rejected (no error estimate).
    **kwargs
        May contain ``atol`` (default 1e-6) and ``rtol`` (default 1e-4) plus
        any keyword accepted by the step function.

    Returns
    -------
    dict[str, ArrayLike]
        The state at ``stop_time``.

    Raises
    ------
    ValueError
        If ``max_steps <= min_steps``, the method is "euler", or the method
        name is unknown.
    TypeError
        If ``method`` is not a string.
    RuntimeError
        If every value in the state became NaN during integration.
    """
    if max_steps <= min_steps:
        raise ValueError("Maximum number of steps must be greater than minimum number of steps.")

    match method:
        case "rk45":
            step_fn = partial(rk45_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=True)
        case "tsit5":
            # NOTE(review): kwargs are filtered against rk45_step here rather
            # than tsit5_step; harmless only while both share the same
            # signature — confirm if either signature changes.
            step_fn = partial(tsit5_step, fn, **filter_kwargs(kwargs, rk45_step), use_adaptive_step_size=True)
        case "euler":
            raise ValueError("Adaptive step sizing is not supported for the 'euler' method.")
        case str() as name:
            raise ValueError(f"Unknown integration method name: {name!r}")
        case other:
            raise TypeError(f"Invalid integration method: {other!r}")

    # Absolute/relative error tolerances and the initial trial step size.
    atol = keras.ops.convert_to_tensor(kwargs.get("atol", 1e-6), dtype="float32")
    rtol = keras.ops.convert_to_tensor(kwargs.get("rtol", 1e-4), dtype="float32")
    initial_step = keras.ops.convert_to_tensor((stop_time - start_time) / float(min_steps), dtype="float32")
    step0 = keras.ops.convert_to_tensor(0.0, dtype="float32")  # accepted-step counter
    count_not_accepted = 0  # rejected-step counter (only used for the debug log)

    # "First Same As Last" (FSAL) property
    k1_0 = fn(start_time, **filter_kwargs(state, fn))

    def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
        # Continue while (a) fewer than min_steps accepted steps were taken, or
        # (b) time remains beyond the next trial step AND max_steps is not yet
        # exhausted — and the state has not fully degenerated to NaNs.
        # The sign factor makes the comparison direction-agnostic.
        time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (_time + _step_size))
        step_lt_min = keras.ops.less(_step, float(min_steps))
        step_lt_max = keras.ops.less(_step, float(max_steps))

        all_nans = _check_all_nans(_state)

        # NOTE(review): despite its name, ``end_now`` is True when the loop
        # should CONTINUE, not end.
        end_now = keras.ops.logical_or(
            step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max)
        )
        return keras.ops.logical_and(~all_nans, end_now)

    def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
        # Time remaining from current point
        time_remaining = keras.ops.abs(stop_time - _time)
        # Bound the trial step so the remaining step budget can still cover the
        # remaining time (lower bound) and min_steps can still be reached
        # (upper bound).
        min_step_size = time_remaining / (max_steps - _step)
        max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0)
        h = keras.ops.sign(_step_size) * keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size)

        # Take one trial step
        new_state, new_time, new_k1, err = step_fn(
            state=_state,
            time=_time,
            step_size=h,
            k1=_k1,
        )

        # New step size suggestion from the standard order-5 controller:
        # h_new = h * clip(0.9 * (1 / error_ratio)^(1/5), 0.2, 5.0),
        # with the error scaled by atol + rtol * max|state|.
        max_abs = None
        for k, v in _state.items():
            m = keras.ops.max(keras.ops.abs(v))
            max_abs = m if max_abs is None else keras.ops.maximum(max_abs, m)
        scale = atol + rtol * max_abs
        error_ratio = err / scale
        new_step_size = h * keras.ops.clip(0.9 * (1.0 / (error_ratio + 1e-12)) ** 0.2, 0.2, 5.0)
        new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip(
            keras.ops.abs(new_step_size), min_step_size, max_step_size
        )

        # Error control: reject when the scaled error exceeds 1, unless the
        # step is already at its minimum size (then accept to guarantee progress).
        too_big = keras.ops.greater(error_ratio, 1.0)
        at_min = keras.ops.less_equal(
            keras.ops.abs(h),
            keras.ops.abs(min_step_size),
        )
        accepted = keras.ops.logical_or(keras.ops.logical_not(too_big), at_min)

        # On rejection, keep the previous state/time/FSAL derivative.
        updated_state = keras.ops.cond(accepted, lambda: new_state, lambda: _state)
        updated_time = keras.ops.cond(accepted, lambda: new_time, lambda: _time)
        updated_k1 = keras.ops.cond(accepted, lambda: new_k1, lambda: _k1)

        # Step counter: increment only on accepted steps
        updated_step = _step + keras.ops.where(accepted, 1.0, 0.0)
        _count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 0.0, 1.0)

        # For the next iteration, always use the new suggested step size
        return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted

    # Run the adaptive loop
    state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop(
        cond,
        body,
        [state, start_time, initial_step, step0, k1_0, count_not_accepted],
    )

    if _check_all_nans(state):
        raise RuntimeError(f"All values are NaNs in state during integration at {time}.")

    # Final step to hit stop_time exactly
    time_diff = stop_time - time
    time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
    if keras.ops.all(time_remaining > 0):
        state, time, _, _ = step_fn(
            state=state,
            time=time,
            step_size=time_diff,
            k1=k1,
        )
        step = step + 1.0

    debug(f"Finished integration after {step} steps with {count_not_accepted} rejected steps.")
    return state


def integrate_scipy(
    fn: Callable,
    state: StateDict,
    start_time: ArrayLike,
    stop_time: ArrayLike,
    scipy_kwargs: dict | None = None,
) -> StateDict:
    """Integrate an ODE system using ``scipy.integrate.solve_ivp``.

    The state dictionary is flattened into a single float64 vector for scipy,
    integrated from ``start_time`` to ``stop_time``, and the final solution is
    unflattened back into tensors of the original shapes.
    """
    import scipy.integrate

    options = scipy_kwargs or {}
    keys = list(state.keys())
    # Remember each entry's shape so the flat solution vector can be restored.
    # Convert to tensor before taking the shape in case plain numbers were passed.
    shapes = keras.tree.map_structure(lambda v: keras.ops.shape(keras.ops.convert_to_tensor(v)), state)
    adapter = Adapter().concatenate(keys, into="x", axis=-1).convert_dtype(np.float32, np.float64)

    def pack(current):
        # Flatten every entry and concatenate into one float64 numpy vector.
        flat = keras.tree.map_structure(keras.ops.convert_to_numpy, current)
        flat = keras.tree.map_structure(lambda v: keras.ops.reshape(v, (-1,)), flat)
        return adapter.forward(flat)["x"]

    def unpack(x):
        # Inverse of `pack`: split the vector and restore shapes and dtypes.
        restored = adapter.inverse({"x": x})
        restored = {key: keras.ops.reshape(value, shapes[key]) for key, value in restored.items()}
        return keras.tree.map_structure(keras.ops.convert_to_tensor, restored)

    def rhs(t, x):
        # scipy-facing right-hand side: vector in, vector of derivatives out.
        current = unpack(x)
        t = keras.ops.convert_to_tensor(t, dtype="float32")
        deltas = fn(t, **filter_kwargs(current, fn))
        return pack(deltas)

    solution = scipy.integrate.solve_ivp(
        rhs,
        (start_time, stop_time),
        pack(state),
        **options,
    )

    return unpack(solution.y[:, -1])


def integrate(
    fn: Callable,
    state: StateDict,
    start_time: ArrayLike | None = None,
    stop_time: ArrayLike | None = None,
    min_steps: int = 50,
    max_steps: int = 10_000,
    steps: int | Literal["adaptive"] | Tensor | np.ndarray = 100,
    method: str = "rk45",
    **kwargs,
) -> StateDict:
    """Integrate an ODE system from ``start_time`` to ``stop_time``.

    Dispatches to adaptive, fixed-step, or scheduled integration depending on
    the value of ``steps``:

    - ``"adaptive"`` (or ``"dynamic"``): adaptive step sizing between
      ``min_steps`` and ``max_steps``; ``method='scipy'`` delegates to
      :func:`integrate_scipy` (ignoring ``min_steps``/``max_steps``).
    - ``int``: fixed number of equal steps.
    - sequence / array / tensor: an explicit time grid for scheduled integration.

    Raises
    ------
    ValueError
        If ``start_time`` or ``stop_time`` is missing when required.
    RuntimeError
        If the type or value of ``steps`` is not understood.
    """
    if isinstance(steps, str) and steps in ["adaptive", "dynamic"]:
        if start_time is None or stop_time is None:
            raise ValueError(
                "Please provide start_time and stop_time for the integration, was "
                f"'start_time={start_time}', 'stop_time={stop_time}'."
            )
        if method == "scipy":
            # Fixed: compare against the actual default (50); the original compared
            # against 10 and therefore warned spuriously for the default value.
            if min_steps != 50:
                warning("Setting min_steps has no effect for method 'scipy'")
            if max_steps != 10_000:
                warning("Setting max_steps has no effect for method 'scipy'")
            return integrate_scipy(fn, state, start_time, stop_time)
        return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
    elif isinstance(steps, int):
        if start_time is None or stop_time is None:
            raise ValueError(
                "Please provide start_time and stop_time for the integration, was "
                f"'start_time={start_time}', 'stop_time={stop_time}'."
            )
        return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
    elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps):
        return integrate_scheduled(fn, state, steps, method, **kwargs)
    else:
        raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
############ SDE Solvers #############


def generate_noise(z: StateDict, seed: keras.random.SeedGenerator) -> StateDict:
    """Draw standard-normal noise with the same shape and dtype as each state entry."""
    noise = {
        k: keras.random.normal(keras.ops.shape(val), dtype=keras.ops.dtype(val), seed=seed) for k, val in z.items()
    }
    return noise


def stochastic_adaptive_step_size_controller(
    state,
    drift,
    adaptive_factor: ArrayLike,
    min_step_size: float = -float("inf"),
    max_step_size: float = float("inf"),
) -> ArrayLike:
    """
    Adaptive step size controller based on [1]. Similar to a tamed explicit Euler
    method when used in Euler-Maruyama.

    Adaptive step sizing uses: h = max(1, ||x||**2) / max(1, ||f(x)||**2) * adaptive_factor

    [1] Fang & Giles, Adaptive Euler-Maruyama Method for SDEs with Non-Globally
    Lipschitz Drift Coefficients (2020)

    Returns
    -------
    New step size, clipped to [min_step_size, max_step_size].
    """
    state_norms = []
    drift_norms = []
    for key in state.keys():
        state_norms.append(keras.ops.norm(state[key], ord=2, axis=-1))
        drift_norms.append(keras.ops.norm(drift[key], ord=2, axis=-1))
    state_norm = keras.ops.stack(state_norms)
    drift_norm = keras.ops.stack(drift_norms)

    max_state_norm = keras.ops.maximum(
        keras.ops.cast(1.0, dtype=keras.ops.dtype(state_norm)), keras.ops.max(state_norm) ** 2
    )
    max_drift_norm = keras.ops.maximum(
        keras.ops.cast(1.0, dtype=keras.ops.dtype(drift_norm)), keras.ops.max(drift_norm) ** 2
    )
    new_step_size = max_state_norm / max_drift_norm * adaptive_factor
    new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
    return new_step_size


def euler_maruyama_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    noise: StateDict,
    use_adaptive_step_size: bool = False,
    min_step_size: float = -float("inf"),
    max_step_size: float = float("inf"),
) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]:
    """
    Performs a single Euler-Maruyama step for stochastic differential equations.

    Args:
        drift_fn: Function computing the drift term f(t, **state).
        diffusion_fn: Function computing the diffusion term g(t, **state).
        state: Current state, mapping variable names to tensors.
        time: Current time scalar tensor.
        step_size: Time increment dt.
        noise: Mapping of variable names to dW noise tensors.
        use_adaptive_step_size: Whether to use adaptive step sizing.
        min_step_size: Minimum allowed step size.
        max_step_size: Maximum allowed step size.

    Returns:
        new_state: Updated state after one Euler-Maruyama step.
        new_time: time + dt.
        new_step_size: The (possibly adapted) step size that was applied.
        state: The pre-step state (only when use_adaptive_step_size is True).
    """
    # Compute drift and diffusion
    drift = drift_fn(time, **filter_kwargs(state, drift_fn))
    diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))

    new_step_size = step_size
    if use_adaptive_step_size:
        # Keep the sign of the requested step; the controller returns a magnitude.
        sign_step = keras.ops.sign(step_size)
        new_step_size = stochastic_adaptive_step_size_controller(
            state=state,
            drift=drift,
            adaptive_factor=max_step_size,
            min_step_size=min_step_size,
            max_step_size=max_step_size,
        )
        new_step_size = sign_step * keras.ops.abs(new_step_size)

    sqrt_step_size = keras.ops.sqrt(keras.ops.abs(new_step_size))

    new_state = {}
    for key, d in drift.items():
        base = state[key] + new_step_size * d
        if key in diffusion:
            base = base + diffusion[key] * sqrt_step_size * noise[key]
        new_state[key] = base

    if use_adaptive_step_size:
        return new_state, time + new_step_size, new_step_size, state
    return new_state, time + new_step_size, new_step_size


def two_step_adaptive_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    noise: StateDict,
    last_state: StateDict = None,
    use_adaptive_step_size: bool = True,
    min_step_size: float = -float("inf"),
    max_step_size: float = float("inf"),
    e_rel: float = 0.1,
    e_abs: float = None,
    r: float = 0.9,
    adapt_safety: float = 0.9,
) -> Union[
    Tuple[StateDict, ArrayLike, ArrayLike],
    Tuple[StateDict, ArrayLike, ArrayLike, StateDict],
]:
    """
    Performs a single adaptive step for stochastic differential equations based on [1].

    This method uses a predictor-corrector approach with error estimation:
    1. Take an Euler-Maruyama step (predictor)
    2. Take another Euler-Maruyama step from the predicted state
    3. Average the two predictions (corrector)
    4. Estimate error and adapt step size

    When step_size reaches min_step_size, steps are always accepted regardless of
    error to ensure progress and termination within max_steps.

    [1] Jolicoeur-Martineau et al. (2021) "Gotta Go Fast When Generating Data with
    Score-Based Models"

    Args:
        drift_fn: Function computing the drift term f(t, **state).
        diffusion_fn: Function computing the diffusion term g(t, **state).
        state: Current state, mapping variable names to tensors.
        time: Current time scalar tensor.
        step_size: Time increment dt.
        noise: Mapping of variable names to dW noise tensors (pre-scaled by sqrt(dt)).
        last_state: Previous state for error estimation.
        use_adaptive_step_size: Whether to adapt step size.
        min_step_size: Minimum allowed step size.
        max_step_size: Maximum allowed step size.
        e_rel: Relative error tolerance.
        e_abs: Absolute error tolerance. Default assumes standardized targets.
        r: Order of the method for step size adaptation.
        adapt_safety: Safety factor for step size adaptation.

    Returns:
        new_state: Updated state after one adaptive step.
        new_time: time + dt (or time if step rejected).
        new_step_size: Adapted step size for next iteration.
        prev_state: State used for the next error estimate (adaptive mode only).
    """
    state_euler, time_mid, _ = euler_maruyama_step(
        drift_fn=drift_fn,
        diffusion_fn=diffusion_fn,
        state=state,
        time=time,
        step_size=step_size,
        min_step_size=min_step_size,
        max_step_size=max_step_size,
        noise=noise,
        use_adaptive_step_size=False,
    )

    # Compute drift and diffusion at new state, but update from old state
    drift_mid = drift_fn(time_mid, **filter_kwargs(state_euler, drift_fn))
    diffusion_mid = diffusion_fn(time_mid, **filter_kwargs(state_euler, diffusion_fn))
    sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size))
    state_euler_mid = {}
    for key, d in drift_mid.items():
        base = state[key] + step_size * d
        if key in diffusion_mid:
            base = base + diffusion_mid[key] * sqrt_step_size * noise[key]
        state_euler_mid[key] = base

    # average the two predictions
    state_heun = {}
    for key in state.keys():
        state_heun[key] = 0.5 * (state_euler[key] + state_euler_mid[key])

    # Error estimation
    if use_adaptive_step_size:
        if e_abs is None:
            e_abs = 0.02576  # 1% of 99% CI of standardized unit variance

        # Check if we're at minimum step size - if so, force acceptance
        at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size)

        # Compute error tolerance for each component
        e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0]))
        e_rel_tensor = keras.ops.cast(e_rel, dtype=keras.ops.dtype(list(state.values())[0]))

        max_error = keras.ops.cast(0.0, dtype=keras.ops.dtype(list(state.values())[0]))
        for key in state.keys():
            # Local error estimate: difference between Heun and first Euler step
            error_estimate = keras.ops.abs(state_heun[key] - state_euler[key])
            # Tolerance threshold
            delta = keras.ops.maximum(
                e_abs_tensor,
                e_rel_tensor * keras.ops.maximum(keras.ops.abs(state_euler[key]), keras.ops.abs(last_state[key])),
            )
            # Normalized error
            normalized_error = error_estimate / (delta + 1e-10)
            # Maximum error across all components and batch dimensions
            component_max_error = keras.ops.max(normalized_error)
            max_error = keras.ops.maximum(max_error, component_max_error)

        error_scale = 1  # 1/sqrt(n_params)
        E2 = error_scale * max_error

        # Accept step if error is acceptable OR if at minimum step size
        error_acceptable = keras.ops.less_equal(E2, keras.ops.cast(1.0, dtype=keras.ops.dtype(E2)))
        accepted = keras.ops.logical_or(error_acceptable, at_min_step)

        # Adapt step size for next iteration (only if not at minimum)
        # Ensure E2 is not zero to avoid division issues
        E2_safe = keras.ops.maximum(E2, 1e-10)
        # New step size based on error estimate
        adapt_factor = adapt_safety * keras.ops.power(E2_safe, -r)
        new_step_candidate = step_size * adapt_factor
        # Clamp to valid range
        new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size)
        new_step_size = keras.ops.sign(step_size) * new_step_size

        # Return appropriate state based on acceptance
        new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state)
        new_time = keras.ops.cond(accepted, lambda: time_mid, lambda: time)
        prev_state = keras.ops.cond(accepted, lambda: state_euler, lambda: state)

        return new_state, new_time, new_step_size, prev_state
    else:
        return state_heun, time_mid, step_size


def compute_levy_area(
    state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike
) -> StateDict:
    """Build the Lévy area term H_k from the primary (W) and auxiliary (Z) normals.

    Keys without a diffusion entry receive a zero area.
    """
    step_size_abs = keras.ops.abs(step_size)
    sqrt_step_size = keras.ops.sqrt(step_size_abs)
    inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs))

    # Build Lévy area H_k from w_k and Z_k
    H = {}
    for k in state.keys():
        if k in diffusion:
            term1 = 0.5 * step_size_abs * noise[k]
            term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k]
            H[k] = term1 + term2
        else:
            H[k] = keras.ops.zeros_like(state[k])
    return H


def sea_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    noise: StateDict,  # standard normals
    noise_aux: StateDict,  # standard normals
) -> Tuple[StateDict, ArrayLike, ArrayLike]:
    """
    Performs a single shifted Euler step for SDEs with additive noise [1].

    Compared to Euler-Maruyama, this evaluates the drift at a shifted state,
    which improves the local error and the global error constant for additive
    noise. The scheme is

        X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n)) * h + g(t_n) * ΔW_n

    [1] Foster et al., "High order splitting methods for SDEs satisfying a
    commutativity condition" (2023)

    Args:
        drift_fn: Function computing the drift term f(t, **state).
        diffusion_fn: Function computing the diffusion term g(t, **state).
        state: Current state, mapping variable names to tensors.
        time: Current time scalar tensor.
        step_size: Time increment dt.
        noise: Mapping of variable names to dW noise tensors.
        noise_aux: Mapping of variable names to auxiliary noise.

    Returns:
        new_state: Updated state after one SEA step.
        new_time: time + dt.
        step_size: The unchanged step size.
    """
    # Compute diffusion
    diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
    sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size))

    la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size)

    # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH)
    shifted_state = {}
    for key, x in state.items():
        if key in diffusion:
            shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key])
        else:
            shifted_state[key] = x

    # Drift evaluated at shifted state
    drift_shifted = drift_fn(time, **filter_kwargs(shifted_state, drift_fn))

    # Final update
    new_state = {}
    for key, d in drift_shifted.items():
        base = state[key] + step_size * d
        if key in diffusion:
            base = base + diffusion[key] * sqrt_step_size * noise[key]
        new_state[key] = base

    return new_state, time + step_size, step_size


def shark_step(
    drift_fn: Callable,
    diffusion_fn: Callable,
    state: StateDict,
    time: ArrayLike,
    step_size: ArrayLike,
    noise: StateDict,
    noise_aux: StateDict,
) -> Tuple[StateDict, ArrayLike, ArrayLike]:
    """
    Shifted Additive noise Runge Kutta (SHARK) for additive SDEs [1].

    Makes two evaluations of the drift and diffusion per step and has a strong
    order 1.5. SHARK method as specified:

        1) ỹ_k = y_k + g(y_k) H_k
        2) ỹ_{k+5/6} = ỹ_k + (5/6)[ f(ỹ_k) h + g(ỹ_k) W_k ]
        3) y_{k+1} = y_k + (2/5) f(ỹ_k) h + (3/5) f(ỹ_{k+5/6}) h
                     + g(ỹ_k) ( 2/5 W_k + 6/5 H_k )
                     + g(ỹ_{k+5/6}) ( 3/5 W_k - 6/5 H_k )

    with H_k = 0.5 * |h| * W_k + (|h| ** 1.5) / (2 * sqrt(3)) * Z_k

    [1] Foster et al., "High order splitting methods for SDEs satisfying a
    commutativity condition" (2023)

    Args:
        drift_fn: Function computing the drift term f(t, **state).
        diffusion_fn: Function computing the diffusion term g(t, **state).
        state: Current state, mapping variable names to tensors.
        time: Current time scalar tensor.
        step_size: Time increment dt.
        noise: Mapping of variable names to dW noise tensors.
        noise_aux: Mapping of variable names to auxiliary noise.

    Returns:
        new_state: Updated state after one SHARK step.
        new_time: time + dt.
        step_size: The unchanged step size.
    """
    h = step_size
    t = time
    h_mag = keras.ops.abs(h)
    sqrt_h_mag = keras.ops.sqrt(h_mag)

    diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn))

    la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size)

    # === 1) shifted initial state ===
    y_tilde_k = {}
    for k in state.keys():
        if k in diffusion:
            y_tilde_k[k] = state[k] + diffusion[k] * la[k]
        else:
            y_tilde_k[k] = state[k]

    # === evaluate drift and diffusion at ỹ_k ===
    f_tilde_k = drift_fn(t, **filter_kwargs(y_tilde_k, drift_fn))
    g_tilde_k = diffusion_fn(t, **filter_kwargs(y_tilde_k, diffusion_fn))

    # === 2) internal stage at 5/6 ===
    y_tilde_mid = {}
    for k in state.keys():
        drift_part = (5.0 / 6.0) * f_tilde_k[k] * h
        if k in g_tilde_k:
            sto_part = (5.0 / 6.0) * g_tilde_k[k] * sqrt_h_mag * noise[k]
        else:
            sto_part = keras.ops.zeros_like(state[k])
        y_tilde_mid[k] = y_tilde_k[k] + drift_part + sto_part

    # === evaluate drift and diffusion at ỹ_(k+5/6) ===
    f_tilde_mid = drift_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, drift_fn))
    g_tilde_mid = diffusion_fn(t + 5.0 / 6.0 * h, **filter_kwargs(y_tilde_mid, diffusion_fn))

    # === 3) final update ===
    new_state = {}
    for k in state.keys():
        # deterministic weights
        det = state[k] + (2.0 / 5.0) * f_tilde_k[k] * h + (3.0 / 5.0) * f_tilde_mid[k] * h
        # stochastic parts
        sto1 = (
            g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k])
            if k in g_tilde_k
            else keras.ops.zeros_like(det)
        )
        sto2 = (
            g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k])
            if k in g_tilde_mid
            else keras.ops.zeros_like(det)
        )
        new_state[k] = det + sto1 + sto2

    return new_state, t + h, h


def _apply_corrector(
    new_state: StateDict,
    new_time: ArrayLike,
    i: ArrayLike,
    corrector_steps: int,
    score_fn: Optional[Callable],
    corrector_noise_history: StateDict | None,
    seed: keras.random.SeedGenerator,
    step_size_factor: ArrayLike = 0.01,
    noise_schedule=None,
) -> StateDict:
    """Helper function to apply corrector steps [1].

    Applies `corrector_steps` rounds of annealed Langevin dynamics to `new_state`,
    drawing noise either freshly from `seed` or from a pre-generated
    `corrector_noise_history` indexed by (step i, corrector step j).

    [1] Song et al., "Score-Based Generative Modeling through Stochastic
    Differential Equations" (2020)
    """
    for j in range(corrector_steps):
        score = score_fn(new_time, **filter_kwargs(new_state, score_fn))
        if corrector_noise_history is None:
            _z_corr = generate_noise(new_state, seed=seed)
        else:
            _z_corr = {k: val[i, j] for k, val in corrector_noise_history.items()}

        log_snr_t = noise_schedule.get_log_snr(t=new_time, training=False)
        alpha_t, _ = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

        for k in new_state.keys():
            if k in score:
                # Calculate required norms for Langevin step
                z_norm = keras.ops.norm(_z_corr[k], axis=-1, keepdims=True)
                score_norm = keras.ops.norm(score[k], axis=-1, keepdims=True)
                score_norm = keras.ops.maximum(score_norm, 1e-8)

                # Compute step size for the Langevin update
                e = 2.0 * alpha_t * (step_size_factor * z_norm / score_norm) ** 2

                # Annealed Langevin Dynamics update
                new_state[k] = new_state[k] + e * score[k] + keras.ops.sqrt(2.0 * e) * _z_corr[k]
    return new_state


def integrate_stochastic_fixed(
    step_fn: Callable,
    state: StateDict,
    start_time: ArrayLike,
    stop_time: ArrayLike,
    steps: int,
    min_step_size: ArrayLike,
    max_step_size: ArrayLike,
    z_history: StateDict | None,
    z_extra_history: StateDict | None,
    score_fn: Optional[Callable],
    step_size_factor: ArrayLike,
    corrector_noise_history: StateDict | None,
    seed: keras.random.SeedGenerator,
    corrector_steps: int = 0,
    noise_schedule=None,
) -> StateDict:
    """
    Performs fixed-step SDE integration.
    """
    initial_step = (stop_time - start_time) / float(steps)

    def cond(_loop_var, _loop_state, _loop_time, _loop_step):
        # Stop early if the whole state degenerated to NaNs.
        all_nans = _check_all_nans(_loop_state)
        end_now = keras.ops.less(_loop_var, steps)
        return keras.ops.logical_and(~all_nans, end_now)

    def body(_i, _current_state, _current_time, _current_step):
        # Determine step size: either the constant size or the remainder to reach stop_time
        remaining = keras.ops.abs(stop_time - _current_time)
        sign = keras.ops.sign(_current_step)
        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining)
        dt = sign * dt_mag

        # Generate noise increment
        if z_history is None:
            _noise_i = generate_noise(_current_state, seed=seed)
        else:
            _noise_i = {k: val[_i] for k, val in z_history.items()}

        _noise_extra_i = None
        if z_extra_history is not None:
            if len(z_extra_history) == 0:
                _noise_extra_i = generate_noise(_current_state, seed=seed)
            else:
                # Bug fix: read from z_extra_history (the original iterated
                # z_history here, duplicating the primary noise as aux noise).
                _noise_extra_i = {k: val[_i] for k, val in z_extra_history.items()}

        step_fn_additional_args = dict(
            noise_aux=_noise_extra_i,
            min_step_size=min_step_size,
            max_step_size=keras.ops.minimum(max_step_size, remaining),
            use_adaptive_step_size=False,
        )
        new_state, new_time, new_step = step_fn(
            state=_current_state,
            time=_current_time,
            step_size=dt,
            noise=_noise_i,
            **filter_kwargs(step_fn_additional_args, step_fn),
        )

        if corrector_steps > 0:
            new_state = _apply_corrector(
                new_state=new_state,
                new_time=new_time,
                i=_i,
                corrector_steps=corrector_steps,
                score_fn=score_fn,
                noise_schedule=noise_schedule,
                step_size_factor=step_size_factor,
                corrector_noise_history=corrector_noise_history,
                seed=seed,
            )
        return _i + 1, new_state, new_time, initial_step

    _, final_state, final_time, _ = keras.ops.while_loop(
        cond,
        body,
        [0, state, start_time, initial_step],
    )

    if _check_all_nans(final_state):
        raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
    return final_state


def integrate_stochastic_adaptive(
    step_fn: Callable,
    state: StateDict,
    start_time: ArrayLike,
    stop_time: ArrayLike,
    max_steps: int,
    min_step_size: ArrayLike,
    max_step_size: ArrayLike,
    initial_step: ArrayLike,
    z_history: StateDict | None,
    z_extra_history: StateDict | None,
    score_fn: Optional[Callable],
    step_size_factor: ArrayLike,
    seed: keras.random.SeedGenerator,
    corrector_noise_history: StateDict | None,
    corrector_steps: int = 0,
    noise_schedule=None,
) -> StateDict:
    """
    Performs adaptive-step SDE integration.
    """
    initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state)

    if keras.backend.backend() == "jax":
        seed = None  # not needed, noise is generated upfront
        # Bug fix: bind seed_body on the JAX branch too; the original left it
        # undefined here, a latent NameError if the loop body ever needed it.
        seed_body = None
    else:
        seed_body = seed

    def cond(i, current_state, current_time, current_step, last_state):
        time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step))
        all_nans = _check_all_nans(current_state)
        end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps))
        return keras.ops.logical_and(~all_nans, end_now)

    def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state):
        # Step Size Control
        remaining = keras.ops.abs(stop_time - _current_time)
        sign = keras.ops.sign(_current_step)
        # Ensure the next step does not overshoot the stop_time
        dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining)
        dt = sign * dt_mag

        if z_history is None:
            _noise_i = generate_noise(_current_state, seed=seed_body)
        else:
            _noise_i = {k: val[_i] for k, val in z_history.items()}

        _noise_extra_i = None
        if z_extra_history is not None:
            if len(z_extra_history) == 0:
                _noise_extra_i = generate_noise(_current_state, seed=seed_body)
            else:
                # Bug fix: read from z_extra_history (the original iterated
                # z_history here, duplicating the primary noise as aux noise).
                _noise_extra_i = {k: val[_i] for k, val in z_extra_history.items()}

        step_fn_additional_args = dict(
            last_state=_last_state,
            noise_aux=_noise_extra_i,
        )
        new_state, new_time, new_step, _new_current_state = step_fn(
            state=_current_state,
            time=_current_time,
            step_size=dt,
            min_step_size=min_step_size,
            max_step_size=keras.ops.minimum(max_step_size, remaining),
            noise=_noise_i,
            use_adaptive_step_size=True,
            **filter_kwargs(step_fn_additional_args, step_fn),
        )

        if corrector_steps > 0:
            new_state = _apply_corrector(
                new_state=new_state,
                new_time=new_time,
                i=_i,
                corrector_steps=corrector_steps,
                score_fn=score_fn,
                noise_schedule=noise_schedule,
                step_size_factor=step_size_factor,
                corrector_noise_history=corrector_noise_history,
                seed=seed_body,
            )
        return _i + 1, new_state, new_time, new_step, _new_current_state

    # Execute the adaptive loop
    final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)

    if _check_all_nans(final_state):
        raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")

    # Final step to hit stop_time exactly
    time_diff = stop_time - final_time
    time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
    if keras.ops.all(time_remaining > 0):
        noise_final = generate_noise(final_state, seed=seed)
        noise_extra_final = None
        if z_extra_history is not None and len(z_extra_history) > 0:
            noise_extra_final = generate_noise(final_state, seed=seed)
        step_fn_additional_args_final = dict(
            noise_aux=noise_extra_final,
            last_state=final_k1,
        )
        final_state, _, _ = step_fn(
            state=final_state,
            time=final_time,
            step_size=time_diff,
            min_step_size=min_step_size,
            max_step_size=time_remaining,
            noise=noise_final,
            use_adaptive_step_size=False,
            **filter_kwargs(step_fn_additional_args_final, step_fn),
        )
        final_counter = final_counter + 1

    debug(f"Finished integration after {final_counter}.")
    return final_state


def integrate_langevin(
    state: StateDict,
    start_time: ArrayLike,
    stop_time: ArrayLike,
    steps: int,
    z_history: StateDict | None,
    score_fn: Callable,
    noise_schedule,
    seed: keras.random.SeedGenerator,
    corrector_noise_history: StateDict | None,
    step_size_factor: ArrayLike = 0.01,
    corrector_steps: int = 0,
) -> StateDict:
    """
    Annealed Langevin dynamics using the given score_fn and noise_schedule [1].

    At each step i with time t_i, performs for every state component k:

        state_k <- state_k + e * score_k + sqrt(2 * e) * z

    Times are stepped linearly from start_time to stop_time.

    [1] Song et al., "Generative Modeling by Estimating Gradients of the Data
    Distribution" (2020)
    """
    if steps <= 0:
        raise ValueError("Number of Langevin steps must be positive.")
    if score_fn is None or noise_schedule is None:
        raise ValueError("score_fn and noise_schedule must be provided.")

    # Linear time grid
    dt = (stop_time - start_time) / float(steps)
    effective_factor = step_size_factor * 100 / np.sqrt(steps)

    def cond(_loop_var, _loop_state, _loop_time):
        all_nans = _check_all_nans(_loop_state)
        end_now = keras.ops.less(_loop_var, steps)
        return keras.ops.logical_and(~all_nans, end_now)

    def body(_i, _loop_state, _loop_time):
        # score at current time
        score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn))

        # noise schedule
        log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False)
        _, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

        new_state: StateDict = {}
        if z_history is None:
            z_history_i = generate_noise(_loop_state, seed=seed)
        else:
            z_history_i = {k: val[_i] for k, val in z_history.items()}

        for k in _loop_state.keys():
            s_k = score.get(k, None)
            if s_k is None:
                # Components without a score entry are left untouched.
                new_state[k] = _loop_state[k]
                continue
            e = effective_factor * sigma_t**2
            new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history_i[k]

        new_time = _loop_time + dt

        if corrector_steps > 0:
            new_state = _apply_corrector(
                new_state=new_state,
                new_time=new_time,
                i=_i,
                corrector_steps=corrector_steps,
                score_fn=score_fn,
                noise_schedule=noise_schedule,
                step_size_factor=step_size_factor,
                corrector_noise_history=corrector_noise_history,
                seed=seed,
            )
        return _i + 1, new_state, new_time

    _, final_state, final_time = keras.ops.while_loop(
        cond,
        body,
        (0, state, start_time),
    )

    if _check_all_nans(final_state):
        raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
    return final_state
[docs] def integrate_stochastic( drift_fn: Callable, diffusion_fn: Callable, state: StateDict, start_time: ArrayLike, stop_time: ArrayLike, seed: keras.random.SeedGenerator, steps: int | Literal["adaptive"] = 100, method: str = "euler_maruyama", min_steps: int = 50, max_steps: int = 1_000, score_fn: Callable = None, corrector_steps: int = 0, noise_schedule=None, step_size_factor: ArrayLike = 0.01, **kwargs, ) -> StateDict: """ Integrate a stochastic differential equation from ``start_time`` to ``stop_time``. This function dispatches to either fixed-step or adaptive-step integration logic, depending on the selected integration ``method`` and the value of ``steps``. Parameters ---------- drift_fn : callable Function computing the drift term of the SDE. It should accept the current state and time as inputs. diffusion_fn : callable Function computing the diffusion term of the SDE. It should accept the current state and time as inputs. state : StateDict Dictionary containing the initial state of the system. start_time : array-like Starting time for integration. stop_time : array-like Ending time for integration. seed : keras.random.SeedGenerator Random seed generator used for noise generation. steps : int or {'adaptive'}, optional Number of integration steps for fixed-step integration, or ``'adaptive'`` to enable adaptive step sizing. Adaptive integration is only supported by the ``'shark'`` method. Default is ``100``. method : str, optional Integration method to use (e.g., ``'euler_maruyama'`` or ``'shark'``). Default is ``'euler_maruyama'``. min_steps : int, optional Minimum number of steps for adaptive integration. Default is ``50``. max_steps : int, optional Maximum number of steps for adaptive integration. Noise is pre-generated up to this number of steps, which may increase memory usage. Default is ``1000``. score_fn : callable, optional Score function used for predictor–corrector sampling. If ``None``, no corrector step is applied. 
corrector_steps : int, optional Number of corrector steps applied after each predictor step. Default is ``0``. noise_schedule : object, optional Noise schedule object used to compute ``alpha_t`` during the corrector step. Required if ``corrector_steps > 0``. step_size_factor : array-like, optional Scaling factor applied to the corrector step size. Default is ``0.01``. **kwargs Additional keyword arguments passed to the underlying step function. Returns ------- StateDict Final state dictionary after integration. """ is_adaptive = isinstance(steps, str) and steps in ["adaptive", "dynamic"] if is_adaptive: if start_time is None or stop_time is None: raise ValueError("Please provide start_time and stop_time for adaptive integration.") if min_steps <= 0 or max_steps <= 0 or max_steps < min_steps: raise ValueError("min_steps and max_steps must be positive, and max_steps >= min_steps.") loop_steps = max_steps initial_step = (stop_time - start_time) / float(min_steps) span_mag = keras.ops.abs(stop_time - start_time) min_step_size = span_mag / keras.ops.cast(max_steps, span_mag.dtype) max_step_size = span_mag / keras.ops.cast(min_steps, span_mag.dtype) else: if steps <= 0: raise ValueError("Number of steps must be positive.") loop_steps = int(steps) initial_step = (stop_time - start_time) / float(loop_steps) # For fixed step, min/max step size are just the fixed step size min_step_size, max_step_size = initial_step, initial_step # Pre-generate corrector noise if requested corrector_noise_history = None if corrector_steps > 0: if score_fn is None or noise_schedule is None: raise ValueError("Please provide both score_fn and noise_schedule when using corrector_steps > 0.") if keras.backend.backend() == "jax": corrector_noise_history = {} for key, val in state.items(): shape = keras.ops.shape(val) corrector_noise_history[key] = keras.random.normal( (loop_steps, corrector_steps, *shape), dtype=keras.ops.dtype(val), seed=seed ) match method: case "euler_maruyama": step_fn_raw = 
euler_maruyama_step case "sea": step_fn_raw = sea_step if is_adaptive: raise ValueError("SEA SDE solver does not support adaptive steps.") case "shark": step_fn_raw = shark_step if is_adaptive: raise ValueError("SHARK SDE solver does not support adaptive steps.") case "two_step_adaptive": step_fn_raw = two_step_adaptive_step case "langevin": if is_adaptive: raise ValueError("Langevin sampling does not support adaptive steps.") z_history = None if keras.backend.backend() == "jax": warning(f"JAX backend needs to preallocate random samples for 'max_steps={max_steps}'.") z_history = {} for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) return integrate_langevin( state=state, start_time=start_time, stop_time=stop_time, steps=loop_steps, z_history=z_history, score_fn=score_fn, noise_schedule=noise_schedule, step_size_factor=step_size_factor, corrector_steps=corrector_steps, corrector_noise_history=corrector_noise_history, seed=seed, ) case other: raise TypeError(f"Invalid integration method: {other!r}") # Partial the step function with common arguments step_fn = partial(step_fn_raw, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **filter_kwargs(kwargs, step_fn_raw)) # Pre-generate standard normals for the predictor step (up to max_steps) z_history = None z_extra_history = None if method not in ["sea", "shark"] else {} if keras.backend.backend() == "jax": warning(f"JAX backend needs to preallocate random samples for 'max_steps={max_steps}'.") z_history = {} for key, val in state.items(): shape = keras.ops.shape(val) z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) if method in ["sea", "shark"]: z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed) if is_adaptive: return integrate_stochastic_adaptive( step_fn=step_fn, state=state, start_time=start_time, 
stop_time=stop_time, max_steps=max_steps, min_step_size=min_step_size, max_step_size=max_step_size, initial_step=initial_step, z_history=z_history, z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, seed=seed, ) else: return integrate_stochastic_fixed( step_fn=step_fn, state=state, start_time=start_time, stop_time=stop_time, min_step_size=min_step_size, max_step_size=max_step_size, steps=loop_steps, z_history=z_history, z_extra_history=z_extra_history, corrector_steps=corrector_steps, score_fn=score_fn, noise_schedule=noise_schedule, step_size_factor=step_size_factor, corrector_noise_history=corrector_noise_history, seed=seed, )