import keras
from bayesflow.types import Shape, Tensor
from bayesflow.utils import find_distribution, keras_kwargs
from bayesflow.utils.decorators import allow_batch_size
class InferenceNetwork(keras.Layer):
    """Base Keras layer for inference networks defined relative to a base distribution.

    Subclasses supply the two transformation directions by overriding
    ``_forward`` and ``_inverse``. On top of those, this class provides:

    * :meth:`sample` draws from ``self.base_distribution`` and pushes the draws
      through the network with ``inverse=True``.
    * :meth:`log_prob` runs the network with ``inverse=False, density=True``
      and returns the second element of the resulting pair.
    """

    # Class-level default config, empty on the base class.
    # NOTE(review): this dict is shared by every subclass that does not rebind it,
    # so in-place mutation would leak across subclasses.
    MLP_DEFAULT_CONFIG = {}

    def __init__(self, base_distribution: str = "normal", **kwargs):
        """Create the layer and resolve its base distribution.

        Parameters
        ----------
        base_distribution : str
            Identifier handed to ``find_distribution``; defaults to ``"normal"``.
        **kwargs
            Passed through ``keras_kwargs`` before reaching ``keras.Layer.__init__``
            (presumably keeping only Keras-recognised arguments -- confirm).
        """
        layer_kwargs = keras_kwargs(kwargs)
        super().__init__(**layer_kwargs)
        self.base_distribution = find_distribution(base_distribution)

    def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
        """Build the base distribution for inputs of shape ``xz_shape``.

        ``conditions_shape`` is accepted so subclasses share this signature;
        the base implementation does not use it.
        """
        self.base_distribution.build(xz_shape)

    def call(
        self,
        xz: Tensor,
        conditions: Tensor = None,
        inverse: bool = False,
        density: bool = False,
        training: bool = False,
        **kwargs,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Run the network in the direction selected by ``inverse``.

        A truthy ``inverse`` routes to ``_inverse``; otherwise ``_forward`` is
        used. All remaining arguments are forwarded unchanged as keywords.
        """
        direction = self._inverse if inverse else self._forward
        return direction(xz, conditions=conditions, density=density, training=training, **kwargs)

    def _forward(
        self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Forward direction; must be overridden by subclasses.

        When called via :meth:`log_prob` (``density=True``), the result is
        unpacked as a two-element pair whose second item is the log-density.
        """
        raise NotImplementedError

    def _inverse(
        self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Inverse direction; must be overridden by subclasses.

        :meth:`sample` calls this with draws from the base distribution and
        ``density=False``, using the return value directly as the samples.
        """
        raise NotImplementedError

    @allow_batch_size
    def sample(self, batch_shape: Shape, conditions: Tensor = None, **kwargs) -> Tensor:
        """Draw samples by inverting base-distribution draws of shape ``batch_shape``.

        ``batch_shape`` handling (e.g. a plain int) is delegated to the
        ``allow_batch_size`` decorator. Extra ``kwargs`` go to :meth:`call`.
        """
        latent = self.base_distribution.sample(batch_shape)
        return self(latent, conditions=conditions, inverse=True, density=False, **kwargs)

    def log_prob(self, samples: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
        """Return the log-density the forward pass reports for ``samples``.

        The forward pass is run with ``density=True`` and must yield exactly two
        outputs; only the second (the log-density) is returned.
        """
        _latent, log_density = self(samples, conditions=conditions, inverse=False, density=True, **kwargs)
        return log_density

    def compute_metrics(
        self, x: Tensor, conditions: Tensor = None, sample_weight: Tensor = None, stage: str = "training"
    ) -> dict[str, Tensor]:
        """Compute sample-based metrics for ``x``.

        The layer is built on first use from the shapes of ``x`` and
        ``conditions``. During the ``"training"`` stage, or when no metrics are
        tracked, an empty dict is returned. Otherwise ``keras.ops.shape(x)[0]``
        samples are drawn (with the same ``conditions``) and each tracked
        metric is evaluated, keyed by ``metric.name``.

        NOTE(review): ``sample_weight`` is currently ignored.
        """
        if not self.built:
            self.build(
                keras.ops.shape(x),
                conditions_shape=None if conditions is None else keras.ops.shape(conditions),
            )

        # NOTE(review): any() tests each metric object's truthiness; for plain
        # objects this amounts to "at least one metric is tracked".
        if stage == "training" or not any(self.metrics):
            return {}

        num_samples = keras.ops.shape(x)[0]
        samples = self.sample((num_samples,), conditions=conditions)

        # NOTE(review): samples are passed first and x second; Keras metrics take
        # (y_true, y_pred), so confirm this order is intended for asymmetric metrics.
        return {metric.name: metric(samples, x) for metric in self.metrics}