from collections.abc import Callable, Sequence
from functools import partial
import keras
import numpy as np
from typing import Literal
from bayesflow.types import Tensor
from bayesflow.utils import filter_kwargs
from . import logging
ArrayLike = int | float | Tensor
def euler_step(
fn: Callable,
state: dict[str, ArrayLike],
time: ArrayLike,
step_size: ArrayLike,
tolerance: ArrayLike = 1e-6,
min_step_size: ArrayLike = -float("inf"),
max_step_size: ArrayLike = float("inf"),
use_adaptive_step_size: bool = False,
) -> tuple[dict[str, ArrayLike], ArrayLike, ArrayLike]:
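    """
    Performs a single explicit Euler step for a system of ODEs given as a state dictionary.
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Current state, mapping variable names to tensors.
        time: Current integration time.
        step_size: Step size to apply in this update.
        tolerance: Target error tolerance used when proposing the next step size.
        min_step_size: Lower bound for the proposed step size.
        max_step_size: Upper bound for the proposed step size.
        use_adaptive_step_size: If True, estimate a new step size from the change of the
            derivatives over a trial step; otherwise the step size is returned unchanged.
    Returns:
        new_state: Updated state after one Euler step.
        new_time: time + step_size.
        new_step_size: Proposed step size for the next step.
    """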
k1 = fn(time, **filter_kwargs(state, fn))
if use_adaptive_step_size:
intermediate_state = state.copy()
for key, delta in k1.items():
intermediate_state[key] = state[key] + step_size * delta
k2 = fn(time + step_size, **filter_kwargs(intermediate_state, fn))
# check all keys are equal
if set(k1.keys()) != set(k2.keys()):
raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.")
# compute next step size
intermediate_error = keras.ops.stack([keras.ops.norm(k2[key] - k1[key], ord=2, axis=-1) for key in k1])
new_step_size = step_size * tolerance / (intermediate_error + 1e-9)
new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
        # consolidate step size: keep the most conservative (smallest-magnitude) proposal across variables
        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))
else:
new_step_size = step_size
# apply updates
new_state = state.copy()
for key in k1.keys():
new_state[key] = state[key] + step_size * k1[key]
new_time = time + step_size
return new_state, new_time, new_step_size
def rk45_step(
fn: Callable,
state: dict[str, ArrayLike],
time: ArrayLike,
last_step_size: ArrayLike,
tolerance: ArrayLike = 1e-6,
min_step_size: ArrayLike = -float("inf"),
max_step_size: ArrayLike = float("inf"),
use_adaptive_step_size: bool = False,
) -> tuple[dict[str, ArrayLike], ArrayLike, ArrayLike]:
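    """
    Performs a single explicit Runge-Kutta step (classical fourth-order update) for a system of ODEs.
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Current state, mapping variable names to tensors.
        time: Current integration time.
        last_step_size: Step size to apply in this update.
        tolerance: Target error tolerance used when proposing the next step size.
        min_step_size: Lower bound for the proposed step size.
        max_step_size: Upper bound for the proposed step size.
        use_adaptive_step_size: If True, spend one extra derivative evaluation to heuristically
            estimate the local error and propose a new step size; otherwise the step size is
            returned unchanged.
    Returns:
        new_state: Updated state after one step.
        new_time: time + step_size.
        new_step_size: Proposed step size for the next step.
    """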
step_size = last_step_size
k1 = fn(time, **filter_kwargs(state, fn))
intermediate_state = state.copy()
for key, delta in k1.items():
intermediate_state[key] = state[key] + 0.5 * step_size * delta
k2 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
intermediate_state = state.copy()
for key, delta in k2.items():
intermediate_state[key] = state[key] + 0.5 * step_size * delta
k3 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
intermediate_state = state.copy()
for key, delta in k3.items():
intermediate_state[key] = state[key] + step_size * delta
k4 = fn(time + step_size, **filter_kwargs(intermediate_state, fn))
if use_adaptive_step_size:
intermediate_state = state.copy()
for key, delta in k4.items():
intermediate_state[key] = state[key] + 0.5 * step_size * delta
k5 = fn(time + 0.5 * step_size, **filter_kwargs(intermediate_state, fn))
# check all keys are equal
if not all(set(k.keys()) == set(k1.keys()) for k in [k2, k3, k4, k5]):
raise ValueError("Keys of the deltas do not match. Please return zero for unchanged variables.")
# compute next step size
intermediate_error = keras.ops.stack([keras.ops.norm(k5[key] - k4[key], ord=2, axis=-1) for key in k5.keys()])
new_step_size = step_size * tolerance / (intermediate_error + 1e-9)
new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
        # consolidate step size: keep the most conservative (smallest-magnitude) proposal across variables
        new_step_size = keras.ops.take(new_step_size, keras.ops.argmin(keras.ops.abs(new_step_size)))
else:
new_step_size = step_size
# apply updates
new_state = state.copy()
for key in k1.keys():
new_state[key] = state[key] + (step_size / 6.0) * (k1[key] + 2.0 * k2[key] + 2.0 * k3[key] + k4[key])
new_time = time + step_size
return new_state, new_time, new_step_size
def integrate_fixed(
fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
steps: int,
method: str = "rk45",
**kwargs,
) -> dict[str, ArrayLike]:
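    """
    Integrates the ODE defined by `fn` from start_time to stop_time using a fixed number of
    equally sized steps.
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Dictionary containing the initial state.
        start_time: Starting time for integration.
        stop_time: Ending time for integration.
        steps: Number of integration steps.
        method: Integration method to use, either 'euler' or 'rk45'.
        **kwargs: Additional arguments passed to the step function.
    Returns:
        The final state dictionary after integration.
    """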
if steps <= 0:
raise ValueError("Number of steps must be positive.")
match method:
case "euler":
step_fn = euler_step
case "rk45":
step_fn = rk45_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
step_size = (stop_time - start_time) / steps
time = start_time
def body(_loop_var, _loop_state):
_state, _time = _loop_state
_state, _time, _ = step_fn(_state, _time, step_size)
return _state, _time
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
return state
def integrate_adaptive(
fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
min_steps: int = 10,
max_steps: int = 1000,
method: str = "rk45",
**kwargs,
) -> dict[str, ArrayLike]:
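    """
    Integrates the ODE defined by `fn` from start_time to stop_time using an adaptive step size,
    taking between min_steps and max_steps steps and finishing with one final step that lands
    exactly on stop_time.
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Dictionary containing the initial state.
        start_time: Starting time for integration.
        stop_time: Ending time for integration.
        min_steps: Minimum number of integration steps.
        max_steps: Maximum number of integration steps.
        method: Integration method to use, either 'euler' or 'rk45'.
        **kwargs: Additional arguments passed to the step function.
    Returns:
        The final state dictionary after integration.
    """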
if max_steps <= min_steps:
raise ValueError("Maximum number of steps must be greater than minimum number of steps.")
match method:
case "euler":
step_fn = euler_step
case "rk45":
step_fn = rk45_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True)
def cond(_state, _time, _step_size, _step):
        # while step < min_steps or (time_remaining > 0 and step < max_steps)
# time remaining after the next step
time_remaining = keras.ops.abs(stop_time - (_time + _step_size))
return keras.ops.logical_or(
keras.ops.all(_step < min_steps),
keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.all(_step < max_steps)),
)
def body(_state, _time, _step_size, _step):
_step = _step + 1
# time remaining after the next step
time_remaining = stop_time - (_time + _step_size)
min_step_size = time_remaining / (max_steps - _step)
max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0)
# reorder
min_step_size, max_step_size = (
keras.ops.minimum(min_step_size, max_step_size),
keras.ops.maximum(min_step_size, max_step_size),
)
_state, _time, _step_size = step_fn(
_state, _time, _step_size, min_step_size=min_step_size, max_step_size=max_step_size
)
return _state, _time, _step_size, _step
# select initial step size conservatively
step_size = (stop_time - start_time) / max_steps
step = 0
time = start_time
state, time, step_size, step = keras.ops.while_loop(cond, body, [state, time, step_size, step])
# do the last step
step_size = stop_time - time
state, _, _ = step_fn(state, time, step_size)
step = step + 1
logging.debug("Finished integration after {} steps.", step)
return state
def integrate_scheduled(
fn: Callable,
state: dict[str, ArrayLike],
steps: Tensor | np.ndarray,
method: str = "rk45",
**kwargs,
) -> dict[str, ArrayLike]:
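    """
    Integrates the ODE defined by `fn` along a user-provided schedule of time points, where step
    i advances the state from steps[i] to steps[i + 1].
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Dictionary containing the initial state.
        steps: Monotonic sequence of time points, including the start and stop times.
        method: Integration method to use, either 'euler' or 'rk45'.
        **kwargs: Additional arguments passed to the step function.
    Returns:
        The final state dictionary after integration.
    """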
match method:
case "euler":
step_fn = euler_step
case "rk45":
step_fn = rk45_step
case str() as name:
raise ValueError(f"Unknown integration method name: {name!r}")
case other:
raise TypeError(f"Invalid integration method: {other!r}")
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
def body(_loop_var, _loop_state):
_time = steps[_loop_var]
step_size = steps[_loop_var + 1] - steps[_loop_var]
_loop_state, _, _ = step_fn(_loop_state, _time, step_size)
return _loop_state
state = keras.ops.fori_loop(0, len(steps) - 1, body, state)
return state
def integrate(
fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike | None = None,
stop_time: ArrayLike | None = None,
min_steps: int = 10,
max_steps: int = 10_000,
    steps: int | Literal["adaptive", "dynamic"] | Sequence[ArrayLike] | Tensor | np.ndarray = 100,
method: str = "euler",
**kwargs,
) -> dict[str, ArrayLike]:
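    """
    Integrates the ODE defined by `fn`, dispatching on `steps`: 'adaptive' (or 'dynamic') uses
    adaptive step sizes, an integer uses that many fixed steps, and a sequence or tensor of time
    points is used as an explicit integration schedule.
    Args:
        fn: Function computing the time derivatives, called as fn(time, **state).
        state: Dictionary containing the initial state.
        start_time: Starting time for integration (required unless `steps` is a schedule).
        stop_time: Ending time for integration (required unless `steps` is a schedule).
        min_steps: Minimum number of steps for adaptive integration.
        max_steps: Maximum number of steps for adaptive integration.
        steps: Number of steps, 'adaptive', or an explicit schedule of time points.
        method: Integration method to use, e.g., 'euler' or 'rk45'.
        **kwargs: Additional arguments passed to the step function.
    Returns:
        The final state dictionary after integration.
    Example (illustrative sketch; the derivative function and variable name are hypothetical):
        >>> def fn(t, x):
        ...     return {"x": -x}
        >>> final = integrate(fn, {"x": keras.ops.ones((8,))}, start_time=0.0, stop_time=1.0, steps=100)
    """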
    if isinstance(steps, str):
        if steps not in ("adaptive", "dynamic"):
            raise ValueError(f"Unknown steps specifier: {steps!r}. Expected 'adaptive', an integer, or a sequence of time points.")
if start_time is None or stop_time is None:
raise ValueError(
"Please provide start_time and stop_time for the integration, was "
f"'start_time={start_time}', 'stop_time={stop_time}'."
)
return integrate_adaptive(fn, state, start_time, stop_time, min_steps, max_steps, method, **kwargs)
elif isinstance(steps, int):
if start_time is None or stop_time is None:
raise ValueError(
"Please provide start_time and stop_time for the integration, was "
f"'start_time={start_time}', 'stop_time={stop_time}'."
)
return integrate_fixed(fn, state, start_time, stop_time, steps, method, **kwargs)
elif isinstance(steps, Sequence) or isinstance(steps, np.ndarray) or keras.ops.is_tensor(steps):
return integrate_scheduled(fn, state, steps, method, **kwargs)
else:
raise RuntimeError(f"Type or value of `steps` not understood (steps={steps})")
def euler_maruyama_step(
drift_fn: Callable,
diffusion_fn: Callable,
state: dict[str, ArrayLike],
time: ArrayLike,
step_size: ArrayLike,
noise: dict[str, ArrayLike],
) -> tuple[dict[str, ArrayLike], ArrayLike]:
"""
Performs a single Euler-Maruyama step for stochastic differential equations.
Args:
drift_fn: Function computing the drift term f(t, **state).
diffusion_fn: Function computing the diffusion term g(t, **state).
state: Current state, mapping variable names to tensors.
time: Current time scalar tensor.
step_size: Time increment dt.
noise: Mapping of variable names to dW noise tensors.
Returns:
new_state: Updated state after one Euler-Maruyama step.
new_time: time + dt.
"""
# Compute drift and diffusion
drift = drift_fn(time, **filter_kwargs(state, drift_fn))
diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
# Check noise keys
if set(diffusion.keys()) != set(noise.keys()):
raise ValueError("Keys of diffusion terms and noise do not match.")
new_state = {}
for key, d in drift.items():
base = state[key] + step_size * d
if key in diffusion: # stochastic update
base = base + diffusion[key] * noise[key]
new_state[key] = base
return new_state, time + step_size
def integrate_stochastic(
drift_fn: Callable,
diffusion_fn: Callable,
state: dict[str, ArrayLike],
start_time: ArrayLike,
stop_time: ArrayLike,
steps: int,
seed: keras.random.SeedGenerator,
method: str = "euler_maruyama",
**kwargs,
) -> dict[str, ArrayLike]:
"""
Integrates a stochastic differential equation from start_time to stop_time.
Args:
drift_fn: Function that computes the drift term.
diffusion_fn: Function that computes the diffusion term.
state: Dictionary containing the initial state.
start_time: Starting time for integration.
stop_time: Ending time for integration.
steps: Number of integration steps.
seed: Random seed for noise generation.
method: Integration method to use, e.g., 'euler_maruyama'.
**kwargs: Additional arguments to pass to the step function.
    Returns:
        The final state dictionary after integration.
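    Example (illustrative sketch; the drift/diffusion functions and the variable name are hypothetical):
        >>> def drift_fn(t, x):
        ...     return {"x": -x}
        >>> def diffusion_fn(t, x):
        ...     return {"x": keras.ops.ones_like(x)}
        >>> seed = keras.random.SeedGenerator(42)
        >>> final = integrate_stochastic(drift_fn, diffusion_fn, {"x": keras.ops.zeros((8,))}, 0.0, 1.0, 100, seed)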
"""
if steps <= 0:
raise ValueError("Number of steps must be positive.")
# Select step function based on method
match method:
case "euler_maruyama":
step_fn = euler_maruyama_step
case other:
raise TypeError(f"Invalid integration method: {other!r}")
# Prepare step function with partial application
step_fn = partial(step_fn, drift_fn=drift_fn, diffusion_fn=diffusion_fn, **kwargs)
# Time step
step_size = (stop_time - start_time) / steps
sqrt_dt = keras.ops.sqrt(keras.ops.abs(step_size))
# Pre-generate noise history: shape = (steps, *state_shape)
noise_history = {}
for key, val in state.items():
noise_history[key] = (
keras.random.normal((steps, *keras.ops.shape(val)), dtype=keras.ops.dtype(val), seed=seed) * sqrt_dt
)
def body(_loop_var, _loop_state):
_current_state, _current_time = _loop_state
_noise_i = {k: noise_history[k][_loop_var] for k in _current_state.keys()}
new_state, new_time = step_fn(state=_current_state, time=_current_time, step_size=step_size, noise=_noise_i)
return new_state, new_time
final_state, final_time = keras.ops.fori_loop(0, steps, body, (state, start_time))
return final_state