CCR/.venv/lib/python3.12/site-packages/xarray/tests/test_backends_datatree.py

652 lines
25 KiB
Python

from __future__ import annotations
import re
from collections.abc import Hashable
from typing import TYPE_CHECKING, cast
import numpy as np
import pytest
import xarray as xr
from xarray.backends.api import open_datatree, open_groups
from xarray.core.datatree import DataTree
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
requires_dask,
requires_h5netcdf,
requires_netCDF4,
requires_zarr,
)
if TYPE_CHECKING:
from xarray.core.datatree_io import T_DataTreeNetcdfEngine
try:
import netCDF4 as nc4
except ImportError:
pass
# True when the installed zarr package is v3; used below to skip the
# datatree-zarr tests, which are not yet implemented for zarr 3.
have_zarr_v3 = xr.backends.zarr._zarr_v3()
def diff_chunks(
    comparison: dict[tuple[str, Hashable], bool], tree1: DataTree, tree2: DataTree
) -> str:
    """Build a human-readable report of variables whose chunk sizes differ.

    Parameters
    ----------
    comparison : dict
        Maps ``(group_path, variable_name)`` to a bool saying whether the
        chunk sizes of that variable agree between the two trees.
    tree1, tree2 : DataTree
        The trees being compared; only consulted to display the chunk sizes
        of the mismatching variables.

    Returns
    -------
    str
        Multi-line message listing each mismatching variable's chunks.
    """
    report_lines = ["Differing chunk sizes:"]
    for (path, name), equal in comparison.items():
        if equal:
            continue
        left = tree1[path].variables[name].chunksizes
        right = tree2[path].variables[name].chunksizes
        report_lines.append(f"L {path}:{name}: {left}")
        report_lines.append(f"R {path}:{name}: {right}")
    return "\n".join(report_lines)
def assert_chunks_equal(
    actual: DataTree, expected: DataTree, enforce_dask: bool = False
) -> None:
    """Assert every variable in ``actual`` has the same chunk sizes as in ``expected``.

    When ``enforce_dask`` is True, additionally require each variable in
    ``actual`` to be backed by a dask array (not merely to have matching,
    possibly empty, chunk sizes).
    """
    __tracebackhide__ = True

    from xarray.namedarray.pycompat import array_type

    dask_cls = array_type("dask")

    comparison: dict[tuple[str, Hashable], bool] = {}
    for path, (node_actual, node_expected) in xr.group_subtrees(actual, expected):
        for name in node_actual.variables:
            var_actual = node_actual.variables[name]
            backed_ok = not enforce_dask or isinstance(var_actual.data, dask_cls)
            comparison[(path, name)] = (
                backed_ok
                and var_actual.chunksizes == node_expected.variables[name].chunksizes
            )

    assert all(comparison.values()), diff_chunks(comparison, actual, expected)
@pytest.fixture(scope="module")
def unaligned_datatree_nc(tmp_path_factory):
    """Yield the path of a netCDF4 file with an *unaligned* group hierarchy.

    Group: /
    │   Dimensions:        (lat: 1, lon: 2)
    │   Dimensions without coordinates: lat, lon
    │   Data variables:
    │       root_variable  (lat, lon) float64 16B ...
    └── Group: /Group1
    │   Dimensions:        (lat: 1, lon: 2)
    │   Dimensions without coordinates: lat, lon
    │   Data variables:
    │       group_1_var    (lat, lon) float64 16B ...
    └── Group: /Group1/subgroup1
        Dimensions:        (lat: 2, lon: 2)
        Dimensions without coordinates: lat, lon
        Data variables:
            subgroup1_var  (lat, lon) float64 32B ...
    """
    path = tmp_path_factory.mktemp("data") / "unaligned_subgroups.nc"
    with nc4.Dataset(path, "w", format="NETCDF4") as root:
        grp1 = root.createGroup("/Group1")
        sub1 = grp1.createGroup("/subgroup1")
        root.createDimension("lat", 1)
        root.createDimension("lon", 2)
        root.createVariable("root_variable", np.float64, ("lat", "lon"))

        grp1_var = grp1.createVariable("group_1_var", np.float64, ("lat", "lon"))
        grp1_var[:] = np.array([[0.1, 0.2]])
        grp1_var.units = "K"
        grp1_var.long_name = "air_temperature"

        # Redefining "lat" with size 2 makes this subgroup misaligned with
        # its parents, which is exactly what the tests rely on.
        sub1.createDimension("lat", 2)
        sub1_var = sub1.createVariable("subgroup1_var", np.float64, ("lat", "lon"))
        sub1_var[:] = np.array([[0.1, 0.2]])
    yield path
