Source code for bayesflow.utils.workflow_utils
import bayesflow.networks
from bayesflow.networks import InferenceNetwork, PointInferenceNetwork, SummaryNetwork
[docs]
def find_inference_network(inference_network: InferenceNetwork | str, **kwargs) -> InferenceNetwork:
if isinstance(inference_network, InferenceNetwork) or isinstance(inference_network, PointInferenceNetwork):
return inference_network
if isinstance(inference_network, type):
return inference_network(**kwargs)
match inference_network.lower():
case "coupling_flow":
return bayesflow.networks.CouplingFlow(**kwargs)
case "flow_matching":
return bayesflow.networks.FlowMatching(**kwargs)
case "consistency_model":
return bayesflow.networks.ConsistencyModel(**kwargs)
case str() as unknown_network:
raise ValueError(f"Unknown inference network: '{unknown_network}'")
case other:
raise TypeError(f"Unknown transform type: {other}")
[docs]
def find_summary_network(summary_network: SummaryNetwork | str, **kwargs) -> SummaryNetwork:
if isinstance(summary_network, SummaryNetwork):
return summary_network
if isinstance(summary_network, type):
return summary_network(**kwargs)
match summary_network.lower():
case "deep_set":
return bayesflow.networks.DeepSet(**kwargs)
case "set_transformer":
return bayesflow.networks.SetTransformer(**kwargs)
case "fusion_transformer":
return bayesflow.networks.FusionTransformer(**kwargs)
case "time_series_transformer":
return bayesflow.networks.TimeSeriesTransformer(**kwargs)
case "time_series_network":
return bayesflow.networks.LSTNet(**kwargs)
case str() as unknown_network:
raise ValueError(f"Unknown summary network: '{unknown_network}'")
case other:
raise TypeError(f"Unknown transform type: {other}")