correlation

bayesflow.diagnostics.correlation(estimates: Mapping[str, np.ndarray] | np.ndarray, targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, test_quantities: dict[str, Callable] = None, aggregation: Callable = np.median) → dict[str, Any]

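Computes the Pearson correlation between estimates and targets across datasets, separately for each variable and for each random draw from the posterior distribution, and then aggregates these per-draw correlations over posterior draws using the aggregation function (np.median by default).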
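The computation described above can be sketched in plain NumPy. This is a minimal sketch, not BayesFlow's implementation: it assumes array inputs of the documented shapes and assumes that the aggregation function is called with axis=0 over the posterior-draw axis. The name correlation_sketch is hypothetical.

```python
import numpy as np


def correlation_sketch(estimates, targets, aggregation=np.median):
    """Hypothetical re-implementation for illustration only.

    estimates: (num_datasets, num_draws_post, num_variables)
    targets:   (num_datasets, num_variables)
    Returns an array of shape (num_variables,).
    """
    estimates = np.asarray(estimates, dtype=float)
    targets = np.asarray(targets, dtype=float)

    # Center both arrays across the datasets axis (axis 0).
    est_c = estimates - estimates.mean(axis=0, keepdims=True)
    tgt_c = targets - targets.mean(axis=0, keepdims=True)

    # Covariance numerator for every (draw, variable) pair:
    # shape (num_draws_post, num_variables).
    cov = np.sum(est_c * tgt_c[:, None, :], axis=0)

    # Product of the two sum-of-squares terms, broadcast to the same shape.
    denom = np.sqrt(np.sum(est_c**2, axis=0) * np.sum(tgt_c**2, axis=0))

    # One Pearson correlation per posterior draw and variable.
    per_draw = cov / denom

    # Aggregate over posterior draws (assumed to be axis 0 in this sketch).
    return aggregation(per_draw, axis=0)


rng = np.random.default_rng(0)
targets = rng.normal(size=(100, 2))  # 100 data sets, 2 variables
# Noisy stand-in "posterior draws" centred on the targets: 50 draws per data set.
estimates = targets[:, None, :] + rng.normal(scale=0.5, size=(100, 50, 2))
print(correlation_sketch(estimates, targets))
```

Passing aggregation=np.mean instead averages the per-draw correlations; the median is less sensitive to a few atypical draws.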

Parameters:
estimates : np.ndarray or dict[str, np.ndarray]

Posterior samples, either as a NumPy array of shape (num_datasets, num_draws_post, num_variables) or as a dictionary mapping variable names to arrays. Contains num_draws_post random draws from the posterior distribution for each of the num_datasets data sets.

targets : np.ndarray or dict[str, np.ndarray]

Prior samples, either as a NumPy array of shape (num_datasets, num_variables) or as a dictionary mapping variable names to arrays. Contains the ground truth for each of the num_datasets data sets.

variable_keys : Sequence[str], optional (default = None)

Keys to select from the dictionaries provided in estimates and targets. If None, all keys are selected.

variable_names : Sequence[str], optional (default = None)

Variable names to show in the output. If None, the names are inferred from the inputs.

test_quantities : dict or None, optional (default = None)

A dict that maps test-quantity names to functions that compute test quantities from the estimate and target draws.

The dict keys are automatically added to variable_keys and variable_names. Test quantity functions are expected to accept a dict of draws with shape (batch_size, ...) as the first positional argument and return a NumPy array of shape (batch_size,).

aggregation : callable, optional (default = np.median)

Function used to aggregate the correlations across posterior draws, typically np.mean or np.median.
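How dictionary inputs, variable_keys, and test_quantities fit together can be sketched as follows. This is a sketch under stated assumptions, not BayesFlow's internal conversion: it assumes each dictionary entry holds one scalar variable (shape (num_datasets,) for targets, (num_datasets, num_draws_post) for estimates), that test quantities are evaluated on the draws flattened to a single batch axis and reshaped back, and that the selected keys are stacked along a trailing variable axis. The helper to_array, the function signal_to_noise, and the keys "mu" and "sigma" are hypothetical.

```python
import numpy as np


def to_array(samples, variable_keys=None, test_quantities=None):
    """Hypothetical helper: dict of scalar-variable draws -> stacked array.

    Assumes every entry of `samples` has the same shape, e.g.
    (num_datasets,) for targets or (num_datasets, num_draws_post) for estimates.
    """
    samples = dict(samples)  # shallow copy so the caller's dict is not modified
    keys = list(samples) if variable_keys is None else list(variable_keys)
    lead_shape = np.shape(next(iter(samples.values())))

    for name, fn in (test_quantities or {}).items():
        # Flatten the leading axes into one batch axis so that `fn` receives
        # draws of shape (batch_size,) and must return shape (batch_size,).
        flat = {k: np.asarray(v).reshape(-1) for k, v in samples.items()}
        samples[name] = np.asarray(fn(flat)).reshape(lead_shape)
        keys.append(name)  # test-quantity keys join the selected variables

    # Stack the selected variables along a new trailing variable axis.
    return np.stack([samples[k] for k in keys], axis=-1), keys


def signal_to_noise(draws):
    # Example test quantity (hypothetical variables "mu" and "sigma").
    return draws["mu"] / draws["sigma"]


rng = np.random.default_rng(1)
targets = {"mu": rng.normal(size=20), "sigma": rng.uniform(0.5, 2.0, size=20)}
# 30 noisy stand-in posterior draws per data set for each variable.
estimates = {k: v[:, None] + rng.normal(scale=0.1, size=(20, 30)) for k, v in targets.items()}

tq = {"mu/sigma": signal_to_noise}
est_arr, names = to_array(estimates, test_quantities=tq)
tgt_arr, _ = to_array(targets, test_quantities=tq)
print(est_arr.shape, tgt_arr.shape, names)  # (20, 30, 3) (20, 3) ['mu', 'sigma', 'mu/sigma']
```

Because the same test quantity is applied to both the estimates and the targets, the derived quantity can then be correlated like any other variable.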

Returns:
result : dict

Dictionary containing:

  • “values” : np.ndarray

    The aggregated Pearson correlation for each variable.

  • “metric_name” : str

    The name of the metric (“Correlation”).

  • “variable_names” : Sequence[str]

    The (inferred) variable names.
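The returned dictionary can be turned into a per-variable summary. This minimal sketch relies only on the three documented keys and assumes that "variable_names" is a sequence aligned with "values"; the helper summarize_correlation is hypothetical.

```python
import numpy as np


def summarize_correlation(result):
    """Hypothetical helper: format the dict returned by correlation().

    Uses only the documented keys "values", "metric_name" and
    "variable_names", pairing each name with its aggregated value.
    """
    values = np.asarray(result["values"]).reshape(-1)
    lines = [
        f"{result['metric_name']} [{name}]: {value:.3f}"
        for name, value in zip(result["variable_names"], values)
    ]
    return "\n".join(lines)
```

For example, print(summarize_correlation(bayesflow.diagnostics.correlation(estimates, targets))) would print one line per variable.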