BasicWorkflow#
- class bayesflow.workflows.BasicWorkflow(simulator: Simulator = None, adapter: Adapter = None, inference_network: InferenceNetwork | str = 'coupling_flow', summary_network: SummaryNetwork | str = None, initial_learning_rate: float = 0.0005, optimizer: Optimizer | type = None, checkpoint_filepath: str = None, checkpoint_name: str = 'model', save_weights_only: bool = False, save_best_only: bool = False, inference_variables: Sequence[str] | str = None, inference_conditions: Sequence[str] | str = None, summary_variables: Sequence[str] | str = None, standardize: Sequence[str] | str | None = 'inference_variables', **kwargs)[source]#
 Bases: Workflow
This class provides methods to set up, simulate, fit, and validate models using amortized Bayesian inference. It supports both online and offline amortized workflows. A minimal construction example follows the parameter list below.
- Parameters:
- simulator: Simulator, optional
 A Simulator object to generate synthetic data for inference (default is None).
- adapter: Adapter, optional
 Adapter for data processing. If not provided, a default adapter will be used (default is None), but you must then provide the correct names for inference_variables and/or inference_conditions and/or summary_variables.
- inference_network: InferenceNetwork or str, optional
 The inference network used for posterior approximation, specified as an instance or by name (default is "coupling_flow").
- summary_network: SummaryNetwork or str, optional
 The summary network used for data summarization, specified as an instance or by name (default is None).
- initial_learning_rate: float, optional
 Initial learning rate for the optimizer (default is 5e-4).
- optimizer: type, optional
 The optimizer to be used for training. If None, a default Adam optimizer will be selected (default is None).
- checkpoint_filepath: str, optional
 Directory path where model checkpoints will be saved (default is None).
- checkpoint_name: str, optional
 Name of the checkpoint file (default is "model").
- save_weights_only: bool, optional
 If True, only the model weights will be saved during checkpointing (default is False).
- save_best_only: bool, optional
 If True, only the best model according to the monitored quantity (training or validation loss) at the end of each epoch will be saved, instead of the last model (default is False). Use with caution, as some losses (e.g., flow matching) do not reliably reflect model performance, and outliers in the validation data can cause unwanted effects.
- inference_variables: Sequence[str] or str, optional
 Variables for inference as a sequence of strings or a single string (default is None). Important for automated diagnostics!
- inference_conditions: Sequence[str] or str, optional
 Variables used as direct conditions for inference (default is None).
- summary_variables: Sequence[str] or str, optional
 Variables to be summarized through the summary network before being used as conditions (default is None).
- standardize: Sequence[str] or str, optional
 Variables to standardize during preprocessing (default is "inference_variables"). These will be passed to the corresponding approximator constructor and can be either "all" or any subset of ["inference_variables", "summary_variables", "inference_conditions"].
- **kwargs: dict, optional
 Additional arguments for configuring networks, adapters, optimizers, etc.
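Example (a minimal sketch; the variable names "theta" and "x" are hypothetical and must match your simulator's output keys)::

    from bayesflow.workflows import BasicWorkflow

    workflow = BasicWorkflow(
        inference_network="coupling_flow",  # default, specified by name
        inference_variables="theta",        # enables automated diagnostics
        inference_conditions="x",           # conditioning data passed directly
    )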
- property adapter#
 
- static samples_to_data_frame(samples: Mapping[str, ndarray]) DataFrame[source]#
 Convert a dictionary of samples into a pandas DataFrame.
- Parameters:
- samples: Mapping[str, np.ndarray]
 A dictionary where keys represent variable names and values are arrays containing sampled data.
- Returns:
 - pd.DataFrame
 A pandas DataFrame where each column corresponds to a variable, and rows represent individual samples.
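For instance, a small sketch with fabricated draws (assuming one-dimensional arrays of equal length per variable)::

    import numpy as np
    from bayesflow.workflows import BasicWorkflow

    samples = {
        "mu": np.random.normal(size=500),
        "sigma": np.abs(np.random.normal(size=500)),
    }
    df = BasicWorkflow.samples_to_data_frame(samples)
    print(df.head())  # columns "mu" and "sigma", one row per draw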
- static default_adapter(inference_variables: Sequence[str] | str, inference_conditions: Sequence[str] | str, summary_variables: Sequence[str] | str) Adapter[source]#
 Create a default adapter for processing inference variables, conditions, summaries, and standardization.
Converts all float64 values to float32 for computational efficiency.
- Parameters:
- inference_variables: Sequence[str] or str
 The variables to be treated as inference targets.
- inference_conditions: Sequence[str] or str
 The variables used as conditions for inference.
- summary_variables: Sequence[str] or str
 The variables used for summarization.
- Returns:
 - Adapter
 A configured Adapter instance that applies dtype conversion, concatenation, and optional standardization.
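Sketch (the variable names are hypothetical placeholders; passing None mirrors the workflow's own defaults)::

    from bayesflow.workflows import BasicWorkflow

    adapter = BasicWorkflow.default_adapter(
        inference_variables="theta",
        inference_conditions=None,
        summary_variables="x",
    )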
- simulate(batch_shape: tuple[int, ...], **kwargs) dict[str, ndarray][source]#
 Generates a batch of simulations using the provided simulator.
- Parameters:
- batch_shape: Shape
 The shape of the batch to be simulated. Typically an integer for simple simulators.
- **kwargs: dict, optional
 Additional keyword arguments passed to the simulator’s sample method.
- Returns:
 - dict[str, np.ndarray]
 A dictionary where keys represent variable names and values are NumPy arrays containing the simulated variables.
- Raises:
 - RuntimeError
 If no simulator is provided.
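For example, assuming workflow was constructed with a simulator::

    sims = workflow.simulate((64,))               # batch of 64 simulations
    print({k: v.shape for k, v in sims.items()})  # raw simulator outputs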
- simulate_adapted(batch_shape: tuple[int, ...], **kwargs) dict[str, ndarray][source]#
 Generates a batch of simulations and applies the adapter to the result.
- Parameters:
- batch_shape: Shape
 The shape of the batch to be simulated. Typically an integer for simple simulators.
- **kwargs: dict, optional
 Additional keyword arguments passed to the simulator’s sample method.
- Returns:
 - dict[str, np.ndarray]
 A dictionary where keys represent variable names and values are NumPy arrays containing the adapted simulated variables.
- Raises:
 - RuntimeError
 If no simulator is provided.
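Sketch, continuing the example above::

    adapted = workflow.simulate_adapted((64,))  # same kind of draw, but passed through the adapter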
- sample(*, num_samples: int, conditions: Mapping[str, ndarray], **kwargs) dict[str, ndarray][source]#
 Draws num_samples samples from the approximator given specified conditions.
- Parameters:
- num_samples: int
 The number of samples to generate.
- conditions: dict[str, np.ndarray]
 A dictionary where keys represent variable names and values are NumPy arrays of conditioning data. All keys used as summary or inference conditions during training should be present.
- **kwargs: dict, optional
 Additional keyword arguments passed to the approximator’s sampling function.
- Returns:
 - dict[str, np.ndarray]
 A dictionary where keys correspond to variable names and values are arrays containing the generated samples.
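Sketch (here the conditions are taken from fresh simulations; any dictionary with the required keys works, and "theta" is a hypothetical variable name)::

    post_samples = workflow.sample(num_samples=1000, conditions=sims)
    # e.g. post_samples["theta"] holds posterior draws per conditioning dataset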
- estimate(*, conditions: Mapping[str, ndarray], **kwargs) dict[str, dict[str, ndarray | dict[str, ndarray]]][source]#
 Estimates point summaries of inference variables based on specified conditions.
- Parameters:
- conditions: Mapping[str, np.ndarray]
 A dictionary mapping variable names to arrays representing the conditions for the estimation process.
- **kwargs
 Additional keyword arguments passed to underlying processing functions.
- Returns:
- estimates: dict[str, dict[str, np.ndarray or dict[str, np.ndarray]]]
 The estimates of inference variables in a nested dictionary:
 - Each first-level key is the name of an inference variable.
 - Each second-level key is the name of a scoring rule.
 - If the scoring rule comprises multiple estimators, each third-level key is the name of an estimator.
 - Each estimator output (i.e., each dictionary value that is not itself a dictionary) is an array of shape (num_datasets, point_estimate_size, variable_block_size).
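Sketch (the scoring-rule names depend on the configured approximator; inspect the returned dictionary for the actual keys)::

    estimates = workflow.estimate(conditions=sims)
    for variable, rules in estimates.items():
        print(variable, list(rules))  # second-level keys are scoring rules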
- log_prob(data: Mapping[str, ndarray], **kwargs) ndarray[source]#
 Compute the log probability of given variables under the approximator.
- Parameters:
- data: Mapping[str, np.ndarray]
 A dictionary where keys represent variable names and values are arrays corresponding to the variables' realizations.
- **kwargs: dict, optional
 Additional keyword arguments passed to the approximator’s log probability function.
- Returns:
 - np.ndarray
 An array containing the log probabilities computed from the provided variables.
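Sketch, assuming data contains both the inference variables and their conditions::

    lp = workflow.log_prob(data=sims)
    print(lp.shape)  # one log-density per dataset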
- plot_default_diagnostics(test_data: Mapping[str, ndarray] | int, num_samples: int = 1000, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, **kwargs) dict[str, Figure][source]#
Generates default diagnostic plots to evaluate the quality of inference. The function produces several diagnostic plots, including:
 - Loss history (if training history is available).
 - Parameter recovery plots.
 - Calibration ECDF plots.
 - Z-score contraction plots.
- Parameters:
- test_data: Mapping[str, np.ndarray] or int
 A dictionary containing test data where keys represent variable names and values are corresponding data arrays. If an integer is provided, that number of test data sets will be generated using the simulator (if available).
- num_samples: int, optional
 The number of samples to draw from the approximator for diagnostics, by default 1000.
- variable_keys: list or None, optional, default: None
 Select keys from the dictionaries provided in estimates and targets. By default, select all keys.
- variable_names: list or None, optional, default: None
 The variable names for nice plot titles.
- **kwargs: dict, optional
 Additional keyword arguments:
- test_data_kwargs: dict, optional
 Arguments to pass to the simulator when generating test data.
- approximator_kwargs: dict, optional
 Arguments to pass to the approximator’s sampling function.
- loss_kwargs: dict, optional
 Arguments for customizing the loss plot.
- recovery_kwargs: dict, optional
 Arguments for customizing the parameter recovery plot.
- calibration_ecdf_kwargs: dict, optional
 Arguments for customizing the empirical cumulative distribution function (ECDF) calibration plot.
- z_score_contraction_kwargs: dict, optional
 Arguments for customizing the z-score contraction plot.
- Returns:
 - dict[str, plt.Figure]
 A dictionary where keys correspond to different diagnostic plot types, and values are the respective matplotlib Figure objects.
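Sketch (the dictionary keys correspond to the plot types; inspect the returned dictionary for the exact names)::

    figs = workflow.plot_default_diagnostics(test_data=300, num_samples=1000)
    for name, fig in figs.items():
        fig.savefig(f"{name}.png")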
- plot_custom_diagnostics(test_data: Mapping[str, ndarray] | int, plot_fns: Mapping[str, Callable], num_samples: int = 1000, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, **kwargs) dict[str, Figure][source]#
Generates custom diagnostic plots to evaluate the quality of inference. The functions passed should have the following signature:
 - fn(samples, inference_variables, variable_names)
They should also return a single matplotlib Figure object.
- Parameters:
- test_data: Mapping[str, np.ndarray] or int
 A dictionary containing test data where keys represent variable names and values are corresponding data arrays. If an integer is provided, that number of test data sets will be generated using the simulator (if available).
- plot_fns: Mapping[str, Callable]
 A dictionary containing custom plotting functions where keys represent the function names and values correspond to the functions themselves. The functions should have a signature of fn(samples, inference_variables, variable_names).
- num_samples: int, optional
 The number of samples to draw from the approximator for diagnostics, by default 1000.
- variable_keys: list or None, optional, default: None
 Select keys from the dictionaries provided in estimates and targets. By default, select all keys.
- variable_names: list or None, optional, default: None
 The variable names for nice plot titles.
- **kwargs: dict, optional
 Additional keyword arguments:
- test_data_kwargs: dict, optional
 Arguments to pass to the simulator when generating test data.
- approximator_kwargs: dict, optional
 Arguments to pass to the approximator’s sampling function.
- loss_kwargs: dict, optional
 Arguments for customizing the loss plot.
- recovery_kwargs: dict, optional
 Arguments for customizing the parameter recovery plot.
- calibration_ecdf_kwargs: dict, optional
 Arguments for customizing the empirical cumulative distribution function (ECDF) calibration plot.
- z_score_contraction_kwargs: dict, optional
 Arguments for customizing the z-score contraction plot.
- Returns:
 - dict[str, plt.Figure]
 A dictionary where keys correspond to different diagnostic plot types, and values are the respective matplotlib Figure objects.
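Sketch of a custom plotting function; the internal structure of samples is assumed here to be a dictionary of arrays::

    import numpy as np
    import matplotlib.pyplot as plt

    def first_variable_hist(samples, inference_variables, variable_names=None):
        # Minimal custom diagnostic: histogram of the first variable's draws.
        fig, ax = plt.subplots()
        ax.hist(np.asarray(list(samples.values())[0]).ravel(), bins=50)
        return fig

    figs = workflow.plot_custom_diagnostics(
        test_data=300,
        plot_fns={"first_variable_hist": first_variable_hist},
    )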
- compute_default_diagnostics(test_data: Mapping[str, ndarray] | int, num_samples: int = 1000, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, as_data_frame: bool = True, **kwargs) Sequence[dict] | DataFrame[source]#
Computes default diagnostic metrics to evaluate the quality of inference. The function computes several diagnostic metrics, including:
 - Root Mean Squared Error (RMSE)
 - Posterior contraction
 - Calibration error
- Parameters:
- test_data: Mapping[str, np.ndarray] or int
 A dictionary containing test data where keys represent variable names and values are corresponding realizations. If an integer is provided, that number of test data sets will be generated using the simulator (if available).
- num_samples: int, optional
 The number of samples to draw from the approximator for diagnostics, by default 1000.
- variable_keys: list or None, optional, default: None
 Select keys from the dictionaries provided in estimates and targets. By default, select all keys.
- variable_names: list or None, optional, default: None
 The variable names for nice table labels.
- as_data_frame: bool, optional
 Whether to return the results as a pandas DataFrame (default: True). If False, a sequence of dictionaries with metric values is returned.
- **kwargs: dict, optional
 Additional keyword arguments:
- test_data_kwargs: dict, optional
 Arguments to pass to the simulator when generating test data.
- approximator_kwargs: dict, optional
 Arguments to pass to the approximator’s sampling function.
- root_mean_squared_error_kwargs: dict, optional
 Arguments for customizing the RMSE computation.
- posterior_contraction_kwargs: dict, optional
 Arguments for customizing the posterior contraction computation.
- calibration_error_kwargs: dict, optional
 Arguments for customizing the calibration error computation.
- Returns:
 - Sequence[dict] or pd.DataFrame
 If as_data_frame is True, returns a pandas DataFrame containing the computed diagnostic metrics for each variable. Otherwise, returns a sequence of dictionaries with metric values.
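Sketch::

    metrics = workflow.compute_default_diagnostics(test_data=300, num_samples=1000)
    print(metrics)  # DataFrame of RMSE, contraction, and calibration error per variable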
- compute_custom_diagnostics(test_data: Mapping[str, ndarray] | int, metrics: Mapping[str, Callable], num_samples: int = 1000, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, as_data_frame: bool = True, **kwargs) Sequence[Mapping] | DataFrame[source]#
Computes custom diagnostic metrics to evaluate the quality of inference. The metric functions should have one of the following signatures:
 - metric_fn(samples, inference_variables, variable_names, variable_keys)
 - metric_fn(samples, inference_variables, **kwargs)
They should return a dictionary containing the metric name under the 'name' key and the metric values under the 'values' key.
- Parameters:
- test_data: Mapping[str, np.ndarray] or int
 A dictionary containing test data where keys represent variable names and values are corresponding realizations. If an integer is provided, that number of test data sets will be generated using the simulator (if available).
- metrics: Mapping[str, Callable]
 A dictionary containing custom metric functions where keys represent the function names and values correspond to the functions themselves. The functions should have a signature of fn(samples, inference_variables, variable_names).
- num_samples: int, optional
 The number of samples to draw from the approximator for diagnostics, by default 1000.
- variable_keys: list or None, optional, default: None
 Select keys from the dictionaries provided in estimates and targets. By default, select all keys.
- variable_names: list or None, optional, default: None
 The variable names for nice table labels.
- as_data_frame: bool, optional
 Whether to return the results as a pandas DataFrame (default: True). If False, a sequence of dictionaries with metric values is returned.
- **kwargs: dict, optional
 Additional keyword arguments:
- test_data_kwargs: dict, optional
 Arguments to pass to the simulator when generating test data.
- approximator_kwargs: dict, optional
 Arguments to pass to the approximator’s sampling function.
- root_mean_squared_error_kwargs: dict, optional
 Arguments for customizing the RMSE computation.
- posterior_contraction_kwargs: dict, optional
 Arguments for customizing the posterior contraction computation.
- calibration_error_kwargs: dict, optional
 Arguments for customizing the calibration error computation.
- Returns:
 - Sequence[dict] or pd.DataFrame
 If as_data_frame is True, returns a pandas DataFrame containing the computed diagnostic metrics for each variable. Otherwise, returns a sequence of dictionaries with metric values.
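Sketch of a custom metric obeying the 'name'/'values' contract; the internal structure of samples is assumed here to be a dictionary of arrays::

    import numpy as np

    def mean_abs_draw(samples, inference_variables, **kwargs):
        # Toy metric: mean absolute posterior draw per variable.
        values = {k: float(np.mean(np.abs(v))) for k, v in samples.items()}
        return {"name": "Mean absolute draw", "values": values}

    results = workflow.compute_custom_diagnostics(
        test_data=300,
        metrics={"mean_abs_draw": mean_abs_draw},
        as_data_frame=False,
    )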
- fit_offline(data: Mapping[str, ndarray], epochs: int = 100, batch_size: int = 32, keep_optimizer: bool = False, validation_data: Mapping[str, ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs) History[source]#
 Train the approximator offline using a fixed dataset. This approach will be faster than online training, since no computation time is spent in generating new data for each batch, but it assumes that simulations can fit in memory.
- Parameters:
- data: Mapping[str, np.ndarray]
 A dictionary containing training data where keys represent variable names and values are corresponding realizations.
- epochs: int, optional
 The number of training epochs, by default 100. Consider increasing this number for free-form inference networks, such as FlowMatching or ConsistencyModel, which generally need more epochs than CouplingFlows.
- batch_size: int, optional
 The batch size used for training, by default 32.
- keep_optimizer: bool, optional
 Whether to retain the current state of the optimizer after training, by default False.
- validation_data: Mapping[str, np.ndarray] or int, optional
 A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used.
- augmentations: dict of str to Callable or Callable, optional
 Dictionary of augmentation functions to apply to each corresponding key in the batch, or a function to apply to the entire batch (possibly adding new keys).
 If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. Otherwise, your function should accept the entire dictionary output and return a dictionary.
 Note: augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training.
- **kwargs: dict, optional
 Additional keyword arguments passed to the underlying _fit method.
- Returns:
- history: keras.callbacks.History
 A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs.
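Sketch, reusing simulations as a fixed training set (pre-computed simulations work just as well); the "x" augmentation key is hypothetical::

    import numpy as np

    train_data = workflow.simulate((5000,))

    def jitter(x):
        # Per-key augmentation: add small Gaussian noise to the observables.
        return x + np.random.normal(scale=0.01, size=x.shape)

    history = workflow.fit_offline(
        data=train_data,
        epochs=100,
        batch_size=64,
        validation_data=300,
        augmentations={"x": jitter},  # applied before the adapter, training only
    )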
- fit_online(epochs: int = 100, num_batches_per_epoch: int = 100, batch_size: int = 32, keep_optimizer: bool = False, validation_data: Mapping[str, ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs) History[source]#
 Train the approximator using an online data-generating process. The dataset is dynamically generated during training, making this approach suitable for scenarios where generating new simulations is computationally cheap.
- Parameters:
- epochs: int, optional
 The number of training epochs, by default 100.
- num_batches_per_epoch: int, optional
 The number of batches generated per epoch, by default 100.
- batch_size: int, optional
 The batch size used for training, by default 32.
- keep_optimizer: bool, optional
 Whether to retain the current state of the optimizer after training, by default False.
- validation_data: Mapping[str, np.ndarray] or int, optional
 A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used.
- augmentations: dict of str to Callable or Callable, optional
 Dictionary of augmentation functions to apply to each corresponding key in the batch, or a function to apply to the entire batch (possibly adding new keys).
 If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. Otherwise, your function should accept the entire dictionary output and return a dictionary.
 Note: augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training.
- **kwargs: dict, optional
 Additional keyword arguments passed to the underlying _fit method.
- Returns:
- history: keras.callbacks.History
 A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs.
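Sketch (requires a workflow constructed with a simulator)::

    history = workflow.fit_online(
        epochs=50,
        num_batches_per_epoch=200,
        batch_size=64,
        validation_data=300,
    )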
- fit_disk(root: PathLike, pattern: str = '*.pkl', batch_size: int = 32, load_fn: callable = None, epochs: int = 100, keep_optimizer: bool = False, validation_data: Mapping[str, ndarray] | int = None, augmentations: Mapping[str, Callable] | Callable = None, **kwargs) History[source]#
 Train the approximator using data stored on disk. This approach is suitable for large sets of simulations that don’t fit in memory.
- Parameters:
- root: os.PathLike
 The root directory containing the dataset files.
- pattern: str, optional
 A filename pattern to match dataset files, by default "*.pkl".
- batch_size: int, optional
 The batch size used for training, by default 32.
- load_fn: callable, optional
 A function to load dataset files. If None, a default loading function is used.
- epochs: int, optional
 The number of training epochs, by default 100. Consider increasing this number for free-form inference networks, such as FlowMatching or ConsistencyModel, which generally need more epochs than CouplingFlows.
- keep_optimizer: bool, optional
 Whether to retain the current state of the optimizer after training, by default False.
- validation_data: Mapping[str, np.ndarray] or int, optional
 A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used.
- augmentations: dict of str to Callable or Callable, optional
 Dictionary of augmentation functions to apply to each corresponding key in the batch, or a function to apply to the entire batch (possibly adding new keys).
 If you provide a dictionary of functions, each function should accept one element of your output batch and return the corresponding transformed element. Otherwise, your function should accept the entire dictionary output and return a dictionary.
 Note: augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training.
- **kwargs: dict, optional
 Additional keyword arguments passed to the underlying _fit method.
- Returns:
- history: keras.callbacks.History
 A history object containing training history, where keys correspond to logged metrics (e.g., loss values) and values are arrays tracking metric evolution over epochs.
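Sketch with a hypothetical pickle loader; each file is assumed to hold one dictionary of simulated arrays::

    import pickle

    def load_pickle(path):
        with open(path, "rb") as f:
            return pickle.load(f)

    history = workflow.fit_disk(
        root="simulations/",
        pattern="*.pkl",
        batch_size=64,
        load_fn=load_pickle,
        epochs=100,
    )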
- build_optimizer(epochs: int, num_batches: int, strategy: str) Optimizer | None[source]#
Build and initialize the optimizer based on the training strategy. Uses a cosine decay learning rate schedule, where the final learning rate is proportional to the square of the initial learning rate, as found to work best in SBI.
The default optimizer uses 5% of the epochs as warmup; during the warmup phase, the learning rate is increased from 10% of the initial learning rate up to the initial learning rate supplied to the workflow.
- Parameters:
- epochs: int
 The total number of training epochs.
- num_batches: int
 The number of batches per epoch.
- strategy: str
 The training strategy, either “online” or another mode that applies weight decay. For “online” training, an Adam optimizer with gradient clipping is used. For other strategies, AdamW is used with weight decay to encourage regularization.
- Returns:
 - keras.Optimizer or None
 The initialized optimizer if it was not already set. Returns None if the optimizer was already defined.
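Sketch; this is presumably invoked internally by the fit methods, so a direct call is only needed for custom setups::

    optimizer = workflow.build_optimizer(epochs=100, num_batches=200, strategy="online")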
- build_graph(*args, **kwargs)#
 
- fit(**kwargs)#