from collections.abc import Mapping
import multiprocessing as mp
import keras
from bayesflow.adapters import Adapter
from bayesflow.datasets import OnlineDataset
from bayesflow.simulators import Simulator
from bayesflow.utils import find_batch_size, filter_kwargs, logging
from .backend_approximators import BackendApproximator
class Approximator(BackendApproximator):
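"""
Base class for approximators, built on top of the backend-specific `BackendApproximator`.
Provides convenience methods for building the underlying networks from data shapes or from
a batch of data, for constructing an `OnlineDataset` from a `Simulator`, and for fitting
the approximator on either a dataset or a simulator.
"""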
def build(self, data_shapes: any) -> None:
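"""Build the approximator from a (nested) structure of data shapes by creating
zero-filled mock data and delegating to `build_from_data`."""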
mock_data = keras.tree.map_structure(keras.ops.zeros, data_shapes)
self.build_from_data(mock_data)
@classmethod
def build_adapter(cls, **kwargs) -> Adapter:
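"""Create the default `Adapter` for this approximator architecture."""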
# implemented by each respective architecture
raise NotImplementedError
def build_from_data(self, data: Mapping[str, any]) -> None:
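"""Build all internal networks by computing metrics once on a batch of training data,
then mark the approximator as built."""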
self.compute_metrics(**data, stage="training")
self.built = True
@classmethod
def build_dataset(
cls,
*,
batch_size: int = "auto",
num_batches: int,
adapter: Adapter = "auto",
memory_budget: str | int = "auto",
simulator: Simulator,
workers: int = "auto",
use_multiprocessing: bool = False,
max_queue_size: int = 32,
**kwargs,
) -> OnlineDataset:
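"""
Build an `OnlineDataset` from the given simulator.
Arguments passed as "auto" are resolved automatically: `batch_size` from the available
memory budget and a sample drawn from the simulator, `adapter` from `build_adapter`,
and `workers` from the number of available CPU cores.
"""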
if batch_size == "auto":
batch_size = find_batch_size(memory_budget=memory_budget, sample=simulator.sample((1,)))
logging.info(f"Using a batch size of {batch_size}.")
if adapter == "auto":
adapter = cls.build_adapter(**filter_kwargs(kwargs, cls.build_adapter))
if workers == "auto":
workers = mp.cpu_count()
logging.info(f"Using {workers} data loading workers.")
workers = workers or 1
return OnlineDataset(
simulator=simulator,
batch_size=batch_size,
num_batches=num_batches,
adapter=adapter,
workers=workers,
use_multiprocessing=use_multiprocessing,
max_queue_size=max_queue_size,
)
def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = None, **kwargs):
"""
Trains the approximator on the provided dataset or on-demand data generated from the given simulator.
If `dataset` is not provided, a dataset is built from the `simulator`.
If the model has not been built, it will be built using a batch from the dataset.
Parameters
----------
dataset : keras.utils.PyDataset, optional
A dataset containing simulations for training. If provided, `simulator` must be None.
simulator : Simulator, optional
A simulator used to generate a dataset. If provided, `dataset` must be None.
**kwargs
Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
batch_size : int or None, default='auto'
Number of samples per gradient update. Do not specify if `dataset` is provided as a
`keras.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader`, or a generator function.
epochs : int, default=1
Number of epochs to train the model.
verbose : {"auto", 0, 1, 2}, default="auto"
Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks : list of keras.callbacks.Callback, optional
List of callbacks to apply during training.
validation_split : float, optional
Fraction of training data to use for validation (only supported if `dataset` consists of NumPy arrays
or tensors).
validation_data : tuple or dataset, optional
Data for validation, overriding `validation_split`.
shuffle : bool, default=True
Whether to shuffle the training data before each epoch (ignored for dataset generators).
initial_epoch : int, default=0
Epoch at which to start training (useful for resuming training).
steps_per_epoch : int or None, optional
Number of steps (batches) before declaring an epoch finished.
validation_steps : int or None, optional
Number of validation steps per validation epoch.
validation_batch_size : int or None, optional
Number of samples per validation batch (defaults to `batch_size`).
validation_freq : int, default=1
Specifies how many training epochs to run before performing validation.
Returns
-------
keras.callbacks.History
A history object containing the training loss and metrics values.
Raises
------
ValueError
If both `dataset` and `simulator` are provided or neither is provided.
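Examples
--------
A minimal usage sketch; `approximator`, `simulator`, and `dataset` are placeholders for
a configured subclass instance, a `Simulator`, and a `keras.utils.PyDataset`, and are
not constructed here.
>>> # Build a dataset on the fly from the simulator and train for 10 epochs.
>>> history = approximator.fit(simulator=simulator, num_batches=200, epochs=10)
>>> # Alternatively, train on an existing dataset.
>>> history = approximator.fit(dataset=dataset, epochs=10)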
"""
if dataset is None:
if simulator is None:
raise ValueError("Received no data to fit on. Please provide either a dataset or a simulator.")
logging.info(f"Building dataset from simulator instance of {simulator.__class__.__name__}.")
dataset = self.build_dataset(simulator=simulator, **filter_kwargs(kwargs, self.build_dataset))
else:
if simulator is not None:
raise ValueError(
"Received conflicting arguments. Please provide either a dataset or a simulator, but not both."
)
logging.info(f"Fitting on dataset instance of {dataset.__class__.__name__}.")
if not self.built:
logging.info("Building on a test batch.")
mock_data = dataset[0]
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)
self.build_from_data(mock_data)
return super().fit(dataset=dataset, **kwargs)