Source code for bayesflow.utils.dict_utils

import inspect
from collections.abc import Callable, Mapping, Sequence
from typing import TypeVar, Any

import keras
import numpy as np

from bayesflow.types import Tensor
from . import logging

T = TypeVar("T")


def convert_args(f: Callable, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
    """Convert positional and keyword arguments to just positional arguments for f."""
    if not kwargs:
        return args

    signature = inspect.signature(f)

    # convert to just kwargs first
    kwargs = convert_kwargs(f, *args, **kwargs)

    parameters = []
    for name, param in signature.parameters.items():
        if param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]:
            continue

        parameters.append(kwargs.get(name, param.default))

    return tuple(parameters)
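
# Usage sketch for ``convert_args`` (illustrative only; ``f`` below is a
# hypothetical function, not part of this module):
#
#     def f(a, b, c=3):
#         return a + b + c
#
#     convert_args(f, 1, c=5, b=2)  # -> (1, 2, 5), ordered to match f's signature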
def convert_kwargs(f: Callable, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Convert positional and keyword arguments to just keyword arguments for f."""
    if not args:
        return kwargs

    signature = inspect.signature(f)

    parameters = dict(zip(signature.parameters, args))

    for name, value in kwargs.items():
        if name in parameters:
            raise TypeError(f"{f.__name__}() got multiple arguments for argument '{name}'")

        parameters[name] = value

    return parameters
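
# Usage sketch for ``convert_kwargs`` (illustrative; ``f`` is hypothetical):
#
#     def f(a, b, c=3):
#         return a + b + c
#
#     convert_kwargs(f, 1, 2, c=5)  # -> {"a": 1, "b": 2, "c": 5}
#     convert_kwargs(f, 1, a=2)     # raises TypeError: argument 'a' given twice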
def filter_kwargs(kwargs: dict[str, T], f: Callable) -> dict[str, T]:
    """Filter keyword arguments for f."""
    signature = inspect.signature(f)

    for parameter in signature.parameters.values():
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            # there is a **kwargs parameter, so anything is valid
            return kwargs

    kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}

    return kwargs
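
# Usage sketch for ``filter_kwargs`` (illustrative; ``f`` is hypothetical):
#
#     def f(a, b):
#         return a + b
#
#     filter_kwargs({"a": 1, "b": 2, "unused": 3}, f)  # -> {"a": 1, "b": 2}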
def keras_kwargs(kwargs: dict[str, T]) -> dict[str, T]:
    """Filter keyword arguments for keras.Layer."""
    valid_keys = ["dtype", "name", "trainable"]
    return {key: value for key, value in kwargs.items() if key in valid_keys}
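
# Usage sketch for ``keras_kwargs``: keep only the keyword arguments that
# ``keras.Layer.__init__`` accepts, e.g. when forwarding **kwargs to a parent
# constructor:
#
#     keras_kwargs({"name": "net", "depth": 4})  # -> {"name": "net"}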
# TODO: rename and streamline and make protected
def check_output(outputs: Mapping[str, Tensor]) -> None:
    # Warn if any NaNs present in output
    for k, v in outputs.items():
        nan_mask = keras.ops.isnan(v)
        if keras.ops.any(nan_mask):
            logging.warning("Found a total of {n:d} nan values for output {k}.", n=int(keras.ops.sum(nan_mask)), k=k)

    # Warn if any inf present in output
    for k, v in outputs.items():
        inf_mask = keras.ops.isinf(v)
        if keras.ops.any(inf_mask):
            logging.warning("Found a total of {n:d} inf values for output {k}.", n=int(keras.ops.sum(inf_mask)), k=k)
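
# Usage sketch for ``check_output``: warn about non-finite values in a dict of
# network outputs (the tensor contents here are illustrative):
#
#     check_output({"z": keras.ops.convert_to_tensor([0.0, np.nan, np.inf])})
#     # logs warnings about 1 nan value and 1 inf value for output "z"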
def split_tensors(data: Mapping[Any, Tensor], axis: int = -1) -> Mapping[Any, Tensor]:
    """Split tensors in the dictionary along the given axis."""
    result = {}

    for key, value in data.items():
        if keras.ops.shape(value)[axis] == 1:
            result[key] = keras.ops.squeeze(value, axis=axis)
            continue

        splits = keras.ops.split(value, keras.ops.shape(value)[axis], axis=axis)
        splits = [keras.ops.squeeze(split, axis=axis) for split in splits]

        for i, split in enumerate(splits):
            result[f"{key}_{i + 1}"] = split

    return result
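
# Usage sketch for ``split_tensors``. Note the one-based suffixes:
#
#     split_tensors({"theta": keras.ops.zeros((8, 2))})
#     # -> {"theta_1": shape (8,), "theta_2": shape (8,)}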
def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str, np.ndarray]:
    """Split arrays in the dictionary along the given axis."""
    result = {}

    for key, value in data.items():
        if not hasattr(value, "shape"):
            result[key] = np.array([value])
            continue

        if len(value.shape) == 1:
            result[key] = value
            continue

        if value.shape[axis] == 1:
            result[key] = np.squeeze(value, axis=axis)
            continue

        splits = np.split(value, value.shape[axis], axis=axis)
        splits = [np.squeeze(split, axis=axis) for split in splits]

        for i, split in enumerate(splits):
            result[f"{key}_{i}"] = split

    return result
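
# Usage sketch for ``split_arrays``. Note the zero-based suffixes here, in
# contrast to the one-based suffixes of ``split_tensors``:
#
#     split_arrays({"theta": np.zeros((8, 2))})
#     # -> {"theta_0": shape (8,), "theta_1": shape (8,)}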
class VariableArray(np.ndarray):
    """
    An enriched numpy array with information on variable keys and names to be used
    in post-processing, specifically the diagnostics module.

    The current implementation is very basic and we may want to extend it in the
    future should this general structure prove useful.

    Design according to
    https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
    """

    def __new__(cls, input_array, variable_keys=None, variable_names=None):
        obj = np.asarray(input_array).view(cls)
        obj.variable_keys = variable_keys
        obj.variable_names = variable_names
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.variable_keys = getattr(obj, "variable_keys", None)
        self.variable_names = getattr(obj, "variable_names", None)


def make_variable_array(
    x: Mapping[str, np.ndarray] | np.ndarray,
    dataset_ids: Sequence[int] | int = None,
    variable_keys: Sequence[str] | str = None,
    variable_names: Sequence[str] | str = None,
    default_name: str = "v",
) -> VariableArray:
    """
    Helper function to validate arrays for use in the diagnostics module.

    Parameters
    ----------
    x : dict[str, ndarray] or ndarray
        Dict of arrays or array to be validated. See dicts_to_arrays.
    dataset_ids : Sequence of integers indexing the datasets to select (default = None).
        By default, use all datasets.
    variable_keys : Sequence[str], optional (default = None)
        Optional variable keys to select from x if a dict is provided.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to act as a filter if dicts provided or actual
        variable names in case of array inputs.
    default_name : str, optional (default = "v")
        The default variable name to use if array arguments and no variable names
        are provided.
    """
    if isinstance(variable_keys, str):
        variable_keys = [variable_keys]

    if isinstance(variable_names, str):
        variable_names = [variable_names]

    if isinstance(x, dict):
        if variable_keys is not None:
            x = {k: x[k] for k in variable_keys}

        variable_keys = x.keys()

        if dataset_ids is not None:
            if isinstance(dataset_ids, int):
                # dataset_ids needs to be a sequence so that np.stack works correctly
                dataset_ids = [dataset_ids]
            x = {k: v[dataset_ids] for k, v in x.items()}

        x = split_arrays(x)

        if variable_names is None:
            variable_names = list(x.keys())

        x = np.stack(list(x.values()), axis=-1)

    # Case arrays provided
    elif isinstance(x, np.ndarray):
        if isinstance(x, VariableArray):
            # reuse existing variable keys and names if contained in x
            if variable_names is None:
                variable_names = x.variable_names
            if variable_keys is None:
                variable_keys = x.variable_keys

        # use default names if not otherwise specified
        if variable_names is None:
            variable_names = [f"{default_name}_{i}" for i in range(x.shape[-1])]

        if dataset_ids is not None:
            x = x[dataset_ids]

    # Throw if unknown type
    else:
        raise TypeError(f"Only dicts and tensors are supported as arguments, but your estimates are of type {type(x)}")

    if len(variable_names) != x.shape[-1]:
        raise ValueError("Length of 'variable_names' should be the same as the number of variables.")

    if variable_keys is None:
        # every variable will count as its own key if not otherwise specified
        variable_keys = variable_names

    x = VariableArray(x, variable_keys=variable_keys, variable_names=variable_names)

    return x


def dicts_to_arrays(
    estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray = None,
    priors: Mapping[str, np.ndarray] | np.ndarray = None,
    dataset_ids: Sequence[int] | int = None,
    variable_keys: Sequence[str] | str = None,
    variable_names: Sequence[str] | str = None,
    default_name: str = "v",
) -> dict[str, Any]:
    """Helper function that prepares estimates and optional ground truths for
    diagnostics (plotting or computation of metrics).

    The function operates on both arrays and dictionaries and assumes either a
    dictionary where each key contains a 1D or a 2D array (i.e., a univariate
    quantity or samples thereof) or a 2D or 3D array where the last axis
    represents all quantities of interest.

    If a `targets` array is provided, it must correspond to estimates in terms
    of type and structure of the first and last axis.

    If a dictionary is provided, `variable_names` acts as a filter to select
    variables from estimates. If an array is provided, `variable_names` can be
    used to override the `default_name`.

    Parameters
    ----------
    estimates : dict[str, ndarray] or ndarray
        The model-generated predictions or estimates, which can take the following forms:

        - ndarray of shape (num_datasets, num_variables)
            Point estimates for each dataset, where `num_datasets` is the number
            of datasets and `num_variables` is the number of variables per dataset.
        - ndarray of shape (num_datasets, num_draws, num_variables)
            Posterior samples for each dataset, where `num_datasets` is the number
            of datasets, `num_draws` is the number of posterior draws, and
            `num_variables` is the number of variables.

    targets : dict[str, ndarray] or ndarray, optional (default = None)
        Ground-truth values corresponding to the estimates. Must match the
        structure and dimensionality of `estimates` in terms of first and last axis.
    priors : dict[str, ndarray] or ndarray, optional (default = None)
        Prior draws, validated with the variable keys and names implied by
        `estimates`. Since priors are data independent, `dataset_ids` is not
        applied to them.
    dataset_ids : Sequence of integers indexing the datasets to select (default = None).
        By default, use all datasets.
    variable_keys : list or None, optional, default: None
        Select keys from the dictionary provided in samples. By default, select all keys.
    variable_names : Sequence[str], optional (default = None)
        Optional variable names to act as a filter if dicts provided or actual
        variable names in case of array inputs.
    default_name : str, optional (default = "v")
        The default variable name to use if array arguments and no variable names
        are provided.
    """
    # the other arrays to be validated (see below) will use
    # the variable_keys and variable_names implied by estimates
    estimates = make_variable_array(
        estimates,
        dataset_ids=dataset_ids,
        variable_keys=variable_keys,
        variable_names=variable_names,
        default_name=default_name,
    )

    if targets is not None:
        targets = make_variable_array(
            targets,
            dataset_ids=dataset_ids,
            variable_keys=estimates.variable_keys,
            variable_names=estimates.variable_names,
        )

    if priors is not None:
        priors = make_variable_array(
            priors,
            # priors are data independent so dataset_ids is not passed here
            variable_keys=estimates.variable_keys,
            variable_names=estimates.variable_names,
        )

    return dict(
        estimates=estimates,
        targets=targets,
        priors=priors,
    )
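
# Usage sketch for ``dicts_to_arrays`` (the shapes and values are illustrative):
#
#     post_draws = {"theta": np.random.rand(10, 500, 2)}  # 10 datasets, 500 draws, 2 variables
#     true_params = {"theta": np.random.rand(10, 2)}      # one ground truth per dataset
#     out = dicts_to_arrays(estimates=post_draws, targets=true_params)
#     out["estimates"].shape           # -> (10, 500, 2)
#     out["estimates"].variable_names  # -> ["theta_0", "theta_1"]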
def squeeze_inner_estimates_dict(estimates):
    """If a dictionary has only one key-value pair and the key is "value",
    return only its value. Otherwise, return the unchanged dictionary.
    This method helps to remove unnecessary nesting levels.
    """
    if len(estimates.keys()) == 1 and "value" in estimates.keys():
        return estimates["value"]
    else:
        return estimates
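
# Usage sketch for ``squeeze_inner_estimates_dict``:
#
#     squeeze_inner_estimates_dict({"value": np.zeros(3)})       # -> array of shape (3,)
#     squeeze_inner_estimates_dict({"mean": 0.0, "value": 1.0})  # -> unchanged dict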