import keras
from keras import ops
from keras.saving import (
    register_keras_serializable,
)
import numpy as np
from bayesflow.types import Tensor
from bayesflow.utils import (
    jvp,
    concatenate_valid,
    find_network,
    keras_kwargs,
    expand_right_as,
    expand_right_to,
    serialize_value_or_type,
    deserialize_value_or_type,
)
from bayesflow.networks import InferenceNetwork
from bayesflow.networks.embeddings import FourierEmbedding
@register_keras_serializable(package="bayesflow.networks")
class ContinuousTimeConsistencyModel(InferenceNetwork):
"""Implements an sCM (simple, stable, and scalable Consistency Model)
with continous-time Consistency Training (CT) as described in [1].
The sampling procedure is taken from [2].
[1] Lu, C., & Song, Y. (2024).
Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models
arXiv preprint arXiv:2410.11081
[2] Song, Y., Dhariwal, P., Chen, M. & Sutskever, I. (2023).
Consistency Models.
arXiv preprint arXiv:2303.01469
"""

    def __init__(
        self,
        subnet: str | type = "mlp",
        sigma_data: float = 1.0,
        **kwargs,
    ):
        """Creates an instance of an sCM to be used for consistency training (CT).

        Parameters
        ----------
        subnet : str or type, optional, default: "mlp"
            A neural network type for the consistency model; will be
            instantiated using `subnet_kwargs`.
        sigma_data : float, optional, default: 1.0
            Standard deviation of the target distribution.
        **kwargs : dict, optional, default: {}
            Additional keyword arguments, such as `subnet_kwargs` for the subnet
            or `embedding_kwargs` for the time embedding.
        """
        super().__init__(base_distribution="normal", **keras_kwargs(kwargs))

        self.subnet = find_network(subnet, **kwargs.get("subnet_kwargs", {}))
        self.subnet_projector = keras.layers.Dense(units=None, bias_initializer="zeros", kernel_initializer="zeros")

        self.weight_fn = find_network("mlp", widths=(256,), dropout=0.0)
        self.weight_fn_projector = keras.layers.Dense(units=1, bias_initializer="zeros", kernel_initializer="zeros")

        self.time_emb = FourierEmbedding(**kwargs.get("embedding_kwargs", {}))
        self.time_emb_dim = self.time_emb.embed_dim

        self.sigma_data = sigma_data

        self.seed_generator = keras.random.SeedGenerator()

        # serialization: store all parameters necessary to call __init__
        self.config = {
            "sigma_data": sigma_data,
            **kwargs,
        }
        self.config = serialize_value_or_type(self.config, "subnet", subnet)

    def get_config(self):
        base_config = super().get_config()
        return base_config | self.config

    @classmethod
    def from_config(cls, config):
        config = deserialize_value_or_type(config, "subnet")
        return cls(**config)

    def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs):
        # exponential schedule on [0, pi/2]: map t_i, linearly spaced in [0, pi/2], to
        # times_i = (pi / 2) * exp(rho * (t_i - pi / 2)), then clamp the first entry to 0.0
        t = np.linspace(0.0, np.pi / 2, num_steps)
        times = np.exp((t - np.pi / 2) * rho) * np.pi / 2
        times[0] = 0.0

        # if rho is set too low, bad schedules can occur
        EPS_WARN = 0.1
        if times[1] > EPS_WARN:
            print("Warning: The last time step is large.")
            print(f"Increasing rho (was {rho}) or num_steps (was {num_steps}) might improve results.")

        return ops.convert_to_tensor(times)

    def build(self, xz_shape, conditions_shape=None):
        super().build(xz_shape)
        self.subnet_projector.units = xz_shape[-1]

        # construct input shape for subnet and subnet projector
        input_shape = list(xz_shape)

        # time vector
        input_shape[-1] += self.time_emb_dim + 1

        if conditions_shape is not None:
            input_shape[-1] += conditions_shape[-1]

        input_shape = tuple(input_shape)

        self.subnet.build(input_shape)
        input_shape = self.subnet.compute_output_shape(input_shape)
        self.subnet_projector.build(input_shape)

        # input shape for time embedding
        self.time_emb.build((xz_shape[0], 1))

        # input shape for weight function and projector
        input_shape = (xz_shape[0], 1)
        self.weight_fn.build(input_shape)
        input_shape = self.weight_fn.compute_output_shape(input_shape)
        self.weight_fn_projector.build(input_shape)

    def call(
        self,
        xz: Tensor,
        conditions: Tensor = None,
        inverse: bool = False,
        **kwargs,
    ):
        if inverse:
            return self._inverse(xz, conditions=conditions, **kwargs)
        return self._forward(xz, conditions=conditions, **kwargs)

    def _forward(self, x: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        # Consistency Models only learn the direction from noise distribution
        # to target distribution, so we cannot implement this function.
        raise NotImplementedError("Consistency Models are not invertible")

    def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        """Generate random draws from the approximate target distribution
        using the multistep sampling algorithm from [2], Algorithm 1.

        Parameters
        ----------
        z : Tensor
            Samples from a standard normal distribution
        conditions : Tensor, optional, default: None
            Conditions for an approximate conditional distribution
        **kwargs : dict, optional, default: {}
            Additional keyword arguments. Include `steps` (default: 15) to
            adjust the number of sampling steps and `rho` (default: 3.5) to
            adjust the time discretization.

        Returns
        -------
        x : Tensor
            The approximate samples
        """
        steps = kwargs.get("steps", 15)
        rho = kwargs.get("rho", 3.5)

        # noise distribution has standard deviation sigma_data
        x = keras.ops.copy(z) * self.sigma_data

        discretized_time = keras.ops.flip(self._discretize_time(steps, rho=rho), axis=-1)
        t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
        x = self.consistency_function(x, t, conditions=conditions)

        for n in range(1, steps):
            noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
            # re-noise to the current (smaller) time before the next consistency step
            t = keras.ops.full_like(t, discretized_time[n])
            x_n = ops.cos(t) * x + ops.sin(t) * noise
            x = self.consistency_function(x_n, t, conditions=conditions)

        return x

    def consistency_function(
        self,
        x: Tensor,
        t: Tensor,
        conditions: Tensor = None,
        training: bool = False,
        **kwargs,
    ) -> Tensor:
        """Compute the consistency function at time t.

        Parameters
        ----------
        x : Tensor
            Input vector
        t : Tensor
            Vector of time samples in [0, pi/2]
        conditions : Tensor, optional, default: None
            The conditioning vector
        training : bool, optional, default: False
            Flag to control whether the inner network operates in training or test mode
        **kwargs : dict, optional, default: {}
            Additional keyword arguments passed to the inner network.

        Returns
        -------
        Tensor
            The consistency function evaluated at x and t.
        """
        xtc = concatenate_valid([x / self.sigma_data, self.time_emb(t), conditions], axis=-1)
        f = self.subnet_projector(self.subnet(xtc, training=training, **kwargs))

        out = ops.cos(t) * x - ops.sin(t) * self.sigma_data * f
        return out

    def compute_metrics(self, x: Tensor, conditions: Tensor = None, stage: str = "training") -> dict[str, Tensor]:
        base_metrics = super().compute_metrics(x, conditions=conditions, stage=stage)

        # Implements Algorithm 1 from [1]
        # training parameters
        p_mean = -1.0
        p_std = 1.6
        c = 0.1

        # generate noise vector
        z = (
            keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
            * self.sigma_data
        )

        # sample time
        tau = (
            keras.random.normal(keras.ops.shape(x)[:1], dtype=keras.ops.dtype(x), seed=self.seed_generator) * p_std
            + p_mean
        )
        t_ = ops.arctan(ops.exp(tau) / self.sigma_data)
        t = expand_right_as(t_, x)

        # generate noisy sample
        xt = ops.cos(t) * x + ops.sin(t) * z

        # calculate estimator for dx_t/dt
        dxtdt = ops.cos(t) * z - ops.sin(t) * x
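        # since x_t = cos(t) * x + sin(t) * z, differentiating with respect to t gives
        # dx_t/dt = -sin(t) * x + cos(t) * z, which is exactly the estimator above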

        r = 1.0  # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here

        def f_teacher(x, t):
            o = self.subnet(
                concatenate_valid([x, self.time_emb(t), conditions], axis=-1), training=stage == "training"
            )
            return self.subnet_projector(o)
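
        # a single jvp call below evaluates the teacher network F once and, in the same pass,
        # its directional derivative along the chosen tangents (a cos(t) * sin(t)-scaled total
        # time derivative of F along the trajectory x_t); gradients through both are stopped,
        # which turns them into the frozen "teacher" targets of Algorithm 1 in [1]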
        primals = (xt / self.sigma_data, t)
        tangents = (
            ops.cos(t) * ops.sin(t) * dxtdt,
            ops.cos(t) * ops.sin(t) * self.sigma_data,
        )

        teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
        teacher_output = ops.stop_gradient(teacher_output)
        cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

        # calculate output of the network
        xtc = concatenate_valid([xt / self.sigma_data, self.time_emb(t), conditions], axis=-1)
        student_out = self.subnet_projector(self.subnet(xtc, training=stage == "training"))

        # calculate the tangent
        g = -(ops.cos(t) ** 2) * (self.sigma_data * teacher_output - dxtdt) - r * ops.cos(t) * ops.sin(t) * (
            xt + self.sigma_data * cos_sin_dFdt
        )
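        # the tangent above follows Algorithm 1 in [1]: a cos(t)^2-weighted residual between the
        # teacher prediction (scaled by sigma_data) and dx_t/dt, plus an r-weighted term that
        # involves the teacher's time derivative cos_sin_dFdt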

        # apply normalization to stabilize training
        g = g / (ops.norm(g, axis=-1, keepdims=True) + c)

        # compute adaptive weights and calculate loss
        w = self.weight_fn_projector(self.weight_fn(expand_right_to(t_, 2)))
        D = ops.shape(x)[-1]
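        # adaptive weighting as in [1]: w is a learned, time-dependent log-weight, so each sample
        # contributes exp(w) / D * mean((student_out - teacher_output - g) ** 2) - w to the
        # batch-averaged loss below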
        loss = ops.mean(
            (ops.exp(w) / D)
            * ops.mean(
                ops.reshape((student_out - teacher_output - g) ** 2, (ops.shape(teacher_output)[0], -1)), axis=-1
            )
            - w
        )

        return base_metrics | {"loss": loss}