# Source code for bayesflow.adapters.transforms.random_subsample
import numpy as np
from bayesflow.utils.serialization import serializable, serialize
from .elementwise_transform import ElementwiseTransform
@serializable(package="bayesflow.adapters")
class RandomSubsample(ElementwiseTransform):
    """
    A transform that takes a random subsample of the data within an axis.

    Parameters
    ----------
    sample_size : int or float
        If an integer, the number of elements to draw along ``axis``. If a float, the
        fraction of the elements along ``axis`` to draw; it must lie strictly between
        0 and 1. The resulting count is rounded with ``np.round`` (ties to even).
    axis : int, optional
        The axis along which to subsample. Default is -1.

    Raises
    ------
    ValueError
        If a float ``sample_size`` is not strictly between 0 and 1, or an integer
        ``sample_size`` is negative.

    Notes
    -----
    Elements are drawn without replacement, using NumPy's global random state
    (``np.random.permutation``). If an integer ``sample_size`` exceeds the length of
    ``axis``, every element along ``axis`` is returned, in random order.
    The transform is not invertible: ``inverse`` returns its input unchanged.

    Examples
    --------
    >>> adapter = bf.Adapter().random_subsample("x", sample_size=3, axis=-1)
    """

    def __init__(
        self,
        sample_size: int | float,
        axis: int = -1,
    ):
        super().__init__()

        if isinstance(sample_size, (float, np.floating)):
            if sample_size <= 0 or sample_size >= 1:
                raise ValueError("Sample size as a percentage must be a float between 0 and 1 exclusive. ")
        elif isinstance(sample_size, (int, np.integer)) and sample_size < 0:
            # A negative count would silently slice as ``[:-k]`` and keep n - k elements.
            raise ValueError("Sample size as an integer must be non-negative.")

        self.sample_size = sample_size
        self.axis = axis

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return a random subsample of ``data`` along ``self.axis``, drawn without replacement."""
        max_sample_size = data.shape[self.axis]

        # NumPy integers are counts too; checking only ``int`` would route them to the fraction branch.
        if isinstance(self.sample_size, (int, np.integer)):
            sample_size = int(self.sample_size)
        else:
            # np.round returns a NumPy float, which is not a valid slice bound, so cast to int.
            sample_size = int(np.round(self.sample_size * max_sample_size))

        # Random sample without replacement. Slicing past the end clamps to max_sample_size.
        sample_indices = np.random.permutation(max_sample_size)[:sample_size]

        return np.take(data, sample_indices, self.axis)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """Return ``data`` unchanged, because the subsampling cannot be inverted."""
        # non invertible transform
        return data

    def get_config(self) -> dict:
        """Return the serialized constructor arguments (``sample_size`` and ``axis``)."""
        config = {"sample_size": self.sample_size, "axis": self.axis}
        return serialize(config)