Source code for bayesflow.networks.inference.scoring.point_network

from typing import Sequence

import keras


from bayesflow.utils.serialization import deserialize, serializable

from .scoring_rule_network import ScoringRuleNetwork
from .scoring_rules import ScoringRule, MeanScore, QuantileScore


@serializable("bayesflow.networks")
class PointNetwork(ScoringRuleNetwork):
    """
    (IN) Bayesian point estimation (e.g. mean and quantiles) on top of a shared
    feed-forward architecture.

    ``PointNetwork`` is a convenience front-end for :py:class:`ScoringRuleNetwork`
    that exposes only part of its functionality through a simpler interface.
    The estimates are selected by name from a fixed set of scoring rules
    (currently mean and quantiles); custom scoring rules and parametric
    distribution scores are not available here.

    Examples
    --------
    The following two are equivalent:

    >>> inference_network = bf.networks.PointNetwork(
    ...     ["mean", "quantiles"], q=[0.1, 0.3, 0.5, 0.7, 0.9]
    ... )  # doctest: +SKIP

    >>> from bayesflow.scoring_rules import MeanScore, QuantileScore  # doctest: +SKIP
    >>> inference_network = bf.networks.ScoringRuleNetwork(  # doctest: +SKIP
    ...     mean=MeanScore(),
    ...     quantiles=QuantileScore([0.1, 0.3, 0.5, 0.7, 0.9]),
    ...     # mvn=MvNormalScore(),  # not supported by PointNetwork
    ... )

    ... but the latter supports passing any subclass of :py:class:`ScoringRule`,
    e.g. parametric distributions.
    """

    def __init__(
        self,
        points: str | Sequence[str],
        q: Sequence[float] | None = None,
        subnet: str | keras.Layer = "mlp",
        **kwargs,
    ):
        """
        Build the network from the names of the requested point estimates.

        Parameters
        ----------
        points : str or Sequence[str]
            Name(s) of the point estimates; each must be ``"mean"`` or
            ``"quantiles"``.
        q : Sequence[float] or None, optional
            Passed as ``q`` to :py:class:`QuantileScore` when ``"quantiles"``
            is requested. Default is None.
        subnet : str or keras.Layer, optional
            Forwarded unchanged to :py:class:`ScoringRuleNetwork`.
            Default is ``"mlp"``.
        **kwargs
            Additional keyword arguments forwarded to
            :py:class:`ScoringRuleNetwork`.

        Raises
        ------
        ValueError
            If an entry of ``points`` is neither ``"mean"`` nor ``"quantiles"``.
        """
        super().__init__(
            scoring_rules=self._resolve_scoring_rules(points, q),
            subnet=subnet,
            **kwargs,
        )

    def _resolve_scoring_rules(self, points: Sequence[str], q: Sequence[float]) -> dict[str, ScoringRule]:
        """
        Translate point-estimate names into scoring rule instances.

        Parameters
        ----------
        points : str or Sequence[str]
            A single name or a sequence of names; a bare string is treated as
            a one-element sequence.
        q : Sequence[float] or None
            Passed as ``q`` to :py:class:`QuantileScore`.

        Returns
        -------
        dict[str, ScoringRule]
            One entry per distinct name in ``points``, keyed by that name.

        Raises
        ------
        ValueError
            If a name is neither ``"mean"`` nor ``"quantiles"``.
        """
        # A lone string would otherwise be iterated character by character.
        names = [points] if isinstance(points, str) else points

        resolved = {}
        for key in names:
            if key == "mean":
                resolved[key] = MeanScore()
            elif key == "quantiles":
                resolved[key] = QuantileScore(q=q)
            else:
                raise ValueError(f"{key} must be either `mean` or `quantiles`")

        return resolved

    @classmethod
    def from_config(cls, config):
        """
        Rebuild an instance from a serialized config.

        Parameters
        ----------
        config : dict
            Serialized configuration holding (at least) the entries
            ``"scoring_rules"`` and ``"subnet"``.

        Returns
        -------
        PointNetwork
            A new instance initialized via ``ScoringRuleNetwork.__init__``.
        """
        # The serialized config comes from the parent's get_config and therefore
        # carries `scoring_rules`, not the `points` argument that
        # PointNetwork.__init__ expects. Skip PointNetwork.__init__ and run the
        # parent's initializer on a bare instance instead.
        params = config.copy()  # leave the caller's dict unmodified
        for field in ("scoring_rules", "subnet"):
            params[field] = deserialize(params[field])

        network = cls.__new__(cls)
        ScoringRuleNetwork.__init__(network, **params)
        return network