# Copyright (c) 2022 The BayesFlow Developers
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
import tensorflow as tf
from numpy import e as EULER_CONST
from numpy import pi as PI_CONST
from bayesflow import default_settings
from bayesflow.exceptions import ConfigurationError
from bayesflow.helper_functions import build_meta_dict
from bayesflow.helper_networks import ActNorm, DenseCouplingNet, Orthogonal, Permutation
class AffineCoupling(tf.keras.Model):
    """Implements a conditional affine coupling block according to [1, 2], with additional
    options, such as residual blocks or Monte Carlo Dropout.

    [1] Kingma, D. P., & Dhariwal, P. (2018).
    Glow: Generative flow with invertible 1x1 convolutions.
    Advances in neural information processing systems, 31.

    [2] Ardizzone, L., Lüth, C., Kruse, J., Rother, C., & Köthe, U. (2019).
    Guided image generation with conditional invertible neural networks.
    arXiv preprint arXiv:1907.02392.
    """

    def __init__(self, dim_out, settings_dict, **kwargs):
        """Creates one half of an affine coupling layer to be used as part of a ``CouplingLayer`` in
        an ``InvertibleNetwork`` instance.

        Parameters
        ----------
        dim_out       : int
            The output dimensionality of the affine coupling layer.
        settings_dict : dict
            The settings for the inner networks. Defaults will use:
            "dense_args"    : dict(units=128, activation="relu"),
            "num_dense"     : 2,
            "spec_norm"     : False,
            "mc_dropout"    : False,
            "dropout"       : True,
            "residual"      : False,
            "dropout_prob"  : 0.01,
            "soft_clamping" : 1.9

            Separate settings for the scale and translation networks can be supplied
            through the keys ``"s_args"`` and ``"t_args"``; if used, both must be given.
        **kwargs      : dict
            Optional keyword arguments passed to the ``tf.keras.Model`` constructor.

        Raises
        ------
        ConfigurationError
            If only one of ``"s_args"`` / ``"t_args"`` is provided.
        """

        # Keras requires the base constructor to run before sub-layers are assigned
        super().__init__(**kwargs)

        self.dim_out = dim_out
        self.soft_clamp = settings_dict["soft_clamping"]

        # Check if separate settings for s and t are provided and adjust accordingly
        s_args = settings_dict.get("s_args")
        t_args = settings_dict.get("t_args")
        if s_args is not None and t_args is not None:
            s_settings, t_settings = s_args, t_args
        elif s_args is not None:
            raise ConfigurationError("s_args were provided, but you also need to provide t_args!")
        elif t_args is not None:
            raise ConfigurationError("t_args were provided, but you also need to provide s_args!")
        else:
            # No separate settings: both inner networks share the same configuration
            s_settings, t_settings = settings_dict, settings_dict

        # Internal network (learnable scale and translation)
        self.scale = DenseCouplingNet(s_settings, dim_out)
        self.translate = DenseCouplingNet(t_settings, dim_out)

    def call(self, split1, split2, condition, inverse=False, **kwargs):
        """Performs one pass through an affine coupling layer (either inverse or forward).

        Parameters
        ----------
        split1    : tf.Tensor of shape (batch_size, ..., input_dim//2)
            The first partition of the input vector(s)
        split2    : tf.Tensor of shape (batch_size, ..., ceil[input_dim//2])
            The second partition of the input vector(s)
        condition : tf.Tensor or None
            The conditioning data of interest, for instance, x = summary_fun(x), shape (batch_size, ...).
            If ``condition is None``, then the layer reduces to an unconditional coupling.
        inverse   : bool, optional, default: False
            Flag indicating whether to run the block forward or backward.
        **kwargs  : dict
            Optional keyword arguments forwarded to the internal networks.

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: The transformed input and the corresponding log-determinant of the
            Jacobian, z shape: (batch_size, ..., input_dim//2), log_det_J shape: (batch_size, ...)

        target          :  tf.Tensor
            If inverse=True: The back-transformed z, shape (batch_size, ..., inp_dim//2)
        """

        if not inverse:
            return self._forward(split1, split2, condition, **kwargs)
        return self._inverse(split1, split2, condition, **kwargs)

    def _scale_and_shift(self, conditioner, condition, **kwargs):
        """Computes the (optionally soft-clamped) log-scale ``s`` and the translation ``t``
        from the partition that is left unchanged by the coupling.
        """

        s = self.scale(conditioner, condition, **kwargs)
        if self.soft_clamp is not None:
            # atan lies in (-pi/2, pi/2), so this smoothly bounds |s| by the clamping constant
            s = (2.0 * self.soft_clamp / PI_CONST) * tf.math.atan(s / self.soft_clamp)
        t = self.translate(conditioner, condition, **kwargs)
        return s, t

    def _forward(self, u1, u2, condition, **kwargs):
        """Performs a forward pass through the coupling layer. Used internally by the instance.

        Parameters
        ----------
        u1        : tf.Tensor of shape (batch_size, ..., dim_1)
            The partition of the input that gets transformed.
        u2        : tf.Tensor of shape (batch_size, ..., dim_2)
            The partition of the input that parameterizes the transformation.
        condition : tf.Tensor of shape (batch_size, ..., dim_condition) or None
            The optional conditioning vector. Batch size must match the batch size
            of the partitions.

        Returns
        -------
        (v, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            The transformed input ``u1 * exp(s) + t`` and the corresponding log-determinant
            of the Jacobian (sum of ``s`` over the last axis).
        """

        s, t = self._scale_and_shift(u2, condition, **kwargs)
        v = u1 * tf.math.exp(s) + t
        log_det_J = tf.reduce_sum(s, axis=-1)
        return v, log_det_J

    def _inverse(self, v1, v2, condition, **kwargs):
        """Performs an inverse pass through the affine coupling block. Used internally by the instance.

        Note that the argument roles are swapped relative to ``_forward``: here the first
        argument parameterizes the transformation and the second one is transformed.

        Parameters
        ----------
        v1        : tf.Tensor of shape (batch_size, ..., dim_1)
            The partition of the latent vector that parameterizes the transformation.
        v2        : tf.Tensor of shape (batch_size, ..., dim_2)
            The partition of the latent vector that gets back-transformed.
        condition : tf.Tensor of shape (batch_size, ..., dim_condition) or None
            The optional conditioning vector. Batch size must match the batch size
            of the partitions.

        Returns
        -------
        u  :  tf.Tensor of shape (batch_size, ..., dim_2)
            The back-transformed input ``(v2 - t) * exp(-s)``.
        """

        s, t = self._scale_and_shift(v1, condition, **kwargs)
        u = (v2 - t) * tf.math.exp(-s)
        return u
