Source code for bayesflow.helper_classes

# Copyright (c) 2022 The BayesFlow Developers

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import re
from copy import deepcopy

import numpy as np
import pandas as pd
import tensorflow as tf

try:
    import cPickle as pickle
except ImportError:
    import pickle

import logging

logging.basicConfig()

from sklearn.linear_model import HuberRegressor

from bayesflow.default_settings import DEFAULT_KEYS


class SimulationDataset:
    """Helper class to create a tensorflow.data.Dataset which parses simulation dictionaries
    and returns simulation dictionaries as expected by BayesFlow amortizers.
    """
    def __init__(self, forward_dict, batch_size, buffer_size=1024):
        """Creates a wrapper holding a ``tf.data.Dataset`` instance for
        offline training in an amortized estimation context.

        Parameters
        ----------
        forward_dict : dict
            The outputs from a ``GenerativeModel`` or a custom function,
            stored in a dictionary with at least the following keys:
            ``sim_data``    - an array representing the batched output of the model
            ``prior_draws`` - an array with draws generated from the model's prior
        batch_size   : int
            The total number of simulations from all models in a given batch.
            The batch size per model will be calculated as ``batch_size // num_models``
        buffer_size  : int, optional, default: 1024
            The buffer size for shuffling elements in a ``tf.data.Dataset``
        """

        slices, keys_used, keys_none, n_sim = self._determine_slices(forward_dict)

        self.data = tf.data.Dataset.from_tensor_slices(tuple(slices)).shuffle(buffer_size).batch(batch_size)
        self.keys_used = keys_used
        self.keys_none = keys_none
        self.n_sim = n_sim
        self.num_batches = len(self.data)
    def _determine_slices(self, forward_dict):
        """Determine slices for a tensorflow Dataset."""

        keys_used = []
        keys_none = []
        slices = []
        for k, v in forward_dict.items():
            if v is not None:
                slices.append(v)
                keys_used.append(k)
            else:
                keys_none.append(k)
        n_sim = forward_dict[DEFAULT_KEYS["sim_data"]].shape[0]
        return slices, keys_used, keys_none, n_sim
    def __call__(self, batch_in):
        """Convert output of tensorflow.data.Dataset to dict."""

        forward_dict = {}
        for key_used, batch_stuff in zip(self.keys_used, batch_in):
            forward_dict[key_used] = batch_stuff.numpy()
        for key_none in self.keys_none:
            forward_dict[key_none] = None
        return forward_dict
    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return map(self, self.data)
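
# Editor's note: the short sketch below is illustrative and not part of the original
# module. It shows how a ``SimulationDataset`` might wrap a simulation dictionary for
# offline training; the keys are assumed to be present in ``DEFAULT_KEYS``, and the
# synthetic arrays are placeholders for outputs that would normally come from a
# ``GenerativeModel``.
def _example_simulation_dataset():
    forward_dict = {
        DEFAULT_KEYS["prior_draws"]: np.random.normal(size=(128, 4)).astype(np.float32),
        DEFAULT_KEYS["sim_data"]: np.random.normal(size=(128, 50, 2)).astype(np.float32),
    }
    dataset = SimulationDataset(forward_dict, batch_size=32)
    for batch in dataset:
        # Each batch is again a dict with the same keys, holding numpy arrays
        print(batch[DEFAULT_KEYS["sim_data"]].shape)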
class MultiSimulationDataset:
    """Helper class for model comparison training with multiple generative models.

    Will create multiple ``SimulationDataset`` instances, each parsing their own
    simulation dictionaries and returning these as expected by BayesFlow amortizers.
    """
    def __init__(self, forward_dict, batch_size, buffer_size=1024):
        """Creates a wrapper holding multiple ``tf.data.Dataset`` instances for
        offline training in an amortized model comparison context.

        Parameters
        ----------
        forward_dict : dict
            The outputs from a ``MultiGenerativeModel`` or a custom function,
            stored in a dictionary with at least the following keys:
            ``model_outputs`` - a list with length equal to the number of models,
            each element representing a batched output of a single model
            ``model_indices`` - a list with integer model indices, which will
            later be one-hot-encoded for the model comparison learning problem.
        batch_size   : int
            The total number of simulations from all models in a given batch.
            The batch size per model will be calculated as ``batch_size // num_models``
        buffer_size  : int, optional, default: 1024
            The buffer size for shuffling elements in a ``tf.data.Dataset``
        """

        self.model_indices = forward_dict[DEFAULT_KEYS["model_indices"]]
        self.num_models = len(self.model_indices)
        self.per_model_batch_size = batch_size // self.num_models
        self.datasets = [
            SimulationDataset(out, self.per_model_batch_size, buffer_size)
            for out in forward_dict[DEFAULT_KEYS["model_outputs"]]
        ]
        self.current_it = 0
        self.num_batches = min([d.num_batches for d in self.datasets])
        self.iters = [iter(d) for d in self.datasets]
        self.batch_size = batch_size

        # Include further keys (= shared context) from forward_dict
        self.further_keys = {}
        for key, value in forward_dict.items():
            if key not in [DEFAULT_KEYS["model_outputs"], DEFAULT_KEYS["model_indices"]]:
                self.further_keys[key] = value
    def __next__(self):
        if self.current_it < self.num_batches:
            outputs = [next(d) for d in self.iters]
            output_dict = {DEFAULT_KEYS["model_outputs"]: outputs, DEFAULT_KEYS["model_indices"]: self.model_indices}
            if self.further_keys:
                output_dict.update(self.further_keys)
            self.current_it += 1
            return output_dict
        self.current_it = 0
        self.iters = [iter(d) for d in self.datasets]
        raise StopIteration

    def __iter__(self):
        return self
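
# Editor's note: an illustrative sketch (not part of the original module) of how a
# ``MultiSimulationDataset`` could be constructed for a two-model comparison setting.
# The helper ``_make_outputs`` and the array shapes are placeholders; real outputs
# would come from a ``MultiGenerativeModel``.
def _example_multi_simulation_dataset():
    def _make_outputs(num_sims):
        return {
            DEFAULT_KEYS["prior_draws"]: np.random.normal(size=(num_sims, 3)).astype(np.float32),
            DEFAULT_KEYS["sim_data"]: np.random.normal(size=(num_sims, 20, 1)).astype(np.float32),
        }

    forward_dict = {
        DEFAULT_KEYS["model_outputs"]: [_make_outputs(64), _make_outputs(64)],
        DEFAULT_KEYS["model_indices"]: [0, 1],
    }
    dataset = MultiSimulationDataset(forward_dict, batch_size=32)
    for batch in dataset:
        # ``model_outputs`` holds one per-model dict of numpy arrays per iteration
        assert len(batch[DEFAULT_KEYS["model_outputs"]]) == 2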
class EarlyStopper:
    """This class will track the total validation loss and trigger an early stopping
    recommendation based on its hyperparameters."""
    def __init__(self, patience=5, tolerance=0.05):
        """Creates an instance with given hyperparameters for tracking the validation loss.

        Parameters
        ----------
        patience  : int, optional, default: 5
            How many successive times the tolerance value is reached before triggering
            an early stopping recommendation.
        tolerance : float, optional, default: 0.05
            The minimum reduction of validation loss to be considered significant.
        """

        self.history = []
        self.patience = patience
        self.tolerance = tolerance
        self._patience_counter = 0
    def update_and_recommend(self, current_val_loss):
        """Adds the loss to the history and checks the difference between successive losses."""

        self.history.append(current_val_loss)
        rec = self._check_patience()
        return rec
    def _check_patience(self):
        """Check whether the patience has been surpassed or not.
        Assumes current_val_loss has previously been added to the internal
        history, so it has at least one element.
        """

        # Not enough history yet, no recommendation
        if len(self.history) <= 1:
            return False

        # Significant decrease in validation loss according to tolerance, reset patience
        if (self.history[-2] - self.history[-1]) >= self.tolerance:
            self._patience_counter = 0
            return False
        # Not a significant decrease, check counter
        else:
            # Still no stop recommendation, but increase counter
            if self._patience_counter < self.patience:
                self._patience_counter += 1
                return False
            # Reset counter and recommend stop
            else:
                self._patience_counter = 0
                return True
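
# Editor's note: a minimal, illustrative sketch (not part of the original module) of how
# an ``EarlyStopper`` could be queried inside a training loop. The hard-coded validation
# losses are made up purely to demonstrate the patience logic.
def _example_early_stopper():
    stopper = EarlyStopper(patience=3, tolerance=0.05)
    for epoch, val_loss in enumerate([1.00, 0.90, 0.89, 0.889, 0.888, 0.8875]):
        if stopper.update_and_recommend(val_loss):
            print(f"Early stopping recommended after epoch {epoch}")
            break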
class RegressionLRAdjuster:
    """This class will compute the slope of the loss trajectory and inform learning rate decay."""

    file_name = "lr_adjuster"
    def __init__(
        self,
        optimizer,
        period=1000,
        wait_between_fits=10,
        patience=10,
        tolerance=-0.05,
        reduction_factor=0.25,
        cooldown_factor=2,
        num_resets=3,
        **kwargs,
    ):
        """Creates an instance with given hyperparameters which will track the slope of the
        loss trajectory according to specified hyperparameters and then issue an optional
        stopping suggestion.

        Parameters
        ----------
        optimizer         : tf.keras.optimizers.Optimizer instance
            An optimizer exposing a mutable ``lr`` attribute (the learning rate variable)
        period            : int, optional, default: 1000
            How many loss values from the past to consider for the slope fit
        wait_between_fits : int, optional, default: 10
            How many backpropagation updates to wait between two successive fits
        patience          : int, optional, default: 10
            How many successive times the tolerance value is reached before lr update.
        tolerance         : float, optional, default: -0.05
            The minimum slope to be considered substantial for training.
        reduction_factor  : float in [0, 1], optional, default: 0.25
            The factor by which the learning rate is reduced upon hitting the `tolerance`
            threshold for `patience` number of times
        cooldown_factor   : float, optional, default: 2
            The factor by which the `period` is multiplied to arrive at a cooldown period.
        num_resets        : int, optional, default: 3
            How many times to reduce the learning rate before issuing an optional stopping
        **kwargs          : dict, optional, default: {}
            Additional keyword arguments passed to the `HuberRegressor` class.
        """

        self.optimizer = optimizer
        self.period = period
        self.wait_between_periods = wait_between_fits
        self.regressor = HuberRegressor(**kwargs)
        self.t_vector = np.linspace(0, 1, self.period)[:, np.newaxis]
        self.patience = patience
        self.tolerance = tolerance
        self.num_resets = num_resets
        self.reduction_factor = reduction_factor
        self.stopping_issued = False
        self.cooldown_factor = cooldown_factor
        self._history = {"iteration": [], "learning_rate": []}
        self._reset_counter = 0
        self._patience_counter = 0
        self._cooldown_counter = 0
        self._wait_counter = 0
        self._slope = None
        self._is_waiting = False
        self._in_cooldown = False
    def get_slope(self, losses):
        """Fits a Huber regression on the provided loss trajectory or returns `None` if
        not enough data points present.
        """

        # Return None if not enough loss values present
        if losses.shape[0] < self.period:
            return None

        # Increment cooldown counter, if currently in a cooldown phase
        if self._in_cooldown:
            self._cooldown_counter += 1

        # Check if still in a waiting phase and return old slope
        # if still waiting, otherwise refit Huber regression
        wait = self._check_waiting()
        if wait:
            return self._slope
        else:
            self.regressor.fit(self.t_vector, losses[-self.period :])
            self._slope = self.regressor.coef_[0]
            self._check_patience()
            return self._slope
    def reset(self):
        """Resets all stateful variables in preparation for a new start."""

        self._reset_counter = 0
        self._patience_counter = 0
        self._cooldown_counter = 0
        self._wait_counter = 0
        self._in_cooldown = False
        self._is_waiting = False
        self.stopping_issued = False
    def save_to_file(self, file_path):
        """Saves the state parameters of a RegressionLRAdjuster object to a pickled dictionary in file_path."""

        # Create path to memory
        memory_path = os.path.join(file_path, f"{RegressionLRAdjuster.file_name}.pkl")

        # Prepare attributes
        states_dict = {}
        states_dict["_history"] = self._history
        states_dict["_reset_counter"] = self._reset_counter
        states_dict["_patience_counter"] = self._patience_counter
        states_dict["_cooldown_counter"] = self._cooldown_counter
        states_dict["_wait_counter"] = self._wait_counter
        states_dict["_slope"] = self._slope
        states_dict["_is_waiting"] = self._is_waiting
        states_dict["_in_cooldown"] = self._in_cooldown

        # Dump as pickle object
        with open(memory_path, "wb") as f:
            pickle.dump(states_dict, f)
    def load_from_file(self, file_path):
        """Loads the saved RegressionLRAdjuster object from file_path."""

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        # Create path to memory
        memory_path = os.path.join(file_path, f"{RegressionLRAdjuster.file_name}.pkl")

        # Case memory file exists
        if os.path.exists(memory_path):
            # Load pickle and fill in attributes
            with open(memory_path, "rb") as f:
                states_dict = pickle.load(f)

            self._history = states_dict["_history"]
            self._reset_counter = states_dict["_reset_counter"]
            self._patience_counter = states_dict["_patience_counter"]
            self._cooldown_counter = states_dict["_cooldown_counter"]
            self._wait_counter = states_dict["_wait_counter"]
            self._slope = states_dict["_slope"]
            self._is_waiting = states_dict["_is_waiting"]
            self._in_cooldown = states_dict["_in_cooldown"]
            logger.info(f"Loaded RegressionLRAdjuster from {memory_path}")

        # Case memory file does not exist
        else:
            logger.info("Initialized a new RegressionLRAdjuster.")
    def _check_patience(self):
        """Determines whether to reduce learning rate or be patient."""

        # Do nothing, if still in cooldown period
        if self._in_cooldown and self._cooldown_counter < int(self.cooldown_factor * self.period):
            return
        # Otherwise set cooldown flag to False and reset counter
        else:
            self._in_cooldown = False
            self._cooldown_counter = 0

        # Check if the (negative) slope is too shallow, i.e., training is no longer improving substantially
        if self._slope > self.tolerance:
            self._patience_counter += 1
        else:
            self._patience_counter = max(0, self._patience_counter - 1)

        # Check if patience surpassed and issue a reduction in learning rate
        if self._patience_counter >= self.patience:
            self._reduce_learning_rate()
            self._patience_counter = 0

    def _reduce_learning_rate(self):
        """Reduces the learning rate by a given factor."""

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        if self._reset_counter >= self.num_resets:
            self.stopping_issued = True
        else:
            # Take care of updating learning rate
            old_lr = self.optimizer.lr.numpy()
            new_lr = round(self.reduction_factor * old_lr, 8)
            self.optimizer.lr.assign(new_lr)
            self._reset_counter += 1

            # Store iteration and learning rate
            self._history["iteration"].append(self.optimizer.iterations.numpy())
            self._history["learning_rate"].append(old_lr)

            # Verbose info to user
            logger.info(f"Reducing learning rate from {old_lr:.8f} to {new_lr:.8f} and entering cooldown...")

        # Set cooldown flag to avoid reset for some time given by self.period
        self._in_cooldown = True

    def _check_waiting(self):
        """Determines whether to compute a new slope or wait."""

        # Case currently waiting
        if self._is_waiting:
            # Case currently waiting but period is over
            if self._wait_counter >= self.wait_between_periods - 1:
                self._wait_counter = 0
                self._is_waiting = False
            # Case currently waiting and period not over
            else:
                self._wait_counter += 1
            return True
        # Case not waiting
        else:
            self._is_waiting = True
            self._wait_counter += 1
            return False
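
# Editor's note: an illustrative sketch (not part of the original module) of how a
# ``RegressionLRAdjuster`` might be driven by a nearly flat, noisy loss trajectory.
# It assumes a TF 2.x Keras optimizer exposing a mutable ``lr`` variable, as the class
# itself does; the hyperparameters are shrunk purely for demonstration.
def _example_lr_adjuster():
    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
    adjuster = RegressionLRAdjuster(
        optimizer, period=50, wait_between_fits=1, patience=2, num_resets=1, cooldown_factor=1
    )
    losses = np.linspace(1.0, 0.99, 300) + np.random.normal(scale=0.005, size=300)
    for it in range(1, losses.shape[0] + 1):
        # Fits a Huber regression on the last ``period`` losses (or waits / cools down)
        adjuster.get_slope(losses[:it])
        if adjuster.stopping_issued:
            print(f"Optional stopping suggested at iteration {it}")
            break
    print("Final learning rate:", optimizer.lr.numpy())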
class LossHistory:
    """Helper class to keep track of losses during training."""

    file_name = "history"
    def __init__(self):
        self.latest = 0
        self.history = {}
        self.val_history = {}
        self.loss_names = []
        self.val_loss_names = []
        self._current_run = 0
        self._total_loss = []
        self._total_val_loss = []
    @property
    def total_loss(self):
        return np.array(self._total_loss)

    @property
    def total_val_loss(self):
        return np.array(self._total_val_loss)
    def last_total_loss(self):
        return self._total_loss[-1]
    def last_total_val_loss(self):
        return self._total_val_loss[-1]
    def start_new_run(self):
        self._current_run += 1
        self.history[f"Run {self._current_run}"] = {}
        self.val_history[f"Run {self._current_run}"] = {}
    def add_val_entry(self, epoch, val_loss):
        """Add validation entry to loss structure. Assume ``loss_names`` already exists
        as an attribute, so no attempt will be made to create names.
        """

        # Add epoch key, if specified
        if self.val_history[f"Run {self._current_run}"].get(f"Epoch {epoch}") is None:
            self.val_history[f"Run {self._current_run}"][f"Epoch {epoch}"] = []

        # Handle dict loss output
        if type(val_loss) is dict:
            # Store keys, if none existing
            if self.val_loss_names == []:
                self.val_loss_names = ["Val." + k for k in val_loss.keys()]

            # Create and store entry
            entry = [v.numpy() if type(v) is not np.ndarray else v for v in val_loss.values()]
            self.val_history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(entry)

            # Add entry to total loss
            self._total_val_loss.append(sum(entry))

        # Handle tuple or list loss output
        elif type(val_loss) is tuple or type(val_loss) is list:
            entry = [v.numpy() if type(v) is not np.ndarray else v for v in val_loss]
            self.val_history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(entry)

            # Store keys, if none existing
            if self.val_loss_names == []:
                self.val_loss_names = [f"Val.Loss.{l}" for l in range(1, len(entry) + 1)]

            # Add entry to total loss
            self._total_val_loss.append(sum(entry))

        # Assume scalar loss output
        else:
            self.val_history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(val_loss.numpy())

            # Store keys, if none existing
            if self.val_loss_names == []:
                self.val_loss_names.append("Default.Val.Loss")

            # Add entry to total loss
            self._total_val_loss.append(val_loss.numpy())
    def add_entry(self, epoch, current_loss):
        """Adds loss entry for current epoch into internal memory data structure."""

        # Add epoch key, if specified
        if self.history[f"Run {self._current_run}"].get(f"Epoch {epoch}") is None:
            self.history[f"Run {self._current_run}"][f"Epoch {epoch}"] = []

        # Handle dict loss output
        if type(current_loss) is dict:
            # Store keys, if none existing
            if self.loss_names == []:
                self.loss_names = [k for k in current_loss.keys()]

            # Create and store entry
            entry = [v.numpy() if type(v) is not np.ndarray else v for v in current_loss.values()]
            self.history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(entry)

            # Add entry to total loss
            self._total_loss.append(sum(entry))

        # Handle tuple or list loss output
        elif type(current_loss) is tuple or type(current_loss) is list:
            entry = [v.numpy() if type(v) is not np.ndarray else v for v in current_loss]
            self.history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(entry)

            # Store keys, if none existing
            if self.loss_names == []:
                self.loss_names = [f"Loss.{l}" for l in range(1, len(entry) + 1)]

            # Add entry to total loss
            self._total_loss.append(sum(entry))

        # Assume scalar loss output
        else:
            self.history[f"Run {self._current_run}"][f"Epoch {epoch}"].append(current_loss.numpy())

            # Store keys, if none existing
            if self.loss_names == []:
                self.loss_names.append("Default.Loss")

            # Add entry to total loss
            self._total_loss.append(current_loss.numpy())
    def get_running_losses(self, epoch):
        """Compute and return running means of the losses for current epoch."""

        means = np.atleast_1d(np.mean(self.history[f"Run {self._current_run}"][f"Epoch {epoch}"], axis=0))
        if means.shape[0] == 1:
            return {"Avg.Loss": means[0]}
        else:
            return {"Avg." + k: v for k, v in zip(self.loss_names, means)}
    def get_plottable(self):
        """Returns the losses as a nicely formatted pandas DataFrame, in case
        only train losses were collected, otherwise a dict of data frames.
        """

        # Assume equal lengths per epoch and run
        try:
            losses_df = self._to_data_frame(self.history, self.loss_names)
            if any([v for v in self.val_history.values()]):
                # Remove decay names
                names = [name for name in self.loss_names if "Decay" not in name]
                val_losses_df = self._to_data_frame(self.val_history, names)
                return {"train_losses": losses_df, "val_losses": val_losses_df}
            return losses_df
        # Handle unequal lengths or problems when user kills training with an interrupt
        except (ValueError, TypeError):
            if any([v for v in self.val_history.values()]):
                return {"train_losses": self.history, "val_losses": self.val_history}
            return self.history
    def flush(self):
        """Returns current history and removes all existing loss history, but keeps loss names."""

        history = self.history
        val_history = self.val_history
        self.history = {}
        self.val_history = {}
        self._total_loss = []
        self._total_val_loss = []
        self._current_run = 0
        return history, val_history
    def save_to_file(self, file_path, max_to_keep):
        """Saves a `LossHistory` object to a pickled dictionary in file_path.
        If max_to_keep saved loss history files are found in file_path, the oldest
        is deleted before a new one is saved.
        """

        # Increment history index
        self.latest += 1

        # Path to history
        history_path = os.path.join(file_path, f"{LossHistory.file_name}_{self.latest}.pkl")

        # Prepare full history dict
        pickle_dict = {
            "history": self.history,
            "val_history": self.val_history,
            "loss_names": self.loss_names,
            "val_loss_names": self.val_loss_names,
            "_current_run": self._current_run,
            "_total_loss": self._total_loss,
            "_total_val_loss": self._total_val_loss,
        }

        # Pickle current
        with open(history_path, "wb") as f:
            pickle.dump(pickle_dict, f)

        # Get list of history checkpoints
        history_checkpoints_list = [l for l in os.listdir(file_path) if "history" in l]

        # Determine the oldest saved loss history and remove it
        if len(history_checkpoints_list) > max_to_keep:
            oldest_history_path = os.path.join(file_path, f"history_{self.latest - max_to_keep}.pkl")
            os.remove(oldest_history_path)
    def load_from_file(self, file_path):
        """Loads the most recent saved `LossHistory` object from `file_path`."""

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        # Get list of histories
        if os.path.exists(file_path):
            history_checkpoints_list = [l for l in os.listdir(file_path) if LossHistory.file_name in l]
        else:
            history_checkpoints_list = []

        # Case history list is not empty
        if len(history_checkpoints_list) > 0:
            # Determine which file contains the latest LossHistory and load it
            file_numbers = [int(re.findall(r"\d+", h)[0]) for h in history_checkpoints_list]
            latest_file = history_checkpoints_list[np.argmax(file_numbers)]
            latest_number = np.max(file_numbers)
            latest_path = os.path.join(file_path, latest_file)

            # Load dictionary
            with open(latest_path, "rb") as f:
                loaded_history_dict = pickle.load(f)

            # Fill public entries
            self.latest = latest_number
            self.history = loaded_history_dict.get("history", {})
            self.val_history = loaded_history_dict.get("val_history", {})
            self.loss_names = loaded_history_dict.get("loss_names", [])
            self.val_loss_names = loaded_history_dict.get("val_loss_names", [])

            # Fill private entries
            self._current_run = loaded_history_dict.get("_current_run", 0)
            self._total_loss = loaded_history_dict.get("_total_loss", [])
            self._total_val_loss = loaded_history_dict.get("_total_val_loss", [])

            # Verbose
            logger.info(f"Loaded loss history from {latest_path}.")

        # Case history list is empty
        else:
            logger.info("Initialized empty loss history.")
    def _to_data_frame(self, history, names):
        """Helper function to convert a history dict into a DataFrame."""

        losses_list = [pd.melt(pd.DataFrame.from_dict(history[r], orient="index").T) for r in history]
        losses_list = pd.concat(losses_list, axis=0).value.to_list()
        losses_list = [l for l in losses_list if l is not None]
        losses_df = pd.DataFrame(losses_list, columns=names)
        return losses_df
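
# Editor's note: an illustrative sketch (not part of the original module) of how a
# ``LossHistory`` instance could record scalar losses over two toy epochs. In actual
# training, the entries would be the (dict, tuple, or scalar) losses returned by an
# amortizer; scalar tensors are used here as stand-ins.
def _example_loss_history():
    history = LossHistory()
    history.start_new_run()
    for epoch in range(1, 3):
        for _ in range(5):
            history.add_entry(epoch, tf.constant(np.random.rand()))
        print(history.get_running_losses(epoch))
    # A pandas DataFrame (or a dict of DataFrames, if validation losses were collected)
    losses_df = history.get_plottable()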
class SimulationMemory:
    """Helper class to keep track of a pre-determined number of simulations during training."""

    file_name = "memory"
    def __init__(self, stores_raw=True, capacity_in_batches=50):
        self.stores_raw = stores_raw
        self._capacity = capacity_in_batches
        self._buffer = [None] * self._capacity
        self._idx = 0
        self.size_in_batches = 0
    def store(self, forward_dict):
        """Stores simulation outputs in `forward_dict`, if internal buffer is not full.

        Parameters
        ----------
        forward_dict : dict
            The configured outputs of the forward model.
        """

        # Only store while there is still capacity left; otherwise silently skip
        if not self.is_full():
            self._buffer[self._idx] = forward_dict
            self._idx += 1
            self.size_in_batches += 1
    def get_memory(self):
        return deepcopy(self._buffer)
    def is_full(self):
        """Returns True if the buffer is full, otherwise False."""

        if self._idx >= self._capacity:
            return True
        return False
    def save_to_file(self, file_path):
        """Saves a `SimulationMemory` object to a pickled dictionary in file_path."""

        # Create path to memory
        memory_path = os.path.join(file_path, f"{SimulationMemory.file_name}.pkl")

        # Prepare attributes
        full_memory_dict = {}
        full_memory_dict["stores_raw"] = self.stores_raw
        full_memory_dict["_capacity"] = self._capacity
        full_memory_dict["_buffer"] = self._buffer
        full_memory_dict["_idx"] = self._idx
        full_memory_dict["_size_in_batches"] = self.size_in_batches

        # Dump as pickle object
        with open(memory_path, "wb") as f:
            pickle.dump(full_memory_dict, f)
    def load_from_file(self, file_path):
        """Loads the saved `SimulationMemory` object from file_path."""

        # Logger init
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

        # Create path to memory
        memory_path = os.path.join(file_path, f"{SimulationMemory.file_name}.pkl")

        # Case memory file exists
        if os.path.exists(memory_path):
            # Load pickle and fill in attributes
            with open(memory_path, "rb") as f:
                full_memory_dict = pickle.load(f)

            self.stores_raw = full_memory_dict["stores_raw"]
            self._capacity = full_memory_dict["_capacity"]
            self._buffer = full_memory_dict["_buffer"]
            self._idx = full_memory_dict["_idx"]
            self.size_in_batches = full_memory_dict["_size_in_batches"]
            logger.info(f"Loaded simulation memory from {memory_path}")

        # Case memory file does not exist
        else:
            logger.info("Initialized empty simulation memory.")
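
# Editor's note: an illustrative sketch (not part of the original module) showing how a
# ``SimulationMemory`` fills up to its capacity and then silently ignores further batches.
# The dictionary content is a placeholder for configured forward-model outputs.
def _example_simulation_memory():
    memory = SimulationMemory(capacity_in_batches=2)
    while not memory.is_full():
        memory.store({DEFAULT_KEYS["sim_data"]: np.random.normal(size=(16, 10, 2))})
    stored_batches = memory.get_memory()  # deep copy of the internal buffer
    print(len(stored_batches))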
class MemoryReplayBuffer:
    """Implements a memory replay buffer for simulation-based inference."""
    def __init__(self, capacity_in_batches=500):
        """Creates a circular buffer following the logic of experience replay.

        Parameters
        ----------
        capacity_in_batches : int, optional, default: 500
            The capacity of the buffer in batches of simulations. Could potentially grow
            very large, so make sure you pick a reasonable number!
        """

        self._capacity = capacity_in_batches
        self._buffer = [None] * self._capacity
        self._idx = 0
        self._size_in_batches = 0
        self._is_full = False
    def store(self, forward_dict):
        """Stores simulation outputs, if internal buffer is not full.

        Parameters
        ----------
        forward_dict : dict
            The configured outputs of the forward model.
        """

        # If full, overwrite at index
        if self._is_full:
            self._overwrite(forward_dict)

        # Otherwise still capacity to append
        else:
            # Add to internal list
            self._buffer[self._idx] = forward_dict

            # Increment index and # of batches currently stored
            self._idx += 1
            self._size_in_batches += 1

            # Check whether buffer is full and set flag if that's the case
            if self._idx == self._capacity:
                self._is_full = True
    def sample(self):
        """Samples one randomly chosen batch of parameter vectors and simulations from the buffer.

        Returns
        -------
        forward_dict : dict
            The (raw or configured) outputs of the forward model.
        """

        rand_idx = np.random.default_rng().integers(low=0, high=self._size_in_batches)
        return self._buffer[rand_idx]
    def _overwrite(self, forward_dict):
        """Overwrites a simulated batch at the current position. Only called when the internal buffer is full."""

        # Reset index, if at the end of buffer
        if self._idx == self._capacity:
            self._idx = 0

        # Overwrite params and data at index
        self._buffer[self._idx] = forward_dict

        # Increment index
        self._idx += 1
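
# Editor's note: an illustrative sketch (not part of the original module) of the circular
# behavior of a ``MemoryReplayBuffer``: storing more batches than the capacity overwrites
# the oldest entries, and ``sample`` returns one randomly chosen stored batch.
def _example_memory_replay_buffer():
    buffer = MemoryReplayBuffer(capacity_in_batches=4)
    for _ in range(6):
        buffer.store({DEFAULT_KEYS["sim_data"]: np.random.normal(size=(16, 10, 2))})
    replayed_batch = buffer.sample()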