# Copyright (c) 2022 The BayesFlow Developers
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
from abc import ABC, abstractmethod
from functools import partial
logging.basicConfig()
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from bayesflow.default_settings import DEFAULT_KEYS
from bayesflow.exceptions import ConfigurationError, SummaryStatsError
from bayesflow.helper_functions import check_tensor_sanity
from bayesflow.losses import log_loss, mmd_summary_space, norm_diff
from bayesflow.networks import EvidentialNetwork
class AmortizedTarget(ABC):
"""An abstract interface for an amortized learned distribution. Children should
implement the following public methods:
1. ``compute_loss(self, input_dict, **kwargs)``
2. ``sample(input_dict, **kwargs)``
3. ``log_prob(input_dict, **kwargs)``
"""
@abstractmethod
def __init__(self, *args, **kwargs):
pass
@abstractmethod
def compute_loss(self, input_dict, **kwargs):
pass
@abstractmethod
def sample(self, input_dict, **kwargs):
pass
@abstractmethod
def log_prob(self, input_dict, **kwargs):
pass
def _check_output_sanity(self, tensor):
logger = logging.getLogger()
check_tensor_sanity(tensor, logger)
class AmortizedPosterior(tf.keras.Model, AmortizedTarget):
"""A wrapper to connect an inference network for parameter estimation with an optional summary network
as in the original BayesFlow set-up described in the paper:
[1] Radev, S. T., Mertens, U. K., Voss, A., Ardizzone, L., & Köthe, U. (2020).
BayesFlow: Learning complex stochastic models with invertible neural networks.
IEEE Transactions on Neural Networks and Learning Systems.
But also allowing for augmented functionality, such as model misspecification detection in summary space:
[2] Schmitt, M., Bürkner, P. C., Köthe, U., & Radev, S. T. (2022).
Detecting Model Misspecification in Amortized Bayesian Inference with Neural Networks
arXiv preprint arXiv:2112.08866.
And learning of fat-tailed posteriors with a Student-t latent pushforward density:
[3] Jaini, P., Kobyzev, I., Yu, Y., & Brubaker, M. (2020, November).
Tails of lipschitz triangular flows.
In International Conference on Machine Learning (pp. 4673-4681). PMLR.
[4] Alexanderson, S., & Henter, G. E. (2020).
Robust model training and generalisation with Studentising flows.
arXiv preprint arXiv:2006.06599.
Serves as an interface for learning ``p(parameters | data, context)``.
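Examples
--------
A minimal usage sketch, assuming ``inference_net`` and ``summary_net`` are already-built
BayesFlow networks (e.g., an invertible inference network and a set-based summary network)
and ``input_dict`` follows the ``DEFAULT_KEYS`` convention:

>>> amortizer = AmortizedPosterior(inference_net, summary_net, summary_loss_fun='MMD')
>>> loss = amortizer.compute_loss(input_dict)               # scalar training loss
>>> samples = amortizer.sample(input_dict, n_samples=500)   # (n_data_sets, 500, n_params)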
"""
def __init__(
self,
inference_net,
summary_net=None,
latent_dist=None,
latent_is_dynamic=False,
summary_loss_fun=None,
**kwargs,
):
"""Initializes a composite neural network to represent an amortized approximate posterior
for a Bayesian generative model.
Parameters
----------
inference_net : tf.keras.Model
An (invertible) inference network which processes the outputs of a generative model
summary_net : tf.keras.Model or None, optional, default: None
An optional summary network to compress non-vector data structures.
latent_dist : callable or None, optional, default: None
The latent distribution towards which to optimize the networks. Defaults to
a multivariate unit Gaussian.
latent_is_dynamic : bool, optional, default: False
If set to ``True``, assumes that ``latent_dist`` is a function of the condition and takes
a different shape depending on the condition. Useful for more expressive transforms
of complex distributions, such as fat-tailed or highly-multimodal distributions.
Important: In the case of dynamic latents, the user is responsible that the
latent is appropriately parameterized! If not using ``tensorflow_probability``,
the ``latent_dist`` object needs to implement the following methods:
- ``latent_dist(x).log_prob(z)`` and
- ``latent_dist(x).sample(n_samples)``
summary_loss_fun : callable, str, or None, optional, default: None
The loss function which accepts the outputs of the summary network. If ``None``, no loss is provided
and the summary space will not be shaped according to a known distribution (see [2]).
If ``summary_loss_fun='MMD'``, the default loss from [2] will be used.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
Important
----------
- If no ``summary_net`` is provided, then the output dictionary of your generative model should not contain
any ``summary_conditions``, i.e., ``summary_conditions`` should be set to ``None``, otherwise these will be ignored.
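Examples
--------
A sketch of supplying a heavier-tailed static latent distribution via ``tensorflow_probability``
(assumes ``inference_net`` has ``latent_dim=5``; the degrees of freedom below are illustrative):

>>> import tensorflow as tf
>>> import tensorflow_probability as tfp
>>> latent_t = tfp.distributions.Independent(
...     tfp.distributions.StudentT(df=30.0, loc=tf.zeros(5), scale=1.0),
...     reinterpreted_batch_ndims=1,
... )
>>> amortizer = AmortizedPosterior(inference_net, latent_dist=latent_t)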
"""
tf.keras.Model.__init__(self, **kwargs)
self.inference_net = inference_net
self.summary_net = summary_net
self.latent_dim = self.inference_net.latent_dim
self.latent_is_dynamic = latent_is_dynamic
self.summary_loss = self._determine_summary_loss(summary_loss_fun)
self.latent_dist = self._determine_latent_dist(latent_dist)
def call(self, input_dict, return_summary=False, **kwargs):
"""Performs a forward pass through the summary and inference network given an input dictionary.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` - the latent model parameters over which a conditional density is learned
``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` - the conditioning variables that are directly passed to the inference network
return_summary : bool, optional, default: False
A flag which determines whether the learnable data summaries (representations) are returned or not.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
For instance, ``kwargs={'training': True}`` is passed automatically during training.
Returns
-------
net_out or (net_out, summary_out) : tuple of tf.Tensor
the outputs of ``inference_net(theta, summary_net(x, c_s), c_d)``, usually a latent variable and
log(det(Jacobian)), that is a tuple ``(z, log_det_J)`` or ``(sum_outputs, (z, log_det_J))`` if
``return_summary`` is set to True and a summary network is defined.
"""
# Concatenate conditions, if given
summary_out, full_cond = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
**kwargs,
)
# Compute output of inference net
net_out = self.inference_net(input_dict[DEFAULT_KEYS["parameters"]], full_cond, **kwargs)
# Return summary outputs or not, depending on parameter
if return_summary:
return net_out, summary_out
return net_out
def compute_loss(self, input_dict, **kwargs):
"""Computes the loss of the posterior amortizer given an input dictionary, which will
typically be the output of a Bayesian ``GenerativeModel`` instance.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` - the latent model parameters over which a conditional density is learned
``summary_conditions`` - the conditioning variables that are first passed through a summary network
``direct_conditions`` - the conditioning variables that are directly passed to the inference network
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
For instance, ``kwargs={'training': True}`` is passed automatically during training.
Returns
-------
total_loss : tf.Tensor of shape (1,) - the total computed loss given input variables
"""
# Get amortizer outputs
net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
z, log_det_J = net_out
# Case summary loss should be computed
if self.summary_loss is not None:
sum_loss = self.summary_loss(sum_out)
# Case no summary loss, simply add 0 for convenience
else:
sum_loss = 0.0
# Case dynamic latent space - function of summary conditions
if self.latent_is_dynamic:
logpdf = self.latent_dist(sum_out).log_prob(z)
# Case static latent space
else:
logpdf = self.latent_dist.log_prob(z)
# Compute and return total loss
total_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss
return total_loss
def call_loop(self, input_list, return_summary=False, **kwargs):
"""Performs a forward pass through the summary and inference network given a list of dicts
with the appropriate entries (i.e., as used for the standard call method).
This method is useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
Parameters
----------
input_list : list of dicts, where each dict contains the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` - the latent model parameters over which a conditional density is learned
``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` - the conditioning variables that are directly passed to the inference network
return_summary : bool, optional, default: False
A flag which determines whether the learnable data summaries (representations) are returned or not.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
Returns
-------
net_out or (net_out, summary_out) : tuple of tf.Tensor
the outputs of ``inference_net(theta, summary_net(x, c_s), c_d)``, usually a latent variable and
log(det(Jacobian)), that is a tuple ``(z, log_det_J)`` or ``(sum_outputs, (z, log_det_J))`` if
``return_summary`` is set to True and a summary network is defined.
"""
outputs = []
for forward_dict in input_list:
outputs.append(self(forward_dict, return_summary, **kwargs))
net_out = [tf.concat([o[i] for o in outputs], axis=0) for i in range(len(outputs[0]))]
return tuple(net_out)
def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Generates random draws from the approximate posterior given a dictionary with conditonal variables.
Parameters
----------
input_dict : dict
Input dictionary containing at least one of the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` : the conditioning variables that are directly passed to the inference network
n_samples : int
The number of posterior draws (samples) to obtain from the approximate posterior
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
Returns
-------
post_samples : tf.Tensor or np.ndarray of shape (n_data_sets, n_samples, n_params)
The sampled parameters from the approximate posterior of each data set
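Examples
--------
A shape-oriented sketch, assuming ``amortizer`` is a trained instance with a summary network
and two data sets of 100 observations each (array shapes are illustrative):

>>> import numpy as np
>>> x = np.random.default_rng(42).normal(size=(2, 100, 3)).astype(np.float32)
>>> post_samples = amortizer.sample({"summary_conditions": x}, n_samples=1000)
>>> post_samples.shape  # (2, 1000, n_params)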
"""
# Compute learnable summaries, if appropriate
_, conditions = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
training=False,
**kwargs,
)
# Obtain number of data sets
n_data_sets = conditions.shape[0]
# Obtain random draws from the approximate posterior given conditioning variables
# Case dynamic, assume tensorflow_probability instance, so need to reshape output from
# (n_samples, n_data_sets, latent_dim) to (n_data_sets, n_samples, latent_dim)
if self.latent_is_dynamic:
z_samples = self.latent_dist(conditions).sample(n_samples)
z_samples = tf.transpose(z_samples, (1, 0, 2))
# Case static latent - marginal samples from the specified dist
else:
z_samples = self.latent_dist.sample((n_data_sets, n_samples))
# Push the latent draws back through the inference network to obtain posterior draws
post_samples = self.inference_net.inverse(z_samples, conditions, training=False, **kwargs)
# Only return a 2D array if the first dimension is 1
if post_samples.shape[0] == 1:
post_samples = post_samples[0]
self._check_output_sanity(post_samples)
# Return numpy version of tensor or tensor itself
if to_numpy:
return post_samples.numpy()
return post_samples
def sample_loop(self, input_list, n_samples, to_numpy=True, **kwargs):
"""Generates random draws from the approximate posterior given a list of dicts with conditonal variables.
Useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
Parameters
----------
input_list : list of dictionaries, each dictionary having the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` : the conditioning variables that are directly passed to the inference network
n_samples : int
The number of posterior draws (samples) to obtain from the approximate posterior
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
Returns
-------
post_samples : tf.Tensor or np.ndarray of shape (n_datasets, n_samples, n_params)
The sampled parameters from the approximate posterior of each data set
"""
post_samples = []
for input_dict in input_list:
post_samples.append(self.sample(input_dict, n_samples, to_numpy, **kwargs))
if to_numpy:
return np.concatenate(post_samples, axis=0)
return tf.concat(post_samples, axis=0)
def log_posterior(self, input_dict, to_numpy=True, **kwargs):
"""Calculates the approximate log-posterior of targets given conditional variables via
the change-of-variable formula for a conditional normalizing flow.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` : the latent model parameters over which a conditional density (i.e., a posterior) is learned
``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` : the conditioning variables that are directly passed to the inference network
to_numpy : bool, optional, default: True
Flag indicating whether to return the lpdf values as a ``np.ndarray`` or a ``tf.Tensor``
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
Returns
-------
log_post : tf.Tensor or np.ndarray of shape (batch_size, n_obs)
the approximate log-posterior density of each parameter
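Examples
--------
A sketch of evaluating the amortized density at the data-generating parameters
(``amortizer`` stands for a trained instance; ``params`` and ``x`` are illustrative arrays
following the ``DEFAULT_KEYS`` convention):

>>> log_post = amortizer.log_posterior({"parameters": params, "summary_conditions": x})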
"""
# Compute learnable summaries, if appropriate
_, conditions = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
training=False,
**kwargs,
)
# Forward pass through the network
z, log_det_J = self.inference_net.forward(
input_dict[DEFAULT_KEYS["parameters"]], conditions, training=False, **kwargs
)
# Compute approximate log posterior
# Case dynamic latent - function of conditions
if self.latent_is_dynamic:
log_post = self.latent_dist(conditions).log_prob(z) + log_det_J
# Case static latent - evaluate marginal density of z
else:
log_post = self.latent_dist.log_prob(z) + log_det_J
self._check_output_sanity(log_post)
if to_numpy:
return log_post.numpy()
return log_post
def log_prob(self, input_dict, to_numpy=True, **kwargs):
"""Identical to `log_posterior(input_dict, to_numpy, **kwargs)`."""
return self.log_posterior(input_dict, to_numpy=to_numpy, **kwargs)
def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
"""Determines how to concatenate the provided conditions."""
# Compute learnable summaries, if given
if self.summary_net is not None:
sum_condition = self.summary_net(summary_conditions, **kwargs)
else:
sum_condition = None
# Concatenate learnable summaries with fixed summaries
if sum_condition is not None and direct_conditions is not None:
full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
elif sum_condition is not None:
full_cond = sum_condition
elif direct_conditions is not None:
full_cond = direct_conditions
else:
raise SummaryStatsError("Could not concatenate or determine conditioning inputs...")
return sum_condition, full_cond
def _determine_latent_dist(self, latent_dist):
"""Determines which latent distribution to use and defaults to unit normal if None provided."""
if latent_dist is None:
return tfp.distributions.MultivariateNormalDiag(loc=[0.0] * self.latent_dim)
else:
return latent_dist
def _determine_summary_loss(self, loss_fun):
"""Determines which summary loss to use if default `None` argument provided, otherwise return identity."""
# Throw, if summary loss without a summary network provided
if loss_fun is not None and self.summary_net is None:
raise ConfigurationError("You need to provide a summary_net if you want to use a summary_loss_fun.")
# If callable, return provided loss
if loss_fun is None or callable(loss_fun):
return loss_fun
# If string, check for MMD or mmd
elif type(loss_fun) is str:
if loss_fun.lower() == "mmd":
return mmd_summary_space
else:
raise NotImplementedError("For now, only 'mmd' is supported as a string argument for summary_loss_fun!")
# Throw if loss type unexpected
else:
raise NotImplementedError(
"Could not infer summary_loss_fun, argument should be of type (None, callable, or str)!"
)
class AmortizedLikelihood(tf.keras.Model, AmortizedTarget):
"""An interface for a surrogate model of a simulator, or an implicit likelihood
``p(data | parameters, context)``.
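Examples
--------
A minimal usage sketch, assuming ``surrogate_net`` is an already-built invertible network and
``input_dict`` provides ``observables`` and ``conditions`` entries:

>>> emulator = AmortizedLikelihood(surrogate_net)
>>> loss = emulator.compute_loss(input_dict)              # negative log-likelihood loss
>>> sims = emulator.sample(input_dict, n_samples=200)     # surrogate data sets
>>> log_lik = emulator.log_likelihood(input_dict)         # change-of-variables density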
"""
def __init__(self, surrogate_net, latent_dist=None, **kwargs):
"""Initializes a composite neural architecture representing an amortized emulator
for the simulator (i.e., the implicit likelihood model).
Parameters
----------
surrogate_net : tf.keras.Model
An (invertible) inference network which processes the outputs of the generative model.
latent_dist : callable or None, optional, default: None
The latent distribution towards which to optimize the surrogate network outputs. Defaults to
a multivariate unit Gaussian.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
"""
tf.keras.Model.__init__(self, **kwargs)
self.surrogate_net = surrogate_net
self.latent_dim = self.surrogate_net.latent_dim
self.latent_dist = self._determine_latent_dist(latent_dist)
def call(self, input_dict, **kwargs):
"""Performs a forward pass through the summary and inference network.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``observables`` - the observables over which a conditional density is learned (i.e., the data)
``conditions`` - the conditioning variables that are directly passed to the surrogate network
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
For instance, ``kwargs={'training': True}`` is passed automatically during training.
Returns
-------
net_out
the outputs of ``surrogate_net(x, c)``, usually a latent variable and
log(det(Jacobian)), that is a tuple ``(z, log_det_J)``.
"""
net_out = self.surrogate_net(
input_dict[DEFAULT_KEYS["observables"]], input_dict[DEFAULT_KEYS["conditions"]], **kwargs
)
return net_out
def call_loop(self, input_list, **kwargs):
"""Performs a forward pass through the surrogate network given a list of dicts
with the appropriate entries (i.e., as used for the standard call method).
This method is useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
Parameters
----------
input_list : list of dicts, where each dict contains the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``observables`` - the observables over which a conditional density is learned (i.e., the data)
``conditions`` - the conditioning variables that are directly passed to the surrogate network
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
Returns
-------
net_out : tuple of tf.Tensor
the concatenated outputs of ``surrogate_net(x, c)`` across the input list, usually a latent variable and
log(det(Jacobian)), that is a tuple ``(z, log_det_J)``.
"""
outputs = []
for forward_dict in input_list:
outputs.append(self(forward_dict, **kwargs))
net_out = [tf.concat([o[i] for o in outputs], axis=0) for i in range(len(outputs[0]))]
return tuple(net_out)
def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Generates `n_samples` random draws from the surrogate likelihood given input conditions.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``conditions`` - the conditioning variables that are directly passed to the surrogate network
n_samples : int
The number of samples to obtain from the surrogate likelihood
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
Returns
-------
lik_samples : tf.Tensor or np.ndarray of shape (n_datasets, n_samples, None)
A simulated batch of observables from the surrogate likelihood.
"""
# Extract condition
conditions = input_dict[DEFAULT_KEYS["conditions"]]
# Obtain number of data sets
n_data_sets = conditions.shape[0]
# Obtain random draws from the surrogate likelihood given conditioning variables
z_samples = self.latent_dist.sample((n_data_sets, n_samples))
# Push the latent draws back through the surrogate network to obtain simulated observables
lik_samples = self.surrogate_net.inverse(z_samples, conditions, training=False, **kwargs)
# Only return a 2D array if the first dimension is 1
if lik_samples.shape[0] == 1:
lik_samples = lik_samples[0]
self._check_output_sanity(lik_samples)
if to_numpy:
return lik_samples.numpy()
return lik_samples
def sample_loop(self, input_list, n_samples, to_numpy=True, **kwargs):
"""Generates random draws from the surrogate network given a list of dicts with conditional variables.
Useful when GPU memory is limited or data sets have a different (non-Tensor) structure.
Parameters
----------
input_list : list of dictionaries, each dictionary having the following mandatory keys (default):
``conditions`` - the conditioning variables that are directly passed to the surrogate network
n_samples : int
The number of draws (samples) to obtain from the surrogate likelihood
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
Returns
-------
lik_samples : tf.Tensor or np.ndarray of shape (n_data_sets, n_samples, data_dim)
the simulated observables per data set
"""
post_samples = []
for input_dict in input_list:
post_samples.append(self.sample(input_dict, n_samples, to_numpy, **kwargs))
if to_numpy:
return np.concatenate(post_samples, axis=0)
return tf.concat(post_samples, axis=0)
def log_likelihood(self, input_dict, to_numpy=True, **kwargs):
"""Calculates the approximate log-likelihood of targets given conditional variables via
the change-of-variable formula for a conditional normalizing flow.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``observables`` - the variables over which a conditional density is learned (i.e., the observables)
``conditions`` - the conditioning variables that are directly passed to the surrogate network
to_numpy : bool, optional, default: True
Boolean flag indicating whether to return the log-lik values as a ``np.ndarray`` or a ``tf.Tensor``
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
Returns
-------
log_lik : tf.Tensor or np.ndarray of shape (batch_size, n_obs)
the approximate log-likelihood of each data point in each data set
"""
# Forward pass through the network
z, log_det_J = self.surrogate_net.forward(
input_dict[DEFAULT_KEYS["observables"]], input_dict[DEFAULT_KEYS["conditions"]], training=False, **kwargs
)
# Compute approximate log likelihood
log_lik = self.latent_dist.log_prob(z) + log_det_J
self._check_output_sanity(log_lik)
# Convert tensor to numpy array, if specified
if to_numpy:
return log_lik.numpy()
return log_lik
def log_prob(self, input_dict, to_numpy=True, **kwargs):
"""Identical to `log_likelihood(input_dict, to_numpy, **kwargs)`."""
return self.log_likelihood(input_dict, to_numpy=to_numpy, **kwargs)
def compute_loss(self, input_dict, **kwargs):
"""Computes the loss of the amortized given input data provided in input_dict.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys:
``observables`` - the observables over which a conditional density is learned (i.e., the data)
``conditions`` - the conditioning variables that are directly passed to the surrogate network
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the network
For instance, ``kwargs={'training': True}`` is passed automatically during simulation-based training.
Returns
-------
loss : tf.Tensor of shape (1,) - the total computed loss given input variables
"""
z, log_det_J = self(input_dict, **kwargs)
loss = tf.reduce_mean(-self.latent_dist.log_prob(z) - log_det_J)
return loss
def _determine_latent_dist(self, latent_dist):
"""Determines which latent distribution to use and defaults to unit normal if ``None`` provided."""
if latent_dist is None:
return tfp.distributions.MultivariateNormalDiag(loc=[0.0] * self.latent_dim)
else:
return latent_dist
class AmortizedPosteriorLikelihood(tf.keras.Model, AmortizedTarget):
"""An interface for jointly learning a surrogate model of the simulator and an approximate
posterior given a generative model, as proposed by:
[1] Radev, S. T., Schmitt, M., Pratz, V., Picchini, U., Köthe, U., & Bürkner, P. C. (2023).
JANA: Jointly Amortized Neural Approximation of Complex Bayesian Models.
arXiv preprint arXiv:2302.09125.
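Examples
--------
A sketch of composing a joint amortizer from two already-built amortizers; the nested
``input_dict`` with ``posterior_inputs`` and ``likelihood_inputs`` entries is assumed to
come from a suitable configurator:

>>> joint = AmortizedPosteriorLikelihood(
...     amortized_posterior=AmortizedPosterior(inference_net, summary_net),
...     amortized_likelihood=AmortizedLikelihood(surrogate_net),
... )
>>> losses = joint.compute_loss(input_dict)  # {'Post.Loss': ..., 'Lik.Loss': ...}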
"""
def __init__(self, amortized_posterior, amortized_likelihood, **kwargs):
"""Initializes a joint learner comprising an amortized posterior and an amortized emulator.
Parameters
----------
amortized_posterior : an instance of AmortizedPosterior or a custom tf.keras.Model
The generative neural posterior approximator
amortized_likelihood : an instance of AmortizedLikelihood or a custom tf.keras.Model
The generative neural likelihood approximator
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance
"""
tf.keras.Model.__init__(self, **kwargs)
self.amortized_posterior = amortized_posterior
self.amortized_likelihood = amortized_likelihood
def call(self, input_dict, **kwargs):
"""Performs a forward pass through both amortizers.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys:
`posterior_inputs` - The input dictionary for the amortized posterior
`likelihood_inputs` - The input dictionary for the amortized likelihood
Returns
-------
(post_out, lik_out) : tuple
The outputs of the posterior and likelihood networks given input variables.
"""
post_out = self.amortized_posterior(input_dict["posterior_inputs"], **kwargs)
lik_out = self.amortized_likelihood(input_dict["likelihood_inputs"], **kwargs)
return post_out, lik_out
def compute_loss(self, input_dict, **kwargs):
"""Computes the loss of the join amortizer by summing the corresponding amortized posterior
and likelihood losses.
Parameters
----------
input_dict : dict
Nested input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged::
`posterior_inputs` - The input dictionary for the amortized posterior
`likelihood_inputs` - The input dictionary for the amortized likelihood
Returns
-------
total_losses : dict
A dictionary with keys `Post.Loss` and `Lik.Loss` containing the individual losses for the
two amortizers.
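Examples
--------
A sketch of aggregating the two entries into a single optimization target by summation,
one possible choice (``joint_amortizer`` stands for an instance of this class and
``input_dict`` for a suitably nested dictionary):

>>> losses = joint_amortizer.compute_loss(input_dict)
>>> total_loss = losses["Post.Loss"] + losses["Lik.Loss"]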
"""
loss_post = self.amortized_posterior.compute_loss(input_dict[DEFAULT_KEYS["posterior_inputs"]], **kwargs)
loss_lik = self.amortized_likelihood.compute_loss(input_dict[DEFAULT_KEYS["likelihood_inputs"]], **kwargs)
return {"Post.Loss": loss_post, "Lik.Loss": loss_lik}
def log_likelihood(self, input_dict, to_numpy=True, **kwargs):
"""Calculates the approximate log-likelihood of data given conditional variables via
the change-of-variable formula for conditional normalizing flows.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`observables` - the variables over which a conditional density is learned (i.e., the observables)
`conditions` - the conditioning variables that are directly passed to the inference network
OR a nested dictionary with key `likelihood_inputs` containing the above input dictionary
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
Returns
-------
log_lik : tf.Tensor of shape (batch_size, n_obs)
the approximate log-likelihood of each data point in each data set
"""
if input_dict.get(DEFAULT_KEYS["likelihood_inputs"]) is not None:
return self.amortized_likelihood.log_likelihood(
input_dict[DEFAULT_KEYS["likelihood_inputs"]], to_numpy=to_numpy, **kwargs
)
return self.amortized_likelihood.log_likelihood(input_dict, to_numpy=to_numpy, **kwargs)
def log_posterior(self, input_dict, to_numpy=True, **kwargs):
"""Calculates the approximate log-posterior of targets given conditional variables via
the change-of-variable formula for conditional normalizing flows.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`parameters` - the latent generative model parameters over which a conditional density is learned
`summary_conditions` - the conditioning variables that are first passed through a summary network
`direct_conditions` - the conditioning variables that are directly passed to the inference network
OR a nested dictionary with key `posterior_inputs` containing the above input dictionary
Returns
-------
log_post : tf.Tensor of shape (batch_size, n_obs)
the approximate log-posterior density of the parameters for each data set
"""
if input_dict.get(DEFAULT_KEYS["posterior_inputs"]) is not None:
return self.amortized_posterior.log_posterior(
input_dict[DEFAULT_KEYS["posterior_inputs"]], to_numpy=to_numpy, **kwargs
)
return self.amortized_posterior.log_posterior(input_dict, to_numpy=to_numpy, **kwargs)
def log_prob(self, input_dict, to_numpy=True, **kwargs):
"""Identical to calling separate `log_likelihood()` and `log_posterior()`.
Returns
-------
out_dict : dict with keys `log_posterior` and `log_likelihood` corresponding
to the computed log_pdfs of the approximate posterior and likelihood.
"""
log_post = self.log_posterior(input_dict, to_numpy=to_numpy, **kwargs)
log_lik = self.log_likelihood(input_dict, to_numpy=to_numpy, **kwargs)
out_dict = {"log_posterior": log_post, "log_likelihood": log_lik}
return out_dict
def sample_data(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Generates `n_samples` random draws from the surrogate likelihood given input conditions.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`conditions` - the conditioning variables that are directly passed to the surrogate network
OR a nested dictionary with key `likelihood_inputs` containing the above input dictionary
n_samples : int
The number of data samples to obtain from the surrogate likelihood
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
Returns
-------
lik_samples : tf.Tensor or np.ndarray of shape (n_datasets, n_samples, None)
Simulated observables from the surrogate likelihood.
"""
if input_dict.get(DEFAULT_KEYS["likelihood_inputs"]) is not None:
return self.amortized_likelihood.sample(
input_dict[DEFAULT_KEYS["likelihood_inputs"]], n_samples, to_numpy=to_numpy, **kwargs
)
return self.amortized_likelihood.sample(input_dict, n_samples, to_numpy=to_numpy, **kwargs)
def sample_parameters(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Generates random draws from the approximate posterior given conditonal variables.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`summary_conditions` : the conditioning variables (including data) that are first passed through a summary network
`direct_conditions` : the conditioning variables that are directly passed to the inference network
OR a nested dictionary with key `posterior_inputs` containing the above input dictionary
n_samples : int
The number of posterior samples to obtain from the approximate posterior
to_numpy : bool, optional, default: True
Boolean flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
Returns
-------
post_samples : tf.Tensor or np.ndarray of shape (n_datasets, n_samples, n_params)
the sampled parameters per data set
"""
if input_dict.get(DEFAULT_KEYS["posterior_inputs"]) is not None:
return self.amortized_posterior.sample(
input_dict[DEFAULT_KEYS["posterior_inputs"]], n_samples, to_numpy=to_numpy, **kwargs
)
return self.amortized_posterior.sample(input_dict, n_samples, to_numpy=to_numpy, **kwargs)
def sample(self, input_dict, n_post_samples, n_lik_samples, to_numpy=True, **kwargs):
"""Identical to calling `sample_parameters()` and `sample_data()` separately.
Returns
-------
out_dict : dict with keys `posterior_samples` and `likelihood_samples` corresponding
to the `n_samples` from the approximate posterior and likelihood, respectively
"""
post_samples = self.sample_parameters(input_dict, n_post_samples, to_numpy=to_numpy, **kwargs)
lik_samples = self.sample_data(input_dict, n_lik_samples, to_numpy=to_numpy, **kwargs)
out_dict = {"posterior_samples": post_samples, "likelihood_samples": lik_samples}
return out_dict
class AmortizedModelComparison(tf.keras.Model):
"""An interface to connect an evidential network for Bayesian model comparison with an optional summary network,
as described in the original paper on evidential neural networks for model comparison according to [1, 2]:
[1] Radev, S. T., D'Alessandro, M., Mertens, U. K., Voss, A., Köthe, U., & Bürkner, P. C. (2021).
Amortized Bayesian model comparison with evidential deep learning.
IEEE Transactions on Neural Networks and Learning Systems.
[2] Elsemüller, L., Schnuerch, M., Bürkner, P. C., & Radev, S. T. (2023).
A Deep Learning Method for Comparing Bayesian Hierarchical Models.
arXiv preprint arXiv:2301.11873.
Note: the original paper [1] does not distinguish between the summary and the evidential networks, but
treats them as a whole, with the appropriate architecture dictated by the model application. For the
sake of consistency and modularity, the BayesFlow library separates the two constructs.
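Examples
--------
A minimal usage sketch, assuming ``evidential_net`` is an already-built network that outputs
model evidences (e.g., an ``EvidentialNetwork``) and ``summary_net`` embeds the raw data sets:

>>> comparator = AmortizedModelComparison(evidential_net, summary_net)
>>> loss = comparator.compute_loss(input_dict)
>>> pmps = comparator.posterior_probs({"summary_conditions": x})  # (batch_size, num_models)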
"""
def __init__(self, inference_net, summary_net=None, loss_fun=None):
"""Initializes a composite neural architecture for amortized bayesian model comparison.
Parameters
----------
inference_net : tf.keras.Model
A neural network which outputs model evidences.
summary_net : tf.keras.Model or None, optional, default: None
An optional summary network
loss_fun : callable or None, optional, default: None
The loss function which accepts the outputs of the amortizer. If None, the loss will be the log-loss.
Important
----------
- If no ``summary_net`` is provided, then the output dictionary of your generative model should not contain
any ``summary_conditions``, i.e., ``summary_conditions`` should be set to ``None``, otherwise these will be ignored.
- If no custom ``loss_fun`` is provided, the loss function will be the log loss for the means of a Dirichlet
distribution or softmax outputs.
"""
super().__init__()
self.inference_net = inference_net
self.summary_net = summary_net
self.loss = self._determine_loss(loss_fun)
self.num_models = self.inference_net.num_models
def call(self, input_dict, return_summary=False, **kwargs):
"""Performs a forward pass through both networks.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged
`summary_conditions` - the conditioning variables that are first passed through a summary network
`direct_conditions` - the conditioning variables that are directly passed to the evidential network
`model_indices` - the ground-truth, one-hot encoded model indices sampled from the model prior
return_summary : bool, optional, default: False
Indicates whether the summary network outputs are returned along with the estimated evidences.
Returns
-------
net_out : tf.Tensor of shape (batch_size, num_models) or tuple of (net_out (batch_size, num_models),
summary_out (batch_size, summary_dim)), the latter being the summary network outputs, if
``return_summary`` is set to True.
"""
summary_out, full_cond = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
**kwargs,
)
net_out = self.inference_net(full_cond, **kwargs)
if not return_summary:
return net_out
return net_out, summary_out
def posterior_probs(self, input_dict, to_numpy=True, **kwargs):
"""Compute posterior model probabilities (PMPs) given a dictionary with observed or
simulated data.
Parameters
----------
input_dict : dict
Input dictionary containing at least one of the following mandatory keys, if DEFAULT_KEYS unchanged
`summary_conditions` - the conditioning variables that are first passed through a summary network
`direct_conditions` - the conditioning variables that are directly passed to the evidential network
to_numpy : bool, optional, default: True
Flag indicating whether to return the PMPs as a ``np.ndarray`` or a ``tf.Tensor``
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
Returns
-------
out : tf.Tensor of shape (batch_size, ..., num_models)
The approximated PMPs
"""
_, full_cond = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
**kwargs,
)
pmps = self.inference_net(full_cond, **kwargs)
if to_numpy:
return pmps.numpy()
return pmps
def compute_loss(self, input_dict, **kwargs):
"""Computes the loss of the amortized model comparison instance.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged::
`summary_conditions` - the conditioning variables that are first passed through a summary network
`direct_conditions` - the conditioning variables that are directly passed to the evidential network
`model_indices` - the ground-truth, one-hot encoded model indices sampled from the model prior
Returns
-------
loss : tf.Tensor of shape (1,) - the total computed loss given input variables
"""
preds = self(input_dict, **kwargs)
loss = self.loss(input_dict[DEFAULT_KEYS["model_indices"]], preds)
return loss
def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
"""Helper method to determines how to concatenate the provided conditions."""
# Compute learnable summaries, if given
if self.summary_net is not None:
sum_condition = self.summary_net(summary_conditions, **kwargs)
else:
sum_condition = None
# Concatenate learnable summaries with fixed summaries
if sum_condition is not None and direct_conditions is not None:
full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
elif sum_condition is not None:
full_cond = sum_condition
elif direct_conditions is not None:
full_cond = direct_conditions
else:
raise SummaryStatsError("Could not concatenarte or determine conditioning inputs...")
return sum_condition, full_cond
def _determine_loss(self, loss_fun):
"""Helper method to determine loss function to use."""
if loss_fun is None:
return partial(log_loss, evidential=isinstance(self.inference_net, EvidentialNetwork))
elif callable(loss_fun):
return loss_fun
else:
raise ConfigurationError(
"Loss function is neither default (`None`) not callable. Please provide a valid loss function!"
)
class TwoLevelAmortizedPosterior(tf.keras.Model, AmortizedTarget):
"""An interface for estimating arbitrary two level hierarchical Bayesian models."""
def __init__(self, local_amortizer, global_amortizer, summary_net=None, **kwargs):
"""Creates an wrapper for estimating two-level hierarchical Bayesian models.
Parameters
----------
local_amortizer : bayesflow.amortizers.AmortizedPosterior
A posterior amortizer without a summary network which will estimate
the full conditional of the (varying numbers of) local parameter vectors.
global_amortizer : bayesflow.amortizers.AmortizedPosterior
A posterior amortizer without a summary network which will estimate the joint
posterior of hyperparameters and optional shared parameters given a representation
of an entire hierarchical data set. If both hyper- and shared parameters are present,
the first dimensions correspond to the hyperparameters and the remaining ones correspond
to the shared parameters.
summary_net : tf.keras.Model or None, optional, default: None
An optional summary network to compress non-vector data structures.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
"""
super().__init__(**kwargs)
self.local_amortizer = local_amortizer
self.global_amortizer = global_amortizer
self.summary_net = summary_net
def call(self, input_dict, **kwargs):
"""Forward pass through the hierarchical amortized posterior."""
local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
local_out = self.local_amortizer(local_inputs, **kwargs)
global_out = self.global_amortizer(global_inputs, **kwargs)
return local_out, global_out
def compute_loss(self, input_dict, **kwargs):
"""Compute loss of all amortizers."""
local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
local_loss = self.local_amortizer.compute_loss(local_inputs, **kwargs)
global_loss = self.global_amortizer.compute_loss(global_inputs, **kwargs)
return {"Local.Loss": local_loss, "Global.Loss": global_loss}
def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
"""Obtains samples from the joint hierarchical posterior given observations.
Important: Currently works only for single hierarchical data sets!
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
`summary_conditions` - the hierarchical data set (to be embedded by the summary net)
As well as optional keys:
`direct_local_conditions` - (Context) variables used to condition the local posterior
`direct_global_conditions` - (Context) variables used to condition the global posterior
n_samples : int
The number of posterior draws (samples) to obtain from the approximate posterior
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the summary network as well as the amortizers
Returns
-------
samples_dict : dict
A dictionary with keys ``global_samples`` and ``local_samples``.
Local samples will hold an array-like of shape (num_replicas, num_samples, num_local)
and global samples will hold an array-like of shape (num_samples, num_hyper + num_shared)
if optional shared parameters are present, otherwise (num_samples, num_hyper).
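Examples
--------
A sketch for a single hierarchical data set (the only case currently supported), with
shapes as documented above (``two_level_amortizer`` stands for a trained instance):

>>> out = two_level_amortizer.sample(input_dict, n_samples=2000)
>>> out["global_samples"].shape  # (2000, num_hyper + num_shared)
>>> out["local_samples"].shape   # (num_replicas, 2000, num_local)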
"""
# Returned shapes will be :
# local_summaries.shape = (1, num_groups, summary_dim_local)
# global_summaries.shape = (1, summary_dim_global)
local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)
num_groups = local_summaries.shape[1]
if local_summaries.shape[0] != 1 or global_summaries.shape[0] != 1:
raise NotImplementedError("Method currently supports only single hierarchical data sets!")
# Obtain samples from p(global | all_data)
inp_global = {DEFAULT_KEYS["direct_conditions"]: global_summaries}
# New, shape will be (n_samples, num_globals)
global_samples = self.global_amortizer.sample(inp_global, n_samples, **kwargs, to_numpy=False)
# Repeat local conditions for n_samples
# New shape -> (num_groups, n_samples, summary_dim_local)
local_summaries = tf.stack([tf.squeeze(local_summaries, axis=0)] * n_samples, axis=1)
# Repeat global samples for num_groups
# New shape -> (num_groups, n_samples, num_globals)
global_samples_rep = tf.stack([global_samples] * num_groups, axis=0)
# Concatenate local summaries with global samples
# New shape -> (num_groups, num_samples, summary_dim_local + num_globals)
local_summaries = tf.concat([local_summaries, global_samples_rep], axis=-1)
# Obtain samples from p(local_i | data_i, global_i)
inp_local = {DEFAULT_KEYS["direct_conditions"]: local_summaries}
local_samples = self.local_amortizer.sample(inp_local, n_samples, to_numpy=False, **kwargs)
if to_numpy:
global_samples = global_samples.numpy()
local_samples = local_samples.numpy()
return {"global_samples": global_samples, "local_samples": local_samples}
def log_prob(self, input_dict):
"""Compute normalized log density."""
raise NotImplementedError
def _prepare_inputs(self, input_dict, local_summaries, global_summaries):
"""Prepare input dictionaries for both amortizers."""
# Prepare inputs for local amortizer
local_inputs = {"direct_conditions": local_summaries, "parameters": input_dict["local_parameters"]}
# Prepare inputs for global amortizer
_parameters = input_dict["hyper_parameters"]
if input_dict.get("shared_parameters") is not None:
_parameters = tf.concat([_parameters, input_dict.get("shared_parameters")], axis=-1)
global_inputs = {"direct_conditions": global_summaries, "parameters": _parameters}
return local_inputs, global_inputs
def _compute_condition(self, input_dict, **kwargs):
"""Determines conditionining variables for both amortizers."""
# Obtain needed summaries
local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)
# At this point, add globals as conditions
num_locals = tf.shape(local_summaries)[1]
# Add hyper parameters as conditions:
# p(local_n | data_n, hyper)
if input_dict.get("hyper_parameters") is not None:
_params = input_dict.get("hyper_parameters")
_params = tf.expand_dims(_params, 1)
_conds = tf.tile(_params, [1, num_locals, 1])
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
# Add shared parameters as conditions:
# p(local_n | data_n, hyper, shared)
if input_dict.get("shared_parameters") is not None:
_params = input_dict.get("shared_parameters")
_params = tf.expand_dims(_params, 1)
_conds = tf.tile(_params, [1, num_locals, 1])
local_summaries = tf.concat([local_summaries, _conds], axis=-1)
return local_summaries, global_summaries
def _get_local_global(self, input_dict, **kwargs):
"""Helper function to obtain local and global condition tensors."""
# Obtain summary conditions
if self.summary_net is not None:
local_summaries, global_summaries = self.summary_net(
input_dict["summary_conditions"], return_all=True, **kwargs
)
if input_dict.get("direct_local_conditions") is not None:
local_summaries = tf.concat([local_summaries, input_dict.get("direct_local_conditions")], axis=-1)
if input_dict.get("direct_global_conditions") is not None:
global_summaries = tf.concat([global_summaries, input_dict.get("direct_global_conditions")], axis=-1)
# If no summary net provided, assume direct conditions exist or fail
else:
local_summaries = input_dict.get("direct_local_conditions")
global_summaries = input_dict.get("direct_global_conditions")
return local_summaries, global_summaries
class AmortizedPointEstimator(tf.keras.Model):
"""An interface to connect a neural point estimator for Bayesian estimation with an optional summary network [1].
[1] Sainsbury-Dale, M., Zammit-Mangion, A., & Huser, R. (2024).
Likelihood-free parameter estimation with neural Bayes estimators.
The American Statistician, 78(1), 1-14.
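Examples
--------
A minimal usage sketch, assuming ``point_net`` is a feed-forward network whose output
dimension equals the number of target parameters:

>>> estimator = AmortizedPointEstimator(point_net, summary_net, norm_ord=1)
>>> loss = estimator.compute_loss(input_dict)                  # mean L1 norm of the error
>>> estimates = estimator.estimate({"summary_conditions": x})  # (num_data_sets, num_params)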
"""
def __init__(self, inference_net, summary_net=None, norm_ord=2, loss_fun=None):
"""Initializes a composite neural architecture for amortized bayesian model comparison.
Parameters
----------
inference_net : tf.keras.Model
A neural network whose final output dimension equals that of the target quantities.
summary_net : tf.keras.Model or None, optional, default: None
An optional summary network
norm_ord : int or np.inf, optional, default: 2
The order of the norm used as a loss function for the point estimator. Should be in ``[1, 2, np.inf]``.
loss_fun : callable or None, optional, default: None
If not None, it overrides the ``norm_ord`` keyword argument.
Important
----------
- If no ``summary_net`` is provided, then the output dictionary of your generative model should not contain
any ``summary_conditions``, i.e., ``summary_conditions`` should be set to ``None``, otherwise these will be ignored.
- If no custom ``loss_fun`` is provided, the loss will be the norm difference of order ``norm_ord`` between the
point estimates and the true parameters.
"""
super().__init__()
self.inference_net = inference_net
self.summary_net = summary_net
self.loss_fn = self._determine_loss(loss_fun, norm_ord)
def call(self, input_dict, return_summary=False, **kwargs):
"""Performs a forward pass through the summary and inference network given an input dictionary.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` - the latent model parameters to be estimated
``summary_conditions`` - the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` - the conditioning variables that are directly passed to the inference network
return_summary : bool, optional, default: False
A flag which determines whether the learnable data summaries (representations) are returned or not.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
For instance, ``kwargs={'training': True}`` is passed automatically during training.
Returns
-------
net_out or (net_out, summary_out) : tuple of tf.Tensor
The outputs of ``inference_net(summary_net(x, c_s), c_d)``, usually a batch of point estimates,
that is, a tensor ``estimates`` or ``(sum_outputs, estimates)`` if ``return_summary`` is set
to True and a summary network is defined.
"""
# Concatenate conditions, if given
summary_out, full_cond = self._compute_summary_condition(
input_dict.get(DEFAULT_KEYS["summary_conditions"]),
input_dict.get(DEFAULT_KEYS["direct_conditions"]),
**kwargs,
)
# Compute output of inference net
net_out = self.inference_net(full_cond, **kwargs)
# Return summary outputs or not, depending on parameter
if return_summary:
return net_out, summary_out
return net_out
def estimate(self, input_dict, to_numpy=True, **kwargs):
"""Obtains Bayesian point estimates given the data in input_dict.
Parameters
----------
input_dict : dict
Input dictionary containing at least one of the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``summary_conditions`` : the conditioning variables (including data) that are first passed through a summary network
``direct_conditions`` : the conditioning variables that are directly passed to the inference network
to_numpy : bool, optional, default: True
Flag indicating whether to return the samples as a ``np.ndarray`` or a ``tf.Tensor``.
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks.
Returns
-------
estimates : tf.Tensor or np.ndarray of shape (num_data_sets, num_params)
The point estimates of the parameters for each data set.
"""
estimates = self(input_dict, **kwargs)
if to_numpy:
return estimates.numpy()
return estimates
def compute_loss(self, input_dict, **kwargs):
"""Computes the loss of the posterior amortizer given an input dictionary, which will
typically be the output of a Bayesian ``GenerativeModel`` instance.
Parameters
----------
input_dict : dict
Input dictionary containing the following mandatory keys, if ``DEFAULT_KEYS`` unchanged:
``parameters`` - the latent model parameters to be estimated
``summary_conditions`` - the conditioning variables that are first passed through a summary network
``direct_conditions`` - the conditioning variables that are directly passed to the inference network
**kwargs : dict, optional, default: {}
Additional keyword arguments passed to the networks
For instance, ``kwargs={'training': True}`` is passed automatically during training.
Returns
-------
total_loss : tf.Tensor of shape (1,) - the total computed loss given input variables
"""
net_out = self(input_dict, **kwargs)
loss = tf.reduce_mean(self.loss_fn(net_out, input_dict[DEFAULT_KEYS["parameters"]]))
return loss
def _compute_summary_condition(self, summary_conditions, direct_conditions, **kwargs):
"""Determines how to concatenate the provided conditions."""
# Compute learnable summaries, if given
if self.summary_net is not None:
sum_condition = self.summary_net(summary_conditions, **kwargs)
else:
sum_condition = None
# Concatenate learnable summaries with fixed summaries
if sum_condition is not None and direct_conditions is not None:
full_cond = tf.concat([sum_condition, direct_conditions], axis=-1)
elif sum_condition is not None:
full_cond = sum_condition
elif direct_conditions is not None:
full_cond = direct_conditions
else:
raise SummaryStatsError("Could not concatenarte or determine conditioning inputs...")
return sum_condition, full_cond
def _determine_loss(self, loss_fun, norm_ord):
"""Determines which loss function to use and defaults to the norm_ord=2 as specified by the ``__init__`` method."""
# In case of user-provided loss, override norm order
if loss_fun is not None:
return loss_fun
return partial(norm_diff, ord=norm_ord, axis=-1)