Source code for utopya.model_registry.utils
"""Utility functions that work on the already initialized model registry."""
import logging
from typing import Tuple, Union
from .._yaml import load_yml
from ..parameter import extract_validation_objects
from . import MODELS
from .info_bundle import ModelInfoBundle
# Module-level logger, named after this module
log = logging.getLogger(name=__name__)
# -----------------------------------------------------------------------------
def get_info_bundle(
    *,
    model_name: str = None,
    info_bundle: ModelInfoBundle = None,
    bundle_label: Union[str, int] = None,
) -> ModelInfoBundle:
    """Determine the model info bundle in cases where both a model name
    and an info bundle are allowed as arguments.

    Exactly one of ``model_name`` and ``info_bundle`` must be given.

    Args:
        model_name (str, optional): The model name. If given, the bundle is
            retrieved from the model registry.
        info_bundle (ModelInfoBundle, optional): The info bundle object.
            If given, will directly return this object again.
        bundle_label (Union[str, int], optional): In cases where only the model
            name is given, the bundle_label can be used for item access,
            e.g. in cases where more than one bundle is available and access
            would be ambiguous. Not used if ``info_bundle`` is given.

    Returns:
        ModelInfoBundle: The selected info bundle item

    Raises:
        ValueError: If neither or both of model_name and info_bundle were given
    """
    # XOR-check existence. The parentheses are essential: without them Python
    # parses this as the chained comparison
    # ``model_name is None and None == info_bundle and info_bundle is None``,
    # which only detects the "neither" case and silently accepts "both".
    if (model_name is None) == (info_bundle is None):
        raise ValueError(
            "Need either of the arguments model_name or "
            "info_bundle, but got both or neither!"
        )

    # Compare against None explicitly instead of relying on the truthiness of
    # the bundle object
    if info_bundle is not None:
        return info_bundle

    # Only the model name was given: look it up in the registry, using the
    # label for item access if one was provided
    if bundle_label is None:
        return MODELS[model_name].item()
    return MODELS[model_name][bundle_label]
def load_model_cfg(**get_info_bundle_kwargs) -> Tuple[dict, str, dict]:
    """Load the default model configuration of a model via its info bundle.

    The info bundle is resolved through
    :py:func:`~utopya.model_registry.utils.get_info_bundle`; its
    ``default_cfg`` path entry determines which file is read via ``load_yml``.
    The loaded configuration is then passed through
    :py:func:`~utopya.parameter.extract_validation_objects`, which extracts
    any Parameter objects that require validation, replaces them with their
    default values, and gathers the Parameter class objects into a separate
    dict.

    Args:
        **get_info_bundle_kwargs: Passed on to
            :py:func:`~utopya.model_registry.utils.get_info_bundle` in order
            to select the info bundle

    Returns:
        Tuple[dict, str, dict]: The model configuration, the path of the
            model configuration file, and the Parameter class objects
            requiring validation.
            If the selected info bundle has no ``default_cfg`` path entry,
            ``{}, None, {}`` is returned instead.

    Raises:
        FileNotFoundError: If the default configuration file is missing
    """
    bundle = get_info_bundle(**get_info_bundle_kwargs)
    paths = bundle.paths

    # Guard clause: without a default configuration, return empty results
    if "default_cfg" not in paths:
        log.debug(
            "No default model configuration available for '%s'.",
            bundle.model_name,
        )
        return {}, None, {}

    cfg_path = paths["default_cfg"]
    log.debug(
        "Loading default model configuration for '%s' model ...\n %s",
        bundle.model_name,
        cfg_path,
    )

    try:
        raw_cfg = load_yml(cfg_path)
    except FileNotFoundError as exc:
        # Re-raise with a message that names the model and the expected path
        msg = (
            "Could not locate default configuration for "
            f"'{bundle.model_name}' model! Expected to find it at: {cfg_path}"
        )
        raise FileNotFoundError(msg) from exc

    # Replace validation objects by their defaults, keeping them separately
    cleaned_cfg, params_to_validate = extract_validation_objects(
        raw_cfg, model_name=bundle.model_name
    )
    return cleaned_cfg, cfg_path, params_to_validate