# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.


import logging
import os
import re
from copy import deepcopy

try:
    import cPickle as pickle
except ImportError:
    import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.linear_model import HuberRegressor

from bayesflow.default_settings import DEFAULT_KEYS

class SimulationDataset:
    """Helper class to create a ``tf.data.Dataset`` which parses simulation dictionaries and returns
    simulation dictionaries as expected by BayesFlow amortizers.
    """

    def __init__(self, forward_dict, batch_size, buffer_size=1024):
        """Creates a wrapper holding a ``tf.data.Dataset`` instance for offline training in an
        amortized estimation context.

        Parameters
        ----------
        forward_dict : dict
            The outputs from a ``GenerativeModel`` or a custom function, stored in a dictionary with
            at least the following keys:
            ``sim_data``    - an array representing the batched output of the model
            ``prior_draws`` - an array with prior generated from the model's prior
            Entries whose value is ``None`` are not put into the dataset; they are re-inserted as
            ``None`` into every dictionary produced by ``__call__``.
        batch_size : int
            The number of simulations per batch.
        buffer_size : int, optional, default: 1024
            The buffer size for shuffling elements in the ``tf.data.Dataset``.
        """

        slices, keys_used, keys_none, n_sim = self._determine_slices(forward_dict)
        # Non-None entries are zipped into tuples, so every batch yields them in ``keys_used`` order
        self.data = tf.data.Dataset.from_tensor_slices(tuple(slices)).shuffle(buffer_size).batch(batch_size)
        self.keys_used = keys_used
        self.keys_none = keys_none
        self.n_sim = n_sim
        self.num_batches = len(self.data)

    def _determine_slices(self, forward_dict):
        """Determine slices for a tensorflow Dataset.

        Returns
        -------
        slices    : list of the non-None values of ``forward_dict`` (in insertion order)
        keys_used : list of the keys belonging to ``slices``
        keys_none : list of the keys whose value is ``None``
        n_sim     : int, leading dimension of the ``sim_data`` entry
        """

        keys_used = []
        keys_none = []
        slices = []
        for key, value in forward_dict.items():
            if value is not None:
                slices.append(value)
                keys_used.append(key)
            else:
                keys_none.append(key)
        n_sim = forward_dict[DEFAULT_KEYS["sim_data"]].shape[0]
        return slices, keys_used, keys_none, n_sim

    def __call__(self, batch_in):
        """Convert one batch of the ``tf.data.Dataset`` back into a forward dictionary.

        Parameters
        ----------
        batch_in : tuple
            One element of ``self.data``: tensors ordered like ``self.keys_used``.

        Returns
        -------
        forward_dict : dict
            ``{key: tensor.numpy()}`` for every used key, plus ``{key: None}`` for every ``None`` key.
        """

        forward_dict = {key: batch.numpy() for key, batch in zip(self.keys_used, batch_in)}
        # Iterate the keys directly: ``zip(self.keys_none)`` yielded 1-tuples, which became wrong dict keys
        for key_none in self.keys_none:
            forward_dict[key_none] = None
        return forward_dict

    def __len__(self):
        """Number of batches in the underlying dataset."""
        return len(self.data)

    def __iter__(self):
        """Iterate over the dataset, yielding forward dictionaries (see ``__call__``)."""
        return map(self, self.data)