@pytest.fixture(scope="module")
def unaligned_datatree_zarr(tmp_path_factory):
    """Yield the path of a zarr store with an *unaligned* group hierarchy.

    Group: /
    │   Dimensions:  (y: 3, x: 2)
    │   Dimensions without coordinates: y, x
    │   Data variables:
    │       a     (y) int64 24B ...
    │       set0  (x) int64 16B ...
    └── Group: /Group1
    │   │   Dimensions:  ()
    │   │   Data variables:
    │   │       a     int64 8B ...
    │   │       b     int64 8B ...
    │   └── /Group1/subgroup1
    │       Dimensions:  ()
    │       Data variables:
    │           a     int64 8B ...
    │           b     int64 8B ...
    └── Group: /Group2
            Dimensions:  (y: 2, x: 2)
            Dimensions without coordinates: y, x
            Data variables:
                a     (y) int64 16B ...
                b     (x) float64 16B ...
    """
    path = tmp_path_factory.mktemp("data") / "unaligned_simple_datatree.zarr"

    root_ds = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
    scalar_ds = xr.Dataset({"a": 0, "b": 1})
    group2_ds = xr.Dataset({"a": ("y", [2, 3]), "b": ("x", [0.1, 0.2])})

    # Root is written first to create the store; child groups are appended.
    root_ds.to_zarr(path)
    for group, ds in (
        ("/Group1", scalar_ds),
        ("/Group2", group2_ds),
        ("/Group1/subgroup1", scalar_ds),
    ):
        ds.to_zarr(path, group=group, mode="a")
    yield path
class DatatreeIOBase:
    """Round-trip tests shared by the netCDF-based datatree engines.

    Subclasses set ``engine`` to select the backend under test.
    """

    engine: T_DataTreeNetcdfEngine | None = None

    def test_to_netcdf(self, tmpdir, simple_datatree):
        """A written tree round-trips through ``open_datatree``."""
        filepath = tmpdir / "test.nc"
        original_dt = simple_datatree
        original_dt.to_netcdf(filepath, engine=self.engine)

        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
            # the opened tree must carry a closer so the file can be released
            assert roundtrip_dt._close is not None
            assert_equal(original_dt, roundtrip_dt)

    def test_to_netcdf_inherited_coords(self, tmpdir):
        """Inherited coordinates are not duplicated into child groups on disk."""
        filepath = tmpdir / "test.nc"
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
                "/sub": xr.Dataset({"b": (("x",), [5, 6])}),
            }
        )
        original_dt.to_netcdf(filepath, engine=self.engine)

        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
            assert_equal(original_dt, roundtrip_dt)
            subtree = cast(DataTree, roundtrip_dt["/sub"])
            # "x" is inherited from the root, so the child has no local copy
            assert "x" not in subtree.to_dataset(inherit=False).coords

    def test_netcdf_encoding(self, tmpdir, simple_datatree):
        """Per-group encodings are applied; unknown encoding groups raise."""
        filepath = tmpdir / "test.nc"
        original_dt = simple_datatree

        # add compression
        comp = dict(zlib=True, complevel=9)
        enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
        original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)

        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
            assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"]
            assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"]

            enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
            with pytest.raises(ValueError, match="unexpected encoding group.*"):
                original_dt.to_netcdf(filepath, encoding=enc, engine=self.engine)

    def test_write_subgroup(self, tmpdir):
        """Writing a non-root node stores it as a new root (its name is dropped)."""
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"x": [1, 2, 3]}),
                "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
            }
        ).children["child"]

        expected_dt = original_dt.copy()
        expected_dt.name = None

        # fix: this base class writes netCDF files, so use a ".nc" suffix
        # (the ".zarr" name was copied from the zarr tests)
        filepath = tmpdir / "test.nc"
        original_dt.to_netcdf(filepath, engine=self.engine)

        with open_datatree(filepath, engine=self.engine) as roundtrip_dt:
            assert_equal(original_dt, roundtrip_dt)
            assert_identical(expected_dt, roundtrip_dt)
