Source code for bayesflow.experimental.free_form_flow.free_form_flow
import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable
from bayesflow.types import Tensor
from bayesflow.utils import (
find_network,
keras_kwargs,
concatenate_valid,
jacobian,
jvp,
vjp,
serialize_value_or_type,
deserialize_value_or_type,
)
from bayesflow.networks import InferenceNetwork
@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
"""Implements a dimensionality-preserving Free-form Flow.
Incorporates ideas from [1-2].
[1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
Free-form Flows: Make Any Architecture a Normalizing Flow.
In International Conference on Artificial Intelligence and Statistics.
[2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., &
Köthe, U. (2024). Lifting Architectural Constraints of Injective Flows.
In International Conference on Learning Representations.
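Examples
--------
A minimal construction sketch; in practice the network is typically built and
trained through the surrounding BayesFlow workflow rather than called directly,
and the keyword arguments shown below are purely illustrative:
>>> flow = FreeFormFlow(beta=50.0, encoder_subnet_kwargs={"dropout": 0.0})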
"""
ENCODER_MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}
DECODER_MLP_DEFAULT_CONFIG = {
"widths": (256, 256, 256, 256),
"activation": "mish",
"kernel_initializer": "he_normal",
"residual": True,
"dropout": 0.05,
"spectral_normalization": False,
}
def __init__(
self,
beta: float = 50.0,
encoder_subnet: str | type = "mlp",
decoder_subnet: str | type = "mlp",
base_distribution: str = "normal",
hutchinson_sampling: str = "qr",
**kwargs,
):
"""Creates an instance of a Free-form Flow.
Parameters
----------
beta : float, optional, default: 50.0
Weight of the reconstruction loss relative to the maximum likelihood loss
in the training objective; larger values enforce more accurate reconstructions.
encoder_subnet : str or type, optional, default: "mlp"
A neural network type for the encoder; it is instantiated using
`encoder_subnet_kwargs` and equipped with a projector to ensure the
correct output dimension, as well as a global skip connection.
decoder_subnet : str or type, optional, default: "mlp"
A neural network type for the decoder; it is instantiated using
`decoder_subnet_kwargs` and equipped with a projector to ensure the
correct output dimension, as well as a global skip connection.
base_distribution : str, optional, default: "normal"
The latent base distribution.
hutchinson_sampling : str, optional, default: "qr"
One of `["sphere", "qr"]`. Selects the sampling scheme for the probe
vectors of the Hutchinson trace estimator.
**kwargs : dict, optional, default: {}
Additional keyword arguments. May include `encoder_subnet_kwargs` and
`decoder_subnet_kwargs` dictionaries that are passed to the respective
subnet constructors.
"""
super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))
if encoder_subnet == "mlp":
encoder_subnet_kwargs = FreeFormFlow.ENCODER_MLP_DEFAULT_CONFIG.copy()
encoder_subnet_kwargs.update(kwargs.get("encoder_subnet_kwargs", {}))
else:
encoder_subnet_kwargs = kwargs.get("encoder_subnet_kwargs", {})
self.encoder_subnet = find_network(encoder_subnet, **encoder_subnet_kwargs)
self.encoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
if decoder_subnet == "mlp":
decoder_subnet_kwargs = FreeFormFlow.DECODER_MLP_DEFAULT_CONFIG.copy()
decoder_subnet_kwargs.update(kwargs.get("decoder_subnet_kwargs", {}))
else:
decoder_subnet_kwargs = kwargs.get("decoder_subnet_kwargs", {})
self.decoder_subnet = find_network(decoder_subnet, **decoder_subnet_kwargs)
self.decoder_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")
self.hutchinson_sampling = hutchinson_sampling
self.beta = beta
self.seed_generator = keras.random.SeedGenerator()
# serialization: store all parameters necessary to call __init__
self.config = {
"beta": beta,
"base_distribution": base_distribution,
"hutchinson_sampling": hutchinson_sampling,
**kwargs,
}
self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)
@classmethod
def from_config(cls, config):
config = deserialize_value_or_type(config, "encoder_subnet")
config = deserialize_value_or_type(config, "decoder_subnet")
return cls(**config)
# noinspection PyMethodOverriding
def build(self, xz_shape, conditions_shape=None):
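"""Builds the encoder/decoder subnets and their zero-initialized projectors,
setting the projector output dimensions to the last dimension of `xz_shape`."""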
super().build(xz_shape)
self.encoder_projector.units = xz_shape[-1]
self.decoder_projector.units = xz_shape[-1]
# construct input shape for subnet and subnet projector
input_shape = list(xz_shape)
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
input_shape = tuple(input_shape)
self.encoder_subnet.build(input_shape)
self.decoder_subnet.build(input_shape)
# build the projectors on the output shapes of their respective subnets
encoder_output_shape = self.encoder_subnet.compute_output_shape(input_shape)
self.encoder_projector.build(encoder_output_shape)
decoder_output_shape = self.decoder_subnet.compute_output_shape(input_shape)
self.decoder_projector.build(decoder_output_shape)
def _forward(
self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
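"""Encodes `x` into the latent space. If `density` is True, also returns the
exact log-density of `x` via the change of variables formula,
log p(x) = log p(z) + log|det J_encode(x)|, using the full Jacobian."""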
if density:
z, jac = jacobian(
lambda inp: self.encode(inp, conditions=conditions, training=training, **kwargs), x, return_output=True
)
log_det = keras.ops.logdet(jac)
log_density = self.base_distribution.log_prob(z) + log_det
return z, log_density
z = self.encode(x, conditions, training=training, **kwargs)
return z
def _inverse(
self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
) -> Tensor | tuple[Tensor, Tensor]:
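"""Decodes latent `z` into data space. If `density` is True, also returns the
log-density of the decoded samples,
log p(x) = log p(z) - log|det J_decode(z)|, using the full Jacobian."""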
if density:
x, jac = jacobian(
lambda inp: self.decode(inp, conditions=conditions, training=training, **kwargs), z, return_output=True
)
log_det = keras.ops.logdet(jac)
log_density = self.base_distribution.log_prob(z) - log_det
return x, log_density
x = self.decode(z, conditions, training=training, **kwargs)
return x
def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
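"""Applies the encoder subnet and projector to `x`, concatenated with
`conditions` if given; the final `+ x` is a global skip connection."""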
if conditions is None:
inp = x
else:
inp = concatenate_valid([x, conditions], axis=-1)
network_out = self.encoder_projector(
self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
)
return network_out + x
def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
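"""Applies the decoder subnet and projector to `z`, concatenated with
`conditions` if given; the final `+ z` is a global skip connection."""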
if conditions is None:
inp = z
else:
inp = concatenate_valid([z, conditions], axis=-1)
network_out = self.decoder_projector(
self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
)
return network_out + z
def _sample_v(self, x):
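"""Samples Hutchinson probe vectors with squared norm equal to the input
dimension, either via QR decomposition of a Gaussian draw ("qr", as in [2])
or by projecting a Gaussian sample onto the sphere of radius sqrt(dim)
("sphere", as in [1])."""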
batch_size = ops.shape(x)[0]
total_dim = ops.shape(x)[-1]
match self.hutchinson_sampling:
case "qr":
# Use QR decomposition as described in [2]
v_raw = keras.random.normal((batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator)
q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
v = q * ops.sqrt(total_dim)
case "sphere":
# Sample from sphere with radius sqrt(total_dim), as implemented in [1]
v_raw = keras.random.normal((batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator)
v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
case _:
raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")
return v
def compute_metrics(
self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
) -> dict[str, Tensor]:
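"""Computes the free-form flow training loss: the negative log-likelihood with a
single-probe Hutchinson surrogate for the encoder log-determinant, plus the
reconstruction loss weighted by `beta`, and merges it into the base metrics."""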
base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)
# sample random vector
v = self._sample_v(x)
def encode(x):
return self.encode(x, conditions, training=stage == "training")
def decode(z):
return self.decode(z, conditions, training=stage == "training")
# VJP computation
z, vjp_fn = vjp(encode, x)
v1 = vjp_fn(v)[0]
# JVP computation
x_pred, v2 = jvp(decode, (z,), (v,))
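# The gradient of stop_gradient(v2) * v1 is a single-probe Hutchinson estimate of the
# gradient of log|det J_encode(x)|, provided the decoder approximates the encoder's inverse [1]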
# equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
nll = -self.base_distribution.log_prob(z)
maximum_likelihood_loss = nll - surrogate
reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
losses = maximum_likelihood_loss + self.beta * reconstruction_loss
loss = self.aggregate(losses, sample_weight)
return base_metrics | {"loss": loss}