# Source code for bayesflow.utils.jacobian.vjp
from collections.abc import Callable
import keras
from bayesflow.types import Tensor
# [docs]
def vjp(f: Callable[[Tensor], Tensor], x: Tensor, return_output: bool = False):
    """Build a vector-Jacobian product function for ``f`` evaluated at ``x``.

    The returned callable maps a cotangent ("projector") ``v`` to
    ``v^T @ J_f(x)`` using the autodiff machinery of the active Keras backend.

    Parameters
    ----------
    f : Callable[[Tensor], Tensor]
        The function to differentiate.
    x : Tensor
        The point at which to evaluate the vector-Jacobian product.
    return_output : bool, optional
        If True, also return ``f(x)`` alongside the vjp function.

    Returns
    -------
    vjp_fn, or the pair ``(f(x), vjp_fn)`` when ``return_output`` is True.

    Raises
    ------
    NotImplementedError
        If the active Keras backend is not jax, tensorflow, or torch.
    """
    backend = keras.backend.backend()

    if backend == "jax":
        import jax

        # jax.vjp yields (f(x), pullback); the pullback returns one cotangent
        # per primal input as a tuple, so unwrap the single entry.
        fx, pullback = jax.vjp(f, x)

        def vjp_fn(projector):
            return pullback(projector)[0]

    elif backend == "tensorflow":
        import tensorflow as tf

        # A persistent tape is required so vjp_fn can be called repeatedly.
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(x)
            fx = f(x)

        def vjp_fn(projector):
            return tape.gradient(fx, x, projector)

    elif backend == "torch":
        import torch

        # Copy so the caller's tensor is not mutated by requires_grad_.
        x = keras.ops.copy(x)
        x.requires_grad_(True)

        with torch.enable_grad():
            fx = f(x)

        def vjp_fn(projector):
            # retain_graph=True keeps the graph alive for repeated calls.
            return torch.autograd.grad(fx, x, projector, retain_graph=True)[0]

    else:
        raise NotImplementedError(f"Cannot build a vjp function for backend '{backend}'.")

    if return_output:
        return fx, vjp_fn
    return vjp_fn