class MultiSimulationDataset:
    """Helper class for model comparison training with multiple generative models.

    Builds one ``SimulationDataset`` per model output and, on each step, yields a dictionary
    combining one batch from every model, as expected by BayesFlow amortizers.
    """

    def __init__(self, forward_dict, batch_size, buffer_size=1024):
        """Creates a wrapper holding multiple ``SimulationDataset`` instances for offline training in
        an amortized model comparison context.

        Parameters
        ----------
        forward_dict : dict
            The outputs from a ``MultiGenerativeModel`` or a custom function, stored in a dictionary
            with at least the following keys:
            ``model_outputs`` - a list with length equal to the number of models, each element
                                representing a batched output of a single model
            ``model_indices`` - a list with integer model indices, which will later be one-hot-encoded
                                for the model comparison learning problem.
            Any further keys are treated as shared context and attached unchanged to every batch.
        batch_size : int
            The total number of simulations from all models in a given batch.
            The batch size per model will be calculated as ``batch_size // num_models``
        buffer_size : int, optional, default: 1024
            The buffer size for shuffling elements in each per-model ``tf.data.Dataset``.
        """

        indices_key = DEFAULT_KEYS["model_indices"]
        outputs_key = DEFAULT_KEYS["model_outputs"]

        self.model_indices = forward_dict[indices_key]
        self.num_models = len(self.model_indices)
        self.per_model_batch_size = batch_size // self.num_models
        self.datasets = [
            SimulationDataset(model_output, self.per_model_batch_size, buffer_size)
            for model_output in forward_dict[outputs_key]
        ]
        self.current_it = 0
        # An epoch ends as soon as the shortest per-model dataset is exhausted
        self.num_batches = min([dataset.num_batches for dataset in self.datasets])
        self.iters = [iter(dataset) for dataset in self.datasets]
        self.batch_size = batch_size

        # Everything except the two reserved keys is shared context passed through as-is
        reserved = [outputs_key, indices_key]
        self.further_keys = {key: value for key, value in forward_dict.items() if key not in reserved}

    def __next__(self):
        """Return the next combined batch, or rewind all iterators and raise ``StopIteration``."""

        if not self.current_it < self.num_batches:
            # Rewind so the same object can be iterated again in the next epoch
            self.current_it = 0
            self.iters = [iter(dataset) for dataset in self.datasets]
            raise StopIteration

        per_model_batches = [next(model_iter) for model_iter in self.iters]
        output_dict = {
            DEFAULT_KEYS["model_outputs"]: per_model_batches,
            DEFAULT_KEYS["model_indices"]: self.model_indices,
        }
        output_dict.update(self.further_keys)
        self.current_it += 1
        return output_dict

    def __iter__(self):
        """The object is its own iterator."""
        return self
class EarlyStopper:
    """Tracks the total validation loss and recommends early stopping based on its hyperparameters."""

    def __init__(self, patience=5, tolerance=0.05):
        """
        Parameters
        ----------
        patience : int, optional, default: 5
            How many successive non-significant changes are tolerated; the stop recommendation is
            issued on the next non-significant change after the counter has reached ``patience``.
        tolerance : float, optional, default: 0.05
            The minimum reduction of validation loss to be considered significant.
        """

        self.history = []
        self.patience = patience
        self.tolerance = tolerance
        self._patience_counter = 0

    def update_and_recommend(self, current_val_loss):
        """Append ``current_val_loss`` to the history and return ``True`` if stopping is recommended."""

        self.history.append(current_val_loss)
        return self._check_patience()

    def _check_patience(self):
        """Check whether the patience has been surpassed or not.

        Assumes the current validation loss has already been appended to ``self.history``.
        """

        # A single value gives nothing to compare against
        if len(self.history) < 2:
            return False

        previous_loss, latest_loss = self.history[-2], self.history[-1]

        # Loss dropped by at least `tolerance`: progress, so start counting from scratch
        if previous_loss - latest_loss >= self.tolerance:
            self._patience_counter = 0
            return False

        # No significant decrease: spend one unit of patience if any is left
        if self._patience_counter < self.patience:
            self._patience_counter += 1
            return False

        # Patience exhausted: reset for a possible future run and recommend stopping
        self._patience_counter = 0
        return True
