Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
)
from zarr.core.config import config
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
from zarr.core.metadata.v3 import (
AllowedExtraField,
check_allowed_extra_field,
parse_extra_fields,
)
from zarr.core.metadata.io import save_metadata
from zarr.core.sync import SyncMixin, sync
from zarr.errors import (
Expand Down Expand Up @@ -354,6 +359,7 @@ class GroupMetadata(Metadata):
zarr_format: ZarrFormat = 3
consolidated_metadata: ConsolidatedMetadata | None = None
node_type: Literal["group"] = field(default="group", init=False)
extra_fields: dict[str, AllowedExtraField] = field(default_factory=dict)

def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
json_indent = config.get("json_indent")
Expand Down Expand Up @@ -408,13 +414,16 @@ def __init__(
attributes: dict[str, Any] | None = None,
zarr_format: ZarrFormat = 3,
consolidated_metadata: ConsolidatedMetadata | None = None,
extra_fields: Mapping[str, AllowedExtraField] | None = None,
) -> None:
attributes_parsed = parse_attributes(attributes)
zarr_format_parsed = parse_zarr_format(zarr_format)
extra_fields_parsed = parse_extra_fields(extra_fields)

object.__setattr__(self, "attributes", attributes_parsed)
object.__setattr__(self, "zarr_format", zarr_format_parsed)
object.__setattr__(self, "consolidated_metadata", consolidated_metadata)
object.__setattr__(self, "extra_fields", extra_fields_parsed)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
Expand All @@ -431,11 +440,34 @@ def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
# extra key in the metadata.
expected = {x.name for x in fields(cls)}
data = {k: v for k, v in data.items() if k in expected}
else:
# zarr v3: extra fields must have must_understand=False
expected = {x.name for x in fields(cls)}
extra_fields: dict[str, AllowedExtraField] = {}
invalid_extra_fields: dict[str, Any] = {}
for key in list(data.keys()):
if key not in expected:
val = data.pop(key)
if check_allowed_extra_field(val):
extra_fields[key] = val
else:
invalid_extra_fields[key] = val
# Reject the document when any unrecognised key failed the
# `check_allowed_extra_field` test applied while scanning `data`.
if invalid_extra_fields:
    # Adjacent string literals are joined verbatim, so every fragment must
    # carry its own trailing space or the sentences run together
    # ("...fields:[...].Extra fields ... keywhich is assigned ...").
    msg = (
        "Got a Zarr V3 group metadata document with the following disallowed extra fields: "
        f"{sorted(invalid_extra_fields)}. "
        'Extra fields are not allowed unless they are a dict with a "must_understand" key '
        "which is assigned the value `False`."
    )
    raise MetadataValidationError(msg)
data["extra_fields"] = extra_fields

return cls(**data)

def to_dict(self) -> dict[str, Any]:
result = asdict(replace(self, consolidated_metadata=None))
extra_fields = result.pop("extra_fields", {})
result = result | extra_fields
if self.consolidated_metadata is not None:
result["consolidated_metadata"] = self.consolidated_metadata.to_dict()
else:
Expand Down
45 changes: 45 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,51 @@ def test_from_dict_extra_fields(self):
expected = GroupMetadata(attributes={"key": "value"}, zarr_format=2)
assert result == expected

def test_from_dict_extra_fields_v3_allowed(self):
    """A v3 document with a ``must_understand: False`` extra key parses into ``extra_fields``."""
    document = dict(
        attributes={"key": "value"},
        zarr_format=3,
        node_type="group",
        my_ext={"must_understand": False, "data": [1, 2, 3]},
    )

    parsed = GroupMetadata.from_dict(document)

    # The unknown top-level key must come back under `extra_fields`.
    assert parsed == GroupMetadata(
        attributes={"key": "value"},
        zarr_format=3,
        extra_fields={"my_ext": {"must_understand": False, "data": [1, 2, 3]}},
    )

def test_from_dict_extra_fields_v3_must_understand_true(self):
    """A v3 document whose extra key sets ``must_understand`` to ``True`` is rejected."""
    document = dict(
        attributes={"key": "value"},
        zarr_format=3,
        node_type="group",
        my_ext={"must_understand": True},
    )
    rejection = pytest.raises(MetadataValidationError, match="disallowed extra fields")
    with rejection:
        GroupMetadata.from_dict(document)

def test_from_dict_extra_fields_v3_non_dict(self):
    """A v3 document whose extra key holds a non-dict value (here ``42``) is rejected."""
    document = dict(
        attributes={"key": "value"},
        zarr_format=3,
        node_type="group",
        my_ext=42,
    )
    rejection = pytest.raises(MetadataValidationError, match="disallowed extra fields")
    with rejection:
        GroupMetadata.from_dict(document)

def test_to_dict_extra_fields_v3(self):
    """``to_dict`` emits each extra field as a top-level key, with no ``extra_fields`` key."""
    serialized = GroupMetadata(
        attributes={"key": "value"},
        zarr_format=3,
        extra_fields={"my_ext": {"must_understand": False, "data": [1, 2, 3]}},
    ).to_dict()

    # The extension sits next to the core keys, not nested under a wrapper.
    assert serialized["my_ext"] == {"must_understand": False, "data": [1, 2, 3]}
    assert "extra_fields" not in serialized


class TestInfo:
def test_info(self):
Expand Down
Loading