@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
    """Datatree I/O tests specific to the netcdf4 engine."""

    engine: T_DataTreeNetcdfEngine | None = "netcdf4"

    def test_open_datatree(self, unaligned_datatree_nc) -> None:
        """Test if `open_datatree` fails to open a netCDF4 with an unaligned group hierarchy."""
        with pytest.raises(
            ValueError,
            match=(
                re.escape(
                    "group '/Group1/subgroup1' is not aligned with its parents:\nGroup:\n"
                )
                + ".*"
            ),
        ):
            open_datatree(unaligned_datatree_nc)

    @requires_dask
    def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
        """Opening with ``chunks=`` yields dask-backed variables with those chunks."""
        filepath = tmpdir / "test.nc"

        chunks = {"x": 2, "y": 1}

        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
        original_tree = DataTree.from_dict(
            {
                "/": root_data.chunk(chunks),
                "/group1": set1_data.chunk(chunks),
                "/group2": set2_data.chunk(chunks),
            }
        )
        original_tree.to_netcdf(filepath, engine="netcdf4")

        with open_datatree(filepath, engine="netcdf4", chunks=chunks) as tree:
            xr.testing.assert_identical(tree, original_tree)
            assert_chunks_equal(tree, original_tree, enforce_dask=True)

    def test_open_groups(self, unaligned_datatree_nc) -> None:
        """Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
        unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)

        # Check that group names are keys in the dictionary of `xr.Datasets`
        assert "/" in unaligned_dict_of_datasets.keys()
        assert "/Group1" in unaligned_dict_of_datasets.keys()
        assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()

        # Check that group name returns the correct datasets
        with xr.open_dataset(unaligned_datatree_nc, group="/") as expected:
            assert_identical(unaligned_dict_of_datasets["/"], expected)
        with xr.open_dataset(unaligned_datatree_nc, group="Group1") as expected:
            assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
        with xr.open_dataset(
            unaligned_datatree_nc, group="/Group1/subgroup1"
        ) as expected:
            assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)

        for ds in unaligned_dict_of_datasets.values():
            ds.close()

    @requires_dask
    def test_open_groups_chunks(self, tmpdir) -> None:
        """Test `open_groups` with chunks on a netcdf4 file."""
        # fix: `chunks` was assigned twice with identical values; keep one.
        filepath = tmpdir / "test.nc"
        chunks = {"x": 2, "y": 1}

        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
        original_tree = DataTree.from_dict(
            {
                "/": root_data.chunk(chunks),
                "/group1": set1_data.chunk(chunks),
                "/group2": set2_data.chunk(chunks),
            }
        )
        original_tree.to_netcdf(filepath, mode="w")

        dict_of_datasets = open_groups(filepath, engine="netcdf4", chunks=chunks)

        for path, ds in dict_of_datasets.items():
            assert {
                k: max(vs) for k, vs in ds.chunksizes.items()
            } == chunks, f"unexpected chunking for {path}"

        for ds in dict_of_datasets.values():
            ds.close()

    def test_open_groups_to_dict(self, tmpdir) -> None:
        """Create an aligned netCDF4 with the following structure to test `open_groups`
        and `DataTree.from_dict`.

        Group: /
        │   Dimensions:        (lat: 1, lon: 2)
        │   Dimensions without coordinates: lat, lon
        │   Data variables:
        │       root_variable  (lat, lon) float64 16B ...
        └── Group: /Group1
        │   Dimensions:        (lat: 1, lon: 2)
        │   Dimensions without coordinates: lat, lon
        │   Data variables:
        │       group_1_var    (lat, lon) float64 16B ...
        └── Group: /Group1/subgroup1
                Dimensions:        (lat: 1, lon: 2)
                Dimensions without coordinates: lat, lon
                Data variables:
                    subgroup1_var  (lat, lon) float64 16B ...
        """
        filepath = tmpdir + "/all_aligned_child_nodes.nc"
        with nc4.Dataset(filepath, "w", format="NETCDF4") as root_group:
            group_1 = root_group.createGroup("/Group1")
            subgroup_1 = group_1.createGroup("/subgroup1")

            root_group.createDimension("lat", 1)
            root_group.createDimension("lon", 2)
            root_group.createVariable("root_variable", np.float64, ("lat", "lon"))

            group_1_var = group_1.createVariable(
                "group_1_var", np.float64, ("lat", "lon")
            )
            group_1_var[:] = np.array([[0.1, 0.2]])
            group_1_var.units = "K"
            group_1_var.long_name = "air_temperature"

            subgroup1_var = subgroup_1.createVariable(
                "subgroup1_var", np.float64, ("lat", "lon")
            )
            subgroup1_var[:] = np.array([[0.1, 0.2]])

        aligned_dict_of_datasets = open_groups(filepath)
        aligned_dt = DataTree.from_dict(aligned_dict_of_datasets)
        with open_datatree(filepath) as opened_tree:
            assert opened_tree.identical(aligned_dt)

        for ds in aligned_dict_of_datasets.values():
            ds.close()

    def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None:
        """Test opening a specific group within a NetCDF file using `open_datatree`."""
        filepath = tmpdir / "test.nc"
        group = "/set1"
        original_dt = simple_datatree
        original_dt.to_netcdf(filepath)
        expected_subtree = original_dt[group].copy()
        expected_subtree.orphan()

        with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree:
            assert subgroup_tree.root.parent is None
            assert_equal(subgroup_tree, expected_subtree)
@requires_h5netcdf
class TestH5NetCDFDatatreeIO(DatatreeIOBase):
    """Runs the shared ``DatatreeIOBase`` round-trip tests with the h5netcdf engine."""

    engine: T_DataTreeNetcdfEngine | None = "h5netcdf"
@pytest.mark.skipif(
    have_zarr_v3, reason="datatree support for zarr 3 is not implemented yet"
)
@requires_zarr
class TestZarrDatatreeIO:
    """Datatree I/O tests for the zarr backend."""

    engine = "zarr"

    def test_to_zarr(self, tmpdir, simple_datatree):
        """A written tree round-trips through ``open_datatree``."""
        filepath = tmpdir / "test.zarr"
        original_dt = simple_datatree
        original_dt.to_zarr(filepath)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            assert_equal(original_dt, roundtrip_dt)

    def test_zarr_encoding(self, tmpdir, simple_datatree):
        """Per-group zarr encodings are applied; unknown encoding groups raise."""
        from numcodecs.blosc import Blosc

        filepath = tmpdir / "test.zarr"
        original_dt = simple_datatree

        comp = {"compressor": Blosc(cname="zstd", clevel=3, shuffle=2)}
        enc = {"/set2": {var: comp for var in original_dt["/set2"].dataset.data_vars}}
        original_dt.to_zarr(filepath, encoding=enc)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            # fix: removed leftover debug print() of the encoding dict
            assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"]

            enc["/not/a/group"] = {"foo": "bar"}  # type: ignore[dict-item]
            with pytest.raises(ValueError, match="unexpected encoding group.*"):
                original_dt.to_zarr(filepath, encoding=enc, engine="zarr")

    def test_to_zarr_zip_store(self, tmpdir, simple_datatree):
        """Round-trip through a zarr ZipStore."""
        from zarr.storage import ZipStore

        filepath = tmpdir / "test.zarr.zip"
        original_dt = simple_datatree
        store = ZipStore(filepath)
        original_dt.to_zarr(store)

        with open_datatree(store, engine="zarr") as roundtrip_dt:  # type: ignore[arg-type, unused-ignore]
            assert_equal(original_dt, roundtrip_dt)

    def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree):
        """``consolidated=False`` must not write .zmetadata; reading then warns."""
        filepath = tmpdir / "test.zarr"
        zmetadata = filepath / ".zmetadata"
        s1zmetadata = filepath / "set1" / ".zmetadata"
        filepath = str(filepath)  # casting to str avoids a pathlib bug in xarray
        original_dt = simple_datatree
        original_dt.to_zarr(filepath, consolidated=False)
        assert not zmetadata.exists()
        assert not s1zmetadata.exists()

        with pytest.warns(RuntimeWarning, match="consolidated"):
            with open_datatree(filepath, engine="zarr") as roundtrip_dt:
                assert_equal(original_dt, roundtrip_dt)

    def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
        """By default ``to_zarr`` refuses to overwrite an existing store."""
        import zarr

        simple_datatree.to_zarr(tmpdir)

        # with default settings, to_zarr should not overwrite an existing dir
        with pytest.raises(zarr.errors.ContainsGroupError):
            simple_datatree.to_zarr(tmpdir)

    @requires_dask
    def test_to_zarr_compute_false(self, tmpdir, simple_datatree):
        """``compute=False`` writes metadata but defers dask-backed chunk data."""
        import dask.array as da

        filepath = tmpdir / "test.zarr"
        original_dt = simple_datatree.chunk()
        original_dt.to_zarr(filepath, compute=False)

        for node in original_dt.subtree:
            for name, variable in node.dataset.variables.items():
                var_dir = filepath / node.path / name
                var_files = var_dir.listdir()
                assert var_dir / ".zarray" in var_files
                assert var_dir / ".zattrs" in var_files
                if isinstance(variable.data, da.Array):
                    # dask chunks are only written when computed
                    assert var_dir / "0" not in var_files
                else:
                    assert var_dir / "0" in var_files

    def test_to_zarr_inherited_coords(self, tmpdir):
        """Inherited coordinates are not duplicated into child groups on disk."""
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
                "/sub": xr.Dataset({"b": (("x",), [5, 6])}),
            }
        )
        filepath = tmpdir / "test.zarr"
        original_dt.to_zarr(filepath)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            assert_equal(original_dt, roundtrip_dt)
            subtree = cast(DataTree, roundtrip_dt["/sub"])
            assert "x" not in subtree.to_dataset(inherit=False).coords

    def test_open_groups_round_trip(self, tmpdir, simple_datatree) -> None:
        """Test `open_groups` opens a zarr store with the `simple_datatree` structure."""
        filepath = tmpdir / "test.zarr"
        original_dt = simple_datatree
        original_dt.to_zarr(filepath)

        roundtrip_dict = open_groups(filepath, engine="zarr")
        roundtrip_dt = DataTree.from_dict(roundtrip_dict)

        with open_datatree(filepath, engine="zarr") as opened_tree:
            assert opened_tree.identical(roundtrip_dt)

        for ds in roundtrip_dict.values():
            ds.close()

    def test_open_datatree(self, unaligned_datatree_zarr) -> None:
        """Test if `open_datatree` fails to open a zarr store with an unaligned group hierarchy."""
        with pytest.raises(
            ValueError,
            match=(
                re.escape("group '/Group2' is not aligned with its parents:") + ".*"
            ),
        ):
            open_datatree(unaligned_datatree_zarr, engine="zarr")

    @requires_dask
    def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
        """Opening with ``chunks=`` yields dask-backed variables with those chunks."""
        filepath = tmpdir / "test.zarr"

        chunks = {"x": 2, "y": 1}

        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
        original_tree = DataTree.from_dict(
            {
                "/": root_data.chunk(chunks),
                "/group1": set1_data.chunk(chunks),
                "/group2": set2_data.chunk(chunks),
            }
        )
        original_tree.to_zarr(filepath)

        with open_datatree(filepath, engine="zarr", chunks=chunks) as tree:
            xr.testing.assert_identical(tree, original_tree)
            assert_chunks_equal(tree, original_tree, enforce_dask=True)

    def test_open_groups(self, unaligned_datatree_zarr) -> None:
        """Test `open_groups` with a zarr store of an unaligned group hierarchy."""
        unaligned_dict_of_datasets = open_groups(unaligned_datatree_zarr, engine="zarr")

        assert "/" in unaligned_dict_of_datasets.keys()
        assert "/Group1" in unaligned_dict_of_datasets.keys()
        assert "/Group1/subgroup1" in unaligned_dict_of_datasets.keys()
        assert "/Group2" in unaligned_dict_of_datasets.keys()

        # Check that group name returns the correct datasets
        with xr.open_dataset(
            unaligned_datatree_zarr, group="/", engine="zarr"
        ) as expected:
            assert_identical(unaligned_dict_of_datasets["/"], expected)
        with xr.open_dataset(
            unaligned_datatree_zarr, group="Group1", engine="zarr"
        ) as expected:
            assert_identical(unaligned_dict_of_datasets["/Group1"], expected)
        with xr.open_dataset(
            unaligned_datatree_zarr, group="/Group1/subgroup1", engine="zarr"
        ) as expected:
            assert_identical(unaligned_dict_of_datasets["/Group1/subgroup1"], expected)
        with xr.open_dataset(
            unaligned_datatree_zarr, group="/Group2", engine="zarr"
        ) as expected:
            assert_identical(unaligned_dict_of_datasets["/Group2"], expected)

        for ds in unaligned_dict_of_datasets.values():
            ds.close()

    def test_open_datatree_specific_group(self, tmpdir, simple_datatree) -> None:
        """Test opening a specific group within a Zarr store using `open_datatree`."""
        filepath = tmpdir / "test.zarr"
        group = "/set2"
        original_dt = simple_datatree
        original_dt.to_zarr(filepath)
        expected_subtree = original_dt[group].copy()
        expected_subtree.orphan()

        with open_datatree(filepath, group=group, engine=self.engine) as subgroup_tree:
            assert subgroup_tree.root.parent is None
            assert_equal(subgroup_tree, expected_subtree)

    @requires_dask
    def test_open_groups_chunks(self, tmpdir) -> None:
        """Test `open_groups` with chunks on a zarr store."""
        # fix: `chunks` was assigned twice with identical values; keep one.
        filepath = tmpdir / "test.zarr"
        chunks = {"x": 2, "y": 1}

        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
        set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
        set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
        original_tree = DataTree.from_dict(
            {
                "/": root_data.chunk(chunks),
                "/group1": set1_data.chunk(chunks),
                "/group2": set2_data.chunk(chunks),
            }
        )
        original_tree.to_zarr(filepath, mode="w")

        dict_of_datasets = open_groups(filepath, engine="zarr", chunks=chunks)

        for path, ds in dict_of_datasets.items():
            assert {
                k: max(vs) for k, vs in ds.chunksizes.items()
            } == chunks, f"unexpected chunking for {path}"

        for ds in dict_of_datasets.values():
            ds.close()

    def test_write_subgroup(self, tmpdir):
        """Writing a non-root node stores it as a new root (its name is dropped)."""
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"x": [1, 2, 3]}),
                "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
            }
        ).children["child"]

        expected_dt = original_dt.copy()
        expected_dt.name = None

        filepath = tmpdir / "test.zarr"
        original_dt.to_zarr(filepath)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            assert_equal(original_dt, roundtrip_dt)
            assert_identical(expected_dt, roundtrip_dt)

    def test_write_inherited_coords_false(self, tmpdir):
        """``write_inherited_coords=False``: children on disk omit inherited coords."""
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"x": [1, 2, 3]}),
                "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
            }
        )
        filepath = tmpdir / "test.zarr"
        original_dt.to_zarr(filepath, write_inherited_coords=False)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            assert_identical(original_dt, roundtrip_dt)

        expected_child = original_dt.children["child"].copy(inherit=False)
        expected_child.name = None
        with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
            assert_identical(expected_child, roundtrip_child)

    def test_write_inherited_coords_true(self, tmpdir):
        """``write_inherited_coords=True``: children on disk include inherited coords."""
        original_dt = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"x": [1, 2, 3]}),
                "/child": xr.Dataset({"foo": ("x", [4, 5, 6])}),
            }
        )
        filepath = tmpdir / "test.zarr"
        original_dt.to_zarr(filepath, write_inherited_coords=True)

        with open_datatree(filepath, engine="zarr") as roundtrip_dt:
            assert_identical(original_dt, roundtrip_dt)

        expected_child = original_dt.children["child"].copy(inherit=True)
        expected_child.name = None
        with open_datatree(filepath, group="child", engine="zarr") as roundtrip_child:
            assert_identical(expected_child, roundtrip_child)