import inspect
from collections.abc import Callable, Mapping, Sequence
from typing import TypeVar, Any
import keras
import numpy as np
from bayesflow.types import Tensor
from . import logging
T = TypeVar("T")
def convert_args(f: Callable, *args: Any, **kwargs: Any) -> tuple[Any, ...]:
"""Convert positional and keyword arguments to just positional arguments for f"""
if not kwargs:
return args
signature = inspect.signature(f)
# convert to just kwargs first
kwargs = convert_kwargs(f, *args, **kwargs)
parameters = []
for name, param in signature.parameters.items():
if param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]:
continue
parameters.append(kwargs.get(name, param.default))
return tuple(parameters)
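# A minimal usage sketch (hypothetical function, for illustration only):
#
#     def f(a, b=2, c=3):
#         return a + b + c
#
#     convert_args(f, 1, c=5)  # -> (1, 2, 5): defaults fill in unspecified parameters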
def convert_kwargs(f: Callable, *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Convert positional and keyword arguments to just keyword arguments for f"""
if not args:
return kwargs
signature = inspect.signature(f)
parameters = dict(zip(signature.parameters, args))
for name, value in kwargs.items():
if name in parameters:
raise TypeError(f"{f.__name__}() got multiple arguments for argument '{name}'")
parameters[name] = value
return parameters
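# A minimal usage sketch (hypothetical function, for illustration only):
# positional arguments are matched to parameter names in signature order,
# then merged with the keyword arguments:
#
#     def f(a, b, c=3):
#         return a + b + c
#
#     convert_kwargs(f, 1, 2, c=5)  # -> {"a": 1, "b": 2, "c": 5}
#     convert_kwargs(f, 1, a=2)     # -> TypeError: f() got multiple arguments for argument 'a'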
def filter_kwargs(kwargs: dict[str, T], f: Callable) -> dict[str, T]:
"""Filter keyword arguments for f"""
signature = inspect.signature(f)
for parameter in signature.parameters.values():
if parameter.kind == inspect.Parameter.VAR_KEYWORD:
# there is a **kwargs parameter, so anything is valid
return kwargs
kwargs = {key: value for key, value in kwargs.items() if key in signature.parameters}
return kwargs
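# A minimal usage sketch (hypothetical functions, for illustration only):
#
#     def f(a, b): ...
#     def g(a, **kwargs): ...
#
#     filter_kwargs({"a": 1, "x": 2}, f)  # -> {"a": 1}; 'x' is dropped
#     filter_kwargs({"a": 1, "x": 2}, g)  # -> {"a": 1, "x": 2}; **kwargs accepts anything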
def keras_kwargs(kwargs: dict[str, T]) -> dict[str, T]:
"""Filter keyword arguments for keras.Layer"""
valid_keys = ["dtype", "name", "trainable"]
return {key: value for key, value in kwargs.items() if key in valid_keys}
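# For example (illustrative values):
#
#     keras_kwargs({"name": "net", "trainable": False, "depth": 4})
#     # -> {"name": "net", "trainable": False}; 'depth' is not a keras.Layer kwarg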
# TODO: rename and streamline and make protected
def check_output(outputs: Mapping[str, Tensor]) -> None:
# Warn if any NaNs present in output
for k, v in outputs.items():
nan_mask = keras.ops.isnan(v)
if keras.ops.any(nan_mask):
logging.warning("Found a total of {n:d} nan values for output {k}.", n=int(keras.ops.sum(nan_mask)), k=k)
# Warn if any inf present in output
for k, v in outputs.items():
inf_mask = keras.ops.isinf(v)
if keras.ops.any(inf_mask):
logging.warning("Found a total of {n:d} inf values for output {k}.", n=int(keras.ops.sum(inf_mask)), k=k)
def split_tensors(data: Mapping[Any, Tensor], axis: int = -1) -> Mapping[Any, Tensor]:
"""Split tensors in the dictionary along the given axis."""
result = {}
for key, value in data.items():
if keras.ops.shape(value)[axis] == 1:
result[key] = keras.ops.squeeze(value, axis=axis)
continue
splits = keras.ops.split(value, keras.ops.shape(value)[axis], axis=axis)
splits = [keras.ops.squeeze(split, axis=axis) for split in splits]
for i, split in enumerate(splits):
result[f"{key}_{i + 1}"] = split
return result
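# For example (illustrative shapes): a size-2 trailing axis is split into
# one-indexed keys, while a size-1 trailing axis is squeezed in place:
#
#     data = {"theta": keras.ops.ones((8, 2)), "x": keras.ops.ones((8, 1))}
#     split_tensors(data)
#     # -> {"theta_1": shape (8,), "theta_2": shape (8,), "x": shape (8,)}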
def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str, np.ndarray]:
"""Split tensors in the dictionary along the given axis."""
result = {}
for key, value in data.items():
if not hasattr(value, "shape"):
result[key] = np.array([value])
continue
if len(value.shape) == 1:
result[key] = value
continue
if value.shape[axis] == 1:
result[key] = np.squeeze(value, axis=axis)
continue
splits = np.split(value, value.shape[axis], axis=axis)
splits = [np.squeeze(split, axis=axis) for split in splits]
for i, split in enumerate(splits):
result[f"{key}_{i}"] = split
return result
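# For example (illustrative values); note that, unlike split_tensors, the
# resulting keys are zero-indexed and scalars are wrapped in 1D arrays:
#
#     data = {"theta": np.zeros((8, 2)), "scalar": 1.0}
#     split_arrays(data)
#     # -> {"theta_0": shape (8,), "theta_1": shape (8,), "scalar": array([1.0])}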
class VariableArray(np.ndarray):
"""
An enriched numpy array with information on variable keys and names
to be used in post-processing, specifically the diagnostics module.
    The current implementation is very basic and we may want to extend it
in the future should this general structure prove useful.
Design according to
https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
"""
def __new__(cls, input_array, variable_keys=None, variable_names=None):
obj = np.asarray(input_array).view(cls)
obj.variable_keys = variable_keys
obj.variable_names = variable_names
return obj
def __array_finalize__(self, obj):
if obj is None:
return
self.variable_keys = getattr(obj, "variable_keys", None)
self.variable_names = getattr(obj, "variable_names", None)
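# A minimal usage sketch (illustrative values): the extra attributes survive
# slicing and other numpy operations via __array_finalize__:
#
#     arr = VariableArray(
#         np.zeros((10, 2)), variable_keys=["theta"], variable_names=["theta_0", "theta_1"]
#     )
#     arr[:5].variable_names  # -> ["theta_0", "theta_1"]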
def make_variable_array(
x: Mapping[str, np.ndarray] | np.ndarray,
    dataset_ids: Sequence[int] | int | None = None,
    variable_keys: Sequence[str] | str | None = None,
    variable_names: Sequence[str] | str | None = None,
default_name: str = "v",
) -> VariableArray:
"""
Helper function to validate arrays for use in the diagnostics module.
Parameters
----------
    x : dict[str, ndarray] or ndarray
        Dict of arrays or array to be validated. See `dicts_to_arrays`.
    dataset_ids : Sequence[int] or int, optional (default = None)
        Integer indices of the datasets to select. By default, all datasets are used.
    variable_keys : Sequence[str] or str, optional (default = None)
        Select keys from the dictionary provided in `x`. By default, all keys are selected.
    variable_names : Sequence[str] or str, optional (default = None)
        Optional variable names to act as a filter if dicts are provided, or actual
        variable names in case of array inputs.
    default_name : str, optional (default = "v")
        The default variable name to use if array arguments and no variable names are provided.
    """
if isinstance(variable_keys, str):
variable_keys = [variable_keys]
if isinstance(variable_names, str):
variable_names = [variable_names]
if isinstance(x, dict):
if variable_keys is not None:
x = {k: x[k] for k in variable_keys}
variable_keys = x.keys()
if dataset_ids is not None:
if isinstance(dataset_ids, int):
# dataset_ids needs to be a sequence so that np.stack works correctly
dataset_ids = [dataset_ids]
x = {k: v[dataset_ids] for k, v in x.items()}
x = split_arrays(x)
if variable_names is None:
variable_names = list(x.keys())
x = np.stack(list(x.values()), axis=-1)
# Case arrays provided
elif isinstance(x, np.ndarray):
if isinstance(x, VariableArray):
# reuse existing variable keys and names if contained in x
if variable_names is None:
variable_names = x.variable_names
            if variable_keys is None:
variable_keys = x.variable_keys
# use default names if not otherwise specified
if variable_names is None:
variable_names = [f"{default_name}_{i}" for i in range(x.shape[-1])]
if dataset_ids is not None:
x = x[dataset_ids]
# Throw if unknown type
else:
raise TypeError(f"Only dicts and tensors are supported as arguments, but your estimates are of type {type(x)}")
    if len(variable_names) != x.shape[-1]:
raise ValueError("Length of 'variable_names' should be the same as the number of variables.")
if variable_keys is None:
# every variable will count as its own key if not otherwise specified
variable_keys = variable_names
x = VariableArray(x, variable_keys=variable_keys, variable_names=variable_names)
return x
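# A minimal usage sketch (hypothetical shapes): a dict entry of shape
# (num_datasets, num_draws, num_variables) is split and re-stacked:
#
#     x = make_variable_array({"theta": np.zeros((100, 500, 2))})
#     x.shape           # -> (100, 500, 2)
#     x.variable_names  # -> ["theta_0", "theta_1"]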
def dicts_to_arrays(
estimates: Mapping[str, np.ndarray] | np.ndarray,
    targets: Mapping[str, np.ndarray] | np.ndarray | None = None,
    priors: Mapping[str, np.ndarray] | np.ndarray | None = None,
    dataset_ids: Sequence[int] | int | None = None,
    variable_keys: Sequence[str] | str | None = None,
    variable_names: Sequence[str] | str | None = None,
default_name: str = "v",
) -> dict[str, Any]:
"""Helper function that prepares estimates and optional ground truths for diagnostics
(plotting or computation of metrics).
The function operates on both arrays and dictionaries and assumes either a dictionary
where each key contains a 1D or a 2D array (i.e., a univariate quantity or samples thereof)
or a 2D or 3D array where the last axis represents all quantities of interest.
    If a `targets` array is provided, it must correspond to estimates in terms of type
and structure of the first and last axis.
If a dictionary is provided, `variable_names` acts as a filter to select variables from
estimates. If an array is provided, `variable_names` can be used to override the `default_name`.
Parameters
----------
estimates : dict[str, ndarray] or ndarray
The model-generated predictions or estimates, which can take the following forms:
- ndarray of shape (num_datasets, num_variables)
Point estimates for each dataset, where `num_datasets` is the number of datasets
and `num_variables` is the number of variables per dataset.
- ndarray of shape (num_datasets, num_draws, num_variables)
Posterior samples for each dataset, where `num_datasets` is the number of datasets,
`num_draws` is the number of posterior draws, and `num_variables` is the number of variables.
targets : dict[str, ndarray] or ndarray, optional (default = None)
Ground-truth values corresponding to the estimates. Must match the structure and dimensionality
of `estimates` in terms of first and last axis.
    dataset_ids : Sequence[int] or int, optional (default = None)
        Integer indices of the datasets to select. By default, all datasets are used.
    variable_keys : Sequence[str] or str, optional (default = None)
        Select keys from the dictionary provided in `estimates`.
        By default, select all keys.
variable_names : Sequence[str], optional (default = None)
Optional variable names to act as a filter if dicts provided or actual variable names in case of array
inputs.
default_name : str, optional (default = "v")
The default variable name to use if array arguments and no variable names are provided.
"""
    # the other arrays to be validated (see below) will use
    # the variable_keys and variable_names implied by estimates
estimates = make_variable_array(
estimates,
dataset_ids=dataset_ids,
variable_keys=variable_keys,
variable_names=variable_names,
default_name=default_name,
)
if targets is not None:
targets = make_variable_array(
targets,
dataset_ids=dataset_ids,
variable_keys=estimates.variable_keys,
variable_names=estimates.variable_names,
)
if priors is not None:
priors = make_variable_array(
priors,
# priors are data independent so datasets_ids is not passed here
variable_keys=estimates.variable_keys,
variable_names=estimates.variable_names,
)
return dict(
estimates=estimates,
targets=targets,
priors=priors,
)
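# A minimal usage sketch (hypothetical shapes): posterior draws and matching
# ground truths for two univariate quantities across 100 datasets:
#
#     out = dicts_to_arrays(
#         estimates={"mu": np.zeros((100, 500, 1)), "sigma": np.ones((100, 500, 1))},
#         targets={"mu": np.zeros((100,)), "sigma": np.ones((100,))},
#     )
#     out["estimates"].shape  # -> (100, 500, 2)
#     out["targets"].shape    # -> (100, 2)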
def squeeze_inner_estimates_dict(estimates):
"""If a dictionary has only one key-value pair and the key is "value", return only its value.
Otherwise, return the unchanged dictionary.
This method helps to remove unnecessary nesting levels.
"""
    if len(estimates) == 1 and "value" in estimates:
return estimates["value"]
else:
return estimates
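# For example:
#
#     squeeze_inner_estimates_dict({"value": np.zeros(3)})     # -> array([0., 0., 0.])
#     squeeze_inner_estimates_dict({"mean": 0.0, "std": 1.0})  # -> unchanged dict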