import os
import urllib.parse
import warnings
from enum import Enum, auto
from pathlib import Path
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, TextIO
import numpy as np
import yaml
from netCDF4 import Dataset
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
UrlConstraints,
field_validator,
model_validator,
)
from simdb.imas.utils import SimDBUrl
def _expand_path(path: Path, base_path: Path) -> Path:
os.environ["MANIFEST_DIR"] = str(base_path)
path = Path(os.path.expandvars(str(path))).expanduser()
if not path.is_absolute():
if not base_path.is_absolute():
raise ValueError("base_path must be absolute")
return base_path / path
else:
path = path.resolve()
return path
# A SimDBUrl restricted to the two schemes a manifest entry may use.
ManifestUrl = Annotated[
    SimDBUrl,
    UrlConstraints(allowed_schemes=["file", "imas"]),
]
class DataType(Enum):
    """Kind of data a manifest entry refers to."""

    # Explicit values identical to what ``auto()`` assigns (starting at 1).
    UNKNOWN = 1
    UUID = 2
    FILE = 3
    IMAS = 4
def _get_data_object_type(uri: SimDBUrl) -> "DataType":
    """Classify a manifest URI as IMAS data or a plain file.

    ``imas`` URIs are always :attr:`DataType.IMAS`. A ``file`` URI is
    :attr:`DataType.IMAS` when it names an existing ``.nc`` file whose
    ``Conventions`` attribute is ``"IMAS"``; otherwise it is
    :attr:`DataType.FILE`.

    Args:
        uri: Validated manifest URI.

    Returns:
        The detected :class:`DataType`.

    Raises:
        ValueError: if a ``file`` URI has no path, or the scheme is neither
            ``file`` nor ``imas``.
    """
    if uri.scheme == "imas":
        return DataType.IMAS
    if uri.scheme == "file":
        if uri.path is None:
            raise ValueError("no path provided")
        path = Path(uri.path)
        # Only open files that exist. A ``file`` path may still be a glob
        # pattern (expanded into concrete items by
        # Manifest._resolve_manifest_items, whose new items are re-classified
        # here) or an output not written yet; opening either would raise
        # instead of classifying it as a plain file.
        if path.suffix == ".nc" and path.is_file():
            with Dataset(uri.path, "r") as ds:
                if getattr(ds, "Conventions", None) == "IMAS":
                    return DataType.IMAS
        return DataType.FILE
    raise ValueError(f"URI scheme ({uri.scheme}:) not recognized")
