Source code for bayesflow.experimental.rectifiers

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from functools import partial

import tensorflow as tf
import tensorflow_probability as tfp

import bayesflow.default_settings as defaults
from bayesflow.computational_utilities import compute_jacobian_trace
from bayesflow.exceptions import SummaryStatsError
from bayesflow.helper_networks import MCDropout
from bayesflow.losses import mmd_summary_space


[docs] class DriftNetwork(tf.keras.Model): """Implements a learnable velocity field for a neural ODE. Will typically be used in conjunction with a ``RectifyingFlow`` instance, as proposed by [1] in the context of unconditional image generation. [1] Liu, X., Gong, C., & Liu, Q. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003. """
[docs] def __init__( self, target_dim, num_dense=3, dense_args=None, dropout=True, mc_dropout=False, dropout_prob=0.05, **kwargs ): """Creates a learnable velocity field instance to be used in the context of rectifying flows or neural ODEs. [1] Liu, X., Gong, C., & Liu, Q. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003. Parameters ---------- target_dim : int The problem dimensionality (e.g., in parameter estimation, the number of parameters) num_dense : int, optional, default: 3 The number of hidden layers for the inner fully-connected network dense_args : dict or None, optional, default: None The arguments to be passed to ``tf.keras.layers.Dense`` constructor. If None, default settings will be fetched from ``bayesflow.default_settings``. dropout : bool, optional, default: True Whether to use dropout in-between the hidden layers. mc_dropout : bool, optional, default: False Whether to use dropout Monte Carlo dropout (i.e., Bayesian approximation) during inference dropout_prob : float in (0, 1), optional, default: 0.05 The dropout probability. Only has effecft if ``dropout=True`` or ``mc_dropout=True`` **kwargs : dict, optional, default: {} Optional keyword arguments passed to the ``tf.keras.Model.__init__`` method. """ super().__init__(**kwargs) self.latent_dim = target_dim if dense_args is None: dense_args = defaults.DEFAULT_SETTING_DENSE_RECT self.net = tf.keras.Sequential() for _ in range(num_dense): self.net.add(tf.keras.layers.Dense(**dense_args)) if mc_dropout: self.net.add(MCDropout(dropout_prob)) elif dropout: self.net.add(tf.keras.layers.Dropout(dropout_prob)) else: pass self.net.add(tf.keras.layers.Dense(self.latent_dim)) self.net.build(input_shape=())
[docs] def call(self, target_vars, latent_vars, time, condition, **kwargs): """Performs a linear interpolation between target and latent variables over time (i.e., a single ODE step during training). Parameters ---------- target_vars : tf.Tensor of shape (batch_size, ..., num_targets) The variables of interest (e.g., parameters) over which we perform inference. latent_vars : tf.Tensor of shape (batch_size, ..., num_targets) The sampled random variates from the base distribution. time : tf.Tensor of shape (batch_size, ..., 1) A vector of time indices in (0, 1) condition : tf.Tensor of shape (batch_size, ..., condition_dim) The optional conditioning variables (e.g., as returned by a summary network) **kwargs : dict, optional, default: {} Optional keyword arguments passed to the ``tf.keras.Model`` call() method """ diff = target_vars - latent_vars wdiff = time * target_vars + (1 - time) * latent_vars drift = self.drift(wdiff, time, condition, **kwargs) return diff, drift
[docs] def drift(self, target_t, time, condition, **kwargs): """Returns the drift at target_t time given optional condition(s). Parameters ---------- target_t : tf.Tensor of shape (batch_size, ..., num_targets) The variables of interest (e.g., parameters) over which we perform inference. time : tf.Tensor of shape (batch_size, ..., 1) A vector of time indices in (0, 1) condition : tf.Tensor of shape (batch_size, ..., condition_dim) The optional conditioning variables (e.g., as returned by a summary network) **kwargs : dict, optional, default: {} Optional keyword arguments passed to the drift network. """ if condition is not None: inp = tf.concat([target_t, condition, time], axis=-1) else: inp = tf.concat([target_t, time], axis=-1) return self.net(inp, **kwargs)
[docs] class RectifiedDistribution(tf.keras.Model): """Implements a rectifying flows according to [1]. To be used as an alternative to a normalizing flow in a BayesFlow pipeline. [1] Liu, X., Gong, C., & Liu, Q. (2022). Flow straight and fast: Learning to generate and transfer data with rectified flow. arXiv preprint arXiv:2209.03003. """
[docs] def __init__(self, drift_net, summary_net=None, latent_dist=None, loss_fun=None, summary_loss_fun=None, **kwargs): """Initializes a composite neural network to represent an amortized approximate posterior through for a rectifying flow. Parameters ---------- drift_net : tf.keras.Model A neural network for the velocity field (drift) of the learnable ODE summary_net : tf.keras.Model or None, optional, default: None An optional summary network to compress non-vector data structures. latent_dist : callable or None, optional, default: None The latent distribution towards which to optimize the networks. Defaults to a multivariate unit Gaussian. loss_fun : callable or None, optional, default: None The loss function for "rectifying" the velocity field. If ``None``, defaults to tf.keras.losses.logcosh. Sensible alternatives are MSE (as in []) summary_loss_fun : callable, str, or None, optional, default: None The loss function which accepts the outputs of the summary network. If ``None``, no loss is provided and the summary space will not be shaped according to a known distribution (see [2]). If ``summary_loss_fun='MMD'``, the default loss from [2] will be used. **kwargs : dict, optional, default: {} Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance. Important ---------- - If no ``summary_net`` is provided, then the output dictionary of your generative model should not contain any ``summary_conditions``, i.e., ``summary_conditions`` should be set to ``None``, otherwise these will be ignored. """ super().__init__(**kwargs) self.drift_net = drift_net self.summary_net = summary_net self.latent_dim = drift_net.latent_dim self.latent_dist = self._determine_latent_dist(latent_dist) self.loss_fun = self._determine_loss(loss_fun) self.summary_loss = self._determine_summary_loss(summary_loss_fun)
[docs] def call(self, input_dict, return_summary=False, num_eval_points=1, **kwargs): """Performs a forward pass through the summary and drift network given an input dictionary. Parameters ---------- input_dict : dict Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged: ``targets`` - the latent model parameters over which a condition density is learned ``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network ``direct_conditions`` - the conditioning variables that the directly passed to the inference network return_summary : bool, optional, default: False A flag which determines whether the learnable data summaries (representations) are returned or not. num_eval_points : int, optional, default: 1 The number of time points for evaluating the noisy estimator. Values larger than the default 1 may reduce the variance of the estimator, but may lead to increased memory demands, since an additional dimension is added at axis 1 of all tensors. **kwargs : dict, optional, default: {} Additional keyword arguments passed to the networks For instance, ``kwargs={'training': True}`` is passed automatically during training. Returns ------- net_out or (net_out, summary_out) """ # Concatenate conditions, if given summary_out, full_cond = self._compute_summary_condition( input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]), input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]), **kwargs, ) # Extract target variables target_vars = input_dict[defaults.DEFAULT_KEYS["parameters"]] # Extract batch size (autograph friendly) batch_size = tf.shape(target_vars)[0] # Sample latent variables latent_vars = self.latent_dist.sample(batch_size) # Do a little trick for less noisy estimator, if evals > 1 if num_eval_points > 1: target_vars = tf.stack([target_vars] * num_eval_points, axis=1) latent_vars = tf.stack([latent_vars] * num_eval_points, axis=1) full_cond = tf.stack([full_cond] * num_eval_points, axis=1) # Sample time time = tf.random.uniform((batch_size, num_eval_points, 1)) else: time = tf.random.uniform((batch_size, 1)) # Compute drift net_out = self.drift_net(target_vars, latent_vars, time, full_cond, **kwargs) # Return summary outputs or not, depending on parameter if return_summary: return net_out, summary_out return net_out
[docs] def compute_loss(self, input_dict, **kwargs): """Computes the loss of the posterior amortizer given an input dictionary, which will typically be the output of a Bayesian ``GenerativeModel`` instance. Parameters ---------- input_dict : dict Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged: ``targets`` - the latent variables over which a condition density is learned ``summary_conditions`` - the conditioning variables that are first passed through a summary network ``direct_conditions`` - the conditioning variables that the directly passed to the inference network **kwargs : dict, optional, default: {} Additional keyword arguments passed to the networks For instance, ``kwargs={'training': True}`` is passed automatically during training. Returns ------- total_loss : tf.Tensor of shape (1,) - the total computed loss given input variables """ net_out, sum_out = self(input_dict, return_summary=True, **kwargs) diff, drift = net_out loss = self.loss_fun(diff, drift) # Case summary loss should be computed if self.summary_loss is not None: sum_loss = self.summary_loss(sum_out) # Case no summary loss, simply add 0 for convenience else: sum_loss = 0.0 # Compute and return total loss total_loss = tf.reduce_mean(loss) + sum_loss return total_loss
[docs] def sample(self, input_dict, n_samples, to_numpy=True, step_size=1e-3, **kwargs): """Generates random draws from the approximate posterior given a dictionary with conditonal variables. Parameters ---------- input_dict : dict Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged: ``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network ``direct_conditions`` : the conditioning variables that the directly passed to the inference network n_samples : int The number of posterior draws (samples) to obtain from the approximate posterior to_numpy : bool, optional, default: True Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor`` step_size : float, optional, default: 0.01 The step size for the stochastic Euler solver. **kwargs : dict, optional, default: {} Additional keyword arguments passed to the networks Returns ------- post_samples : tf.Tensor or np.ndarray of shape (n_data_sets, n_samples, n_params) The sampled parameters from the approximate posterior of each data set """ # Compute condition (direct, summary, or both) _, conditions = self._compute_summary_condition( input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]), input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]), training=False, **kwargs, ) n_data_sets = tf.shape(conditions)[0] # Sample initial latent variables -> shape (n_data_sets, n_samples, latent_dim) latent_vars = self.latent_dist.sample((n_data_sets, n_samples)) # Replicate conditions and solve ODEs simulatenously conditions = tf.stack([conditions] * n_samples, axis=1) post_samples = self._solve_euler(latent_vars, conditions, step_size, **kwargs) # Remove trailing first dimension in the single data case if n_data_sets == 1: post_samples = tf.squeeze(post_samples, axis=0) # Return numpy version of tensor or tensor itself if to_numpy: return post_samples.numpy() return post_samples
[docs] def log_density(self, input_dict, to_numpy=True, step_size=1e-3, **kwargs): """Computes the log density...""" # Compute condition (direct, summary, or both) _, conditions = self._compute_summary_condition( input_dict.get(defaults.DEFAULT_KEYS["summary_conditions"]), input_dict.get(defaults.DEFAULT_KEYS["direct_conditions"]), training=False, **kwargs, ) # Extract targets target_vars = input_dict[defaults.DEFAULT_KEYS["parameters"]] # Reverse ODE and log pdf computation with the trace method latents, trace = self._solve_euler_inv(target_vars, conditions, step_size, **kwargs) lpdf = self.latent_dist.log_prob(latents) + trace # Return numpy version of tensor or tensor itself if to_numpy: return lpdf.numpy() return lpdf
def _solve_euler(self, latent_vars, condition, dt=1e-3, **kwargs): """Simple stochastic parallel Euler solver.""" num_steps = int(1 / dt) time_vec = tf.zeros((tf.shape(latent_vars)[0], tf.shape(latent_vars)[1], 1)) target = tf.identity(latent_vars) for _ in range(num_steps + 1): target += self.drift_net.drift(target, time_vec, condition, **kwargs) * dt time_vec += dt return target def _solve_euler_inv(self, targets, condition, dt=1e-3, **kwargs): """Solves the reverse ODE (negative direction of drift) and returns the trace.""" def velocity(latents, drift, time_vec, condition, **kwargs): v = drift(latents, time_vec, condition, **kwargs) return v batch_size = tf.shape(targets)[0] num_samples = tf.shape(targets)[1] num_steps = int(1 / dt) time_vec = tf.ones((batch_size, num_samples, 1)) trace = tf.zeros((batch_size, num_samples)) latents = tf.identity(targets) for _ in range(num_steps + 1): f = partial(velocity, drift=self.drift_net.drift, time_vec=time_vec, condition=condition) drift_t, trace_t = compute_jacobian_trace(f, latents, **kwargs) latents -= drift_t * dt trace -= trace_t * dt time_vec -= dt return latents, trace def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs): """Determines how to concatenate the provided conditions.""" # Compute learnable summaries, if given if self.summary_net is not None: sum_condition = self.summary_net(summary_conditions, **kwargs) else: sum_condition = None # Concatenate learnable summaries with fixed summaries if sum_condition is not None and direct_conditions is not None: full_cond = tf.concat([sum_condition, direct_conditions], axis=-1) elif sum_condition is not None: full_cond = sum_condition elif direct_conditions is not None: full_cond = direct_conditions else: raise SummaryStatsError("Could not concatenarte or determine conditioning inputs...") return sum_condition, full_cond def _determine_latent_dist(self, latent_dist): """Determines which latent distribution to use and defaults to unit normal if ``None`` provided.""" if latent_dist is None: return tfp.distributions.MultivariateNormalDiag(loc=[0.0] * self.latent_dim) else: return latent_dist def _determine_summary_loss(self, loss_fun): """Determines which summary loss to use if default `None` argument provided, otherwise return identity.""" # If callable, return provided loss if loss_fun is None or callable(loss_fun): return loss_fun # If string, check for MMD or mmd elif type(loss_fun) is str: if loss_fun.lower() == "mmd": return mmd_summary_space else: raise NotImplementedError("For now, only 'mmd' is supported as a string argument for summary_loss_fun!") # Throw if loss type unexpected else: raise NotImplementedError( "Could not infer summary_loss_fun, argument should be of type (None, callable, or str)!" ) def _determine_loss(self, loss_fun): """Determines which summary loss to use if default ``None`` argument provided, otherwise return identity.""" if loss_fun is None: return tf.keras.losses.log_cosh return loss_fun