Source code for bayesflow.networks.consistency_models.consistency_model

import keras
from keras import ops
from keras.saving import (
    register_keras_serializable,
)

import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import (
    find_network,
    keras_kwargs,
    serialize_value_or_type,
    deserialize_value_or_type,
    weighted_sum,
)


from ..inference_network import InferenceNetwork


@register_keras_serializable(package="bayesflow.networks")
class ConsistencyModel(InferenceNetwork):
    """Implements a Consistency Model with Consistency Training (CT) as described in [1-2].
    The adaptations to CT described in [2] were taken into account in our implementation for ABI [3].

    [1] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023). Consistency Models.
    arXiv preprint arXiv:2303.01469

    [2] Song, Y., & Dhariwal, P. (2023). Improved Techniques for Training Consistency Models.
    arXiv preprint arXiv:2310.14189. Discussion: https://openreview.net/forum?id=WNzy9bRDvG

    [3] Schmitt, M., Pratz, V., Köthe, U., Bürkner, P. C., & Radev, S. T. (2023).
    Consistency models for scalable and fast simulation-based inference.
    arXiv preprint arXiv:2312.05440
    """

    MLP_DEFAULT_CONFIG = {
        "widths": (256, 256, 256, 256, 256),
        "activation": "mish",
        "kernel_initializer": "he_normal",
        "residual": True,
        "dropout": 0.05,
        "spectral_normalization": False,
    }

    def __init__(
        self,
        total_steps: int | float,
        subnet: str | type = "mlp",
        max_time: int | float = 200,
        sigma2: float = 1.0,
        eps: float = 0.001,
        s0: int | float = 10,
        s1: int | float = 50,
        subnet_kwargs: dict[str, any] = None,
        **kwargs,
    ):
        """Creates an instance of a consistency model (CM) to be used for standalone
        consistency training (CT).

        Parameters
        ----------
        total_steps : int
            The total number of training steps. It must be calculated as
            number of epochs * number of batches and cannot be inferred at construction time.
        subnet : str or type, optional, default: "mlp"
            A neural network type for the consistency model, will be instantiated using subnet_kwargs.
        max_time : int or float, optional, default: 200
            The maximum time of the diffusion process.
        sigma2 : float or Tensor of dimension (input_dim, 1), optional, default: 1.0
            Controls the shape of the skip function.
        eps : float, optional, default: 0.001
            The minimum time.
        s0 : int or float, optional, default: 10
            Initial number of discretization steps.
        s1 : int or float, optional, default: 50
            Final number of discretization steps.
        subnet_kwargs : dict[str, any], optional
            Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments.
        """
        super().__init__(base_distribution="normal", **keras_kwargs(kwargs))

        self.total_steps = float(total_steps)

        subnet_kwargs = subnet_kwargs or {}
        if subnet == "mlp":
            subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs

        self.student = find_network(subnet, **subnet_kwargs)
        self.student_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

        self.sigma2 = ops.convert_to_tensor(sigma2)
        self.sigma = ops.sqrt(sigma2)
        self.eps = eps
        self.max_time = max_time
        self.c_huber = None
        self.c_huber2 = None

        self.s0 = float(s0)
        self.s1 = float(s1)

        # create variable that works with JIT compilation
        self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int")
        self.current_step.assign(0)

        self.seed_generator = keras.random.SeedGenerator()

        # serialization: store all parameters necessary to call __init__
        self.config = {
            "total_steps": total_steps,
            "max_time": max_time,
            "sigma2": sigma2,
            "eps": eps,
            "s0": s0,
            "s1": s1,
            "subnet_kwargs": subnet_kwargs,
            **kwargs,
        }
        self.config = serialize_value_or_type(self.config, "subnet", subnet)
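    # Construction sketch (illustrative, not part of the original source; values are
    # assumptions). `total_steps` must match the planned training schedule, e.g. for
    # 100 epochs with 32 batches per epoch:
    #
    #     cm = ConsistencyModel(total_steps=100 * 32, subnet="mlp", subnet_kwargs={"dropout": 0.0})
    #
    # With `subnet="mlp"`, `subnet_kwargs` updates MLP_DEFAULT_CONFIG rather than replacing it.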
    def get_config(self):
        base_config = super().get_config()
        return base_config | self.config
    @classmethod
    def from_config(cls, config):
        config = deserialize_value_or_type(config, "subnet")
        return cls(**config)
    def _schedule_discretization(self, step) -> float:
        """Schedule function for adjusting the discretization level `N` during
        the course of training.

        Implements the function N(k) from [2], Section 3.4.
        """

        k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
        out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
        return out

    def _discretize_time(self, num_steps, rho=7.0):
        """Function for obtaining the discretized time according to [2],
        Section 2, bottom of page 2.
        """

        N = num_steps + 1
        indices = ops.arange(1, N + 1, dtype="float32")
        one_over_rho = 1.0 / rho
        discretized_time = (
            self.eps**one_over_rho
            + (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
        ) ** rho
        return discretized_time
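    # For intuition (comment added for this listing): _discretize_time returns the
    # rho-spaced grid
    #     t_i = (eps^(1/rho) + (i - 1) / num_steps * (max_time^(1/rho) - eps^(1/rho)))^rho,
    # for i = 1, ..., num_steps + 1, so t_1 = eps and t_{num_steps + 1} = max_time.
    # With the defaults eps=0.001, max_time=200, rho=7 and num_steps=2, this yields
    # roughly [0.001, 4.8, 200].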
    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.student_projector.units = xz_shape[-1]

        input_shape = list(xz_shape)

        # time vector
        input_shape[-1] += 1

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.student.build(input_shape)

        input_shape = self.student.compute_output_shape(input_shape)
        self.student_projector.build(input_shape)

        # Choose coefficient according to [2], Section 3.3
        self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
        self.c_huber2 = self.c_huber**2

        ## Calculate discretization schedule in advance
        # The Jax compiler requires fixed-size arrays, so we have
        # to store all the discretized_times in one matrix in advance
        # and later only access the relevant entries.

        # First, we calculate all unique numbers of discretization steps n
        # in a loop, as self.total_steps might be large
        self.max_n = int(self._schedule_discretization(self.total_steps))
        if self.max_n != self.s1 + 1:
            raise ValueError("The maximum number of discretization steps must be equal to s1 + 1.")
        unique_n = set()
        for step in range(int(self.total_steps)):
            unique_n.add(int(self._schedule_discretization(step)))
        unique_n = sorted(list(unique_n))

        # Next, we calculate the discretized times for each n
        # and establish a mapping between n and the position i of the
        # discretized times in the vector
        discretized_times = np.zeros((len(unique_n), self.max_n + 1))
        discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
        for i, n in enumerate(unique_n):
            disc = ops.convert_to_numpy(self._discretize_time(n))
            discretized_times[i, : len(disc)] = disc
            discretization_map[n] = i

        # Finally, we convert the vectors to tensors
        self.discretized_times = ops.convert_to_tensor(discretized_times, dtype="float32")
        self.discretization_map = ops.convert_to_tensor(discretization_map)
    def call(
        self,
        xz: Tensor,
        conditions: Tensor = None,
        inverse: bool = False,
        **kwargs,
    ):
        if inverse:
            return self._inverse(xz, conditions=conditions, **kwargs)
        return self._forward(xz, conditions=conditions, **kwargs)
    def _forward_train(self, x: Tensor, noise: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        """Forward function for training. Calls the consistency function with noisy input."""
        inp = x + t * noise
        return self.consistency_function(inp, t, conditions=conditions, **kwargs)

    def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        # Consistency Models only learn the direction from the noise distribution
        # to the target distribution, so we cannot implement this function.
        raise NotImplementedError("Consistency Models are not invertible")

    def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        """Generate random draws from the approximate target distribution
        using the multistep sampling algorithm from [1], Algorithm 1.

        Parameters
        ----------
        z : Tensor
            Samples from a standard normal distribution
        conditions : Tensor, optional, default: None
            Conditions for an approximate conditional distribution
        **kwargs : dict, optional, default: {}
            Additional keyword arguments. Include `steps` (default: 10) to
            adjust the number of sampling steps.

        Returns
        -------
        x : Tensor
            The approximate samples
        """
        steps = kwargs.get("steps", 10)
        x = keras.ops.copy(z) * self.max_time
        discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1)
        t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
        x = self.consistency_function(x, t, conditions=conditions)
        for n in range(1, steps):
            noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
            x_n = x + keras.ops.sqrt(keras.ops.square(discretized_time[n]) - self.eps**2) * noise
            t = keras.ops.full_like(t, discretized_time[n])
            x = self.consistency_function(x_n, t, conditions=conditions)
        return x
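    # For intuition (comment added for this listing): after the first denoising step,
    # each iteration re-noises the current sample to the next (smaller) time,
    #     x_n = x + sqrt(t_n^2 - eps^2) * noise,
    # and denoises it again with a single consistency-function evaluation, so sampling
    # costs `steps` network calls. Extra keyword arguments passed to `call` are
    # forwarded here, which is how `steps` (default 10) can be adjusted per call.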
    def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        """Compute the consistency function.

        Parameters
        ----------
        x : Tensor
            Input vector
        t : Tensor
            Vector of time samples in [eps, T]
        conditions : Tensor
            The conditioning vector
        **kwargs : dict, optional, default: {}
            Additional keyword arguments passed to the network.
        """

        if conditions is not None:
            xtc = ops.concatenate([x, t, conditions], axis=-1)
        else:
            xtc = ops.concatenate([x, t], axis=-1)

        f = self.student_projector(self.student(xtc, **kwargs))

        # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)).
        # Broadcasting against the time vector of shape (batch_size, 1) gives cskip and cout
        # of shape (batch_size, input_dim).
        skip = self.sigma2 / ((t - self.eps) ** 2 + self.sigma2)
        out = self.sigma * (t - self.eps) / (ops.sqrt(self.sigma2 + t**2))

        out = skip * x + out * f
        return out
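    # In the notation of [1-2], the method above evaluates
    #     f(x, t) = c_skip(t) * x + c_out(t) * F(x, t, conditions)
    # with c_skip(t) = sigma2 / ((t - eps)^2 + sigma2) and
    #      c_out(t)  = sigma * (t - eps) / sqrt(sigma2 + t^2),
    # so that f(x, eps) = x, the boundary condition of consistency models.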
    def compute_metrics(
        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
    ) -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)

        # The discretization schedule requires the number of training steps passed so far.
        # To be independent of external information, we track it here.
        if stage == "training":
            self.current_step.assign_add(1)
            self.current_step.assign(ops.minimum(self.current_step, self.total_steps - 1))

        discretization_index = ops.take(
            self.discretization_map, ops.cast(self._schedule_discretization(self.current_step), "int")
        )
        discretized_time = ops.take(self.discretized_times, discretization_index, axis=0)

        # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
        # adapted noise schedule from [2], Section 3.5
        p_mean = -1.1
        p_std = 2.0
        p = ops.where(
            discretized_time[1:] > 0.0,
            ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
            0.0,
        )
        log_p = ops.log(p)
        times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
        t1 = ops.take(discretized_time, times)[..., None]
        t2 = ops.take(discretized_time, times + 1)[..., None]

        # generate noise vector
        noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)

        teacher_out = self._forward_train(x, noise, t1, conditions=conditions, training=stage == "training")
        # difference between teacher and student: different time,
        # and no gradient for the teacher
        teacher_out = ops.stop_gradient(teacher_out)
        student_out = self._forward_train(x, noise, t2, conditions=conditions, training=stage == "training")

        # weighting function, see [2], Section 3.1
        lam = 1 / (t2 - t1)

        # Pseudo-Huber loss, see [2], Section 3.3
        loss = lam * (ops.sqrt(ops.square(teacher_out - student_out) + self.c_huber2) - self.c_huber)
        loss = weighted_sum(loss, sample_weight)

        return base_metrics | {"loss": loss}
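
# End-to-end usage sketch (illustrative, not part of the original module; the shapes,
# sizes, and variable names below are assumptions, and it is assumed that keyword
# arguments such as `inverse` are forwarded to `call`). After construction, `build`
# precomputes the discretization schedule, `compute_metrics` returns the training
# loss, and calling the model with `inverse=True` draws approximate samples:
#
#     import keras
#
#     cm = ConsistencyModel(total_steps=100 * 32)
#     cm.build(xz_shape=(None, 2), conditions_shape=(None, 4))
#
#     x = keras.random.normal((16, 2))   # targets / parameters
#     c = keras.random.normal((16, 4))   # conditions / summaries
#     metrics = cm.compute_metrics(x, conditions=c, stage="training")   # {"loss": ...}
#
#     z = keras.random.normal((16, 2))   # draws from the standard normal base
#     samples = cm(z, conditions=c, inverse=True, steps=10)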