randomly_mask_along_axis

bayesflow.utils.randomly_mask_along_axis(x: Tensor, drop_prob: float, axis: int = 0, seed_generator: SeedGenerator = None) -> Tensor

Randomly zero out entire slices of a tensor along an axis.

Each slice of x along the given axis is zeroed independently with probability drop_prob. With the default axis=0, this drops entire batch samples, which is the standard way to drop conditioning inputs when training with classifier-free guidance.
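The masking described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not bayesflow's implementation: mask_along_axis_sketch is a hypothetical helper, a NumPy Generator stands in for the Keras seed generator, and kept slices are returned unchanged (the page does not say whether the real function rescales them).

```python
import numpy as np


def mask_along_axis_sketch(x, drop_prob, axis=0, rng=None):
    """Zero each slice of `x` along `axis` independently with probability `drop_prob`.

    Hypothetical NumPy helper for illustration; not bayesflow's implementation.
    """
    if not 0.0 <= drop_prob <= 1.0:
        raise ValueError("drop_prob must be in [0, 1]")
    x = np.asarray(x)
    rng = np.random.default_rng() if rng is None else rng
    axis = axis % x.ndim  # normalize negative axes, e.g. -1 -> last axis

    # One uniform draw per slice; a slice is kept when its draw is >= drop_prob,
    # so it is dropped with probability drop_prob.
    keep = rng.random(x.shape[axis]) >= drop_prob

    # Give the mask size 1 on every other axis so it broadcasts over the slice.
    mask_shape = [1] * x.ndim
    mask_shape[axis] = x.shape[axis]
    keep = keep.reshape(mask_shape)

    # Kept slices pass through unchanged; dropped slices become zeros.
    return np.where(keep, x, np.zeros_like(x))


# Usage: drop whole batch samples (rows) of a (batch, features) array.
rng = np.random.default_rng(0)
conditions = np.arange(1.0, 13.0).reshape(4, 3)
masked = mask_along_axis_sketch(conditions, drop_prob=0.5, axis=0, rng=rng)
print(masked.shape)  # (4, 3)
```

Because the mask has size 1 on every axis except the masked one, a single keep/drop decision is broadcast across the whole slice, so each row (for axis=0) is either entirely zero or identical to the input row.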

Parameters:
x : Tensor

Input tensor.

drop_prob : float

Probability of dropping each slice. Must be in [0, 1].

axis : int, optional

Axis along which to mask. Default is 0 (batch axis).

seed_generator : keras.random.SeedGenerator, optional

Seed generator used for randomness.
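The following NumPy analogy sketches what drop_prob and seed_generator control: drop_prob is the per-slice drop probability, so over many slices the fraction dropped approaches it, and fixing the source of randomness makes the mask reproducible. sample_keep_mask is a hypothetical helper, and a seeded NumPy Generator stands in for keras.random.SeedGenerator; this does not show how bayesflow itself consumes the seed generator.

```python
import numpy as np


def sample_keep_mask(n_slices, drop_prob, seed):
    """Hypothetical helper: one Bernoulli keep/drop decision per slice.

    A fixed integer seed stands in for a keras.random.SeedGenerator here.
    """
    rng = np.random.default_rng(seed)
    return rng.random(n_slices) >= drop_prob  # True = keep, False = drop


mask_a = sample_keep_mask(10_000, drop_prob=0.1, seed=42)
mask_b = sample_keep_mask(10_000, drop_prob=0.1, seed=42)

print(np.array_equal(mask_a, mask_b))  # True: same seed, same mask
print(1.0 - mask_a.mean())             # empirical drop rate, close to 0.1
```

The edge cases follow from the comparison: drop_prob=0 keeps every slice and drop_prob=1 drops every slice, matching the documented range [0, 1].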

Returns:
Tensor

Tensor with the same shape as x, with some slices zeroed out.