class SplineCoupling(tf.keras.Model):
    """Implements a conditional spline coupling block according to [1, 2], with additional
    options, such as residual blocks or Monte Carlo Dropout.

    [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G. (2019).
    Neural spline flows. Advances in Neural Information Processing Systems, 32.

    [2] Ardizzone, L., Lüth, C., Kruse, J., Rother, C., & Köthe, U. (2019).
    Guided image generation with conditional invertible neural networks.
    arXiv preprint arXiv:1907.02392.

    Implement only rational quadratic splines (RQS), since these appear to work
    best in practice and lead to stable training.
    """

    def __init__(self, dim_out, settings_dict, **kwargs):
        """Creates one half of a spline coupling layer to be used as part of a ``CouplingLayer`` in
        an ``InvertibleNetwork`` instance.

        Parameters
        ----------
        dim_out       : int
            The output dimensionality of the coupling layer.
        settings_dict : dict
            The settings for the inner networks. Defaults will use:
            "dense_args"     : dict(units=128, activation="relu"),
            "num_dense"      : 2,
            "spec_norm"      : False,
            "mc_dropout"     : False,
            "dropout"        : True,
            "residual"       : False,
            "dropout_prob"   : 0.05,
            "bins"           : 16,
            "default_domain" : (-5., 5., -5., 5.)
        **kwargs      : dict
            Optional keyword arguments passed to the ``tf.keras.Model`` constructor.
        """

        # Keras requires the base constructor to run before sub-layers are assigned
        super().__init__(**kwargs)

        self.dim_out = dim_out
        self.bins = settings_dict["bins"]
        self.default_domain = settings_dict["default_domain"]

        # Number of raw network outputs per transformed dimension, by role.
        # The insertion order defines how the network output is split later on.
        self.spline_params_counts = {
            "left_edge": 1,
            "bottom_edge": 1,
            "widths": self.bins,
            "heights": self.bins,
            "derivatives": self.bins - 1,
        }
        self.num_total_spline_params = sum(self.spline_params_counts.values()) * self.dim_out

        # Internal network (learnable spline parameters)
        self.net = DenseCouplingNet(settings_dict, self.num_total_spline_params)

    def call(self, split1, split2, condition, inverse=False, **kwargs):
        """Performs one pass through a spline coupling layer (either inverse or forward).

        Parameters
        ----------
        split1    : tf.Tensor of shape (batch_size, ..., input_dim//2)
            The first partition of the input vector(s)
        split2    : tf.Tensor of shape (batch_size, ..., ceil[input_dim//2])
            The second partition of the input vector(s)
        condition : tf.Tensor or None
            The conditioning data of interest, for instance, x = summary_fun(x), shape (batch_size, ...).
            If ``condition is None``, then the layer reduces to an unconditional coupling.
        inverse   : bool, optional, default: False
            Flag indicating whether to run the block forward or backward.
        **kwargs  : dict
            Optional keyword arguments forwarded to the internal network.

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: The transformed input and the corresponding log-determinant of the
            Jacobian, z shape: (batch_size, ..., input_dim//2), log_det_J shape: (batch_size, ...)

        target          :  tf.Tensor
            If inverse=True: The back-transformed z, shape (batch_size, ..., inp_dim//2)
        """

        if not inverse:
            return self._forward(split1, split2, condition, **kwargs)
        return self._inverse(split1, split2, condition, **kwargs)

    def _forward(self, u1, u2, condition, **kwargs):
        """Performs a forward pass through the spline coupling layer. Used internally by the instance.

        Parameters
        ----------
        u1        : tf.Tensor of shape (batch_size, ..., dim_1)
            The partition of the input that gets transformed by the spline.
        u2        : tf.Tensor of shape (batch_size, ..., dim_2)
            The partition of the input from which the spline parameters are computed.
        condition : tf.Tensor of shape (batch_size, ..., dim_condition) or None
            The optional conditioning vector. Batch size must match the batch size
            of the partitions.

        Returns
        -------
        (v, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            The transformed input and the corresponding log-determinant of the Jacobian.
        """

        spline_params = self.net(u2, condition, **kwargs)
        spline_params = self._semantic_spline_parameters(spline_params)
        spline_params = self._constrain_parameters(spline_params)
        v, log_det_J = self._calculate_spline(u1, spline_params, inverse=False)
        return v, log_det_J

    def _inverse(self, v1, v2, condition, **kwargs):
        """Performs an inverse pass through the coupling block. Used internally by the instance.

        Note that the argument roles are swapped relative to ``_forward``: here the first
        argument yields the spline parameters and the second one is back-transformed.

        Parameters
        ----------
        v1        : tf.Tensor of shape (batch_size, ..., dim_1)
            The partition of the latent vector from which the spline parameters are computed.
        v2        : tf.Tensor of shape (batch_size, ..., dim_2)
            The partition of the latent vector that gets back-transformed.
        condition : tf.Tensor of shape (batch_size, ..., dim_condition) or None
            The optional conditioning vector. Batch size must match the batch size
            of the partitions.

        Returns
        -------
        u  :  tf.Tensor of shape (batch_size, ..., dim_2)
            The back-transformed input.
        """

        spline_params = self.net(v1, condition, **kwargs)
        spline_params = self._semantic_spline_parameters(spline_params)
        spline_params = self._constrain_parameters(spline_params)
        u = self._calculate_spline(v2, spline_params, inverse=True)
        return u

    def _calculate_spline(self, target, spline_params, inverse=False):
        """Computes both directions of a rational quadratic spline (RQS) as in
        Durkan et al. (2019), Neural spline flows.

        At this point, ``spline_params`` represents a tuple with the parameters of the RQS learned
        by the internal neural network (given optional conditional information).

        Entries of ``target`` inside the spline domain are transformed by the RQS; entries
        outside are transformed by a linear map whose slope is total height / total width.

        Parameters
        ----------
        target         : tf.Tensor of shape (batch_size, ..., dim_2)
            The target partition of the input vector to transform.
        spline_params  : tuple(tf.Tensor,...)
            A tuple with tensors corresponding to the learnable spline features:
            (left_edge, bottom_edge, widths, heights, derivatives)
        inverse        : bool, optional, default: False
            Flag indicating whether to run the block forward or backward.

        Returns
        -------
        (result, log_det_J) :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: The transformed input and the corresponding log-determinant of the
            Jacobian, result shape: (batch_size, ..., dim_2), log_det_J shape: (batch_size, ...)

        result              :  tf.Tensor
            If inverse=True: The back-transformed latent, shape (batch_size, ..., dim_2)
        """

        # Extract all learnable parameters
        left_edge, bottom_edge, widths, heights, derivatives = spline_params

        # Placeholders for results
        result = tf.zeros_like(target)
        log_jac = tf.zeros_like(target)

        total_width = tf.reduce_sum(widths, axis=-1, keepdims=True)
        total_height = tf.reduce_sum(heights, axis=-1, keepdims=True)

        # Knot positions: lower corner followed by the cumulative bin sizes
        knots_x = tf.concat([left_edge, left_edge + tf.math.cumsum(widths, axis=-1)], axis=-1)
        knots_y = tf.concat([bottom_edge, bottom_edge + tf.math.cumsum(heights, axis=-1)], axis=-1)

        # Determine which targets are in domain and which are not.
        # The forward pass locates targets on the x-axis, the inverse pass on the y-axis.
        knots = knots_y if inverse else knots_x
        target_in_domain = tf.logical_and(knots[..., 0] < target, target <= knots[..., -1])
        higher_indices = tf.searchsorted(knots, target[..., None])

        target_in = target[target_in_domain]
        target_in_idx = tf.where(target_in_domain)
        target_out = target[~target_in_domain]
        target_out_idx = tf.where(~target_in_domain)

        # In-domain computation
        if tf.size(target_in_idx) > 0:
            # Index crunching: build full index tuples of the knots enclosing each target
            higher_indices = tf.gather_nd(higher_indices, target_in_idx)
            higher_indices = tf.cast(higher_indices, tf.int32)
            lower_indices = higher_indices - 1
            base_idx = tf.cast(target_in_idx, tf.int32)
            lower_idx_tuples = tf.concat([base_idx, lower_indices], axis=-1)
            higher_idx_tuples = tf.concat([base_idx, higher_indices], axis=-1)

            # Spline computation
            dk = tf.gather_nd(derivatives, lower_idx_tuples)
            dkp = tf.gather_nd(derivatives, higher_idx_tuples)
            xk = tf.gather_nd(knots_x, lower_idx_tuples)
            xkp = tf.gather_nd(knots_x, higher_idx_tuples)
            yk = tf.gather_nd(knots_y, lower_idx_tuples)
            ykp = tf.gather_nd(knots_y, higher_idx_tuples)
            x = target_in
            dx = xkp - xk
            dy = ykp - yk
            sk = dy / dx
            xi = (x - xk) / dx

            # Forward pass
            if not inverse:
                numerator = dy * (sk * xi**2 + dk * xi * (1 - xi))
                denominator = sk + (dkp + dk - 2 * sk) * xi * (1 - xi)
                result_in = yk + numerator / denominator

                # Log Jacobian for in-domain (small constants guard against log(0))
                numerator = sk**2 * (dkp * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2)
                denominator = denominator**2
                log_jac_in = tf.math.log(numerator + 1e-10) - tf.math.log(denominator + 1e-10)
                log_jac = tf.tensor_scatter_nd_update(log_jac, target_in_idx, log_jac_in)
            # Inverse pass
            else:
                # Solve a * xi^2 + b * xi + c = 0 for the relative position xi in the bin
                y = x
                a = dy * (sk - dk) + (y - yk) * (dkp + dk - 2 * sk)
                b = dy * dk - (y - yk) * (dkp + dk - 2 * sk)
                c = -sk * (y - yk)
                # Clip at zero so that round-off cannot produce a NaN in the square root
                discriminant = tf.maximum(b**2 - 4 * a * c, 0.0)
                xi = 2 * c / (-b - tf.math.sqrt(discriminant))
                result_in = xi * dx + xk

            result = tf.tensor_scatter_nd_update(result, target_in_idx, result_in)

        # Out-of-domain: linear map with slope total_height / total_width
        if tf.size(target_out_idx) > 0:
            scale = total_height / total_width
            shift = bottom_edge - scale * left_edge
            scale_out = tf.gather_nd(scale, target_out_idx)
            shift_out = tf.gather_nd(shift, target_out_idx)

            if not inverse:
                result_out = scale_out * target_out[..., None] + shift_out
                # Log Jacobian for out-of-domain points
                log_jac_out = tf.math.log(scale_out + 1e-10)
                log_jac_out = tf.squeeze(log_jac_out, axis=-1)
                log_jac = tf.tensor_scatter_nd_update(log_jac, target_out_idx, log_jac_out)
            else:
                result_out = (target_out[..., None] - shift_out) / scale_out

            result_out = tf.squeeze(result_out, axis=-1)
            result = tf.tensor_scatter_nd_update(result, target_out_idx, result_out)

        if not inverse:
            return result, tf.reduce_sum(log_jac, axis=-1)
        return result

    def _semantic_spline_parameters(self, parameters):
        """Builds a list of tensors from the output of the coupling net.

        Parameters
        ----------
        parameters : tf.Tensor of shape (batch_size, ..., num_spline_parameters)
            All learnable spline parameters packed in a single tensor, which will be
            partitioned according to the role of each spline parameter.

        Returns
        -------
        parameters : list of five tf.Tensor
            The partitioned spline parameters according to their role in the spline computation,
            in the order (left_edge, bottom_edge, widths, heights, derivatives).

        Raises
        ------
        NotImplementedError
            If ``parameters`` is neither a 2D nor a 3D tensor.
        """

        shape = tf.shape(parameters)
        # Use the static rank, which is also available for symbolic tensors
        rank = parameters.shape.rank
        if rank == 2:
            new_shape = (shape[0], self.dim_out, -1)
        elif rank == 3:
            new_shape = (shape[0], shape[1], self.dim_out, -1)
        else:
            raise NotImplementedError("Spline flows can currently only operate on 2D and 3D inputs!")

        # One row of raw spline parameters per transformed dimension
        parameters = tf.reshape(parameters, new_shape)
        parameters = tf.split(parameters, list(self.spline_params_counts.values()), axis=-1)
        return parameters

    def _constrain_parameters(self, parameters):
        """Takes care of zero spline parameters due to zeros kernel initializer and
        applies parameter constraints for stability.

        The softplus shifts are chosen such that all-zero raw parameters yield the default
        bin widths/heights and unit inner derivatives.

        Parameters
        ----------
        parameters : tuple(tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor)
            The unconstrained spline parameters.

        Returns
        -------
        parameters : tuple(tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor)
            The constrained spline parameters.
        """

        left_edge, bottom_edge, widths, heights, derivatives = parameters

        # Set lower corners of domain relative to default domain
        left_edge = left_edge + self.default_domain[0]
        bottom_edge = bottom_edge + self.default_domain[2]

        # Compute default widths and heights
        default_width = (self.default_domain[1] - self.default_domain[0]) / self.bins
        default_height = (self.default_domain[3] - self.default_domain[2]) / self.bins

        # Compute shifts for softplus function: softplus(log(exp(w) - 1)) == w
        xshift = tf.math.log(tf.math.exp(default_width) - 1)
        yshift = tf.math.log(tf.math.exp(default_height) - 1)

        # Constrain widths and heights to be positive
        widths = tf.math.softplus(widths + xshift)
        heights = tf.math.softplus(heights + yshift)

        # Compute spline derivatives: softplus(log(e - 1)) == 1
        shift = tf.math.log(EULER_CONST - 1.0)
        derivatives = tf.nn.softplus(derivatives + shift)

        # Add in edge derivatives, which equal the slope of the linear out-of-domain map
        total_height = tf.reduce_sum(heights, axis=-1, keepdims=True)
        total_width = tf.reduce_sum(widths, axis=-1, keepdims=True)
        scale = total_height / total_width
        derivatives = tf.concat([scale, derivatives, scale], axis=-1)
        return left_edge, bottom_edge, widths, heights, derivatives