class DataObject(BaseModel):
    """A manifest data entry (input or output) addressed by a URI.

    Only ``file`` and ``imas`` schemes are accepted. ``file`` URIs are made
    absolute by :meth:`validate_uri`. The entry's :class:`DataType` is
    determined once the whole model has been validated and is exposed via
    :attr:`type`.
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    uri: ManifestUrl = Field()
    # Set by ``_resolve_type``; UNKNOWN until validation completes.
    _type: DataType = PrivateAttr(default=DataType.UNKNOWN)

    @field_validator("uri", mode="after")
    @classmethod
    def validate_uri(cls, v: ManifestUrl, info):
        """Check the URI carries a location and normalise ``file`` paths.

        ``file`` paths are expanded with ``_expand_path`` against the
        ``base_path`` from the validation context, falling back to the
        current working directory. ``imas`` URIs must carry either a ``path``
        query parameter or all of ``shot``, ``run`` and ``database``.
        """
        ctx = info.context or {}
        anchor = ctx.get("base_path") or Path.cwd()
        if v.path is None:
            raise ValueError("no uri path provided")
        if v.scheme == "file":
            absolute = _expand_path(Path(v.path), anchor)
            return v.build(scheme="file", path=absolute.as_posix())
        if v.scheme == "imas":
            params = dict(v.query_params())
            has_path = "path" in params
            has_pulse = all(key in params for key in ("shot", "run", "database"))
            if not (has_path or has_pulse):
                raise ValueError(
                    "no path or (shot, run, database) provided in IMAS uri"
                )
        return v

    @model_validator(mode="after")
    def _resolve_type(self) -> "DataObject":
        # Runs after field validation, so ``uri`` is already normalised.
        self._type = _get_data_object_type(self.uri)
        return self

    @property
    def name(self) -> str:
        """The URI in its encoded string form."""
        return self.uri.encoded_string()

    @property
    def type(self):
        """The :class:`DataType` detected for :attr:`uri`."""
        return self._type
class Source(DataObject):
    # Type used for manifest ``inputs`` entries; adds nothing to DataObject.
    ...
class Sink(DataObject):
    # Type used for manifest ``outputs`` entries; adds nothing to DataObject.
    ...
def _convert_v1_metadata(metadata: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Flatten version 1 ``values`` metadata entries into version 2 name/value pairs.
Version 1 stores metadata as a list of ``{"values": {name: value, ...}}``
entries, whereas version 2 expects a list of single ``{name: value}`` pairs.
"""
converted: List[Dict[str, Any]] = []
for entry in metadata:
values = entry.get("values", {})
for name, value in values.items():
converted.append({name: value})
return converted
class ManifestV1(BaseModel):
    """Legacy (version 1) manifest schema.

    Retained only so old manifests can be read and upgraded through
    :meth:`to_v2_data`. Both versions describe inputs and outputs with a
    ``uri`` key, so those lists are handed over as-is and validated by
    :class:`Manifest`; only the metadata layout is rewritten.
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    manifest_version: Literal[1] = Field(default=1)
    alias: Optional[str] = None
    responsible_name: Optional[str] = None
    inputs: List[Dict[str, Any]] = Field(default_factory=list)
    outputs: List[Dict[str, Any]] = Field(default_factory=list)
    metadata: List[Dict[str, Any]] = Field(default_factory=list)

    def to_v2_data(self) -> Dict[str, Any]:
        """Build raw version 2 data suitable for ``Manifest`` validation.

        ``alias`` and ``responsible_name`` are included only when set.
        """
        optional_fields = {
            "alias": self.alias,
            "responsible_name": self.responsible_name,
        }
        return {
            "manifest_version": 2,
            "inputs": self.inputs,
            "outputs": self.outputs,
            "metadata": _convert_v1_metadata(self.metadata),
            **{key: val for key, val in optional_fields.items() if val is not None},
        }
class Manifest(BaseModel):
    """Current (version 2) manifest schema.

    Inputs, outputs and metadata are read from the ``inputs``, ``outputs`` and
    ``metadata`` keys and stored on the ``*_raw`` fields. Inputs that use the
    deprecated ``version`` key, or ``manifest_version: 1``, are upgraded by
    :meth:`check_deprecated_version` before field validation.
    """

    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    manifest_version: Literal[2] = Field(default=2)
    alias: Optional[str] = None
    responsible_name: Optional[str] = None
    inputs_raw: List[Source] = Field(default_factory=list, alias="inputs")
    outputs_raw: List[Sink] = Field(default_factory=list, alias="outputs")
    metadata_raw: List[Dict[str, Any]] = Field(default_factory=list, alias="metadata")

    # File the manifest was loaded from; set by the loader classmethods.
    _path: Path = PrivateAttr(default_factory=Path)
    # NOTE(review): nothing in this class as shown assigns _inputs, _outputs or
    # _metadata (or calls _resolve_manifest_items); presumably that happens
    # elsewhere. Confirm, otherwise the properties below always return empties.
    _inputs: List[Source] = PrivateAttr(default_factory=list)
    _outputs: List[Sink] = PrivateAttr(default_factory=list)
    _metadata: Dict[str, Any] = PrivateAttr(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def check_deprecated_version(cls, data: Any) -> Any:
        """Upgrade deprecated manifest layouts before field validation.

        * A ``version`` key is renamed to ``manifest_version`` (unless that key
          is already present) and a :class:`DeprecationWarning` is emitted.
        * A version 1 manifest is converted to version 2 data through
          :class:`ManifestV1`, also with a :class:`DeprecationWarning`.

        Non-dict input is passed through unchanged. The caller's dict is never
        modified.
        """
        if not isinstance(data, dict):
            return data
        # Work on a shallow copy: a mode="before" validator receives the
        # caller's object, and popping keys from it would mutate their data.
        data = dict(data)
        if "version" in data:
            warnings.warn(
                "The 'version' field is deprecated and will be removed "
                "in a future version. Please use 'manifest_version' instead.",
                DeprecationWarning,
                stacklevel=3,
            )
            if "manifest_version" not in data:
                data["manifest_version"] = data.pop("version")
        # Accept legacy version 1 manifests by upgrading them to version 2.
        if data.get("manifest_version") == 1:
            warnings.warn(
                "Manifest version 1 is deprecated and will be removed in a "
                "future version. It has been converted to version 2 in memory; "
                "please update the manifest to version 2.",
                DeprecationWarning,
                stacklevel=3,
            )
            data = ManifestV1.model_validate(data).to_v2_data()
        return data

    @field_validator("alias")
    @classmethod
    def validate_alias(cls, v: Optional[str]) -> Optional[str]:
        """Reject aliases that ``urllib.parse.quote`` would percent-encode.

        ``/`` is accepted because it is in ``quote``'s default safe set.
        """
        if v is not None and urllib.parse.quote(v) != v:
            raise ValueError(f"illegal characters in alias: {v}")
        return v

    @field_validator("inputs_raw", "outputs_raw")
    @classmethod
    def validate_uris(cls, v: List[DataObject], info) -> List[DataObject]:
        """Reject lists containing the same (encoded) URI more than once."""
        seen_uris = set()
        for item in v:
            uri_str = item.name
            if uri_str in seen_uris:
                raise ValueError(
                    "Duplicate URI found in "
                    f"{info.field_name.replace('_raw', '')}: {uri_str}"
                )
            seen_uris.add(uri_str)
        return v

    def _resolve_manifest_items(self, items, factory_cls, skip_glob_check):
        """Expand glob patterns in ``file`` items into one item per match.

        Only the final path component is treated as a pattern.

        Args:
            items: Validated Source/Sink objects.
            factory_cls: Class used to build an item for each matched path.
            skip_glob_check: When true, a pattern with no matches is kept as a
                single literal path instead of raising.

        Returns:
            A new list. Non-``file`` items are passed through unchanged; glob
            matches are emitted in sorted order.

        Raises:
            ValueError: if a ``file`` pattern matches nothing and
                *skip_glob_check* is false.
        """
        resolved = []
        for item in items:
            if item.type != DataType.FILE or not item.uri.path:
                resolved.append(item)
                continue
            path_obj = Path(item.uri.path)
            # Sort so the resulting order does not depend on the filesystem's
            # directory listing order.
            matches = sorted(path_obj.parent.glob(path_obj.name))
            if not matches:
                if not skip_glob_check:
                    raise ValueError(f"No files found matching path {path_obj}")
                matches = [path_obj]
            resolved.extend(
                factory_cls(uri=SimDBUrl.build(scheme="file", path=p.as_posix()))
                for p in matches
            )
        return resolved

    @classmethod
    def _get_loader(cls):
        """Return a ``yaml.SafeLoader`` subclass that understands ``!ndarray``.

        An ``!ndarray`` mapping is built with ``np.array(data, dtype)``, where
        ``dtype`` is optional. The constructor is registered on a private
        subclass so the global ``yaml.SafeLoader`` (used by every
        ``yaml.safe_load`` in the process) is left untouched.
        """

        def ndarray_constructor(
            loader: yaml.SafeLoader, node: yaml.nodes.MappingNode
        ) -> np.ndarray:
            mapping = loader.construct_mapping(node, deep=True)
            return np.array(mapping["data"], mapping.get("dtype", None))

        class _ManifestLoader(yaml.SafeLoader):
            pass

        _ManifestLoader.add_constructor("!ndarray", ndarray_constructor)
        return _ManifestLoader

    @classmethod
    def _load_yaml(cls, file_path: Path) -> Any:
        """Parse *file_path* with the manifest YAML loader.

        Raises:
            ValueError: if the file is not valid YAML.
        """
        with file_path.open() as file:
            try:
                return yaml.load(file, Loader=cls._get_loader())
            except yaml.YAMLError as err:
                raise ValueError("badly formatted manifest") from err

    @classmethod
    def from_template(cls) -> "Manifest":
        """Load the ``template.yaml`` stored next to this module.

        Validation runs with ``skip_glob_check`` set in the context and without
        a ``base_path``. Presumably this is because the template's paths are
        placeholders that need not exist.
        """
        template_path = Path(__file__).resolve().parent / "template.yaml"
        raw_data = cls._load_yaml(template_path)
        model = cls.model_validate(raw_data, context={"skip_glob_check": True})
        model._path = template_path
        return model

    @classmethod
    def load_from_file(
        cls, file_path: Path, overrides: Optional[dict] = None
    ) -> "Manifest":
        """Load and validate a manifest from *file_path*.

        Args:
            file_path: Manifest YAML file (``str`` paths are accepted too).
            overrides: Top-level keys that replace those read from the file.

        Returns:
            The validated manifest. Relative URIs are resolved against the
            directory containing *file_path*.

        Raises:
            ValueError: if the file is not valid YAML, if *overrides* are given
                but the document is not a mapping, or (as pydantic's
                ``ValidationError``, a ``ValueError`` subclass) if validation
                fails.
        """
        file_path = Path(file_path)
        raw_data = cls._load_yaml(file_path)
        if overrides:
            # e.g. an empty file loads as None, which cannot take overrides.
            if not isinstance(raw_data, dict):
                raise ValueError("badly formatted manifest")
            raw_data.update(overrides)
        model = cls.model_validate(
            raw_data, context={"base_path": file_path.absolute().parent}
        )
        model._path = file_path
        return model

    def save(self, out_file: TextIO) -> None:
        """Write the manifest as YAML to *out_file*.

        Fields are written under their aliases (``inputs``, ``outputs``,
        ``metadata``), and fields whose value is ``None`` are omitted.
        """
        yaml.dump(
            self.model_dump(mode="json", by_alias=True, exclude_none=True),
            out_file,
            default_flow_style=False,
        )

    @property
    def version(self) -> int:
        """Alias of :attr:`manifest_version`."""
        return self.manifest_version

    @property
    def metadata(self) -> Dict[str, Any]:
        """The ``_metadata`` private attribute."""
        return self._metadata

    @property
    def inputs(self) -> Iterable[Source]:
        """The ``_inputs`` private attribute."""
        return self._inputs

    @property
    def outputs(self) -> Iterable[Sink]:
        """The ``_outputs`` private attribute."""
        return self._outputs