diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index de8c8e9a68..9b3da67261 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -48,6 +48,11 @@
 )
 from zarr.core.config import config
 from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
 from zarr.core.metadata.io import save_metadata
+from zarr.core.metadata.v3 import (
+    AllowedExtraField,
+    check_allowed_extra_field,
+    parse_extra_fields,
+)
 from zarr.core.sync import SyncMixin, sync
 from zarr.errors import (
@@ -354,6 +359,7 @@ class GroupMetadata(Metadata):
     zarr_format: ZarrFormat = 3
     consolidated_metadata: ConsolidatedMetadata | None = None
     node_type: Literal["group"] = field(default="group", init=False)
+    extra_fields: dict[str, AllowedExtraField] = field(default_factory=dict)
 
     def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
         json_indent = config.get("json_indent")
@@ -408,13 +414,16 @@ def __init__(
         attributes: dict[str, Any] | None = None,
         zarr_format: ZarrFormat = 3,
         consolidated_metadata: ConsolidatedMetadata | None = None,
+        extra_fields: Mapping[str, AllowedExtraField] | None = None,
     ) -> None:
         attributes_parsed = parse_attributes(attributes)
         zarr_format_parsed = parse_zarr_format(zarr_format)
+        extra_fields_parsed = parse_extra_fields(extra_fields)
 
         object.__setattr__(self, "attributes", attributes_parsed)
         object.__setattr__(self, "zarr_format", zarr_format_parsed)
         object.__setattr__(self, "consolidated_metadata", consolidated_metadata)
+        object.__setattr__(self, "extra_fields", extra_fields_parsed)
 
     @classmethod
     def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
@@ -431,11 +440,34 @@ def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
             # extra key in the metadata.
             expected = {x.name for x in fields(cls)}
             data = {k: v for k, v in data.items() if k in expected}
+        else:
+            # Zarr v3: move unknown keys into extra_fields; reject any that fail the allowed check.
+            expected = {x.name for x in fields(cls)}
+            extra_fields: dict[str, AllowedExtraField] = {}
+            invalid_extra_fields: dict[str, Any] = {}
+            for key in list(data):
+                if key not in expected:
+                    val = data.pop(key)
+                    if check_allowed_extra_field(val):
+                        extra_fields[key] = val
+                    else:
+                        invalid_extra_fields[key] = val
+            if invalid_extra_fields:
+                msg = (
+                    "Got a Zarr V3 group metadata document with the following disallowed extra fields: "
+                    f"{sorted(invalid_extra_fields)}. "
+                    'Extra fields are not allowed unless they are a dict with a "must_understand" key '
+                    "which is assigned the value `False`."
+                )
+                raise MetadataValidationError(msg)
+            data["extra_fields"] = extra_fields
 
         return cls(**data)
 
     def to_dict(self) -> dict[str, Any]:
         result = asdict(replace(self, consolidated_metadata=None))
+        extra_fields = result.pop("extra_fields", {})
+        result |= extra_fields
         if self.consolidated_metadata is not None:
             result["consolidated_metadata"] = self.consolidated_metadata.to_dict()
         else:
diff --git a/tests/test_group.py b/tests/test_group.py
index e05df0dfcb..9def662447 100644
--- a/tests/test_group.py
+++ b/tests/test_group.py
@@ -1601,6 +1601,51 @@ def test_from_dict_extra_fields(self):
         expected = GroupMetadata(attributes={"key": "value"}, zarr_format=2)
         assert result == expected
 
+    def test_from_dict_extra_fields_v3_allowed(self):
+        data = {
+            "attributes": {"key": "value"},
+            "zarr_format": 3,
+            "node_type": "group",
+            "my_ext": {"must_understand": False, "data": [1, 2, 3]},
+        }
+        result = GroupMetadata.from_dict(data)
+        expected = GroupMetadata(
+            attributes={"key": "value"},
+            zarr_format=3,
+            extra_fields={"my_ext": {"must_understand": False, "data": [1, 2, 3]}},
+        )
+        assert result == expected
+
+    def test_from_dict_extra_fields_v3_must_understand_true(self):
+        data = {
+            "attributes": {"key": "value"},
+            "zarr_format": 3,
+            "node_type": "group",
+            "my_ext": {"must_understand": True},
+        }
+        with pytest.raises(MetadataValidationError, match="disallowed extra fields"):
+            GroupMetadata.from_dict(data)
+
+    def test_from_dict_extra_fields_v3_non_dict(self):
+        data = {
+            "attributes": {"key": "value"},
+            "zarr_format": 3,
+            "node_type": "group",
+            "my_ext": 42,
+        }
+        with pytest.raises(MetadataValidationError, match="disallowed extra fields"):
+            GroupMetadata.from_dict(data)
+
+    def test_to_dict_extra_fields_v3(self):
+        metadata = GroupMetadata(
+            attributes={"key": "value"},
+            zarr_format=3,
+            extra_fields={"my_ext": {"must_understand": False, "data": [1, 2, 3]}},
+        )
+        result = metadata.to_dict()
+        assert result["my_ext"] == {"must_understand": False, "data": [1, 2, 3]}
+        assert "extra_fields" not in result
+
 
 class TestInfo:
     def test_info(self):