# Source code for bayesflow.utils.serialization

from copy import copy

import builtins
import inspect
import keras
import numpy as np
import sys
from warnings import warn

# this import needs to be exactly like this to work with monkey patching
from keras.saving import deserialize_keras_object

from .context_managers import monkey_patch
from .decorators import allow_args


# Key prefix string. It is not referenced by any definition visible in this chunk.
# NOTE(review): presumably kept for older serialization code paths; confirm whether it is still needed.
PREFIX = "_bayesflow_"

# Marker that `serialize` prepends to the registered name of a class. `deserialize` checks for and
# strips it again, to recognize that a string config denotes a type rather than a plain value.
_type_prefix = "__bayesflow_type__"


def serialize_value_or_type(config, name, obj):
    """Deprecated shim, superseded by :py:func:`serialize`.

    Kept only for backward compatibility. Calling it emits a ``DeprecationWarning``
    attributed to the caller and does nothing else. ``config``, ``name`` and ``obj``
    are ignored and ``None`` is returned.
    """
    deprecation_message = "This method is deprecated. It was replaced by bayesflow.utils.serialization.serialize."
    # stacklevel=2 attributes the warning to the code that called this function.
    warn(deprecation_message, DeprecationWarning, stacklevel=2)
def deserialize_value_or_type(config, name):
    """Deprecated shim, superseded by :py:func:`deserialize`.

    Kept only for backward compatibility. Calling it emits a ``DeprecationWarning``
    attributed to the caller. ``config`` and ``name`` are ignored and ``None`` is returned.
    """
    warn(
        "This method is deprecated. It was replaced by bayesflow.utils.serialization.deserialize.",
        category=DeprecationWarning,
        stacklevel=2,  # report the warning at the caller's line
    )
def deserialize(config: dict, custom_objects=None, safe_mode=True, **kwargs):
    """Deserialize an object serialized with :py:func:`serialize`.

    Wrapper around `keras.saving.deserialize_keras_object` that can additionally restore
    classes (types). While this function runs, `deserialize_keras_object` is temporarily
    replaced by this function via ``monkey_patch``. Presumably this routes nested keras
    deserialization calls back through here as well.

    Parameters
    ----------
    config : dict
        Python dict describing the object. Also accepted: a string produced by
        :py:func:`serialize` for a class (``_type_prefix`` + registered name), or a class,
        which is returned unchanged.
    custom_objects : dict, optional
        Python dict containing a mapping between custom object names and the corresponding
        classes or functions. Forwarded to `keras.saving.deserialize_keras_object` and also
        used when resolving serialized types.
    safe_mode : bool, optional
        Boolean, whether to disallow unsafe lambda deserialization. When safe_mode=False,
        loading an object has the potential to trigger arbitrary code execution. This argument
        is only applicable to the Keras v3 model format. Defaults to True.
        Forwarded to `keras.saving.deserialize_keras_object`.
    **kwargs
        Forwarded to `keras.saving.deserialize_keras_object`.

    Returns
    -------
    obj :
        The object described by ``config``.

    Raises
    ------
    ValueError
        If a type in the config can not be deserialized.

    See Also
    --------
    serialize
    """
    with monkey_patch(deserialize_keras_object, deserialize) as original_deserialize:
        if isinstance(config, str) and config.startswith(_type_prefix):
            # This string was produced by `serialize` for a class: recover the registered name.
            registered_name = config.removeprefix(_type_prefix)

            # TODO: can we pass module objects without overwriting numpy's dict with builtins?
            # With `|`, entries of builtins win over numpy entries of the same name.
            lookup_namespace = np.__dict__ | builtins.__dict__
            resolved_type = keras.saving.get_registered_object(
                registered_name,
                custom_objects=custom_objects,
                module_objects=lookup_namespace,
            )
            if resolved_type is not None:
                return resolved_type
            raise ValueError(
                f"Could not deserialize type {registered_name!r}. Make sure it is registered with "
                f"`keras.saving.register_keras_serializable` or pass it in `custom_objects`."
            )

        if inspect.isclass(config):
            # keras does not cover bare classes, so pass them through unchanged
            return config

        return original_deserialize(config, custom_objects=custom_objects, safe_mode=safe_mode, **kwargs)
