Source code for bayesflow.networks.subnets.unet.unet

from typing import Sequence, Literal

import keras

from bayesflow.types import Tensor
from bayesflow.utils import layer_kwargs, concatenate_valid
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.utils import check_lengths_same

from .blocks.residual import ResidualBlock2D
from .blocks.upsample import UpSample2D
from .blocks.downsample import DownSample2D
from .blocks.attention import SelfAttention2D
from .embeddings.dense_fourier import DenseFourier

from ...helpers import SimpleNorm


[docs] @serializable("bayesflow.networks") class UNet(keras.Layer): """ Time-conditioned U-Net backbone for diffusion models [1]. Expects inputs `(x, t, cond)`, where `cond` is concatenated channel-wise to `x` and a learned time embedding conditions residual blocks (optionally via FiLM). The network follows a DDPM-style encoder–decoder with skip connections, optional self-attention per stage, and pad/crop logic to support odd spatial sizes (see [1]). [1] Nain (2022) Keras example: Denoising Diffusion Probabilistic Model (https://keras.io/examples/generative/ddpm/) """ def __init__( self, widths: Sequence[int] = (64, 128, 256, 512), res_blocks: Sequence[int] | int = 2, attn_stage: Sequence[bool] | None = (False, False, True, True), time_emb_dim: int = 32, time_emb: keras.Layer | None = None, time_emb_include_identity: bool = True, time_emb_use_residual_mlp: bool = True, use_film: bool = False, activation: str = "swish", kernel_initializer: str | keras.initializers.Initializer = "he_normal", dropout: Sequence[float] | float = 0.0, groups: int = 8, num_heads: int = 1, down_mode: Literal["average", "conv"] = "average", up_kernel_size: Literal[1, 3] = 3, up_conv_first: bool = False, norm: Literal["layer", "group"] = "group", **kwargs, ): """ Time-conditioned U-Net backbone for diffusion models. Parameters ---------- widths : Sequence[int], optional Channel widths per resolution stage. res_blocks : Sequence[int] or int, optional Number of residual blocks per stage (decoder uses `+1` per stage). attn_stage : Sequence[bool] or None, optional Whether to use self-attention within each stage. time_emb_dim : int, optional Dimensionality of the time embedding. If 1, time is used directly. time_emb : keras.layers.Layer or None, optional Custom global time embedding layer. If None, uses `DenseFourier` when `time_emb_dim > 1`. time_emb_include_identity : bool, optional Whether the time embedding includes the original time scalar concatenated to the Fourier features. Default is True. time_emb_use_residual_mlp : bool, optional Whether the time embedding uses a residual MLP instead of a simple MLP. Default is True. use_film : bool, optional Whether residual blocks use FiLM or additive conditioning with the local time embedding. activation : str, optional Activation used throughout the network. kernel_initializer : str or keras.initializers.Initializer, optional Kernel initializer for convolution layers. dropout : Sequence[float] or float, optional Dropout rate used inside residual blocks. Default is 0.0. groups : int, optional Number of groups for group normalization where applicable. num_heads : int, optional Number of attention heads for self-attention layers. Default is 1. down_mode : {"average", "conv"}, optional "conv" uses a strided convolution, while "average" uses average pooling followed by a convolution. Default is "conv". up_kernel_size : {1, 3}, optional Kernel size for upsampling convolutions. Default is 3. up_conv_first : bool, optional If True, applies convolution before upsampling, after upsampling otherwise. Default is False. norm: Literal["layer", "group"], optional The type of normalization layer applied, defaults to "group" **kwargs Additional keyword arguments. """ super().__init__(**layer_kwargs(kwargs)) self.widths = widths self.res_blocks = (res_blocks,) * len(self.widths) if isinstance(res_blocks, int) else res_blocks self.attn_stage = (False,) * len(self.widths) if attn_stage is None else attn_stage self.time_emb_dim = time_emb_dim self.time_emb_include_identity = time_emb_include_identity self.time_emb_use_residual_mlp = time_emb_use_residual_mlp self.use_film = use_film self.activation = activation self.kernel_initializer = kernel_initializer self.dropout = (dropout,) * len(self.widths) if isinstance(dropout, float) else dropout self.groups = groups self.num_heads = num_heads self.down_mode = down_mode self.up_kernel_size = up_kernel_size self.up_conv_first = up_conv_first self.norm = norm check_lengths_same(self.widths, self.res_blocks, self.attn_stage, self.dropout) if time_emb is None: if self.time_emb_dim == 1: self.time_emb = keras.layers.Identity() else: self.time_emb = DenseFourier( emb_dim=self.time_emb_dim, include_identity=self.time_emb_include_identity, use_residual_mlp=self.time_emb_use_residual_mlp, kernel_initializer=self.kernel_initializer, ) else: self.time_emb = time_emb self.input_projector = keras.layers.Conv2D( filters=self.widths[0], kernel_size=3, padding="same", kernel_initializer=self.kernel_initializer, ) self.down_stage_names = [] self.downsamples = [] self.paddings = [] for si, ch in enumerate(self.widths): blocks = [] for bi in range(self.res_blocks[si]): blocks.append( ResidualBlock2D( width=ch, activation=self.activation, norm=self.norm, groups=self.groups, dropout=self.dropout[si], kernel_initializer=self.kernel_initializer, use_film=self.use_film, ) ) if self.attn_stage[si]: blocks.append( SelfAttention2D( num_heads=self.num_heads, groups=self.groups, residual="norm", kernel_initializer=self.kernel_initializer, ) ) stage_name = f"down_stage_{si}" setattr(self, stage_name, blocks) self.down_stage_names.append(stage_name) if si < len(self.widths) - 1: self.downsamples.append(DownSample2D(width=self.widths[si + 1], mode=self.down_mode)) self.mid1 = ResidualBlock2D( width=self.widths[-1], activation=self.activation, norm=self.norm, groups=self.groups, dropout=self.dropout[-1], kernel_initializer=self.kernel_initializer, use_film=self.use_film, ) self.mid_attn = SelfAttention2D( num_heads=self.num_heads, groups=self.groups, residual="norm", kernel_initializer=self.kernel_initializer, ) self.mid2 = ResidualBlock2D( width=self.widths[-1], activation=self.activation, norm=self.norm, groups=self.groups, dropout=self.dropout[-1], kernel_initializer=self.kernel_initializer, use_film=self.use_film, ) self.upsamples = [] self.up_stage_names = [] self.crops = [] for ri, ch in enumerate(reversed(self.widths)): si = (len(self.widths) - 1) - ri blocks = [] for bi in range(self.res_blocks[si] + 1): blocks.append( ResidualBlock2D( width=ch, activation=self.activation, norm=self.norm, groups=self.groups, dropout=self.dropout[si], kernel_initializer=self.kernel_initializer, use_film=self.use_film, ) ) if self.attn_stage[si]: blocks.append( SelfAttention2D( num_heads=self.num_heads, groups=self.groups, residual="norm", kernel_initializer=self.kernel_initializer, ) ) stage_name = f"up_stage_{ri}" setattr(self, stage_name, blocks) self.up_stage_names.append(stage_name) if ri != len(self.widths) - 1: self.upsamples.append( UpSample2D( width=self.widths[si - 1], kernel_size=self.up_kernel_size, conv_first=self.up_conv_first, ) ) self.out_norm = SimpleNorm(method=self.norm, groups=self.groups, center=True, scale=True) self.out_conv = None
[docs] @classmethod def from_config(cls, config, custom_objects=None): return cls(**deserialize(config, custom_objects=custom_objects))
[docs] def get_config(self): base_config = super().get_config() base_config = layer_kwargs(base_config) config = { "widths": self.widths, "res_blocks": self.res_blocks, "attn_stage": self.attn_stage, "time_emb_dim": self.time_emb_dim, "time_emb": self.time_emb, "time_emb_include_identity": self.time_emb_include_identity, "time_emb_use_residual_mlp": self.time_emb_use_residual_mlp, "use_film": self.use_film, "activation": self.activation, "kernel_initializer": self.kernel_initializer, "dropout": self.dropout, "groups": self.groups, "num_heads": self.num_heads, "down_mode": self.down_mode, "up_kernel_size": self.up_kernel_size, "up_conv_first": self.up_conv_first, "norm": self.norm, } return base_config | serialize(config)
[docs] def build(self, input_shape): if self.built: return assert len(input_shape) == 3, "UNet expects input shape to be a tuple of (x_shape, t_shape, cond_shape)" x_shape, t_shape, cond_shape = input_shape assert x_shape[-1] is not None, "UNet requires a known channel dimension for x." assert x_shape[1] is not None and x_shape[2] is not None, "UNet requires known spatial dimensions for x." t_shape = (t_shape[0], 1) self.time_emb.build(t_shape) t_emb_shape = self.time_emb.compute_output_shape(t_shape) # concatenate condition at beginning h_shape = list(x_shape) h_shape[-1] = x_shape[-1] + cond_shape[-1] h_shape = tuple(h_shape) self.input_projector.build(h_shape) h_shape = self.input_projector.compute_output_shape(h_shape) # down skip_shapes = [h_shape] padding = [] for si, stage_name in enumerate(self.down_stage_names): blocks = getattr(self, stage_name) for layer in blocks: if isinstance(layer, ResidualBlock2D): layer.build((h_shape, t_emb_shape)) h_shape = layer.compute_output_shape((h_shape, t_emb_shape)) if self.attn_stage[si]: continue else: layer.build(h_shape) skip_shapes.append(h_shape) if si < len(self.widths) - 1: pad_h = h_shape[1] % 2 != 0 pad_w = h_shape[2] % 2 != 0 padding.append((pad_h, pad_w)) layer = keras.layers.ZeroPadding2D(padding=((0, int(pad_h)), (0, int(pad_w)))) layer.build(h_shape) h_shape = layer.compute_output_shape(h_shape) self.paddings.append(layer) self.downsamples[si].build(h_shape) h_shape = self.downsamples[si].compute_output_shape(h_shape) skip_shapes.append(h_shape) # mid self.mid1.build((h_shape, t_emb_shape)) h_shape = self.mid1.compute_output_shape((h_shape, t_emb_shape)) self.mid_attn.build(h_shape) self.mid2.build((h_shape, t_emb_shape)) h_shape = self.mid2.compute_output_shape((h_shape, t_emb_shape)) # up for ri, stage_name in enumerate(self.up_stage_names): blocks = getattr(self, stage_name) si = (len(self.widths) - 1) - ri for layer in blocks: if isinstance(layer, ResidualBlock2D): skip_shape = skip_shapes.pop() h_shape = list(h_shape) h_shape[-1] = h_shape[-1] + skip_shape[-1] h_shape = tuple(h_shape) layer.build((h_shape, t_emb_shape)) h_shape = layer.compute_output_shape((h_shape, t_emb_shape)) else: layer.build(h_shape) if ri != len(self.widths) - 1: # Upsampling and Crop self.upsamples[ri].build(h_shape) h_shape = self.upsamples[ri].compute_output_shape(h_shape) pad_h, pad_w = padding[si - 1] layer = keras.layers.Cropping2D(((0, int(pad_h)), (0, int(pad_w)))) layer.build(h_shape) h_shape = layer.compute_output_shape(h_shape) self.crops.append(layer) self.out_norm.build(h_shape) self.out_conv = keras.layers.Conv2D( filters=x_shape[-1], kernel_size=3, padding="same", kernel_initializer="zeros", ) self.out_conv.build(h_shape)
[docs] def compute_output_shape(self, input_shape): return tuple(input_shape[0])
[docs] def call(self, inputs: tuple[Tensor, Tensor, Tensor], training: bool = False) -> Tensor: x, t, condition = inputs x = self._prepare_inputs(x, condition) t_emb = self._compute_time_embedding(t, training=training) x, skips = self.encode(x, t_emb, training=training) x = self.bottleneck(x, t_emb, training=training) x = self.decode(x, t_emb, skips, training=training) x = self._project_output(x, training=training) return x
[docs] def encode(self, x: Tensor, t_emb: Tensor, training: bool) -> tuple[Tensor, list[Tensor]]: skips = [x] for idx in range(len(self.down_stage_names)): x, skips = self._run_down_stage(idx, x, t_emb, skips, training=training) if idx < len(self.downsamples): x = self.paddings[idx](x, training=training) x = self.downsamples[idx](x, training=training) skips.append(x) return x, skips
[docs] def bottleneck(self, x: Tensor, t_emb: Tensor, training: bool) -> Tensor: x = self.mid1((x, t_emb), training=training) x = self.mid_attn(x, training=training) x = self.mid2((x, t_emb), training=training) return x
[docs] def decode(self, x: Tensor, t_emb: Tensor, skips: list[Tensor], training: bool) -> Tensor: for idx in range(len(self.up_stage_names)): x = self._run_up_stage(idx, x, t_emb, skips, training=training) if idx != len(self.widths) - 1: x = self.upsamples[idx](x, training=training) x = self.crops[idx](x, training=training) return x
def _prepare_inputs(self, x: Tensor, cond: Tensor) -> Tensor: x = concatenate_valid([x, cond], axis=-1) return self.input_projector(x) def _compute_time_embedding(self, t: Tensor, training: bool) -> Tensor: # Ensure shape [B, 1] even if t comes in with extra dims. t = keras.ops.reshape(t, (keras.ops.shape(t)[0], -1))[:, :1] return self.time_emb(t, training=training) def _run_down_stage( self, idx: int, x: Tensor, t_emb: Tensor, skips: list[Tensor], training: bool ) -> tuple[Tensor, list[Tensor]]: for layer in getattr(self, self.down_stage_names[idx]): is_residual = isinstance(layer, ResidualBlock2D) x = layer((x, t_emb), training=training) if is_residual else layer(x, training=training) # Don't store the residual output because the next layer is attention. if not (is_residual and self.attn_stage[idx]): skips.append(x) return x, skips def _run_up_stage(self, idx: int, x: Tensor, t_emb: Tensor, skips: list[Tensor], training: bool) -> Tensor: for layer in getattr(self, self.up_stage_names[idx]): if isinstance(layer, ResidualBlock2D): skip = skips.pop() x = concatenate_valid([x, skip], axis=-1) x = layer((x, t_emb), training=training) else: x = layer(x, training=training) return x def _project_output(self, x: Tensor, training: bool) -> Tensor: x = self.out_norm(x, training=training) return self.out_conv(x)