class CouplingLayer(tf.keras.Model):
    """General wrapper for a coupling layer (either affine or spline) with different settings."""

    def __init__(
        self,
        latent_dim,
        coupling_settings=None,
        coupling_design="affine",
        permutation="fixed",
        use_act_norm=True,
        act_norm_init=None,
        **kwargs,
    ):
        """Creates an invertible coupling layers instance with the provided hyperparameters.

        Parameters
        ----------
        latent_dim         : int
            The dimensionality of the latent space (equal to the dimensionality of the target variable)
        coupling_settings  : dict or None, optional, default: None
            The coupling network settings to pass to the internal coupling layers. See ``default_settings``
            for the required entries.
        coupling_design    : str, optional, default: 'affine'
            The type of internal coupling network to use. Must be in ['affine', 'spline'].
            In general, spline couplings run slower than affine couplings, but require fewer coupling
            layers. Spline couplings may work best with complex (e.g., multimodal) low-dimensional
            problems. The difference will become less and less pronounced as we move to higher dimensions.
        permutation        : str or None, optional, default: 'fixed'
            Whether to use permutations between coupling layers. Highly recommended if ``num_coupling_layers > 1``
            Important: Must be in ['fixed', 'learnable', None]
        use_act_norm       : bool, optional, default: True
            Whether to use activation normalization after each coupling layer. Recommended to keep default.
        act_norm_init      : np.ndarray of shape (num_simulations, num_params) or None, optional, default: None
            Optional data-dependent initialization for the internal ``ActNorm`` layers.
        **kwargs           : dict
            Optional keyword arguments (e.g., name) passed to the tf.keras.Model __init__ method.

        Raises
        ------
        ConfigurationError
            If ``coupling_settings`` is neither ``None`` nor a dict, or if ``permutation``
            is not one of the supported values.
        NotImplementedError
            If ``coupling_design`` is not one of the supported values.
        """

        # Keras requires the base constructor to run before sub-layers are assigned
        super().__init__(**kwargs)

        # Set dimensionality attributes; the second split takes the extra unit for odd sizes
        self.latent_dim = latent_dim
        self.dim_out1 = self.latent_dim // 2
        self.dim_out2 = self.latent_dim - self.dim_out1

        # Determine coupling net settings
        if coupling_settings is None:
            user_dict = dict()
        elif isinstance(coupling_settings, dict):
            user_dict = coupling_settings
        else:
            raise ConfigurationError("coupling_settings argument must be None or a dict!")

        # Determine type of coupling (affine or spline) and build settings
        if coupling_design == "affine":
            coupling_type = AffineCoupling
            coupling_settings = build_meta_dict(
                user_dict=user_dict, default_setting=default_settings.DEFAULT_SETTING_AFFINE_COUPLING
            )
        elif coupling_design == "spline":
            coupling_type = SplineCoupling
            coupling_settings = build_meta_dict(
                user_dict=user_dict, default_setting=default_settings.DEFAULT_SETTING_SPLINE_COUPLING
            )
        else:
            raise NotImplementedError('coupling_design must be in ["affine", "spline"]')

        # Two-in-one coupling block (i.e., no inactive part after a forward pass)
        self.net1 = coupling_type(self.dim_out1, coupling_settings)
        self.net2 = coupling_type(self.dim_out2, coupling_settings)

        # Optional (learnable or fixed) permutation
        if permutation not in ["fixed", "learnable", None]:
            raise ConfigurationError('Argument permutation should be in ["fixed", "learnable", None]')
        if permutation == "fixed":
            self.permutation = Permutation(self.latent_dim)
            self.permutation.trainable = False
        elif permutation == "learnable":
            self.permutation = Orthogonal(self.latent_dim)
        else:
            self.permutation = None

        # Optional learnable activation normalization
        if use_act_norm:
            self.act_norm = ActNorm(latent_dim, act_norm_init)
        else:
            self.act_norm = None

    def call(self, target_or_z, condition, inverse=False, **kwargs):
        """Performs one pass through the coupling layer (either inverse or forward).

        Parameters
        ----------
        target_or_z : tf.Tensor
            The estimation quantities of interest or latent representations z ~ p(z), shape (batch_size, ...)
        condition   : tf.Tensor or None
            The conditioning data of interest, for instance, x = summary_fun(x), shape (batch_size, ...).
            If `condition is None`, then the layer reduces to an unconditional coupling layer.
        inverse     : bool, optional, default: False
            Flag indicating whether to run the block forward or backward.
        **kwargs    : dict
            Optional keyword arguments forwarded to the internal coupling networks.

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            If inverse=False: The transformed input and the corresponding log-determinant of the
            Jacobian, z shape: (batch_size, inp_dim), log_det_J shape: (batch_size, )

        target          :  tf.Tensor
            If inverse=True: The back-transformed z, shape (batch_size, inp_dim)

        Notes
        -----
        If ``inverse=False``, the return is ``(z, log_det_J)``.\n
        If ``inverse=True``, the return is ``target``
        """

        if not inverse:
            return self.forward(target_or_z, condition, **kwargs)
        return self.inverse(target_or_z, condition, **kwargs)

    def forward(self, target, condition, **kwargs):
        """Performs a forward pass through a coupling layer with optional `Permutation` and `ActNorm` layers.

        Parameters
        ----------
        target     : tf.Tensor
            The estimation quantities of interest, for instance, parameter vector of shape (batch_size, theta_dim)
        condition  : tf.Tensor or None
            The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim)
            If `None`, transformation amounts to unconditional estimation.

        Returns
        -------
        (z, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            The transformed input and the accumulated log-determinant of the Jacobian
            over all sub-transformations applied.
        """

        # Initialize log_det_Js accumulator (broadcasts against the per-sample terms)
        log_det_Js = tf.zeros(1)

        # Normalize activation, if specified
        if self.act_norm is not None:
            target, log_det_J_act = self.act_norm(target)
            log_det_Js += log_det_J_act

        # Permute, if indicated
        if self.permutation is not None:
            target = self.permutation(target)
            # A trainable permutation layer returns its output together with a log-determinant
            if self.permutation.trainable:
                target, log_det_J_p = target
                log_det_Js += log_det_J_p

        # Pass through coupling layer
        latent, log_det_J_c = self._forward(target, condition, **kwargs)
        log_det_Js += log_det_J_c
        return latent, log_det_Js

    def inverse(self, latent, condition, **kwargs):
        """Performs an inverse pass through a coupling layer with optional `Permutation` and `ActNorm` layers.

        The sub-transformations are undone in the reverse order of the forward pass:
        coupling first, then permutation, then activation normalization.

        Parameters
        ----------
        latent     : tf.Tensor
            latent variables z ~ p(z), shape (batch_size, theta_dim)
        condition  : tf.Tensor or None
            The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim).
            If `None`, transformation amounts to unconditional estimation.

        Returns
        -------
        target  :  tf.Tensor
            The back-transformed latent variable z.
        """

        target = self._inverse(latent, condition, **kwargs)
        if self.permutation is not None:
            target = self.permutation(target, inverse=True)
        if self.act_norm is not None:
            target = self.act_norm(target, inverse=True)
        return target

    def _forward(self, target, condition, **kwargs):
        """Performs a forward pass through the coupling layer. Used internally by the instance.

        Parameters
        ----------
        target     : tf.Tensor
            The estimation quantities of interest, for instance, parameter vector of shape (batch_size, theta_dim)
        condition  : tf.Tensor or None
            The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim)
            If `None`, transformation amounts to unconditional estimation.

        Returns
        -------
        (v, log_det_J)  :  tuple(tf.Tensor, tf.Tensor)
            The transformed input and the corresponding log-determinant of the Jacobian.
        """

        # Split input along last axis and perform forward coupling:
        # the first half is transformed given the second, then the second given the new first
        u1, u2 = tf.split(target, [self.dim_out1, self.dim_out2], axis=-1)
        v1, log_det_J1 = self.net1(u1, u2, condition, inverse=False, **kwargs)
        v2, log_det_J2 = self.net2(u2, v1, condition, inverse=False, **kwargs)
        v = tf.concat((v1, v2), axis=-1)

        # Compute log determinant of the Jacobians from both splits
        log_det_J = log_det_J1 + log_det_J2
        return v, log_det_J

    def _inverse(self, latent, condition, **kwargs):
        """Performs an inverse pass through the coupling block. Used internally by the instance.

        Parameters
        ----------
        latent     : tf.Tensor
            latent variables z ~ p(z), shape (batch_size, theta_dim)
        condition  : tf.Tensor or None
            The conditioning vector of interest, for instance, x = summary(x), shape (batch_size, summary_dim).
            If `None`, transformation amounts to unconditional estimation.

        Returns
        -------
        u  :  tf.Tensor
            The back-transformed input.
        """

        # Split input along last axis and perform inverse coupling:
        # undo the second coupling first, then the first one using the recovered half
        v1, v2 = tf.split(latent, [self.dim_out1, self.dim_out2], axis=-1)
        u2 = self.net2(v1, v2, condition, inverse=True, **kwargs)
        u1 = self.net1(u2, v1, condition, inverse=True, **kwargs)
        u = tf.concat((u1, u2), axis=-1)
        return u