# Source code for bayesflow.utils.jacobian.jvp
from collections.abc import Callable
import keras
from bayesflow.types import Tensor
[docs]
def jvp(
    f: Callable, x: Tensor | tuple[Tensor, ...], tangents: Tensor | tuple[Tensor, ...], return_output: bool = False
):
    """Compute the Jacobian-vector product of ``f`` at ``x`` with ``tangents``.

    The computation is delegated to the autodiff facility of the active
    Keras backend (``torch``, ``tensorflow`` or ``jax``).

    Parameters
    ----------
    f : Callable
        Function to differentiate. It is handed to the backend's JVP routine
        together with the primals; in the TensorFlow branch it is invoked
        directly as ``f(*x)``.
    x : Tensor or tuple of Tensor
        Primal input(s). A single tensor is wrapped into a one-element tuple.
    tangents : Tensor or tuple of Tensor
        Tangent(s) paired with ``x``. A single tensor is wrapped into a
        one-element tuple.
    return_output : bool, optional
        If True, the function output is returned alongside the
        Jacobian-vector product. Default is False.

    Returns
    -------
    The Jacobian-vector product, or the pair ``(f_output, jvp)`` when
    ``return_output`` is True.

    Raises
    ------
    NotImplementedError
        If the active Keras backend is not one of the three handled here.
    """
    # Every backend branch below works on tuples, so lift bare tensors first.
    primals = (x,) if keras.ops.is_tensor(x) else x
    directions = (tangents,) if keras.ops.is_tensor(tangents) else tangents

    backend_name = keras.backend.backend()

    # Backend modules are imported lazily so only the active one is required.
    if backend_name == "torch":
        import torch

        outputs, product = torch.autograd.functional.jvp(f, primals, directions)
    elif backend_name == "tensorflow":
        import tensorflow as tf

        # The accumulator records the forward pass of f at the primals.
        with tf.autodiff.ForwardAccumulator(primals=primals, tangents=directions) as accumulator:
            outputs = f(*primals)

        product = accumulator.jvp(outputs)
    elif backend_name == "jax":
        import jax

        outputs, product = jax.jvp(f, primals, directions)
    else:
        raise NotImplementedError(f"JVP not implemented for backend {keras.backend.backend()!r}")

    return (outputs, product) if return_output else product