# Source code for bayesflow.utils.logging

import logging
from collections.abc import Callable
from functools import lru_cache

import keras


# Module-wide logger used by every helper below. logging.getLogger returns the
# same instance for the same name, so callers can configure handlers/levels via
# logging.getLogger("bayesflow").
logger = logging.getLogger(name="bayesflow")


def _log(msg, *args, callback_fn: Callable[[str], object] = print, **kwargs):
    """Format ``msg`` with ``str.format`` and pass the resulting string to ``callback_fn``.

    Parameters
    ----------
    msg : str
        A ``str.format`` template, e.g. ``"loss={}"`` or ``"loss={loss}"``.
    *args, **kwargs
        Positional / keyword values substituted into ``msg``.
    callback_fn : callable, optional
        Receives the fully formatted string. Defaults to ``print``.

    Notes
    -----
    With the JAX backend, formatting is performed inside ``jax.debug.callback``
    rather than immediately -- presumably so that messages logged from traced
    (e.g. jitted) code show runtime values instead of tracers.
    """
    if keras.backend.backend() == "jax":
        import jax

        def _emit(fmt_args, fmt_kwargs):
            callback_fn(msg.format(*fmt_args, **fmt_kwargs))

        # Hand the format arguments over as two positional pytrees instead of
        # splatting **kwargs into jax.debug.callback: that function has its own
        # keyword options (e.g. ``ordered``), which would otherwise silently
        # swallow identically named format keys.
        jax.debug.callback(_emit, args, kwargs)
    else:
        callback_fn(msg.format(*args, **kwargs))


def critical(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it at CRITICAL level on the ``bayesflow`` logger."""
    emit = logger.critical
    _log(msg, *args, **kwargs, callback_fn=emit)
def debug(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it at DEBUG level on the ``bayesflow`` logger.

    The message is formatted before the logger's level check, i.e. even when
    DEBUG output is disabled.
    """
    emit = logger.debug
    _log(msg, *args, **kwargs, callback_fn=emit)
def error(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it at ERROR level on the ``bayesflow`` logger."""
    emit = logger.error
    _log(msg, *args, **kwargs, callback_fn=emit)
def exception(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it with ``logger.exception`` (ERROR + traceback).

    Intended to be called from inside an ``except`` block.
    NOTE(review): with the JAX backend the message is emitted from a
    ``jax.debug.callback``, possibly outside the active ``except`` block, so
    the traceback may be missing -- confirm.
    """
    emit = logger.exception
    _log(msg, *args, **kwargs, callback_fn=emit)
def info(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it at INFO level on the ``bayesflow`` logger."""
    emit = logger.info
    _log(msg, *args, **kwargs, callback_fn=emit)
def log(msg, *args, level=logging.INFO, **kwargs):
    """Format ``msg`` via ``str.format`` and log it on the ``bayesflow`` logger at ``level``.

    ``logging.Logger.log`` takes the level as its first argument, so passing
    ``logger.log`` straight through as the callback (which only receives the
    formatted text) always raised ``TypeError``. The level is bound here instead.

    Parameters
    ----------
    msg : str
        ``str.format`` template.
    *args, **kwargs
        Values substituted into ``msg``. ``level`` is reserved for this
        function and cannot be used as a format keyword.
    level : int, optional
        Numeric logging level, e.g. ``logging.WARNING``. Defaults to ``logging.INFO``.
    """

    def _emit(text):
        logger.log(level, text)

    _log(msg, *args, callback_fn=_emit, **kwargs)
def warning(msg, *args, **kwargs):
    """Format ``msg`` via ``str.format`` and log it at WARNING level on the ``bayesflow`` logger."""
    emit = logger.warning
    _log(msg, *args, **kwargs, callback_fn=emit)
@lru_cache(maxsize=None)
def warn_once(msg, *args, **kwargs):
    """Log ``warning(msg, *args, **kwargs)`` only the first time a given argument set is seen.

    Deduplication comes from ``functools.lru_cache`` (unbounded) keyed on all
    arguments, so every argument must be hashable; repeat calls with equal
    arguments hit the cache and log nothing. Reset with ``warn_once.cache_clear()``.
    """
    return warning(msg, *args, **kwargs)