correlation
- bayesflow.diagnostics.correlation(estimates: Mapping[str, numpy.ndarray] | numpy.ndarray, targets: Mapping[str, numpy.ndarray] | numpy.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, test_quantities: dict[str, Callable] = None, aggregation: Callable = numpy.median) -> dict[str, Any]
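Computes the Pearson correlation between estimates and targets, separately for each variable. For each random draw from the posterior distribution, the correlation is computed across datasets; the resulting per-draw correlations are then combined with the aggregation function.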
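The computation described above can be sketched in plain NumPy for the array case. This is a minimal illustration, not BayesFlow's implementation: the function name `correlation_sketch` is hypothetical, and it assumes the aggregation function accepts an `axis` keyword (as `np.median` and `np.mean` do).

```python
import numpy as np


def correlation_sketch(estimates, targets, aggregation=np.median):
    """Hypothetical sketch: per-draw Pearson r across datasets, then aggregated.

    estimates: array of shape (num_datasets, num_draws_post, num_variables)
    targets:   array of shape (num_datasets, num_variables)
    Returns an array of shape (num_variables,).
    """
    # Repeat the ground truths for every posterior draw.
    t = np.broadcast_to(targets[:, None, :], estimates.shape)

    # Center both arrays across the dataset axis (axis 0).
    e_c = estimates - estimates.mean(axis=0, keepdims=True)
    t_c = t - t.mean(axis=0, keepdims=True)

    # Pearson r for every (draw, variable) pair: shape (num_draws_post, num_variables).
    r = (e_c * t_c).sum(axis=0) / np.sqrt(
        (e_c ** 2).sum(axis=0) * (t_c ** 2).sum(axis=0)
    )

    # Assumption: the aggregation callable takes an `axis` keyword.
    return aggregation(r, axis=0)


rng = np.random.default_rng(0)
targets = rng.normal(size=(100, 2))
# Synthetic "posterior draws": the true values plus small noise.
estimates = targets[:, None, :] + 0.1 * rng.normal(size=(100, 50, 2))
print(correlation_sketch(estimates, targets))  # both values should be close to 1
```

Computing one correlation per posterior draw, rather than correlating a single point estimate with the targets, is what makes the aggregation step meaningful: `np.median` gives a summary that is robust to a few unusual draws, while `np.mean` averages them all.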
- Parameters:
- estimates : np.ndarray or dict[str, np.ndarray]
Posterior samples, either as a NumPy array of shape (num_datasets, num_draws_post, num_variables) or as a dictionary mapping variable names to arrays. Contains num_draws_post random draws from the posterior distribution for each of the num_datasets data sets.
- targets : np.ndarray or dict[str, np.ndarray]
Prior samples, either as a NumPy array of shape (num_datasets, num_variables) or as a dictionary mapping variable names to arrays. Comprises num_datasets ground truths.
- variable_keys : Sequence[str], optional (default = None)
Selects keys from the dictionaries provided in estimates and targets. By default, all keys are selected.
- variable_names : Sequence[str], optional (default = None)
Optional variable names to show in the output.
- test_quantities : dict or None, optional (default = None)
A dict that maps plot titles to functions that compute test quantities based on estimate/target draws.
The dict keys are automatically added to variable_keys and variable_names. Test quantity functions are expected to accept a dict of draws with shape (batch_size, ...) as the first positional argument and return a NumPy array of shape (batch_size,).
- aggregation : callable, optional (default = np.median)
Function to aggregate the correlations across posterior draws. Typically np.mean or np.median.
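A test quantity following the contract described for test_quantities might look like the sketch below. The variable name `"beta"` and the function `beta_difference` are hypothetical; the function takes a dict of draws whose arrays have shape (batch_size, ...) and returns an array of shape (batch_size,).

```python
import numpy as np


def beta_difference(draws):
    """Hypothetical test quantity: difference of the two components of "beta".

    Assumes draws["beta"] has shape (batch_size, 2) and returns
    an array of shape (batch_size,).
    """
    return draws["beta"][:, 1] - draws["beta"][:, 0]


# The dict key is the title under which the derived quantity is reported.
test_quantities = {"beta_1 - beta_0": beta_difference}
# It would then be passed as, e.g.:
# bayesflow.diagnostics.correlation(estimates, targets, test_quantities=test_quantities)

rng = np.random.default_rng(1)
draws = {"beta": rng.normal(size=(8, 2))}
tq = test_quantities["beta_1 - beta_0"](draws)
print(tq.shape)  # (8,)
```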
- Returns:
- result : dict
Dictionary containing:
- “values” : np.ndarray
The aggregated Pearson correlation for each variable.
- “metric_name” : str
The name of the metric (“Correlation”).
- “variable_names” : str
The (inferred) variable names.
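The sketch below shows how a dictionary with the documented keys could be assembled from dict inputs. It is not BayesFlow's implementation, and several parts are assumptions: the function name `correlation_from_dicts` is hypothetical; each estimates entry is assumed to have shape (num_datasets, num_draws_post) and each targets entry shape (num_datasets,), so every key is one scalar variable; and the selected keys are returned as a list of names.

```python
import numpy as np


def correlation_from_dicts(estimates, targets, variable_keys=None,
                           aggregation=np.median):
    """Hypothetical sketch returning "values", "metric_name", "variable_names".

    Assumes one scalar variable per key: estimates[k] has shape
    (num_datasets, num_draws_post) and targets[k] has shape (num_datasets,).
    """
    # Default to all keys, mirroring the documented variable_keys behavior.
    keys = list(variable_keys) if variable_keys is not None else list(estimates)

    # Stack the selected keys along a trailing variable axis.
    est = np.stack([estimates[k] for k in keys], axis=-1)  # (N, S, V)
    tgt = np.stack([targets[k] for k in keys], axis=-1)    # (N, V)

    # Pearson r per draw and variable, computed across datasets (axis 0).
    e_c = est - est.mean(axis=0, keepdims=True)
    t_c = (tgt - tgt.mean(axis=0))[:, None, :]  # (N, 1, V)
    r = (e_c * t_c).sum(axis=0) / np.sqrt(
        (e_c ** 2).sum(axis=0) * (t_c ** 2).sum(axis=0)
    )  # (S, V)

    return {
        "values": aggregation(r, axis=0),
        "metric_name": "Correlation",
        "variable_names": keys,
    }


rng = np.random.default_rng(2)
truth = {"mu": rng.normal(size=100), "sigma": rng.gamma(2.0, size=100)}
# Synthetic posterior draws: truth plus small noise, 40 draws per dataset.
draws = {k: v[:, None] + 0.1 * rng.normal(size=(100, 40)) for k, v in truth.items()}

result = correlation_from_dicts(draws, truth, variable_keys=["mu"])
for name, value in zip(result["variable_names"], result["values"]):
    print(f"{name}: {value:.3f}")
```

Keeping "values" and "variable_names" aligned by position lets the two be zipped directly, as in the loop above.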