"""Source code for bayesflow.diagnostics.plots.loss: plotting of training/validation loss trajectories."""

from collections.abc import Sequence

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import keras.src.callbacks

from ...utils.plot_utils import make_figure, add_titles_and_labels


def _validation_step_index(train_step_index: np.ndarray, num_val: int) -> np.ndarray:
    """Return the x-positions of ``num_val`` validation points spread over ``train_step_index``.

    Validation is placed after every ``len(train_step_index) // num_val`` training entries,
    and the result is truncated to at most ``num_val`` positions.
    """
    # Clamp the stride to 1 so that having more validation than training entries does not
    # produce a zero slice step (which raises ValueError).
    val_step = max(1, len(train_step_index) // num_val)
    return train_step_index[val_step - 1 :: val_step][:num_val]


def loss(
    history: keras.callbacks.History,
    train_key: str = "loss",
    val_key: str = "val_loss",
    moving_average: bool = False,
    per_training_step: bool = False,
    ma_window_fraction: float = 0.01,
    figsize: Sequence[float] = None,
    train_color: str = "#132a70",
    val_color: str = "black",
    lw_train: float = 2.0,
    lw_val: float = 3.0,
    legend_fontsize: int = 14,
    label_fontsize: int = 14,
    title_fontsize: int = 16,
) -> plt.Figure:
    """
    A generic helper function to plot the losses of a series of training epochs and runs.

    Parameters
    ----------
    history : keras.callbacks.History
        History object as returned by `keras.Model.fit`.
    train_key : str, optional, default: "loss"
        The training loss key to look for in ``history.history``.
    val_key : str, optional, default: "val_loss"
        The validation loss key to look for in ``history.history``. If the key is missing
        or its entry is empty, no validation curve is drawn.
    moving_average : bool, optional, default: False
        A flag for adding a moving-average line of the training losses to every panel.
    per_training_step : bool, optional, default: False
        Selects the x-axis label: "Training step #" if True, "Training epoch #" otherwise.
        The plotted data is whatever ``history`` contains in either case.
    ma_window_fraction : float, optional, default: 0.01
        Window size for the moving average as a fraction of the number of training
        entries. The resulting window always spans at least one entry.
    figsize : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor.
        Inferred as ``(16, 4 * num_rows)`` if ``None``.
    train_color : str, optional, default: "#132a70"
        The color for the train loss trajectory
    val_color : str, optional, default: "black"
        The color for the optional validation loss trajectory
    lw_train : float, optional, default: 2.0
        The linewidth for the training loss curve (also used for the moving average)
    lw_val : float, optional, default: 3.0
        The linewidth for the validation loss curve
    legend_fontsize : int, optional, default: 14
        The font size of the legend text
    label_fontsize : int, optional, default: 14
        The font size of the axis label text
    title_fontsize : int, optional, default: 16
        The font size of the title text

    Returns
    -------
    f : plt.Figure
        The figure instance for optional saving.

    Raises
    ------
    ValueError
        If ``train_key`` is not present in ``history.history``.
    """
    train_losses = history.history.get(train_key)
    if train_losses is None:
        raise ValueError(
            f"Training loss key '{train_key}' not found in history. Available keys: {list(history.history)}."
        )
    val_losses = history.history.get(val_key)

    # Each column of the frame is one loss series; a 1-D history yields a single column.
    train_losses = pd.DataFrame(np.array(train_losses))
    # An absent or empty validation entry means there is nothing to plot (and would
    # otherwise cause a division by zero when computing the validation stride).
    if val_losses is not None and len(val_losses) > 0:
        val_losses = pd.DataFrame(np.array(val_losses))
    else:
        val_losses = None

    # One subplot row per loss series
    num_row = len(train_losses.columns)
    num_train = len(train_losses)

    fig, axes = make_figure(num_row=num_row, num_col=1, figsize=(16, int(4 * num_row)) if figsize is None else figsize)

    # 1-based x-positions of the training entries
    train_step_index = np.arange(1, num_train + 1)

    if val_losses is not None:
        val_step_index = _validation_step_index(train_step_index, len(val_losses))
        # Keep x and y the same length if there are more validation entries than positions.
        val_losses = val_losses.iloc[: len(val_step_index)]

    if moving_average:
        # int() truncation yields 0 for short histories, which rolling() rejects.
        moving_average_window = max(1, int(num_train * ma_window_fraction))

    for i, ax in enumerate(axes.flat):
        train_curve = train_losses.iloc[:, i]
        ax.plot(train_step_index, train_curve, color=train_color, lw=lw_train, alpha=0.9, label="Training")

        # Columns are positional integers, so the smoothed curve is drawn for every series
        # rather than being gated on a column name that can never occur.
        if moving_average:
            smoothed_loss = train_curve.rolling(window=moving_average_window).mean()
            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")

        # Plot optional val curve, only for series that also exist in the validation history
        if val_losses is not None and i < val_losses.shape[1]:
            ax.plot(
                val_step_index,
                val_losses.iloc[:, i],
                linestyle="--",
                marker="o",
                color=val_color,
                lw=lw_val,
                label="Validation",
            )

        sns.despine(ax=ax)
        ax.grid(alpha=0.5)

        # A legend is only informative when more than the training curve is shown
        if val_losses is not None or moving_average:
            ax.legend(fontsize=legend_fontsize)

    add_titles_and_labels(
        axes=axes,
        num_row=num_row,
        num_col=1,
        title=["Loss Trajectory"],
        xlabel="Training step #" if per_training_step else "Training epoch #",
        ylabel="Value",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    fig.tight_layout()
    return fig