from collections.abc import Callable, Sequence, Mapping
import matplotlib.pyplot as plt
import numpy as np
from bayesflow.utils.dict_utils import make_variable_array, dicts_to_arrays, filter_kwargs, compute_test_quantities
from bayesflow.utils.plot_utils import (
add_titles_and_labels,
make_figure,
set_layout,
prettify_subplots,
)
from bayesflow.utils.validators import check_estimates_prior_shapes
# [docs]
def plot_quantity(
    values: Mapping[str, np.ndarray] | np.ndarray | Callable,
    targets: Mapping[str, np.ndarray] | np.ndarray,
    *,
    variable_keys: Sequence[str] | None = None,
    variable_names: Sequence[str] | None = None,
    estimates: Mapping[str, np.ndarray] | np.ndarray | None = None,
    test_quantities: dict[str, Callable] | None = None,
    figsize: Sequence[int] | None = None,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    tick_fontsize: int = 12,
    color: str = "#132a70",
    markersize: float = 25.0,
    marker: str = "o",
    alpha: float = 0.5,
    xlabel: str = "Ground truth",
    ylabel: str = "",
    num_col: int | None = None,
    num_row: int | None = None,
    default_name: str = "v",
) -> plt.Figure:
    """
    Plot a quantity as a function of a variable for each variable key.

    One scatter subplot is drawn per variable, with the target values on the
    x-axis and the quantity on the y-axis.

    The function supports the following different combinations to pass
    or compute the values:

    1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables)
    2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names'
       as provided by the metrics functions. Note that the functions have to be called
       without aggregation to obtain value per dataset.
    3. pass a function to `values`, as well as `estimates`. The function should have the
       signature fn(estimates, targets, [aggregation]) and return an object like the
       `values` described in the previous options.

    Parameters
    ----------
    values : dict[str, np.ndarray] | np.ndarray | Callable
        The value of the quantity to plot. One of the following:

        1. an array of shape (num_datasets,) or (num_datasets, num_variables)
        2. a dictionary with the keys 'values', 'metric_name' and 'variable_names'
           as provided by the metrics functions. Note that the functions have to be called
           without aggregation to obtain value per dataset.
        3. a callable, requires passing `estimates` as well. The function should have the
           signature fn(estimates, targets, [aggregation]) and return an object like the
           ones described in the previous options.
    targets : dict[str, np.ndarray] | np.ndarray
        The parameter values plotted on the x-axis.
    variable_keys : list or None, optional, default: None
        Select keys from the dictionaries provided as input.
        By default, select all keys.
    variable_names : list or None, optional, default: None
        The parameter names for nice plot titles. Inferred if None.
    estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None
        The posterior draws obtained from n_data_sets. Required if
        `values` is a callable.
    test_quantities : dict or None, optional, default: None
        A dict that maps plot titles to functions that compute
        test quantities based on estimate/target draws.
        Can only be supplied if `values` is a function.

        The dict keys are automatically added to ``variable_keys``
        and ``variable_names``.
        Test quantity functions are expected to accept a dict of draws with
        shape ``(batch_size, ...)`` as the first (typically only)
        positional argument and return a NumPy array of shape
        ``(batch_size,)``.
        The functions do not have to deal with an additional
        sample dimension, as appropriate reshaping is done internally.
    figsize : tuple or None, optional, default: None
        The figure size passed to the matplotlib constructor. Inferred if None.
    label_fontsize : int, optional, default: 16
        The font size of the axis label text.
    title_fontsize : int, optional, default: 18
        The font size of the title text.
    tick_fontsize : int, optional, default: 12
        The font size of the axis ticklabels.
    color : str, optional, default: '#132a70'
        The color of the scatter points.
    markersize : float, optional, default: 25.0
        The marker size in points**2 for the scatter plot.
    marker : str, optional, default: 'o'
        The marker for the scatter plot.
    alpha : float, optional, default: 0.5
        The opacity for the scatter plot.
    xlabel : str, optional, default: 'Ground truth'
        The x-axis label passed on to the labelling helper.
    ylabel : str, optional, default: ''
        The y-axis label passed on to the labelling helper.
    num_col : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.
    default_name : str, optional, default: 'v'
        The default name to use for estimates if None provided.

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving

    Raises
    ------
    ValueError
        If `values` is a callable but `estimates` is missing, or if
        `test_quantities` is supplied while `values` is not a callable.
    ShapeError
        If there is a deviation from the expected shapes of ``estimates`` and ``targets``.
    """
    # Validate the argument combination up front, before any computation.
    values_is_callable = callable(values)
    if values_is_callable and estimates is None:
        raise ValueError("Supplied a callable as `values`, but no `estimates`.")
    if not values_is_callable and test_quantities is not None:
        raise ValueError(
            "Supplied `test_quantities`, but `values` is not a function. "
            "As the values have to be calculated for the test quantities, "
            "passing a function is required."
        )

    prepared = _prepare_values(
        values=values,
        targets=targets,
        estimates=estimates,
        variable_keys=variable_keys,
        variable_names=variable_names,
        test_quantities=test_quantities,
        label=None,
        default_name=default_name,
    )
    # Only the plotted arrays and the subplot titles are needed from here on.
    values = prepared["values"]
    targets = prepared["targets"]
    variable_names = prepared["variable_names"]
    num_variables = len(variable_names)

    # Configure layout
    num_row, num_col = set_layout(num_variables, num_row, num_col)

    # Initialize figure
    fig, axes = make_figure(num_row, num_col, figsize=figsize)

    # One scatter plot per variable; the grid may hold more axes than
    # variables, so the iteration is bounded by ``num_variables``.
    for i, ax in zip(range(num_variables), axes.flat):
        ax.scatter(targets[:, i], values[:, i], color=color, alpha=alpha, s=markersize, marker=marker)

    prettify_subplots(axes, num_subplots=num_variables, tick_fontsize=tick_fontsize)

    # Add labels, titles, and set font sizes
    add_titles_and_labels(
        axes=axes,
        num_row=num_row,
        num_col=num_col,
        title=variable_names,
        xlabel=xlabel,
        ylabel=ylabel,
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    fig.tight_layout()
    return fig
def _prepare_values(
*,
values: Mapping[str, np.ndarray] | np.ndarray | Callable,
targets: Mapping[str, np.ndarray] | np.ndarray,
estimates: Mapping[str, np.ndarray] | np.ndarray | None,
variable_keys: Sequence[str],
variable_names: Sequence[str],
test_quantities: dict[str, Callable],
label: str | None,
default_name: str,
):
"""
Private helper function to compute/extract the values required for plotting
a quantity.
Refer to pairs_quantity and plot_quantity for details.
"""
is_values_callable = isinstance(values, Callable)
# Optionally, compute and prepend test quantities from draws
if test_quantities is not None:
updated_data = compute_test_quantities(
targets=targets,
estimates=estimates,
variable_keys=variable_keys,
variable_names=variable_names,
test_quantities=test_quantities,
)
variable_names = updated_data["variable_names"]
variable_keys = updated_data["variable_keys"]
estimates = updated_data["estimates"]
targets = updated_data["targets"]
if estimates is not None:
if is_values_callable:
values = values(
estimates=estimates,
targets=targets,
variable_keys=variable_keys,
**filter_kwargs({"aggregation": None}, values),
)
data = dicts_to_arrays(
estimates=estimates,
targets=targets,
variable_keys=variable_keys,
variable_names=variable_names,
default_name=default_name,
)
check_estimates_prior_shapes(data["estimates"], data["targets"])
estimates = data["estimates"]
targets = data["targets"]
variable_keys = variable_keys or estimates.variable_keys
if test_quantities is None:
variable_names = variable_names or estimates.variable_names
if all([key in values for key in ["values", "metric_name", "variable_names"]]):
# output of a metric function
label = values["metric_name"] if label is None else label
variable_names = variable_names or values["variable_names"]
values = values["values"]
if hasattr(values, "variable_keys"):
variable_keys = variable_keys or values.variable_keys
if hasattr(values, "variable_names") and test_quantities is None:
variable_names = variable_names or values.variable_names
try:
targets = make_variable_array(
targets,
variable_keys=variable_keys,
variable_names=variable_names,
default_name=default_name,
)
except ValueError:
raise ValueError(
"Length of 'variable_names' and number of variables do not match. "
"Did you forget to specify `variable_keys`?"
)
variable_names = targets.variable_names
variable_keys = targets.variable_keys
if values.ndim == 1:
values = values[:, None].repeat(len(variable_names), axis=-1)
try:
values = make_variable_array(
values,
variable_keys=variable_keys,
variable_names=variable_names,
default_name=default_name,
)
except ValueError:
raise ValueError(
"Length of 'variable_names' and number of variables do not match. "
"Did you forget to specify `variable_keys`?"
)
return {
"values": values,
"targets": targets,
"variable_keys": variable_keys,
"variable_names": variable_names,
"test_quantities": test_quantities,
"label": label,
}