Source code for bayesflow.configuration

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
from tensorflow.keras.utils import to_categorical

from bayesflow.default_settings import DEFAULT_KEYS
from bayesflow.exceptions import ConfigurationError


class DefaultJointConfigurator:
    """Fallback class for a generic configurator for joint posterior and likelihood approximation."""

    def __init__(self, default_float_type=np.float32):
        self.posterior_config = DefaultPosteriorConfigurator(default_float_type=default_float_type)
        self.likelihood_config = DefaultLikelihoodConfigurator(default_float_type=default_float_type)
        self.default_float_type = default_float_type
[docs] def __call__(self, forward_dict): """Configures the outputs of a generative model for joint learning.""" input_dict = {} input_dict[DEFAULT_KEYS["posterior_inputs"]] = self.posterior_config(forward_dict) input_dict[DEFAULT_KEYS["likelihood_inputs"]] = self.likelihood_config(forward_dict) return input_dict
class DefaultLikelihoodConfigurator:
    """Fallback class for a generic configurator for amortized likelihood approximation."""

    def __init__(self, default_float_type=np.float32):
        self.default_float_type = default_float_type
[docs] def __call__(self, forward_dict): """Configures the output of a generative model for likelihood estimation.""" # Attempt to combine inputs input_dict = self._combine(forward_dict) # Convert everything to default type or fail gently input_dict = {k: v.astype(self.default_float_type) if v is not None else v for k, v in input_dict.items()} return input_dict
    def _combine(self, forward_dict):
        """Default combination of entries in forward_dict."""

        out_dict = {DEFAULT_KEYS["observables"]: None, DEFAULT_KEYS["conditions"]: None}

        # Determine whether simulated or observed data is available, throw if neither is present
        if forward_dict.get(DEFAULT_KEYS["sim_data"]) is None and forward_dict.get(DEFAULT_KEYS["obs_data"]) is None:
            raise ConfigurationError(
                f"Either {DEFAULT_KEYS['sim_data']} or {DEFAULT_KEYS['obs_data']}"
                + " should be present as keys in the forward_dict."
            )

        # If only simulated or observed data is present, all good
        elif forward_dict.get(DEFAULT_KEYS["sim_data"]) is not None:
            data = forward_dict.get(DEFAULT_KEYS["sim_data"])
        elif forward_dict.get(DEFAULT_KEYS["obs_data"]) is not None:
            data = forward_dict.get(DEFAULT_KEYS["obs_data"])

        # Else if neither 'sim_data' nor 'obs_data' is present, throw again
        else:
            raise ConfigurationError(
                f"Either {DEFAULT_KEYS['sim_data']} or {DEFAULT_KEYS['obs_data']}"
                + " should be present as keys in the forward_dict."
            )

        # Extract targets and conditions
        out_dict[DEFAULT_KEYS["observables"]] = data
        out_dict[DEFAULT_KEYS["conditions"]] = forward_dict[DEFAULT_KEYS["prior_draws"]]
        return out_dict
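
# Hedged usage sketch (illustrative only, not part of the library API): assumes the default
# DEFAULT_KEYS mapping; the toy shapes are assumptions made for demonstration.
def _example_default_likelihood_configurator():
    forward_dict = {
        "prior_draws": np.ones((16, 5)),  # parameter draws act as conditions for the surrogate likelihood
        "sim_data": np.ones((16, 100)),   # simulated data sets act as the observables to be modeled
    }
    config = DefaultLikelihoodConfigurator()
    input_dict = config(forward_dict)
    # Expected (under the default keys): "observables" with shape (16, 100) and
    # "conditions" with shape (16, 5), both cast to float32.
    return input_dict
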
class DefaultCombiner:
    """Fallback class for a generic combiner of conditions."""

[docs] def __call__(self, forward_dict): """Converts all condition-related variables or fails.""" out_dict = { DEFAULT_KEYS["summary_conditions"]: None, DEFAULT_KEYS["direct_conditions"]: None, } # Determine whether simulated or observed data available, throw if None present if forward_dict.get(DEFAULT_KEYS["sim_data"]) is None and forward_dict.get(DEFAULT_KEYS["obs_data"]) is None: raise ConfigurationError( f"Either {DEFAULT_KEYS['sim_data']} or {DEFAULT_KEYS['obs_data']}" + " should be present as keys in the forward_dict, but not both!" ) # If only simulated or observed data present, all good elif forward_dict.get(DEFAULT_KEYS["sim_data"]) is not None: data = forward_dict.get(DEFAULT_KEYS["sim_data"]) elif forward_dict.get(DEFAULT_KEYS["obs_data"]) is not None: data = forward_dict.get(DEFAULT_KEYS["obs_data"]) # Else if neither 'sim_data' nor 'obs_data' present, throw again else: raise ConfigurationError( f"Either {DEFAULT_KEYS['sim_data']} or {DEFAULT_KEYS['obs_data']}" + " should be present as keys in the forward_dict." ) # Handle simulated or observed data or throw if the data could not be converted to an array try: if type(data) is not np.ndarray: summary_conditions = np.array(data) else: summary_conditions = data except Exception as _: raise ConfigurationError("Could not convert input data to array...") # Handle prior batchable context or throw if error encountered if forward_dict.get(DEFAULT_KEYS["prior_batchable_context"]) is not None: try: if type(forward_dict[DEFAULT_KEYS["prior_batchable_context"]]) is not np.ndarray: pbc_as_array = np.array(forward_dict[DEFAULT_KEYS["prior_batchable_context"]]) else: pbc_as_array = forward_dict[DEFAULT_KEYS["prior_batchable_context"]] except Exception as _: raise ConfigurationError("Could not convert prior batchable context to array.") try: summary_conditions = np.concatenate([summary_conditions, pbc_as_array], axis=-1) except Exception as _: raise ConfigurationError( f"Could not concatenate data and prior batchable context. Shape mismatch: " + "data - {summary_conditions.shape}, prior_batchable_context - {pbc_as_array.shape}." ) # Handle simulation batchable context, or throw if error encountered if forward_dict.get(DEFAULT_KEYS["sim_batchable_context"]) is not None: try: if type(forward_dict[DEFAULT_KEYS["sim_batchable_context"]]) is not np.ndarray: sbc_as_array = np.array(forward_dict[DEFAULT_KEYS["sim_batchable_context"]]) else: sbc_as_array = forward_dict[DEFAULT_KEYS["sim_batchable_context"]] except Exception as _: raise ConfigurationError("Could not convert simulation batchable context to array!") try: summary_conditions = np.concatenate([summary_conditions, sbc_as_array], axis=-1) except Exception as _: raise ConfigurationError( f"Could not concatenate data (+optional prior context) and" + f" simulation batchable context. 
Shape mismatch:" + f" data - {summary_conditions.shape}, prior_batchable_context - {sbc_as_array.shape}" ) # Add summary conditions to output dict out_dict[DEFAULT_KEYS["summary_conditions"]] = summary_conditions # Handle non-batchable contexts if ( forward_dict.get(DEFAULT_KEYS["prior_non_batchable_context"]) is None and forward_dict.get(DEFAULT_KEYS["sim_non_batchable_context"]) is None ): return out_dict # Handle prior non-batchable context direct_conditions = None if forward_dict.get(DEFAULT_KEYS["prior_non_batchable_context"]) is not None: try: if type(forward_dict[DEFAULT_KEYS["prior_non_batchable_context"]]) is not np.ndarray: pnbc_conditions = np.array(forward_dict[DEFAULT_KEYS["prior_non_batchable_context"]]) else: pnbc_conditions = forward_dict[DEFAULT_KEYS["prior_non_batchable_context"]] except Exception as _: raise ConfigurationError("Could not convert prior non_batchable_context to an array!") direct_conditions = pnbc_conditions # Handle simulation non-batchable context if forward_dict.get(DEFAULT_KEYS["sim_non_batchable_context"]) is not None: try: if type(forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]]) is not np.ndarray: snbc_conditions = np.array(forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]]) else: snbc_conditions = forward_dict[DEFAULT_KEYS["sim_non_batchable_context"]] except Exception as _: raise ConfigurationError("Could not convert sim_non_batchable_context to array!") try: if direct_conditions is not None: direct_conditions = np.concatenate([direct_conditions, snbc_conditions], axis=-1) else: direct_conditions = snbc_conditions except Exception as _: raise ConfigurationError( f"Could not concatenate prior non-batchable context and \ simulation non-batchable context. Shape mismatch: \ - {direct_conditions.shape} vs. {snbc_conditions.shape}" ) out_dict[DEFAULT_KEYS["direct_conditions"]] = direct_conditions return out_dict
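
# Hedged usage sketch (illustrative only, not part of the library API): shows how batchable
# context gets appended to the summary conditions along the last axis, while non-batchable
# context is passed on as direct conditions. Key names assume the default DEFAULT_KEYS
# mapping; the toy shapes are assumptions made for demonstration.
def _example_default_combiner():
    forward_dict = {
        "sim_data": np.zeros((8, 10, 2)),                     # 8 data sets of 10 bivariate observations
        "sim_batchable_context": np.zeros((8, 10, 1)),        # per-observation context, concatenated along axis -1
        "sim_non_batchable_context": np.full((8, 1), 30.0),   # batch-level context, repeated so it can act as a direct condition
    }
    combiner = DefaultCombiner()
    out_dict = combiner(forward_dict)
    # Expected (under the default keys): "summary_conditions" with shape (8, 10, 3) and
    # "direct_conditions" with shape (8, 1); the combiner itself performs no dtype conversion.
    return out_dict
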
class DefaultPosteriorConfigurator:
    """Fallback class for a generic configurator for amortized posterior approximation."""

    def __init__(self, default_float_type=np.float32):
        self.default_float_type = default_float_type
        self.combiner = DefaultCombiner()
[docs] def __call__(self, forward_dict): """Processes the forward dict to configure the input to an amortizer.""" # Combine inputs (conditionals) input_dict = self.combiner(forward_dict) input_dict[DEFAULT_KEYS["parameters"]] = forward_dict[DEFAULT_KEYS["prior_draws"]] # Convert everything to default type or fail gently input_dict = {k: v.astype(self.default_float_type) if v is not None else v for k, v in input_dict.items()} return input_dict
class DefaultModelComparisonConfigurator:
    """Fallback class for a default configurator for amortized model comparison."""

    def __init__(self, num_models, combiner=None, default_float_type=np.float32):
        self.num_models = num_models
        if combiner is None:
            self.combiner = DefaultCombiner()
        else:
            self.combiner = combiner
        self.default_float_type = default_float_type
[docs] def __call__(self, forward_dict): """Convert all variables to arrays and combines them for inference into a dictionary with the following keys, if DEFAULT_KEYS dictionary unchanged: `model_indices` - a list of model indices, e.g., if two models, then [0, 1] `model_outputs` - a list of dictionaries, e.g., if two models, then [dict0, dict1] """ # Prepare placeholders input_dict = { DEFAULT_KEYS["summary_conditions"]: None, DEFAULT_KEYS["direct_conditions"]: None, DEFAULT_KEYS["model_indices"]: None, } summary_conditions = [] direct_conditions = [] model_indices = [] # Loop through outputs of individual models for m_idx, dict_m in zip( forward_dict[DEFAULT_KEYS["model_indices"]], forward_dict[DEFAULT_KEYS["model_outputs"]] ): # Configure individual model outputs conf_out = self.combiner(dict_m) # Extract summary conditions if conf_out.get(DEFAULT_KEYS["summary_conditions"]) is not None: summary_conditions.append(conf_out[DEFAULT_KEYS["summary_conditions"]]) num_draws_m = conf_out[DEFAULT_KEYS["summary_conditions"]].shape[0] # Extract direct conditions if conf_out.get(DEFAULT_KEYS["direct_conditions"]) is not None: direct_conditions.append(conf_out[DEFAULT_KEYS["direct_conditions"]]) num_draws_m = conf_out[DEFAULT_KEYS["direct_conditions"]].shape[0] model_indices.append(to_categorical([m_idx] * num_draws_m, self.num_models)) # At this point, all elements of the input_dicts should be arrays with identical keys input_dict[DEFAULT_KEYS["summary_conditions"]] = ( np.concatenate(summary_conditions) if summary_conditions else None ) input_dict[DEFAULT_KEYS["direct_conditions"]] = np.concatenate(direct_conditions) if direct_conditions else None input_dict[DEFAULT_KEYS["model_indices"]] = np.concatenate(model_indices) # Convert to default types input_dict = {k: v.astype(self.default_float_type) if v is not None else v for k, v in input_dict.items()} return input_dict