class RegressionLRAdjuster:
    """This class will compute the slope of the loss trajectory and inform learning rate decay."""

    file_name = "lr_adjuster"

    # Stateful attributes persisted by ``save_to_file`` and restored by ``load_from_file``
    _state_attributes = (
        "_history",
        "_reset_counter",
        "_patience_counter",
        "_cooldown_counter",
        "_wait_counter",
        "_slope",
        "_is_waiting",
        "_in_cooldown",
    )

    def __init__(
        self,
        optimizer,
        period=1000,
        wait_between_fits=10,
        patience=10,
        tolerance=-0.05,
        reduction_factor=0.25,
        cooldown_factor=2,
        num_resets=3,
        **kwargs,
    ):
        """Creates an instance which tracks the slope of the loss trajectory according to the given
        hyperparameters, reduces the learning rate when progress stalls, and eventually issues an
        optional stopping suggestion.

        Parameters
        ----------
        optimizer         : tf.keras.optimizers.Optimizer instance
            An optimizer exposing an ``lr`` variable (with ``numpy()`` and ``assign()``) and an
            ``iterations`` counter.
        period            : int, optional, default: 1000
            How many of the most recent loss values to consider for each fit.
        wait_between_fits : int, optional, default: 10
            A new slope is fitted once every ``wait_between_fits`` calls to ``get_slope``;
            the calls in between reuse the last slope.
        patience          : int, optional, default: 10
            How many successive times the tolerance value is reached before lr update.
        tolerance         : float, optional, default: -0.05
            The minimum slope to be considered substantial for training.
        reduction_factor  : float in [0, 1], optional, default: 0.25
            The factor by which the learning rate is reduced upon hitting the `tolerance`
            threshold for `patience` number of times
        cooldown_factor   : float, optional, default: 2
            The factor by which the `period` is multiplied to arrive at a cooldown period.
        num_resets        : int, optional, default: 3
            How many times to reduce the learning rate before issuing an optional stopping
        **kwargs          : dict, optional, default {}
            Additional keyword arguments passed to the `HuberRegressor` class.
        """

        self.optimizer = optimizer
        self.period = period
        self.wait_between_periods = wait_between_fits
        self.regressor = HuberRegressor(**kwargs)
        # Regressor input: `period` equally spaced points on [0, 1], shaped (period, 1) for sklearn
        self.t_vector = np.linspace(0, 1, self.period)[:, np.newaxis]
        self.patience = patience
        self.tolerance = tolerance
        self.num_resets = num_resets
        self.reduction_factor = reduction_factor
        self.stopping_issued = False
        self.cooldown_factor = cooldown_factor

        self._history = {"iteration": [], "learning_rate": []}
        self._reset_counter = 0
        self._patience_counter = 0
        self._cooldown_counter = 0
        self._wait_counter = 0
        self._slope = None
        self._is_waiting = False
        self._in_cooldown = False

    def get_slope(self, losses):
        """Fits a Huber regression of the last ``period`` losses against ``self.t_vector`` and returns
        its slope, or `None` if fewer than ``period`` losses are available.

        Between refits (see ``wait_between_fits``) the previously fitted slope is returned.
        Every refit also updates the patience logic, which may reduce the learning rate.

        Parameters
        ----------
        losses : array-like of shape (num_steps,)
            The full loss trajectory so far.
        """

        losses = np.asarray(losses)

        # Return None if not enough loss values present
        if losses.shape[0] < self.period:
            return None

        # Count calls spent in cooldown
        if self._in_cooldown:
            self._cooldown_counter += 1

        # Reuse the cached slope while waiting, otherwise refit the Huber regression
        if self._check_waiting():
            return self._slope

        self.regressor.fit(self.t_vector, losses[-self.period :])
        self._slope = self.regressor.coef_[0]
        self._check_patience()
        return self._slope

    def reset(self):
        """Resets all stateful variables in preparation for a new start."""

        self._reset_counter = 0
        self._patience_counter = 0
        self._cooldown_counter = 0
        self._wait_counter = 0
        self._slope = None
        self._in_cooldown = False
        self._is_waiting = False
        self.stopping_issued = False

    def save_to_file(self, file_path):
        """Saves the state parameters of a RegressionLRAdjuster object to a pickled dictionary in file_path."""

        memory_path = os.path.join(file_path, f"{RegressionLRAdjuster.file_name}.pkl")
        states_dict = {name: getattr(self, name) for name in RegressionLRAdjuster._state_attributes}
        with open(memory_path, "wb") as f:
            pickle.dump(states_dict, f)

    def load_from_file(self, file_path):
        """Loads the saved LRAdjuster state from file_path, if a saved state exists.

        Note: this unpickles the file, so only load directories you trust.
        """

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        memory_path = os.path.join(file_path, f"{RegressionLRAdjuster.file_name}.pkl")

        if os.path.exists(memory_path):
            with open(memory_path, "rb") as f:
                states_dict = pickle.load(f)
            for name in RegressionLRAdjuster._state_attributes:
                setattr(self, name, states_dict[name])
            logger.info(f"Loaded RegressionLRAdjuster from {memory_path}")
        else:
            logger.info("Initialized a new RegressionLRAdjuster.")

    def _check_patience(self):
        """Determines whether to reduce learning rate or be patient."""

        # Do nothing, if still in cooldown period
        if self._in_cooldown and self._cooldown_counter < int(self.cooldown_factor * self.period):
            return
        # Otherwise leave the cooldown and reset its counter
        self._in_cooldown = False
        self._cooldown_counter = 0

        # A slope above the (negative) tolerance means the loss is no longer decreasing fast enough
        if self._slope > self.tolerance:
            self._patience_counter += 1
        else:
            self._patience_counter = max(0, self._patience_counter - 1)

        # Check if patience surpassed and issue a reduction in learning rate
        if self._patience_counter >= self.patience:
            self._reduce_learning_rate()
            self._patience_counter = 0

    def _reduce_learning_rate(self):
        """Reduces the learning rate by ``reduction_factor``, or flags stopping once ``num_resets`` is used up."""

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        if self._reset_counter >= self.num_resets:
            self.stopping_issued = True
            return

        # Take care of updating learning rate
        old_lr = self.optimizer.lr.numpy()
        new_lr = round(self.reduction_factor * old_lr, 8)
        self.optimizer.lr.assign(new_lr)
        self._reset_counter += 1

        # Store iteration and learning rate
        self._history["iteration"].append(self.optimizer.iterations.numpy())
        self._history["learning_rate"].append(old_lr)

        logger.info(f"Reducing learning rate from {old_lr:.8f} to: {new_lr:.8f} and entering cooldown...")

        # Set cooldown flag to avoid another reduction for cooldown_factor * period calls
        self._in_cooldown = True

    def _check_waiting(self):
        """Return ``True`` if the cached slope should be reused, ``False`` if a refit is due.

        A refit is due on the first call (or after ``reset``) and then once every
        ``wait_between_periods`` calls. The previous implementation fell off the end of the
        function (returning ``None``) when a wait period ended, so the following call refit
        again, producing two consecutive fits and a double patience update per period.
        """

        period = max(1, self.wait_between_periods)
        refit_due = not self._is_waiting or self._wait_counter >= period - 1
        if refit_due:
            # This call refits; start a new waiting period
            self._is_waiting = True
            self._wait_counter = 0
            return False
        self._wait_counter += 1
        return True