@allow_args
def serializable(cls, package: str, name: str | None = None, disable_module_check: bool = False):
    """Register class as Keras serializable.

    Wrapper function around `keras.saving.register_keras_serializable` to automatically check consistency
    of the supplied `package` argument with the module a class resides in. The `package` name should
    generally be the module the class resides in, truncated at depth two. Valid examples would be
    "bayesflow.networks" or "bayesflow.adapters". The check can be disabled if necessary by setting
    `disable_module_check` to True. This should only be done in exceptional cases, and accompanied by a
    comment why it is necessary for a given class.

    Parameters
    ----------
    cls : type
        The class to register.
    package : str
        `package` argument forwarded to `keras.saving.register_keras_serializable`.
        Should generally correspond to the module of the class, truncated at depth two
        (e.g., "bayesflow.networks").
    name : str, optional
        `name` argument forwarded to `keras.saving.register_keras_serializable`.
        If None is provided, the class's ``__name__`` attribute is used.
    disable_module_check : bool, optional
        Disable check that the provided `package` is consistent with the location of the class
        within the library.

    Returns
    -------
    The result of applying the decorator returned by `keras.saving.register_keras_serializable`
    to ``cls``.

    Raises
    ------
    ValueError
        If the supplied `package` does not correspond to the module of the class, truncated at
        depth two, and `disable_module_check` is False. No error is thrown when a class is not
        part of the bayesflow module.
    """
    if not disable_module_check:
        # Look up the module that applies this decorator. Depth 2 skips this function's own frame
        # and, presumably, one wrapper frame added by `allow_args`.
        # NOTE(review): keep in sync if `allow_args` changes its call structure.
        frame = sys._getframe(2)
        module_name = frame.f_globals.get("__name__", "")
        module_parts = module_name.split(".")

        # only apply this check if the class is inside the bayesflow module
        is_bayesflow = module_parts[0] == "bayesflow"
        auto_package = ".".join(module_parts[:2])
        if is_bayesflow and package != auto_package:
            raise ValueError(
                "'package' should be the first two levels of the module the class resides in (e.g., bayesflow.networks)"
                f'. In this case it should be \'package="{auto_package}"\' (was "{package}"). If this is not possible'
                " (e.g., because a class was moved to a different module, and serializability should be preserved),"
                " please set 'disable_module_check=True' and add a comment why it is necessary for this class."
            )

    if name is None:
        # str is immutable, so the name can be used directly without copying
        name = cls.__name__

    # register subclasses as keras serializable
    return keras.saving.register_keras_serializable(package=package, name=name)(cls)
def serialize(obj):
    """Serialize an object using Keras.

    Wrapper around `keras.saving.serialize_keras_object` that can additionally serialize
    classes (types).

    Parameters
    ----------
    obj : Keras serializable object, class, or tuple/list/dict of these
        The object to serialize. Tuples, lists and dicts are traversed with
        `keras.tree.map_structure`, and each leaf is passed through this function.

    Returns
    -------
    config : dict, str, or nested structure
        A representation that can be deserialized via :py:func:`deserialize`. A class is
        encoded as a string consisting of ``_type_prefix`` followed by its registered Keras
        name. For tuples, lists and dicts, the container structure is rebuilt by
        `keras.tree.map_structure`.

    See Also
    --------
    deserialize
    """
    if isinstance(obj, (tuple, list, dict)):
        # serialize every leaf individually, keeping the container structure
        return keras.tree.map_structure(serialize, obj)

    if not inspect.isclass(obj):
        return keras.saving.serialize_keras_object(obj)

    # classes are stored as a marked string so that `deserialize` can recognize them
    registered_name = keras.saving.get_registered_name(obj)
    return _type_prefix + registered_name