Source code for bayesflow.experimental.free_form_flow.free_form_flow

import keras
from keras import ops
from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
from bayesflow.utils import (
    find_network,
    keras_kwargs,
    concatenate_valid,
    jacobian,
    jvp,
    vjp,
    serialize_value_or_type,
    deserialize_value_or_type,
)

from bayesflow.networks import InferenceNetwork


@serializable(package="networks.free_form_flow")
class FreeFormFlow(InferenceNetwork):
    """Implements a dimensionality-preserving Free-form Flow.
    Incorporates ideas from [1-2].

    [1] Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024).
    Free-form Flows: Make Any Architecture a Normalizing Flow.
    In International Conference on Artificial Intelligence and Statistics.

    [2] Sorrenson, P., Draxler, F., Rousselot, A., Hummerich, S., Zimmermann, L., & Köthe, U. (2024).
    Lifting Architectural Constraints of Injective Flows.
    In International Conference on Learning Representations.
    """

    ENCODER_MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.05,
        "spectral_normalization": False,
    }

    DECODER_MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.05,
        "spectral_normalization": False,
    }

    def __init__(
        self,
        beta: float = 50.0,
        encoder_subnet: str | type = "mlp",
        decoder_subnet: str | type = "mlp",
        base_distribution: str = "normal",
        hutchinson_sampling: str = "qr",
        **kwargs,
    ):
        """Creates an instance of a Free-form Flow.

        Parameters
        ----------
        beta : float, optional, default: 50.0
            Weight of the reconstruction loss relative to the maximum likelihood
            loss (see :meth:`compute_metrics`).
        encoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow; will be instantiated using
            `encoder_subnet_kwargs`. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        decoder_subnet : str or type, optional, default: "mlp"
            A neural network type for the flow; will be instantiated using
            `decoder_subnet_kwargs`. Will be equipped with a projector to ensure
            the correct output dimension and a global skip connection.
        base_distribution : str, optional, default: "normal"
            The latent distribution.
        hutchinson_sampling : str, optional, default: "qr"
            One of `["sphere", "qr"]`. Selects the sampling scheme for the
            vectors of the Hutchinson trace estimator.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments.
        """
        super().__init__(base_distribution=base_distribution, **keras_kwargs(kwargs))

        if encoder_subnet == "mlp":
            encoder_subnet_kwargs = FreeFormFlow.ENCODER_MLP_DEFAULT_CONFIG.copy()
            encoder_subnet_kwargs.update(kwargs.get("encoder_subnet_kwargs", {}))
        else:
            encoder_subnet_kwargs = kwargs.get("encoder_subnet_kwargs", {})

        self.encoder_subnet = find_network(encoder_subnet, **encoder_subnet_kwargs)
        self.encoder_projector = keras.layers.Dense(
            units=None, bias_initializer="zeros", kernel_initializer="zeros"
        )

        if decoder_subnet == "mlp":
            decoder_subnet_kwargs = FreeFormFlow.DECODER_MLP_DEFAULT_CONFIG.copy()
            decoder_subnet_kwargs.update(kwargs.get("decoder_subnet_kwargs", {}))
        else:
            decoder_subnet_kwargs = kwargs.get("decoder_subnet_kwargs", {})

        self.decoder_subnet = find_network(decoder_subnet, **decoder_subnet_kwargs)
        self.decoder_projector = keras.layers.Dense(
            units=None, bias_initializer="zeros", kernel_initializer="zeros"
        )

        self.hutchinson_sampling = hutchinson_sampling
        self.beta = beta

        self.seed_generator = keras.random.SeedGenerator()

        # serialization: store all parameters necessary to call __init__
        self.config = {
            "beta": beta,
            "base_distribution": base_distribution,
            "hutchinson_sampling": hutchinson_sampling,
            **kwargs,
        }
        self.config = serialize_value_or_type(self.config, "encoder_subnet", encoder_subnet)
        self.config = serialize_value_or_type(self.config, "decoder_subnet", decoder_subnet)
    def get_config(self):
        base_config = super().get_config()
        return base_config | self.config
    @classmethod
    def from_config(cls, config):
        config = deserialize_value_or_type(config, "encoder_subnet")
        config = deserialize_value_or_type(config, "decoder_subnet")
        return cls(**config)
    # noinspection PyMethodOverriding
    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)

        self.encoder_projector.units = xz_shape[-1]
        self.decoder_projector.units = xz_shape[-1]

        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.encoder_subnet.build(input_shape)
        self.decoder_subnet.build(input_shape)

        input_shape = self.encoder_subnet.compute_output_shape(input_shape)
        self.encoder_projector.build(input_shape)

        input_shape = self.decoder_subnet.compute_output_shape(input_shape)
        self.decoder_projector.build(input_shape)
    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            z, jac = jacobian(
                lambda inp: self.encode(inp, conditions=conditions, training=training, **kwargs),
                x,
                return_output=True,
            )
            log_det = keras.ops.logdet(jac)
            log_density = self.base_distribution.log_prob(z) + log_det
            return z, log_density

        z = self.encode(x, conditions, training=training, **kwargs)
        return z

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        if density:
            x, jac = jacobian(
                lambda inp: self.decode(inp, conditions=conditions, training=training, **kwargs),
                z,
                return_output=True,
            )
            log_det = keras.ops.logdet(jac)
            log_density = self.base_distribution.log_prob(z) - log_det
            return x, log_density

        x = self.decode(z, conditions, training=training, **kwargs)
        return x
    def encode(self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = x
        else:
            inp = concatenate_valid([x, conditions], axis=-1)

        network_out = self.encoder_projector(
            self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + x
    def decode(self, z: Tensor, conditions: Tensor = None, training: bool = False, **kwargs) -> Tensor:
        if conditions is None:
            inp = z
        else:
            inp = concatenate_valid([z, conditions], axis=-1)

        network_out = self.decoder_projector(
            self.decoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
        )
        return network_out + z
    def _sample_v(self, x):
        batch_size = ops.shape(x)[0]
        total_dim = ops.shape(x)[-1]

        match self.hutchinson_sampling:
            case "qr":
                # Use QR decomposition as described in [2]
                v_raw = keras.random.normal(
                    (batch_size, total_dim, 1), dtype=ops.dtype(x), seed=self.seed_generator
                )
                q = ops.reshape(ops.qr(v_raw)[0], ops.shape(x))
                v = q * ops.sqrt(total_dim)
            case "sphere":
                # Sample from sphere with radius sqrt(total_dim), as implemented in [1]
                v_raw = keras.random.normal(
                    (batch_size, total_dim), dtype=ops.dtype(x), seed=self.seed_generator
                )
                v = v_raw * ops.sqrt(total_dim) / ops.sqrt(ops.sum(v_raw**2, axis=-1, keepdims=True))
            case _:
                raise ValueError(f"{self.hutchinson_sampling} is not a valid value for hutchinson_sampling.")

        return v
    def compute_metrics(
        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
    ) -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, sample_weight=sample_weight, stage=stage)

        # sample random vector
        v = self._sample_v(x)

        def encode(x):
            return self.encode(x, conditions, training=stage == "training")

        def decode(z):
            return self.decode(z, conditions, training=stage == "training")

        # VJP computation
        z, vjp_fn = vjp(encode, x)
        v1 = vjp_fn(v)[0]

        # JVP computation
        x_pred, v2 = jvp(decode, (z,), (v,))

        # equivalent: surrogate = ops.matmul(ops.stop_gradient(v2[:, None]), v1[:, :, None])[:, 0, 0]
        surrogate = ops.sum((ops.stop_gradient(v2) * v1), axis=-1)
        nll = -self.base_distribution.log_prob(z)

        maximum_likelihood_loss = nll - surrogate
        reconstruction_loss = ops.sum((x - x_pred) ** 2, axis=-1)
        losses = maximum_likelihood_loss + self.beta * reconstruction_loss

        loss = self.aggregate(losses, sample_weight)

        return base_metrics | {"loss": loss}
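

# Illustrative sketch, not part of the original module: a small numerical check of the
# Hutchinson identity that the surrogate in `compute_metrics` relies on. For vectors v
# with E[v v^T] = I (e.g. uniform on the sphere of radius sqrt(dim), mirroring the
# "sphere" branch of `_sample_v`), the dot product
# v1 . v2 = (J_enc^T v) . (J_dec v) = v^T J_enc J_dec v
# is an unbiased estimate of trace(J_enc J_dec). The dimension and the stand-in matrix
# below are arbitrary choices for the demonstration.
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(0)
    dim = 8
    a = rng.normal(size=(dim, dim))  # stand-in for the Jacobian product J_enc @ J_dec

    # sample many vectors uniformly from the sphere of radius sqrt(dim)
    v = rng.normal(size=(100_000, dim))
    v = v * np.sqrt(dim) / np.linalg.norm(v, axis=-1, keepdims=True)

    # the average of v^T A v over samples should be close to trace(A)
    estimate = np.mean(np.einsum("ni,ij,nj->n", v, a, v))
    print(f"Hutchinson estimate: {estimate:.4f}, exact trace: {np.trace(a):.4f}")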