class LossHistory:
    """Helper class to keep track of losses during training."""

    file_name = "history"

    def __init__(self):
        self.latest = 0
        self.history = {}
        self.val_history = {}
        self.loss_names = []
        self.val_loss_names = []
        self._current_run = 0
        self._total_loss = []
        self._total_val_loss = []

    @property
    def total_loss(self):
        return np.array(self._total_loss)

    @property
    def total_val_loss(self):
        return np.array(self._total_val_loss)

    def last_total_loss(self):
        return self._total_loss[-1]

    def last_total_val_loss(self):
        return self._total_val_loss[-1]

    def start_new_run(self):
        """Open a new ``"Run k"`` section in both the train and validation histories."""
        self._current_run += 1
        self.history[f"Run {self._current_run}"] = {}
        self.val_history[f"Run {self._current_run}"] = {}

    def add_val_entry(self, epoch, val_loss):
        """Add a validation entry for ``epoch`` of the current run.

        ``val_loss`` may be a dict (names ``"Val.<key>"``), a tuple/list (names ``"Val.Loss.<i>"``)
        or a scalar (name ``"Default.Val.Loss"``). Names are only set if none exist yet.
        """

        entry, total, names = self._normalize_loss(val_loss, prefix="Val.")
        self.val_history[f"Run {self._current_run}"].setdefault(f"Epoch {epoch}", []).append(entry)
        if not self.val_loss_names:
            self.val_loss_names = names
        self._total_val_loss.append(total)

    def add_entry(self, epoch, current_loss):
        """Adds loss entry for current epoch into internal memory data structure.

        ``current_loss`` may be a dict (names are its keys), a tuple/list (names ``"Loss.<i>"``)
        or a scalar (name ``"Default.Loss"``). Names are only set if none exist yet.
        """

        entry, total, names = self._normalize_loss(current_loss, prefix="")
        self.history[f"Run {self._current_run}"].setdefault(f"Epoch {epoch}", []).append(entry)
        if not self.loss_names:
            self.loss_names = names
        self._total_loss.append(total)

    def get_running_losses(self, epoch):
        """Compute and return running means of the losses for current epoch."""

        means = np.atleast_1d(np.mean(self.history[f"Run {self._current_run}"][f"Epoch {epoch}"], axis=0))
        if means.shape[0] == 1:
            return {"Avg.Loss": means[0]}
        return {"Avg." + k: v for k, v in zip(self.loss_names, means)}

    def get_plottable(self):
        """Returns the losses as a nicely formatted pandas DataFrame, in case only train losses were collected,
        otherwise a dict of data frames. Falls back to the raw history dicts if they cannot be tabulated
        (e.g., unequal lengths after the user interrupted training).
        """

        has_val = any(self.val_history.values())
        try:
            losses_df = self._to_data_frame(self.history, self.loss_names)
            if has_val:
                # Remove decay terms, which are only part of the training loss
                names = [name for name in self.loss_names if "Decay" not in name]
                val_losses_df = self._to_data_frame(self.val_history, names)
                return {"train_losses": losses_df, "val_losses": val_losses_df}
            return losses_df
        except (ValueError, TypeError):
            if has_val:
                return {"train_losses": self.history, "val_losses": self.val_history}
            return self.history

    def flush(self):
        """Returns current history and removes all existing loss history, but keeps loss names."""

        history = self.history
        val_history = self.val_history
        self.history = {}
        self.val_history = {}
        self._total_loss = []
        self._total_val_loss = []
        self._current_run = 0
        return history, val_history

    def save_to_file(self, file_path, max_to_keep):
        """Saves a `LossHistory` object to a pickled dictionary in file_path.

        Afterwards only the ``max_to_keep`` newest history files are kept (oldest removed first).
        If ``max_to_keep`` is None, no files are removed.
        """

        # Increment history index
        self.latest += 1
        history_path = os.path.join(file_path, f"{LossHistory.file_name}_{self.latest}.pkl")

        pickle_dict = {
            "history": self.history,
            "val_history": self.val_history,
            "loss_names": self.loss_names,
            "val_loss_names": self.val_loss_names,
            "_current_run": self._current_run,
            "_total_loss": self._total_loss,
            "_total_val_loss": self._total_val_loss,
        }
        with open(history_path, "wb") as f:
            pickle.dump(pickle_dict, f)

        if max_to_keep is None:
            return
        # Prune by the actual files present, so gaps in numbering or several surplus files are handled
        checkpoints = self._list_checkpoints(file_path)
        num_to_remove = max(len(checkpoints) - max_to_keep, 0)
        for _, name in checkpoints[:num_to_remove]:
            os.remove(os.path.join(file_path, name))

    def load_from_file(self, file_path):
        """Loads the most recent saved `LossHistory` object from `file_path`.

        Note: this unpickles the file, so only load directories you trust.
        """

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        checkpoints = self._list_checkpoints(file_path) if os.path.isdir(file_path) else []

        if not checkpoints:
            logger.info("Initialized empty loss history.")
            return

        latest_number, latest_file = checkpoints[-1]
        latest_path = os.path.join(file_path, latest_file)
        with open(latest_path, "rb") as f:
            loaded_history_dict = pickle.load(f)

        # Fill public entries
        self.latest = latest_number
        self.history = loaded_history_dict.get("history", {})
        self.val_history = loaded_history_dict.get("val_history", {})
        self.loss_names = loaded_history_dict.get("loss_names", [])
        self.val_loss_names = loaded_history_dict.get("val_loss_names", [])

        # Fill private entries
        self._current_run = loaded_history_dict.get("_current_run", 0)
        self._total_loss = loaded_history_dict.get("_total_loss", [])
        self._total_val_loss = loaded_history_dict.get("_total_val_loss", [])

        logger.info(f"Loaded loss history from {latest_path}.")

    def _to_data_frame(self, history, names):
        """Helper function to convert a history dict into a DataFrame."""

        losses_list = [pd.melt(pd.DataFrame.from_dict(history[r], orient="index").T) for r in history]
        losses_list = pd.concat(losses_list, axis=0).value.to_list()
        losses_list = [loss for loss in losses_list if loss is not None]
        return pd.DataFrame(losses_list, columns=names)

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor-like value via ``.numpy()``; NumPy arrays/scalars and Python numbers pass through."""
        return value.numpy() if hasattr(value, "numpy") else value

    @staticmethod
    def _normalize_loss(loss, prefix):
        """Return ``(entry, total, default_names)`` for a dict, tuple/list or scalar loss output."""

        if isinstance(loss, dict):
            entry = [LossHistory._to_numpy(v) for v in loss.values()]
            names = [prefix + key for key in loss] if prefix else list(loss)
            return entry, sum(entry), names
        if isinstance(loss, (tuple, list)):
            entry = [LossHistory._to_numpy(v) for v in loss]
            names = [f"{prefix}Loss.{i}" for i in range(1, len(entry) + 1)]
            return entry, sum(entry), names
        # Assume scalar loss output
        entry = LossHistory._to_numpy(loss)
        return entry, entry, [f"Default.{prefix}Loss"]

    @staticmethod
    def _list_checkpoints(file_path):
        """Return ``(number, file_name)`` pairs of saved history files in ``file_path``, oldest first."""

        pattern = re.compile(rf"^{re.escape(LossHistory.file_name)}_(\d+)\.pkl$")
        found = []
        for name in os.listdir(file_path):
            match = pattern.match(name)
            if match:
                found.append((int(match.group(1)), name))
        return sorted(found)
class SimulationMemory:
    """Helper class to keep track of a pre-determined number of simulations during training.

    The memory fills once: after ``capacity_in_batches`` stores, further batches are ignored.
    """

    file_name = "memory"

    def __init__(self, stores_raw=True, capacity_in_batches=50):
        # Flag persisted with the memory; presumably marks whether raw (unconfigured) simulator
        # outputs are stored — it is not interpreted inside this class (TODO confirm with callers)
        self.stores_raw = stores_raw
        self._capacity = capacity_in_batches
        self._buffer = [None] * self._capacity
        self._idx = 0
        self.size_in_batches = 0

    def store(self, forward_dict):
        """Stores simulation outputs in `forward_dict`, if internal buffer is not full.

        Parameters
        ----------
        forward_dict : dict
            The configured outputs of the forward model.
        """

        # Only append while there is room; a full memory silently drops new batches (no overwriting)
        if not self.is_full():
            self._buffer[self._idx] = forward_dict
            self._idx += 1
            self.size_in_batches += 1

    def get_memory(self):
        """Return a deep copy of the internal buffer (unfilled slots are ``None``)."""
        return deepcopy(self._buffer)

    def is_full(self):
        """Returns True if the buffer is full, otherwise False."""
        return self._idx >= self._capacity

    def save_to_file(self, file_path):
        """Saves a `SimulationMemory` object to a pickled dictionary in file_path."""

        memory_path = os.path.join(file_path, f"{SimulationMemory.file_name}.pkl")
        full_memory_dict = {
            "stores_raw": self.stores_raw,
            "_capacity": self._capacity,
            "_buffer": self._buffer,
            "_idx": self._idx,
            "_size_in_batches": self.size_in_batches,
        }
        with open(memory_path, "wb") as f:
            pickle.dump(full_memory_dict, f)

    def load_from_file(self, file_path):
        """Loads the saved `SimulationMemory` object from file_path, if a saved memory exists.

        Note: this unpickles the file, so only load directories you trust.
        """

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        memory_path = os.path.join(file_path, f"{SimulationMemory.file_name}.pkl")

        if os.path.exists(memory_path):
            with open(memory_path, "rb") as f:
                full_memory_dict = pickle.load(f)
            self.stores_raw = full_memory_dict["stores_raw"]
            self._capacity = full_memory_dict["_capacity"]
            self._buffer = full_memory_dict["_buffer"]
            self._idx = full_memory_dict["_idx"]
            self.size_in_batches = full_memory_dict["_size_in_batches"]
            logger.info(f"Loaded simulation memory from {memory_path}")
        else:
            logger.info("Initialized empty simulation memory.")
class MemoryReplayBuffer:
    """Implements a memory replay buffer for simulation-based inference."""

    def __init__(self, capacity_in_batches=500):
        """Creates a circular buffer following the logic of experience replay.

        Parameters
        ----------
        capacity_in_batches : int, optional, default: 500
            The capacity of the buffer in batches of simulations. Could potentially grow
            very large, so make sure you pick a reasonable number!
        """

        self._capacity = capacity_in_batches
        self._buffer = [None] * self._capacity
        self._idx = 0
        self._size_in_batches = 0
        self._is_full = False

    def store(self, forward_dict):
        """Stores simulation outputs; once the buffer is full, the oldest batch is overwritten.

        Parameters
        ----------
        forward_dict : dict
            The configured outputs of the forward model.
        """

        if self._is_full:
            self._overwrite(forward_dict)
            return

        # Still capacity to append
        self._buffer[self._idx] = forward_dict
        self._idx += 1
        self._size_in_batches += 1
        if self._idx == self._capacity:
            self._is_full = True

    def sample(self):
        """Returns one stored batch, drawn uniformly at random from the filled part of the buffer.

        Returns
        -------
        forward_dict : dict
            The (raw or configured) outputs of the forward model.

        Raises
        ------
        ValueError
            If nothing has been stored yet.
        """

        if self._size_in_batches == 0:
            raise ValueError("Cannot sample from an empty MemoryReplayBuffer; call store() first.")
        rand_idx = np.random.default_rng().integers(low=0, high=self._size_in_batches)
        return self._buffer[rand_idx]

    def _overwrite(self, forward_dict):
        """Overwrites a simulated batch at current position. Only called when the internal buffer is full."""

        # Wrap around to the start once the end of the buffer is reached
        if self._idx == self._capacity:
            self._idx = 0
        self._buffer[self._idx] = forward_dict
        self._idx += 1