DiffusionModel#

class bayesflow.experimental.DiffusionModel(*args, **kwargs)[source]#

Bases: InferenceNetwork

Diffusion model as described in the overview paper [1]; the underlying score-based SDE formulation follows [2].

[1] Kingma et al. (2023). Variational Diffusion Models 2.0: Understanding Diffusion Model Objectives as the ELBO with Simple Data Augmentation.

[2] Song et al. (2021). Score-Based Generative Modeling through Stochastic Differential Equations.

Initializes a diffusion model with a configurable subnet architecture, noise schedule, and prediction/loss types for amortized Bayesian inference. A construction sketch follows the parameter list below.

Note that score-based diffusion is the slowest of the available samplers, so expect slower inference than flow matching and much slower than normalizing flows.

Parameters:
subnet : str, type or keras.Layer, optional

Architecture for the transformation network. Can be “mlp”, a custom network class, or a Layer object, e.g., bayesflow.networks.MLP(widths=[32, 32]). Default is “mlp”.

noise_schedule : {‘edm’, ‘cosine’} or NoiseSchedule or type, optional

Noise schedule controlling the diffusion dynamics. Can be a string identifier, a schedule class, or a pre-initialized schedule instance. Default is “edm”.

prediction_type : {‘velocity’, ‘noise’, ‘F’, ‘x’}, optional

Output format of the model’s prediction. Default is “F”.

loss_type : {‘velocity’, ‘noise’, ‘F’}, optional

Loss function used to train the model. Default is “noise”.

subnet_kwargs : dict[str, any], optional

Additional keyword arguments passed to the subnet constructor. Default is None.

schedule_kwargs : dict[str, any], optional

Additional keyword arguments passed to the noise schedule constructor. Default is None.

integrate_kwargs : dict[str, any], optional

Configuration dictionary for integration during training or inference. Default is None.

**kwargs

Additional keyword arguments passed to the base class and internal components.
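
Example (a minimal construction sketch using only the documented arguments; the specific widths and step count are illustrative assumptions, not recommended settings):

from bayesflow.experimental import DiffusionModel

diffusion = DiffusionModel(
    subnet="mlp",
    subnet_kwargs={"widths": (256, 256, 256)},           # assumed MLP widths
    noise_schedule="edm",
    prediction_type="F",
    loss_type="noise",
    integrate_kwargs={"method": "euler", "steps": 250},  # overrides the default config below
)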

MLP_DEFAULT_CONFIG = {'activation': 'mish', 'dropout': 0.0, 'kernel_initializer': 'he_normal', 'residual': True, 'spectral_normalization': False, 'widths': (256, 256, 256, 256, 256)}#
INTEGRATE_DEFAULT_CONFIG = {'method': 'euler', 'steps': 100}#
build(xz_shape: tuple[int, ...], conditions_shape: tuple[int, ...] = None) → None[source]#
get_config()[source]#

Returns the config of the object.

An object config is a Python dictionary (serializable) containing the information needed to re-instantiate it.

classmethod from_config(config, custom_objects=None)[source]#

Creates an operation from its config.

This method is the reverse of get_config, capable of instantiating the same operation from the config dictionary.

Note: If you override this method, you might receive a serialized dtype config, which is a dict. You can deserialize it as follows:

from keras import dtype_policies

if "dtype" in config and isinstance(config["dtype"], dict):
    policy = dtype_policies.deserialize(config["dtype"])
Args:

config: A Python dictionary, typically the output of get_config.

Returns:

An operation instance.

convert_prediction_to_x(pred: Tensor, z: Tensor, alpha_t: Tensor, sigma_t: Tensor, log_snr_t: Tensor) → Tensor[source]#

Converts the neural network prediction into the denoised data x, depending on the prediction type configured for the model.

Parameters:
pred : Tensor

The output prediction from the neural network, typically representing noise, velocity, or a transformation of the clean signal.

z : Tensor

The noisy latent variable z to be denoised.

alpha_t : Tensor

The noise schedule’s scaling factor for the clean signal at time t.

sigma_t : Tensor

The standard deviation of the noise at time t.

log_snr_t : Tensor

The log signal-to-noise ratio at time t.

Returns:
Tensor

The reconstructed clean signal x from the model prediction.
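
For orientation, a hedged sketch of the standard conversions under the forward model z = alpha_t * x + sigma_t * eps; this is not the class’s literal implementation, and the ‘F’ (EDM-style) branch is schedule-dependent and omitted here:

def prediction_to_x(pred, z, alpha_t, sigma_t, prediction_type):
    # Standard identities under z = alpha_t * x + sigma_t * eps.
    if prediction_type == "x":
        return pred
    if prediction_type == "noise":
        return (z - sigma_t * pred) / alpha_t
    if prediction_type == "velocity":
        # v-prediction; assumes a variance-preserving schedule
        # with alpha_t**2 + sigma_t**2 == 1.
        return alpha_t * z - sigma_t * pred
    raise ValueError(f"Unknown prediction type: {prediction_type!r}")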

velocity(xz: Tensor, time: float | Tensor, stochastic_solver: bool, conditions: Tensor = None, training: bool = False) → Tensor[source]#

Computes the velocity (i.e., time derivative) of the target or latent variable xz for either a stochastic differential equation (SDE) or ordinary differential equation (ODE).

Parameters:
xz : Tensor

The current state of the latent variable z, typically of shape (…, D), where D is the dimensionality of the latent space.

time : float or Tensor

Scalar or tensor representing the time (or noise level) at which the velocity is computed. Will be broadcast to match xz.

stochastic_solver : bool

If True, computes the velocity for the stochastic formulation (SDE). If False, uses the deterministic formulation (ODE).

conditions : Tensor, optional

Optional conditional inputs to the network, such as conditioning variables or encoder outputs. Shape must be broadcastable with xz. Default is None.

training : bool, optional

Whether the model is in training mode. Affects behavior of dropout, batch norm, or other stochastic layers. Default is False.

Returns:
Tensor

The velocity tensor of the same shape as xz, representing the right-hand side of the SDE or ODE at the given time.
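
As a usage sketch (not the library’s built-in sampler), velocity can drive a fixed-step Euler integration of the deterministic ODE; the convention of integrating from t = 1 (noise) to t = 0 (data) is an assumption:

def euler_ode_sample(model, z, conditions=None, steps=100):
    # Hypothetical helper: integrate dz/dt = velocity(z, t) with Euler steps.
    t, dt = 1.0, -1.0 / steps
    for _ in range(steps):
        v = model.velocity(z, time=t, stochastic_solver=False, conditions=conditions)
        z = z + dt * v
        t = t + dt
    return z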

diffusion_term(xz: Tensor, time: float | Tensor, training: bool = False) → Tensor[source]#

Compute the diffusion term (standard deviation of the noise) at a given time.

Parameters:
xz : Tensor

Input tensor of shape (…, D), typically representing the target or latent variables at the given time.

time : float or Tensor

The diffusion time step(s). Can be a scalar or a tensor broadcastable to the shape of xz.

training : bool, optional

Whether to use the training noise schedule (default is False).

Returns:
Tensor

The diffusion term tensor with shape matching xz except for the last dimension, which is set to 1.
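
Combined with velocity in its SDE form, the diffusion term gives a textbook Euler–Maruyama step; a hedged sketch under the same assumed time convention:

from keras import ops, random

def euler_maruyama_step(model, z, t, dt, conditions=None):
    # Hypothetical reverse-time SDE step: drift from velocity(),
    # noise scale from diffusion_term() (shape (..., 1), broadcast over z).
    drift = model.velocity(z, time=t, stochastic_solver=True, conditions=conditions)
    g = model.diffusion_term(z, time=t)
    eps = random.normal(ops.shape(z))
    return z + dt * drift + g * ops.sqrt(ops.abs(dt)) * eps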

compute_metrics(x: Tensor | Sequence[Tensor, ...], conditions: Tensor = None, sample_weight: Tensor = None, stage: str = 'training') → dict[str, Tensor][source]#
__call__(*args, **kwargs)#

Call self as a function.

add_loss(loss)#

Can be called inside of the call() method to add a scalar loss.

Example:

from keras import Layer, ops

class MyLayer(Layer):
    ...
    def call(self, x):
        self.add_loss(ops.sum(x))
        return x
add_metric(*args, **kwargs)#
add_variable(shape, initializer, dtype=None, trainable=True, autocast=True, regularizer=None, constraint=None, name=None)#

Add a weight variable to the layer.

Alias of add_weight().

add_weight(shape=None, initializer=None, dtype=None, trainable=True, autocast=True, regularizer=None, constraint=None, aggregation='none', overwrite_with_gradient=False, name=None)#

Add a weight variable to the layer.

Args:

shape: Shape tuple for the variable. Must be fully-defined (no None entries). Defaults to () (scalar) if unspecified.

initializer: Initializer object to use to populate the initial variable value, or string name of a built-in initializer (e.g. “random_normal”). If unspecified, defaults to “glorot_uniform” for floating-point variables and to “zeros” for all other types (e.g. int, bool).

dtype: Dtype of the variable to create, e.g. “float32”. If unspecified, defaults to the layer’s variable dtype (which itself defaults to “float32” if unspecified).

trainable: Boolean, whether the variable should be trainable via backprop or whether its updates are managed manually. Defaults to True.

autocast: Boolean, whether to autocast layers variables when accessing them. Defaults to True.

regularizer: Regularizer object to call to apply penalty on the weight. These penalties are summed into the loss function during optimization. Defaults to None.

constraint: Constraint object to call on the variable after any optimizer update, or string name of a built-in constraint. Defaults to None.

aggregation: Optional string, one of None, “none”, “mean”, “sum” or “only_first_replica”. Annotates the variable with the type of multi-replica aggregation to be used for this variable when writing custom data parallel training loops. Defaults to “none”.

overwrite_with_gradient: Boolean, whether to overwrite the variable with the computed gradient. This is useful for float8 training. Defaults to False.

name: String name of the variable. Useful for debugging purposes.
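
Example (the canonical Keras pattern inside a custom layer’s build(); not specific to DiffusionModel):

from keras import Layer

class Linear(Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Create a trainable kernel once the input size is known.
        self.kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
            name="kernel",
        )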

build_from_config(config)#

Builds the layer’s states with the supplied config dict.

By default, this method calls the build(config[“input_shape”]) method, which creates weights based on the layer’s input shape in the supplied config. If your config contains other information needed to load the layer’s state, you should override this method.

Args:

config: Dict containing the input shape associated with this layer.

call(xz: Tensor, conditions: Tensor = None, inverse: bool = False, density: bool = False, training: bool = False, **kwargs) → Tensor | tuple[Tensor, Tensor]#
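
A hedged usage sketch, assuming a built model diffusion, a data batch x, and matching conditions c; with density=True the call also returns the log-density:

z = diffusion(x, conditions=c)                       # forward: data -> latent
x_hat = diffusion(z, conditions=c, inverse=True)     # inverse: latent -> data
z, log_p = diffusion(x, conditions=c, density=True)  # forward with log-density
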
property compute_dtype#

The dtype of the computations performed by the layer.

compute_mask(inputs, previous_mask)#
compute_output_shape(*args, **kwargs)#
compute_output_spec(*args, **kwargs)#
count_params()#

Count the total number of scalars composing the weights.

Returns:

An integer count.

property dtype#

Alias of layer.variable_dtype.

property dtype_policy#
get_build_config()#

Returns a dictionary with the layer’s input shape.

This method returns a config dict that can be used by build_from_config(config) to create all states (e.g. Variables and Lookup tables) needed by the layer.

By default, the config only contains the input shape that the layer was built with. If you’re writing a custom layer that creates state in an unusual way, you should override this method to make sure this state is already created when Keras attempts to load its value upon model loading.

Returns:

A dict containing the input shape associated with the layer.

get_weights()#

Return the values of layer.weights as a list of NumPy arrays.

property input#

Retrieves the input tensor(s) of a symbolic operation.

Only returns the tensor(s) corresponding to the first time the operation was called.

Returns:

Input tensor or list of input tensors.

property input_dtype#

The dtype layer inputs should be converted to.

property input_spec#
load_own_variables(store)#

Loads the state of the layer.

You can override this method to take full control of how the state of the layer is loaded upon calling keras.models.load_model().

Args:

store: Dict from which the state of the model will be loaded.

log_prob(samples: Tensor, conditions: Tensor = None, **kwargs) → Tensor#
property losses#

List of scalar losses from add_loss, regularizers and sublayers.

property metrics#

List of all metrics.

property metrics_variables#

List of all metric variables.

property non_trainable_variables#

List of all non-trainable layer state.

This extends layer.non_trainable_weights to include all state used by the layer including state for metrics and `SeedGenerator`s.

property non_trainable_weights#

List of all non-trainable weight variables of the layer.

These are the weights that should not be updated by the optimizer during training. Unlike layer.non_trainable_variables, this excludes metric state and random seeds.

property output#

Retrieves the output tensor(s) of a layer.

Only returns the tensor(s) corresponding to the first time the operation was called.

Returns:

Output tensor or list of output tensors.

property path#

The path of the layer.

If the layer has not been built yet, it will be None.

property quantization_mode#

The quantization mode of this layer, None if not quantized.

quantize(mode, type_check=True)#
quantized_build(input_shape, mode)#
quantized_call(*args, **kwargs)#
rematerialized_call(layer_call, *args, **kwargs)#

Enable rematerialization dynamically for layer’s call method.

Args:

layer_call: The original call method of a layer.

Returns:

Rematerialized layer’s call method.

sample(batch_shape: tuple[int, ...], conditions: Tensor = None, **kwargs) → Tensor#
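
Typical amortized usage, sketched under the assumption that conditions holds the (summary) conditions for the target dataset:

posterior_draws = diffusion.sample((1000,), conditions=conditions)  # 1000 draws
log_p = diffusion.log_prob(posterior_draws, conditions=conditions)  # companion density
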
save_own_variables(store)#

Saves the state of the layer.

You can override this method to take full control of how the state of the layer is saved upon calling model.save().

Args:

store: Dict where the state of the model will be saved.

set_weights(weights)#

Sets the values of layer.weights from a list of NumPy arrays.

stateless_call(trainable_variables, non_trainable_variables, *args, return_losses=False, **kwargs)#

Call the layer without any side effects.

Args:

trainable_variables: List of trainable variables of the model.

non_trainable_variables: List of non-trainable variables of the model.

*args: Positional arguments to be passed to call().

return_losses: If True, stateless_call() will return the list of losses created during call() as part of its return values.

**kwargs: Keyword arguments to be passed to call().

Returns:

A tuple. By default, returns (outputs, non_trainable_variables). If return_losses = True, then returns (outputs, non_trainable_variables, losses).

Note: non_trainable_variables include not only non-trainable weights such as BatchNormalization statistics, but also RNG seed state (if there are any random operations part of the layer, such as dropout), and Metric state (if there are any metrics attached to the layer). These are all elements of state of the layer.

Example:

model = ...
data = ...
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
# Call the model with zero side effects
outputs, non_trainable_variables = model.stateless_call(
    trainable_variables,
    non_trainable_variables,
    data,
)
# Attach the updated state to the model
# (until you do this, the model is still in its pre-call state).
for ref_var, value in zip(
    model.non_trainable_variables, non_trainable_variables
):
    ref_var.assign(value)
property supports_masking#

Whether this layer supports computing a mask using compute_mask.

symbolic_call(*args, **kwargs)#
property trainable#

Settable boolean, whether this layer should be trainable or not.

property trainable_variables#

List of all trainable layer state.

This is equivalent to layer.trainable_weights.

property trainable_weights#

List of all trainable weight variables of the layer.

These are the weights that get updated by the optimizer during training.

property variable_dtype#

The dtype of the state (weights) of the layer.

property variables#

List of all layer state, including random seeds.

This extends layer.weights to include all state used by the layer including `SeedGenerator`s.

Note that metrics variables are not included here, use metrics_variables to visit all the metric variables.

property weights#

List of all weight variables of the layer.

Unlike layer.variables, this excludes metric state and random seeds.