import re
import sys
import typing
from collections.abc import Mapping
from copy import copy, deepcopy
from textwrap import dedent

import numpy as np
import pytest

import xarray as xr
from xarray import DataArray, Dataset
from xarray.core.coordinates import DataTreeCoordinates
from xarray.core.datatree import DataTree
from xarray.core.treenode import NotFoundInTreeError
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
    assert_array_equal,
    create_test_data,
    requires_dask,
    source_ndarray,
)

# Referenced by a ``skipif`` marker further down in this module.
ON_WINDOWS = sys.platform == "win32"


class TestTreeCreation:
    """Constructing ``DataTree`` nodes directly."""

    def test_empty(self) -> None:
        """A node built with only a name has no parent, children or data."""
        dt = DataTree(name="root")
        assert dt.name == "root"
        assert dt.parent is None
        assert dt.children == {}
        assert_identical(dt.to_dataset(), xr.Dataset())

    def test_name(self) -> None:
        """Names can be set on parentless nodes but not on attached children."""
        dt = DataTree()
        assert dt.name is None

        dt = DataTree(name="foo")
        assert dt.name == "foo"

        dt.name = "bar"
        assert dt.name == "bar"

        dt = DataTree(children={"foo": DataTree()})
        assert dt["/foo"].name == "foo"
        with pytest.raises(
            ValueError, match="cannot set the name of a node which already has a parent"
        ):
            dt["/foo"].name = "bar"

        # A copy of the child can be renamed, unlike the attached original.
        detached = dt["/foo"].copy()
        assert detached.name == "foo"
        detached.name = "bar"
        assert detached.name == "bar"

    def test_bad_names(self) -> None:
        """Non-string names and names containing '/' are rejected."""
        with pytest.raises(TypeError):
            DataTree(name=5)  # type: ignore[arg-type]

        with pytest.raises(ValueError):
            DataTree(name="folder/data")

    def test_data_arg(self) -> None:
        """``dataset=`` accepts a Dataset and rejects a DataArray."""
        ds = xr.Dataset({"foo": 42})
        tree: DataTree = DataTree(dataset=ds)
        assert_identical(tree.to_dataset(), ds)

        with pytest.raises(TypeError):
            DataTree(dataset=xr.DataArray(42, name="foo"))  # type: ignore[arg-type]

    def test_child_data_not_copied(self) -> None:
        """Trees can be built around data whose ``__deepcopy__`` raises."""
        # regression test for https://github.com/pydata/xarray/issues/9683
        class NoDeepCopy:
            def __deepcopy__(self, memo):
                raise TypeError("class can't be deepcopied")

        da = xr.DataArray(NoDeepCopy())
        ds = xr.Dataset({"var": da})
        dt1 = xr.DataTree(ds)
        dt2 = xr.DataTree(ds, children={"child": dt1})
        dt3 = xr.DataTree.from_dict({"/": ds, "child": ds})
        assert_identical(dt2, dt3)


class TestFamilyTree:
    """Parent/child relationships between nodes."""

    def test_dont_modify_children_inplace(self) -> None:
        """Passing a node as a child leaves the original node parentless."""
        # GH issue 9196
        child = DataTree()
        DataTree(children={"child": child})
        assert child.parent is None

    def test_create_two_children(self) -> None:
        """Nodes created via ``from_dict`` are named after their path component."""
        root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
        set1_data = xr.Dataset({"a": 0, "b": 1})
        root = DataTree.from_dict(
            {"/": root_data, "/set1": set1_data, "/set1/set2": None}
        )
        assert root["/set1"].name == "set1"
        assert root["/set1/set2"].name == "set2"

    def test_create_full_tree(self, simple_datatree) -> None:
        """``to_dict`` of the fixture tree yields the expected paths in order."""
        d = simple_datatree.to_dict()
        d_keys = list(d)

        expected_keys = [
            "/",
            "/set1",
            "/set2",
            "/set3",
            "/set1/set1",
            "/set1/set2",
            "/set2/set1",
        ]
        assert d_keys == expected_keys


class TestNames:
    """Naming of nodes and restrictions on variable names."""

    def test_child_gets_named_on_attach(self) -> None:
        """An unnamed child takes its key in ``children`` as its name."""
        sue = DataTree()
        mary = DataTree(children={"Sue": sue})
        assert mary.children["Sue"].name == "Sue"

    def test_dataset_containing_slashes(self) -> None:
        """Variables whose names contain '/' are rejected with a clear message."""
        xda: xr.DataArray = xr.DataArray(
            [[1, 2]],
            coords={"label": ["a"], "R30m/y": [30, 60]},
        )
        xds: xr.Dataset = xr.Dataset({"group/subgroup/my_variable": xda})
        with pytest.raises(
            ValueError,
            match=re.escape(
                "Given variables have names containing the '/' character: "
                "['R30m/y', 'group/subgroup/my_variable']. "
                "Variables stored in DataTree objects cannot have names containing '/' characters, "
                "as this would make path-like access to variables ambiguous."
            ),
        ):
            DataTree(xds)


class TestPaths:
    """Absolute and relative paths between nodes."""

    def test_path_property(self) -> None:
        """``path`` is the absolute path of a node; the root is '/'."""
        john = DataTree.from_dict({"/Mary/Sue": DataTree()})
        assert john["/Mary/Sue"].path == "/Mary/Sue"
        assert john.path == "/"

    def test_path_roundtrip(self) -> None:
        """Looking a node up by its path returns the node with that name."""
        john = DataTree.from_dict({"/Mary/Sue": DataTree()})
        assert john["/Mary/Sue"].name == "Sue"

    def test_same_tree(self) -> None:
        """Sibling nodes report that they belong to the same tree."""
        john = DataTree.from_dict({"/Mary": DataTree(), "/Kate": DataTree()})
        assert john["/Mary"].same_tree(john["/Kate"])

    def test_relative_paths(self) -> None:
        """``relative_to`` works within a tree and raises across trees."""
        john = DataTree.from_dict({"/Mary/Sue": DataTree(), "/Annie": DataTree()})

        # Assert (rather than branch on) the node type: this narrows the type
        # for the checker and fails right here if the lookup returns something
        # other than a DataTree, instead of leaving the name unbound.
        sue = john["Mary/Sue"]
        assert isinstance(sue, DataTree)
        annie = john["Annie"]
        assert isinstance(annie, DataTree)

        assert sue.relative_to(john) == "Mary/Sue"
        assert john.relative_to(sue) == "../.."
        assert annie.relative_to(sue) == "../../Annie"
        assert sue.relative_to(annie) == "../Mary/Sue"
        assert sue.relative_to(sue) == "."

        evil_kate = DataTree()
        with pytest.raises(
            NotFoundInTreeError, match="nodes do not lie within the same tree"
        ):
            sue.relative_to(evil_kate)


class TestStoreDatasets:
    """Storing a Dataset on a node."""

    def test_create_with_data(self) -> None:
        """A Dataset passed at construction round-trips; other types raise."""
        dat = xr.Dataset({"a": 0})
        john = DataTree(name="john", dataset=dat)

        assert_identical(john.to_dataset(), dat)

        with pytest.raises(TypeError):
            DataTree(name="mary", dataset="junk")  # type: ignore[arg-type]

    def test_set_data(self) -> None:
        """Assigning ``.dataset`` stores a Dataset and rejects other types."""
        john = DataTree(name="john")
        dat = xr.Dataset({"a": 0})
        john.dataset = dat  # type: ignore[assignment]

        assert_identical(john.to_dataset(), dat)

        with pytest.raises(TypeError):
            john.dataset = "junk"  # type: ignore[assignment]

    def test_has_data(self) -> None:
        """``has_data`` is true only when the node holds variables."""
        john = DataTree(name="john", dataset=xr.Dataset({"a": 0}))
        assert john.has_data

        john_no_data = DataTree(name="john", dataset=None)
        assert not john_no_data.has_data

    def test_is_hollow(self) -> None:
        """``is_hollow`` turns false once a node with children also holds data."""
        john = DataTree(dataset=xr.Dataset({"a": 0}))
        assert john.is_hollow

        eve = DataTree(children={"john": john})
        assert eve.is_hollow

        eve.dataset = xr.Dataset({"a": 1})  # type: ignore[assignment]
        assert not eve.is_hollow


class TestToDataset:
    """Conversion of a node to a Dataset."""

    def test_to_dataset_inherited(self) -> None:
        """``inherit`` controls whether parent coordinates appear on a child."""
        base = xr.Dataset(coords={"a": [1], "b": 2})
        sub = xr.Dataset(coords={"c": [3]})
        tree = DataTree.from_dict({"/": base, "/sub": sub})
        subtree = typing.cast(DataTree, tree["sub"])

        assert_identical(tree.to_dataset(inherit=False), base)
        assert_identical(subtree.to_dataset(inherit=False), sub)

        sub_and_base = xr.Dataset(coords={"a": [1], "c": [3]})  # no "b"
        assert_identical(tree.to_dataset(inherit=True), base)
        assert_identical(subtree.to_dataset(inherit=True), sub_and_base)


class TestVariablesChildrenNameCollisions:
    """A node may not have a variable and a child with the same name."""

    def test_parent_already_has_variable_with_childs_name(self) -> None:
        """``from_dict`` raises when a child path collides with a variable."""
        with pytest.raises(KeyError, match="already contains a variable named a"):
            DataTree.from_dict({"/": xr.Dataset({"a": [0], "b": 1}), "/a": None})

    def test_parent_already_has_variable_with_childs_name_update(self) -> None:
        """``update`` raises when a new child collides with a variable."""
        dt = DataTree(dataset=xr.Dataset({"a": [0], "b": 1}))
        with pytest.raises(ValueError, match="already contains a variable named a"):
            dt.update({"a": DataTree()})

    def test_assign_when_already_child_with_variables_name(self) -> None:
        """Assigning a dataset raises when a variable collides with a child."""
        dt = DataTree.from_dict({"/a": DataTree()})

        with pytest.raises(ValueError, match="node already contains a variable"):
            dt.dataset = xr.Dataset({"a": 0})  # type: ignore[assignment]

        # An empty dataset has nothing to collide with, so this succeeds.
        dt.dataset = xr.Dataset()  # type: ignore[assignment]

        new_ds = dt.to_dataset().assign(a=xr.DataArray(0))
        with pytest.raises(ValueError, match="node already contains a variable"):
            dt.dataset = new_ds  # type: ignore[assignment]


# Placeholder class; no tests are defined here.
class TestGet: ...


class TestGetItem:
    """``DataTree.__getitem__`` for nodes and variables."""

    def test_getitem_node(self) -> None:
        """Relative paths return the child and grandchild nodes."""
        folder1 = DataTree.from_dict({"/results/highres": DataTree()})

        assert folder1["results"].name == "results"
        assert folder1["results/highres"].name == "highres"

    def test_getitem_self(self) -> None:
        """The path '.' returns the node itself."""
        dt = DataTree()
        assert dt["."] is dt

    def test_getitem_single_data_variable(self) -> None:
        """A variable name returns that variable from the node."""
        data = xr.Dataset({"temp": [0, 50]})
        results = DataTree(name="results", dataset=data)
        assert_identical(results["temp"], data["temp"])

    def test_getitem_single_data_variable_from_node(self) -> None:
        """A path ending in a variable name returns it from the nested node."""
        data = xr.Dataset({"temp": [0, 50]})
        folder1 = DataTree.from_dict({"/results/highres": data})
        assert_identical(folder1["results/highres/temp"], data["temp"])

    def test_getitem_nonexistent_node(self) -> None:
        """A path to a missing node raises ``KeyError``."""
        folder1 = DataTree.from_dict({"/results": DataTree()}, name="folder1")
        with pytest.raises(KeyError):
            folder1["results/highres"]

    def test_getitem_nonexistent_variable(self) -> None:
        """A missing variable name raises ``KeyError``."""
        data = xr.Dataset({"temp": [0, 50]})
        results = DataTree(name="results", dataset=data)
        with pytest.raises(KeyError):
            results["pressure"]

    @pytest.mark.xfail(reason="Should be deprecated in favour of .subset")
    def test_getitem_multiple_data_variables(self) -> None:
        """Indexing with a list of names (expected to fail for now)."""
        data = xr.Dataset({"temp": [0, 50], "p": [5, 8, 7]})
        results = DataTree(name="results", dataset=data)
        assert_identical(results[["temp", "p"]], data[["temp", "p"]])  # type: ignore[index]

    @pytest.mark.xfail(
        reason="Indexing needs to return whole tree (GH https://github.com/xarray-contrib/datatree/issues/77)"
    )
    def test_getitem_dict_like_selection_access_to_dataset(self) -> None:
        """Indexing with a dict selection (expected to fail for now)."""
        data = xr.Dataset({"temp": [0, 50]})
        results = DataTree(name="results", dataset=data)
        assert_identical(results[{"temp": 1}], data[{"temp": 1}])  # type: ignore[index]


class TestUpdate:
    """``DataTree.update`` with variables, children and coordinates."""

    def test_update(self) -> None:
        """A mixed mapping adds both a variable and a child node."""
        dt = DataTree()
        dt.update({"foo": xr.DataArray(0), "a": DataTree()})
        expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "a": None})
        assert_equal(dt, expected)
        assert dt.groups == ("/", "/a")

    def test_update_new_named_dataarray(self) -> None:
        """A named DataArray is stored under the mapping key, not its own name."""
        da = xr.DataArray(name="temp", data=[0, 50])
        folder1 = DataTree(name="folder1")
        folder1.update({"results": da})
        expected = da.rename("results")
        assert_equal(folder1["results"], expected)

    def test_update_doesnt_alter_child_name(self) -> None:
        """A child added via ``update`` is named by its key in the tree."""
        dt = DataTree()
        dt.update({"foo": xr.DataArray(0), "a": DataTree(name="b")})
        assert "a" in dt.children
        child = dt["a"]
        assert child.name == "a"

    def test_update_overwrite(self) -> None:
        """Updating an existing child key replaces that child."""
        actual = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 1}))})
        actual.update({"a": DataTree(xr.Dataset({"x": 2}))})
        expected = DataTree.from_dict({"a": DataTree(xr.Dataset({"x": 2}))})
        assert_equal(actual, expected)

    def test_update_coordinates(self) -> None:
        """Updating with a Dataset adds its coordinates to the node."""
        expected = DataTree.from_dict({"/": xr.Dataset(coords={"a": 1})})
        actual = DataTree.from_dict({"/": xr.Dataset()})
        actual.update(xr.Dataset(coords={"a": 1}))
        assert_equal(actual, expected)

    def test_update_inherited_coords(self) -> None:
        """Updating a child adds coordinates to that child's own dataset."""
        expected = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"a": 1}),
                "/b": xr.Dataset(coords={"c": 1}),
            }
        )
        actual = DataTree.from_dict(
            {
                "/": xr.Dataset(coords={"a": 1}),
                "/b": xr.Dataset(),
            }
        )
        actual["/b"].update(xr.Dataset(coords={"c": 1}))
        assert_identical(actual, expected)

        # DataTree.identical() currently does not require that non-inherited
        # coordinates are defined identically, so we need to check this
        # explicitly
        actual_node = actual.children["b"].to_dataset(inherit=False)
        expected_node = expected.children["b"].to_dataset(inherit=False)
        assert_identical(actual_node, expected_node)


class TestCopy:
    """Shallow and deep copies of trees and subtrees."""

    def test_copy(self, create_test_datatree) -> None:
        """Shallow copies share variable data and attribute values."""
        dt = create_test_datatree()

        for node in dt.root.subtree:
            node.attrs["Test"] = [1, 2, 3]

        for copied in [dt.copy(deep=False), copy(dt)]:
            assert_identical(dt, copied)

            for node, copied_node in zip(
                dt.root.subtree, copied.root.subtree, strict=True
            ):
                assert node.encoding == copied_node.encoding
                # Note: IndexVariable objects with string dtype are always
                # copied because of xarray.core.util.safe_cast_to_index.
                # Limiting the test to data variables.
                for k in node.data_vars:
                    v0 = node.variables[k]
                    v1 = copied_node.variables[k]
                    assert source_ndarray(v0.data) is source_ndarray(v1.data)

                # Changes to the copy must not leak back into the original.
                copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z")
                assert "foo" not in node

                copied_node.attrs["foo"] = "bar"
                assert "foo" not in node.attrs
                assert node.attrs["Test"] is copied_node.attrs["Test"]

    def test_copy_subtree(self) -> None:
        """Copying an inner node yields a new tree rooted at that node."""
        dt = DataTree.from_dict({"/level1/level2/level3": xr.Dataset()})

        actual = dt["/level1/level2"].copy()
        expected = DataTree.from_dict({"/level3": xr.Dataset()}, name="level2")

        assert_identical(actual, expected)

    def test_copy_coord_inheritance(self) -> None:
        """Inherited coordinates are kept on a copied child unless disabled."""
        tree = DataTree.from_dict(
            {"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()}
        )
        actual = tree.copy()
        node_ds = actual.children["c"].to_dataset(inherit=False)
        assert_identical(node_ds, xr.Dataset())

        actual = tree.children["c"].copy()
        expected = DataTree(Dataset(coords={"x": [0, 1]}), name="c")
        assert_identical(expected, actual)

        actual = tree.children["c"].copy(inherit=False)
        expected = DataTree(name="c")
        assert_identical(expected, actual)

    def test_deepcopy(self, create_test_datatree) -> None:
        """Deep copies duplicate variable data and attribute values."""
        dt = create_test_datatree()

        for node in dt.root.subtree:
            node.attrs["Test"] = [1, 2, 3]

        for copied in [dt.copy(deep=True), deepcopy(dt)]:
            assert_identical(dt, copied)

            for node, copied_node in zip(
                dt.root.subtree, copied.root.subtree, strict=True
            ):
                assert node.encoding == copied_node.encoding
                # Note: IndexVariable objects with string dtype are always
                # copied because of xarray.core.util.safe_cast_to_index.
                # Limiting the test to data variables.
                for k in node.data_vars:
                    v0 = node.variables[k]
                    v1 = copied_node.variables[k]
                    assert source_ndarray(v0.data) is not source_ndarray(v1.data)

                # Changes to the copy must not leak back into the original.
                copied_node["foo"] = xr.DataArray(data=np.arange(5), dims="z")
                assert "foo" not in node

                copied_node.attrs["foo"] = "bar"
                assert "foo" not in node.attrs
                assert node.attrs["Test"] is not copied_node.attrs["Test"]

    @pytest.mark.xfail(reason="data argument not yet implemented")
    def test_copy_with_data(self, create_test_datatree) -> None:
        """Copying with replacement data (expected to fail for now)."""
        orig = create_test_datatree()
        # TODO use .data_vars once that property is available
        data_vars = {
            k: v for k, v in orig.variables.items() if k not in orig._coord_names
        }
        new_data = {k: np.random.randn(*v.shape) for k, v in data_vars.items()}
        actual = orig.copy(data=new_data)

        expected = orig.copy()
        for k, v in new_data.items():
            expected[k].data = v
        assert_identical(expected, actual)

        # TODO test parents and children?


class TestSetItem: def test_setitem_new_child_node(self) -> None: john = DataTree(name="john") mary = DataTree(name="mary") john["mary"] = mary grafted_mary = john["mary"] assert grafted_mary.parent is john assert grafted_mary.name == "mary" def test_setitem_unnamed_child_node_becomes_named(self) -> None: john2 = DataTree(name="john2") john2["sonny"] = DataTree() assert john2["sonny"].name == "sonny" def test_setitem_new_grandchild_node(self) -> None: john = DataTree.from_dict({"/Mary/Rose": DataTree()}) new_rose = DataTree(dataset=xr.Dataset({"x": 0})) john["Mary/Rose"] = new_rose grafted_rose = john["Mary/Rose"] assert grafted_rose.parent is john["/Mary"] assert grafted_rose.name == "Rose" def test_grafted_subtree_retains_name(self) -> None: subtree = DataTree(name="original_subtree_name") root = DataTree(name="root") root["new_subtree_name"] = subtree assert subtree.name == "original_subtree_name" def test_setitem_new_empty_node(self) -> None: john = DataTree(name="john") john["mary"] = DataTree() mary = john["mary"] assert isinstance(mary, DataTree) assert_identical(mary.to_dataset(), xr.Dataset()) def test_setitem_overwrite_data_in_node_with_none(self) -> None: john = DataTree.from_dict({"/mary": xr.Dataset()}, name="john") john["mary"] = DataTree() assert_identical(john["mary"].to_dataset(), xr.Dataset()) john.dataset = xr.Dataset() # type: ignore[assignment] with pytest.raises(ValueError, match="has no name"): john["."] = DataTree() @pytest.mark.xfail(reason="assigning Datasets doesn't yet create new nodes") def test_setitem_dataset_on_this_node(self) -> None: data = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results") results["."] = data assert_identical(results.to_dataset(), data) def test_setitem_dataset_as_new_node(self) -> None: data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results"] = data assert_identical(folder1["results"].to_dataset(), data) def 
test_setitem_dataset_as_new_node_requiring_intermediate_nodes(self) -> None: data = xr.Dataset({"temp": [0, 50]}) folder1 = DataTree(name="folder1") folder1["results/highres"] = data assert_identical(folder1["results/highres"].to_dataset(), data) def test_setitem_named_dataarray(self) -> None: da = xr.DataArray(name="temp", data=[0, 50]) folder1 = DataTree(name="folder1") folder1["results"] = da expected = da.rename("results") assert_equal(folder1["results"], expected) def test_setitem_unnamed_dataarray(self) -> None: data = xr.DataArray([0, 50]) folder1 = DataTree(name="folder1") folder1["results"] = data assert_equal(folder1["results"], data) def test_setitem_variable(self) -> None: var = xr.Variable(data=[0, 50], dims="x") folder1 = DataTree(name="folder1") folder1["results"] = var assert_equal(folder1["results"], xr.DataArray(var)) def test_setitem_coerce_to_dataarray(self) -> None: folder1 = DataTree(name="folder1") folder1["results"] = 0 assert_equal(folder1["results"], xr.DataArray(0)) def test_setitem_add_new_variable_to_empty_node(self) -> None: results = DataTree(name="results") results["pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results.dataset results["temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results.dataset # What if there is a path to traverse first? 
results_with_path = DataTree(name="results") results_with_path["highres/pressure"] = xr.DataArray(data=[2, 3]) assert "pressure" in results_with_path["highres"].dataset results_with_path["highres/temp"] = xr.Variable(data=[10, 11], dims=["x"]) assert "temp" in results_with_path["highres"].dataset def test_setitem_dataarray_replace_existing_node(self) -> None: t = xr.Dataset({"temp": [0, 50]}) results = DataTree(name="results", dataset=t) p = xr.DataArray(data=[2, 3]) results["pressure"] = p expected = t.assign(pressure=p) assert_identical(results.to_dataset(), expected) class TestCoords: def test_properties(self) -> None: # use int64 for repr consistency on windows ds = Dataset( data_vars={ "foo": (["x", "y"], np.random.randn(2, 3)), }, coords={ "x": ("x", np.array([-1, -2], "int64")), "y": ("y", np.array([0, 1, 2], "int64")), "a": ("x", np.array([4, 5], "int64")), "b": np.int64(-10), }, ) dt = DataTree(dataset=ds) dt["child"] = DataTree() coords = dt.coords assert isinstance(coords, DataTreeCoordinates) # len assert len(coords) == 4 # iter assert list(coords) == ["x", "y", "a", "b"] assert_identical(coords["x"].variable, dt["x"].variable) assert_identical(coords["y"].variable, dt["y"].variable) assert "x" in coords assert "a" in coords assert 0 not in coords assert "foo" not in coords assert "child" not in coords with pytest.raises(KeyError): coords["foo"] # TODO this currently raises a ValueError instead of a KeyError # with pytest.raises(KeyError): # coords[0] # repr expected = dedent( """\ Coordinates: * x (x) int64 16B -1 -2 * y (y) int64 24B 0 1 2 a (x) int64 16B 4 5 b int64 8B -10""" ) actual = repr(coords) assert expected == actual # dims assert coords.sizes == {"x": 2, "y": 3} # dtypes assert coords.dtypes == { "x": np.dtype("int64"), "y": np.dtype("int64"), "a": np.dtype("int64"), "b": np.dtype("int64"), } def test_modify(self) -> None: ds = Dataset( data_vars={ "foo": (["x", "y"], np.random.randn(2, 3)), }, coords={ "x": ("x", np.array([-1, -2], 
"int64")), "y": ("y", np.array([0, 1, 2], "int64")), "a": ("x", np.array([4, 5], "int64")), "b": np.int64(-10), }, ) dt = DataTree(dataset=ds) dt["child"] = DataTree() actual = dt.copy(deep=True) actual.coords["x"] = ("x", ["a", "b"]) assert_array_equal(actual["x"], ["a", "b"]) actual = dt.copy(deep=True) actual.coords["z"] = ("z", ["a", "b"]) assert_array_equal(actual["z"], ["a", "b"]) actual = dt.copy(deep=True) with pytest.raises(ValueError, match=r"conflicting dimension sizes"): actual.coords["x"] = ("x", [-1]) assert_identical(actual, dt) # should not be modified # TODO: re-enable after implementing reset_coords() # actual = dt.copy() # del actual.coords["b"] # expected = dt.reset_coords("b", drop=True) # assert_identical(expected, actual) with pytest.raises(KeyError): del dt.coords["not_found"] with pytest.raises(KeyError): del dt.coords["foo"] # TODO: re-enable after implementing assign_coords() # actual = dt.copy(deep=True) # actual.coords.update({"c": 11}) # expected = dt.assign_coords({"c": 11}) # assert_identical(expected, actual) # # regression test for GH3746 # del actual.coords["x"] # assert "x" not in actual.xindexes # test that constructors can also handle the `DataTreeCoordinates` object ds2 = Dataset(coords=dt.coords) assert_identical(ds2.coords, dt.coords) da = DataArray(coords=dt.coords) assert_identical(da.coords, dt.coords) # DataTree constructor doesn't accept coords= but should still be able to handle DatasetCoordinates dt2 = DataTree(dataset=dt.coords) assert_identical(dt2.coords, dt.coords) def test_inherited(self) -> None: ds = Dataset( data_vars={ "foo": (["x", "y"], np.random.randn(2, 3)), }, coords={ "x": ("x", np.array([-1, -2], "int64")), "y": ("y", np.array([0, 1, 2], "int64")), "a": ("x", np.array([4, 5], "int64")), "b": np.int64(-10), }, ) dt = DataTree(dataset=ds) dt["child"] = DataTree() child = dt["child"] assert set(dt.coords) == {"x", "y", "a", "b"} assert set(child.coords) == {"x", "y"} actual = child.copy(deep=True) 
actual.coords["x"] = ("x", ["a", "b"]) assert_array_equal(actual["x"], ["a", "b"]) actual = child.copy(deep=True) actual.coords.update({"c": 11}) expected = child.copy(deep=True) expected.coords["c"] = 11 # check we have only altered the child node assert_identical(expected.root, actual.root) with pytest.raises(KeyError): # cannot delete inherited coordinate from child node del child["x"] # TODO requires a fix for #9472 # actual = child.copy(deep=True) # actual.coords.update({"c": 11}) # expected = child.assign_coords({"c": 11}) # assert_identical(expected, actual) def test_delitem() -> None: ds = Dataset({"a": 0}, coords={"x": ("x", [1, 2]), "z": "a"}) dt = DataTree(ds, children={"c": DataTree()}) with pytest.raises(KeyError): del dt["foo"] # test delete children del dt["c"] assert dt.children == {} assert set(dt.variables) == {"x", "z", "a"} with pytest.raises(KeyError): del dt["c"] # test delete variables del dt["a"] assert set(dt.coords) == {"x", "z"} with pytest.raises(KeyError): del dt["a"] # test delete coordinates del dt["z"] assert set(dt.coords) == {"x"} with pytest.raises(KeyError): del dt["z"] # test delete indexed coordinates del dt["x"] assert dt.variables == {} assert dt.coords == {} assert dt.indexes == {} with pytest.raises(KeyError): del dt["x"] class TestTreeFromDict: def test_data_in_root(self) -> None: dat = xr.Dataset() dt = DataTree.from_dict({"/": dat}) assert dt.name is None assert dt.parent is None assert dt.children == {} assert_identical(dt.to_dataset(), dat) def test_one_layer(self) -> None: dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"b": 2}) dt = DataTree.from_dict({"run1": dat1, "run2": dat2}) assert_identical(dt.to_dataset(), xr.Dataset()) assert dt.name is None assert_identical(dt["run1"].to_dataset(), dat1) assert dt["run1"].children == {} assert_identical(dt["run2"].to_dataset(), dat2) assert dt["run2"].children == {} def test_two_layers(self) -> None: dat1, dat2 = xr.Dataset({"a": 1}), xr.Dataset({"a": [1, 2]}) dt = 
DataTree.from_dict({"highres/run": dat1, "lowres/run": dat2}) assert "highres" in dt.children assert "lowres" in dt.children highres_run = dt["highres/run"] assert_identical(highres_run.to_dataset(), dat1) def test_nones(self) -> None: dt = DataTree.from_dict({"d": None, "d/e": None}) assert [node.name for node in dt.subtree] == [None, "d", "e"] assert [node.path for node in dt.subtree] == ["/", "/d", "/d/e"] assert_identical(dt["d/e"].to_dataset(), xr.Dataset()) def test_full(self, simple_datatree) -> None: dt = simple_datatree paths = list(node.path for node in dt.subtree) assert paths == [ "/", "/set1", "/set2", "/set3", "/set1/set1", "/set1/set2", "/set2/set1", ] def test_datatree_values(self) -> None: dat1 = DataTree(dataset=xr.Dataset({"a": 1})) expected = DataTree() expected["a"] = dat1 actual = DataTree.from_dict({"a": dat1}) assert_identical(actual, expected) def test_roundtrip_to_dict(self, simple_datatree) -> None: tree = simple_datatree roundtrip = DataTree.from_dict(tree.to_dict()) assert_identical(tree, roundtrip) def test_to_dict(self): tree = DataTree.from_dict({"/a/b/c": None}) roundtrip = DataTree.from_dict(tree.to_dict()) assert_identical(tree, roundtrip) roundtrip = DataTree.from_dict(tree.to_dict(relative=True)) assert_identical(tree, roundtrip) roundtrip = DataTree.from_dict(tree.children["a"].to_dict(relative=False)) assert_identical(tree, roundtrip) expected = DataTree.from_dict({"b/c": None}) actual = DataTree.from_dict(tree.children["a"].to_dict(relative=True)) assert_identical(expected, actual) def test_roundtrip_unnamed_root(self, simple_datatree) -> None: # See GH81 dt = simple_datatree dt.name = "root" roundtrip = DataTree.from_dict(dt.to_dict()) assert roundtrip.equals(dt) def test_insertion_order(self) -> None: # regression test for GH issue #9276 reversed = DataTree.from_dict( { "/Homer/Lisa": xr.Dataset({"age": 8}), "/Homer/Bart": xr.Dataset({"age": 10}), "/Homer": xr.Dataset({"age": 39}), "/": xr.Dataset({"age": 83}), } ) expected 
= DataTree.from_dict( { "/": xr.Dataset({"age": 83}), "/Homer": xr.Dataset({"age": 39}), "/Homer/Lisa": xr.Dataset({"age": 8}), "/Homer/Bart": xr.Dataset({"age": 10}), } ) assert reversed.equals(expected) # Check that Bart and Lisa's order is still preserved within the group, # despite 'Bart' coming before 'Lisa' when sorted alphabetically assert list(reversed["Homer"].children.keys()) == ["Lisa", "Bart"] def test_array_values(self) -> None: data = {"foo": xr.DataArray(1, name="bar")} with pytest.raises(TypeError): DataTree.from_dict(data) # type: ignore[arg-type] def test_relative_paths(self) -> None: tree = DataTree.from_dict({".": None, "foo": None, "./bar": None, "x/y": None}) paths = [node.path for node in tree.subtree] assert paths == [ "/", "/foo", "/bar", "/x", "/x/y", ] def test_root_keys(self): ds = Dataset({"x": 1}) expected = DataTree(dataset=ds) actual = DataTree.from_dict({"": ds}) assert_identical(actual, expected) actual = DataTree.from_dict({".": ds}) assert_identical(actual, expected) actual = DataTree.from_dict({"/": ds}) assert_identical(actual, expected) actual = DataTree.from_dict({"./": ds}) assert_identical(actual, expected) with pytest.raises( ValueError, match="multiple entries found corresponding to the root node" ): DataTree.from_dict({"": ds, "/": ds}) def test_name(self): tree = DataTree.from_dict({"/": None}, name="foo") assert tree.name == "foo" tree = DataTree.from_dict({"/": DataTree()}, name="foo") assert tree.name == "foo" tree = DataTree.from_dict({"/": DataTree(name="bar")}, name="foo") assert tree.name == "foo" class TestDatasetView: def test_view_contents(self) -> None: ds = create_test_data() dt = DataTree(dataset=ds) assert ds.identical( dt.dataset ) # this only works because Dataset.identical doesn't check types assert isinstance(dt.dataset, xr.Dataset) def test_immutability(self) -> None: # See issue https://github.com/xarray-contrib/datatree/issues/38 dt = DataTree.from_dict( { "/": None, "/a": None, }, name="root", ) 
with pytest.raises( AttributeError, match="Mutation of the DatasetView is not allowed" ): dt.dataset["a"] = xr.DataArray(0) with pytest.raises( AttributeError, match="Mutation of the DatasetView is not allowed" ): dt.dataset.update({"a": 0}) # TODO are there any other ways you can normally modify state (in-place)? # (not attribute-like assignment because that doesn't work on Dataset anyway) def test_methods(self) -> None: ds = create_test_data() dt = DataTree(dataset=ds) assert ds.mean().identical(dt.dataset.mean()) assert isinstance(dt.dataset.mean(), xr.Dataset) def test_arithmetic(self, create_test_datatree) -> None: dt = create_test_datatree() expected = create_test_datatree(modify=lambda ds: 10.0 * ds)[ "set1" ].to_dataset() result = 10.0 * dt["set1"].dataset assert result.identical(expected) def test_init_via_type(self) -> None: # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188 # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray a = xr.DataArray( np.random.rand(3, 4, 10), dims=["x", "y", "time"], coords={"area": (["x", "y"], np.random.rand(3, 4))}, ).to_dataset(name="data") dt = DataTree(dataset=a) def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) weighted_mean(dt.dataset) class TestAccess: def test_attribute_access(self, create_test_datatree) -> None: dt = create_test_datatree() # vars / coords for key in ["a", "set0"]: assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # dims assert_equal(dt["a"]["y"], dt.a.y) assert "y" in dir(dt["a"]) # children for key in ["set1", "set2", "set3"]: assert_equal(dt[key], getattr(dt, key)) assert key in dir(dt) # attrs dt.attrs["meta"] = "NASA" assert dt.attrs["meta"] == "NASA" assert "meta" in dir(dt) def test_ipython_key_completions_complex(self, create_test_datatree) -> None: dt = create_test_datatree() key_completions = dt._ipython_key_completions_() node_keys = [node.path[1:] for node in dt.descendants] assert all(node_key in 
key_completions for node_key in node_keys) var_keys = list(dt.variables.keys()) assert all(var_key in key_completions for var_key in var_keys) def test_ipython_key_completitions_subnode(self) -> None: tree = xr.DataTree.from_dict({"/": None, "/a": None, "/a/b/": None}) expected = ["b"] actual = tree["a"]._ipython_key_completions_() assert expected == actual def test_operation_with_attrs_but_no_data(self) -> None: # tests bug from xarray-datatree GH262 xs = xr.Dataset({"testvar": xr.DataArray(np.ones((2, 3)))}) dt = DataTree.from_dict({"node1": xs, "node2": xs}) dt.attrs["test_key"] = 1 # sel works fine without this line dt.sel(dim_0=0) class TestRepr: def test_repr_four_nodes(self) -> None: dt = DataTree.from_dict( { "/": xr.Dataset( {"e": (("x",), [1.0, 2.0])}, coords={"x": [2.0, 3.0]}, ), "/b": xr.Dataset({"f": (("y",), [3.0])}), "/b/c": xr.Dataset(), "/b/d": xr.Dataset({"g": 4.0}), } ) result = repr(dt) expected = dedent( """ Group: / │ Dimensions: (x: 2) │ Coordinates: │ * x (x) float64 16B 2.0 3.0 │ Data variables: │ e (x) float64 16B 1.0 2.0 └── Group: /b │ Dimensions: (y: 1) │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d Dimensions: () Data variables: g float64 8B 4.0 """ ).strip() assert result == expected result = repr(dt.b) expected = dedent( """ Group: /b │ Dimensions: (x: 2, y: 1) │ Inherited coordinates: │ * x (x) float64 16B 2.0 3.0 │ Dimensions without coordinates: y │ Data variables: │ f (y) float64 8B 3.0 ├── Group: /b/c └── Group: /b/d Dimensions: () Data variables: g float64 8B 4.0 """ ).strip() assert result == expected result = repr(dt.b.d) expected = dedent( """ Group: /b/d Dimensions: (x: 2, y: 1) Inherited coordinates: * x (x) float64 16B 2.0 3.0 Dimensions without coordinates: y Data variables: g float64 8B 4.0 """ ).strip() assert result == expected def test_repr_two_children(self) -> None: tree = DataTree.from_dict( { "/": Dataset(coords={"x": [1.0]}), "/first_child": None, 
"/second_child": Dataset({"foo": ("x", [0.0])}, coords={"z": 1.0}), } ) result = repr(tree) expected = dedent( """ Group: / │ Dimensions: (x: 1) │ Coordinates: │ * x (x) float64 8B 1.0 ├── Group: /first_child └── Group: /second_child Dimensions: (x: 1) Coordinates: z float64 8B 1.0 Data variables: foo (x) float64 8B 0.0 """ ).strip() assert result == expected result = repr(tree["first_child"]) expected = dedent( """ Group: /first_child Dimensions: (x: 1) Inherited coordinates: * x (x) float64 8B 1.0 """ ).strip() assert result == expected result = repr(tree["second_child"]) expected = dedent( """ Group: /second_child Dimensions: (x: 1) Coordinates: z float64 8B 1.0 Inherited coordinates: * x (x) float64 8B 1.0 Data variables: foo (x) float64 8B 0.0 """ ).strip() assert result == expected def test_repr_inherited_dims(self) -> None: tree = DataTree.from_dict( { "/": Dataset({"foo": ("x", [1.0])}), "/child": Dataset({"bar": ("y", [2.0])}), } ) result = repr(tree) expected = dedent( """ Group: / │ Dimensions: (x: 1) │ Dimensions without coordinates: x │ Data variables: │ foo (x) float64 8B 1.0 └── Group: /child Dimensions: (y: 1) Dimensions without coordinates: y Data variables: bar (y) float64 8B 2.0 """ ).strip() assert result == expected result = repr(tree["child"]) expected = dedent( """ Group: /child Dimensions: (x: 1, y: 1) Dimensions without coordinates: x, y Data variables: bar (y) float64 8B 2.0 """ ).strip() assert result == expected @pytest.mark.skipif( ON_WINDOWS, reason="windows (pre NumPy2) uses int32 instead of int64" ) def test_doc_example(self) -> None: # regression test for https://github.com/pydata/xarray/issues/9499 time = xr.DataArray( data=np.array(["2022-01", "2023-01"], dtype=" Group: / │ Dimensions: (time: 2) │ Coordinates: │ * time (time) Group: /weather │ Dimensions: (time: 2, station: 6) │ Coordinates: │ * station (station) str: return re.escape(dedent(message).strip()) return "^" + re.escape(dedent(message.rstrip())) + "$" class 
TestInheritance: def test_inherited_dims(self) -> None: dt = DataTree.from_dict( { "/": xr.Dataset({"d": (("x",), [1, 2])}), "/b": xr.Dataset({"e": (("y",), [3])}), "/c": xr.Dataset({"f": (("y",), [3, 4, 5])}), } ) assert dt.sizes == {"x": 2} # nodes should include inherited dimensions assert dt.b.sizes == {"x": 2, "y": 1} assert dt.c.sizes == {"x": 2, "y": 3} # dataset objects created from nodes should not assert dt.b.dataset.sizes == {"y": 1} assert dt.b.to_dataset(inherit=True).sizes == {"y": 1} assert dt.b.to_dataset(inherit=False).sizes == {"y": 1} def test_inherited_coords_index(self) -> None: dt = DataTree.from_dict( { "/": xr.Dataset({"d": (("x",), [1, 2])}, coords={"x": [2, 3]}), "/b": xr.Dataset({"e": (("y",), [3])}), } ) assert "x" in dt["/b"].indexes assert "x" in dt["/b"].coords xr.testing.assert_identical(dt["/x"], dt["/b/x"]) def test_inherit_only_index_coords(self) -> None: dt = DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1], "y": 2}), "/b": xr.Dataset(coords={"z": 3}), } ) assert dt.coords.keys() == {"x", "y"} xr.testing.assert_equal( dt["/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "y": 2}) ) xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords={"y": 2})) assert dt["/b"].coords.keys() == {"x", "z"} xr.testing.assert_equal( dt["/b/x"], xr.DataArray([1], dims=["x"], coords={"x": [1], "z": 3}) ) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords={"z": 3})) def test_inherited_coords_with_index_are_deduplicated(self) -> None: dt = DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1, 2]}), "/b": xr.Dataset(coords={"x": [1, 2]}), } ) child_dataset = dt.children["b"].to_dataset(inherit=False) expected = xr.Dataset() assert_identical(child_dataset, expected) dt["/c"] = xr.Dataset({"foo": ("x", [4, 5])}, coords={"x": [1, 2]}) child_dataset = dt.children["c"].to_dataset(inherit=False) expected = xr.Dataset({"foo": ("x", [4, 5])}) assert_identical(child_dataset, expected) def test_deduplicated_after_setitem(self) -> None: # 
regression test for GH #9601 dt = DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1, 2]}), "/b": None, } ) dt["b/x"] = dt["x"] child_dataset = dt.children["b"].to_dataset(inherit=False) expected = xr.Dataset() assert_identical(child_dataset, expected) def test_inconsistent_dims(self) -> None: expected_msg = _exact_match( """ group '/b' is not aligned with its parents: Group: Dimensions: (x: 1) Dimensions without coordinates: x Data variables: c (x) float64 8B 3.0 From parents: Dimensions: (x: 2) Dimensions without coordinates: x """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), "/b": xr.Dataset({"c": (("x",), [3.0])}), } ) dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c"] = xr.DataArray([3.0], dims=["x"]) b = DataTree(dataset=xr.Dataset({"c": (("x",), [3.0])})) with pytest.raises(ValueError, match=expected_msg): DataTree( dataset=xr.Dataset({"a": (("x",), [1.0, 2.0])}), children={"b": b}, ) def test_inconsistent_child_indexes(self) -> None: expected_msg = _exact_match( """ group '/b' is not aligned with its parents: Group: Dimensions: (x: 1) Coordinates: * x (x) float64 8B 2.0 Data variables: *empty* From parents: Dimensions: (x: 1) Coordinates: * x (x) float64 8B 1.0 """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1.0]}), "/b": xr.Dataset(coords={"x": [2.0]}), } ) dt = DataTree() dt.dataset = xr.Dataset(coords={"x": [1.0]}) # type: ignore[assignment] dt["/b"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b"].dataset = xr.Dataset(coords={"x": [2.0]}) b = DataTree(xr.Dataset(coords={"x": [2.0]})) with pytest.raises(ValueError, match=expected_msg): DataTree(dataset=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) def test_inconsistent_grandchild_indexes(self) -> None: expected_msg = _exact_match( """ group '/b/c' is not 
aligned with its parents: Group: Dimensions: (x: 1) Coordinates: * x (x) float64 8B 2.0 Data variables: *empty* From parents: Dimensions: (x: 1) Coordinates: * x (x) float64 8B 1.0 """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1.0]}), "/b/c": xr.Dataset(coords={"x": [2.0]}), } ) dt = DataTree() dt.dataset = xr.Dataset(coords={"x": [1.0]}) # type: ignore[assignment] dt["/b/c"] = DataTree() with pytest.raises(ValueError, match=expected_msg): dt["/b/c"].dataset = xr.Dataset(coords={"x": [2.0]}) c = DataTree(xr.Dataset(coords={"x": [2.0]})) b = DataTree(children={"c": c}) with pytest.raises(ValueError, match=expected_msg): DataTree(dataset=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) def test_inconsistent_grandchild_dims(self) -> None: expected_msg = _exact_match( """ group '/b/c' is not aligned with its parents: Group: Dimensions: (x: 1) Dimensions without coordinates: x Data variables: d (x) float64 8B 3.0 From parents: Dimensions: (x: 2) Dimensions without coordinates: x """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), "/b/c": xr.Dataset({"d": (("x",), [3.0])}), } ) dt = DataTree() dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): dt["/b/c/d"] = xr.DataArray([3.0], dims=["x"]) class TestRestructuring: def test_drop_nodes(self) -> None: sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) # test drop just one node dropped_one = sue.drop_nodes(names="Mary") assert "Mary" not in dropped_one.children # test drop multiple nodes dropped = sue.drop_nodes(names=["Mary", "Kate"]) assert not set(["Mary", "Kate"]).intersection(set(dropped.children)) assert "Ashley" in dropped.children # test raise with pytest.raises(KeyError, match="nodes {'Mary'} not present"): dropped.drop_nodes(names=["Mary", "Ashley"]) # test ignore childless = dropped.drop_nodes(names=["Mary", 
"Ashley"], errors="ignore") assert childless.children == {} def test_assign(self) -> None: dt = DataTree() expected = DataTree.from_dict({"/": xr.Dataset({"foo": 0}), "/a": None}) # kwargs form result = dt.assign(foo=xr.DataArray(0), a=DataTree()) assert_equal(result, expected) # dict form result = dt.assign({"foo": xr.DataArray(0), "a": DataTree()}) assert_equal(result, expected) class TestPipe: def test_noop(self, create_test_datatree) -> None: dt = create_test_datatree() actual = dt.pipe(lambda tree: tree) assert actual.identical(dt) def test_params(self, create_test_datatree) -> None: dt = create_test_datatree() def f(tree, **attrs): return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) attrs = {"x": 1, "y": 2, "z": 3} actual = dt.pipe(f, **attrs) assert actual["arr_with_attrs"].attrs == attrs def test_named_self(self, create_test_datatree) -> None: dt = create_test_datatree() def f(x, tree, y): tree.attrs.update({"x": x, "y": y}) return tree attrs = {"x": 1, "y": 2} actual = dt.pipe((f, "tree"), **attrs) assert actual is dt and actual.attrs == attrs class TestIsomorphicEqualsAndIdentical: def test_isomorphic(self): tree = DataTree.from_dict({"/a": None, "/a/b": None, "/c": None}) diff_data = DataTree.from_dict( {"/a": None, "/a/b": None, "/c": xr.Dataset({"foo": 1})} ) assert tree.isomorphic(diff_data) diff_order = DataTree.from_dict({"/c": None, "/a": None, "/a/b": None}) assert tree.isomorphic(diff_order) diff_nodes = DataTree.from_dict({"/a": None, "/a/b": None, "/d": None}) assert not tree.isomorphic(diff_nodes) more_nodes = DataTree.from_dict( {"/a": None, "/a/b": None, "/c": None, "/d": None} ) assert not tree.isomorphic(more_nodes) def test_minimal_variations(self): tree = DataTree.from_dict( { "/": Dataset({"x": 1}), "/child": Dataset({"x": 2}), } ) assert tree.equals(tree) assert tree.identical(tree) child = tree.children["child"] assert child.equals(child) assert child.identical(child) new_child = DataTree(dataset=Dataset({"x": 2}), 
name="child") assert child.equals(new_child) assert child.identical(new_child) anonymous_child = DataTree(dataset=Dataset({"x": 2})) # TODO: re-enable this after fixing .equals() not to require matching # names on the root node (i.e., after switching to use zip_subtrees) # assert child.equals(anonymous_child) assert not child.identical(anonymous_child) different_variables = DataTree.from_dict( { "/": Dataset(), "/other": Dataset({"x": 2}), } ) assert not tree.equals(different_variables) assert not tree.identical(different_variables) different_root_data = DataTree.from_dict( { "/": Dataset({"x": 4}), "/child": Dataset({"x": 2}), } ) assert not tree.equals(different_root_data) assert not tree.identical(different_root_data) different_child_data = DataTree.from_dict( { "/": Dataset({"x": 1}), "/child": Dataset({"x": 3}), } ) assert not tree.equals(different_child_data) assert not tree.identical(different_child_data) different_child_node_attrs = DataTree.from_dict( { "/": Dataset({"x": 1}), "/child": Dataset({"x": 2}, attrs={"foo": "bar"}), } ) assert tree.equals(different_child_node_attrs) assert not tree.identical(different_child_node_attrs) different_child_variable_attrs = DataTree.from_dict( { "/": Dataset({"x": 1}), "/child": Dataset({"x": ((), 2, {"foo": "bar"})}), } ) assert tree.equals(different_child_variable_attrs) assert not tree.identical(different_child_variable_attrs) different_name = DataTree.from_dict( { "/": Dataset({"x": 1}), "/child": Dataset({"x": 2}), }, name="different", ) # TODO: re-enable this after fixing .equals() not to require matching # names on the root node (i.e., after switching to use zip_subtrees) # assert tree.equals(different_name) assert not tree.identical(different_name) def test_differently_inherited_coordinates(self): root = DataTree.from_dict( { "/": Dataset(coords={"x": [1, 2]}), "/child": Dataset(), } ) child = root.children["child"] assert child.equals(child) assert child.identical(child) new_child = 
DataTree(dataset=Dataset(coords={"x": [1, 2]}), name="child") assert child.equals(new_child) assert not child.identical(new_child) deeper_root = DataTree(children={"root": root}) grandchild = deeper_root["/root/child"] assert child.equals(grandchild) assert child.identical(grandchild) class TestSubset: def test_match(self) -> None: # TODO is this example going to cause problems with case sensitivity? dt = DataTree.from_dict( { "/a/A": None, "/a/B": None, "/b/A": None, "/b/B": None, } ) result = dt.match("*/B") expected = DataTree.from_dict( { "/a/B": None, "/b/B": None, } ) assert_identical(result, expected) result = dt.children["a"].match("B") expected = DataTree.from_dict({"/B": None}, name="a") assert_identical(result, expected) def test_filter(self) -> None: simpsons = DataTree.from_dict( { "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), "/Homer": xr.Dataset({"age": 39}), "/Homer/Bart": xr.Dataset({"age": 10}), "/Homer/Lisa": xr.Dataset({"age": 8}), "/Homer/Maggie": xr.Dataset({"age": 1}), }, name="Abe", ) expected = DataTree.from_dict( { "/": xr.Dataset({"age": 83}), "/Herbert": xr.Dataset({"age": 40}), "/Homer": xr.Dataset({"age": 39}), }, name="Abe", ) elders = simpsons.filter(lambda node: node["age"].item() > 18) assert_identical(elders, expected) expected = DataTree.from_dict({"/Bart": xr.Dataset({"age": 10})}, name="Homer") actual = simpsons.children["Homer"].filter( lambda node: node["age"].item() == 10 ) assert_identical(actual, expected) class TestIndexing: def test_isel_siblings(self) -> None: tree = DataTree.from_dict( { "/first": xr.Dataset({"a": ("x", [1, 2])}), "/second": xr.Dataset({"b": ("x", [1, 2, 3])}), } ) expected = DataTree.from_dict( { "/first": xr.Dataset({"a": 2}), "/second": xr.Dataset({"b": 3}), } ) actual = tree.isel(x=-1) assert_identical(actual, expected) expected = DataTree.from_dict( { "/first": xr.Dataset({"a": ("x", [1])}), "/second": xr.Dataset({"b": ("x", [1])}), } ) actual = tree.isel(x=slice(1)) 
assert_identical(actual, expected) actual = tree.isel(x=[0]) assert_identical(actual, expected) actual = tree.isel(x=slice(None)) assert_identical(actual, tree) def test_isel_inherited(self) -> None: tree = DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1, 2]}), "/child": xr.Dataset({"foo": ("x", [3, 4])}), } ) expected = DataTree.from_dict( { "/": xr.Dataset(coords={"x": 2}), "/child": xr.Dataset({"foo": 4}), } ) actual = tree.isel(x=-1) assert_identical(actual, expected) expected = DataTree.from_dict( { "/child": xr.Dataset({"foo": 4}), } ) actual = tree.isel(x=-1, drop=True) assert_identical(actual, expected) expected = DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1]}), "/child": xr.Dataset({"foo": ("x", [3])}), } ) actual = tree.isel(x=[0]) assert_identical(actual, expected) actual = tree.isel(x=slice(None)) # TODO: re-enable after the fix to copy() from #9628 is submitted # actual = tree.children["child"].isel(x=slice(None)) # expected = tree.children["child"].copy() # assert_identical(actual, expected) actual = tree.children["child"].isel(x=0) expected = DataTree( dataset=xr.Dataset({"foo": 3}, coords={"x": 1}), name="child", ) assert_identical(actual, expected) def test_sel(self) -> None: tree = DataTree.from_dict( { "/first": xr.Dataset({"a": ("x", [1, 2, 3])}, coords={"x": [1, 2, 3]}), "/second": xr.Dataset({"b": ("x", [4, 5])}, coords={"x": [2, 3]}), } ) expected = DataTree.from_dict( { "/first": xr.Dataset({"a": 2}, coords={"x": 2}), "/second": xr.Dataset({"b": 4}, coords={"x": 2}), } ) actual = tree.sel(x=2) assert_identical(actual, expected) actual = tree.children["first"].sel(x=2) expected = DataTree( dataset=xr.Dataset({"a": 2}, coords={"x": 2}), name="first", ) assert_identical(actual, expected) class TestAggregations: def test_reduce_method(self) -> None: ds = xr.Dataset({"a": ("x", [False, True, False])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) expected = DataTree.from_dict({"/": ds.any(), "/results": ds.any()}) result = 
dt.any() assert_equal(result, expected) def test_nan_reduce_method(self) -> None: ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) expected = DataTree.from_dict({"/": ds.mean(), "/results": ds.mean()}) result = dt.mean() assert_equal(result, expected) def test_cum_method(self) -> None: ds = xr.Dataset({"a": ("x", [1, 2, 3])}) dt = DataTree.from_dict({"/": ds, "/results": ds}) expected = DataTree.from_dict( { "/": ds.cumsum(), "/results": ds.cumsum(), } ) result = dt.cumsum() assert_equal(result, expected) def test_dim_argument(self) -> None: dt = DataTree.from_dict( { "/a": xr.Dataset({"A": ("x", [1, 2])}), "/b": xr.Dataset({"B": ("y", [1, 2])}), } ) expected = DataTree.from_dict( { "/a": xr.Dataset({"A": 1.5}), "/b": xr.Dataset({"B": 1.5}), } ) actual = dt.mean() assert_equal(expected, actual) actual = dt.mean(dim=...) assert_equal(expected, actual) expected = DataTree.from_dict( { "/a": xr.Dataset({"A": 1.5}), "/b": xr.Dataset({"B": ("y", [1.0, 2.0])}), } ) actual = dt.mean("x") assert_equal(expected, actual) with pytest.raises( ValueError, match=re.escape("Dimension(s) 'invalid' do not exist."), ): dt.mean("invalid") def test_subtree(self) -> None: tree = DataTree.from_dict( { "/child": Dataset({"a": ("x", [1, 2])}), } ) expected = DataTree(dataset=Dataset({"a": 1.5}), name="child") actual = tree.children["child"].mean() assert_identical(expected, actual) class TestOps: def test_unary_op(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) expected = DataTree.from_dict({"/": (-ds1), "/subnode": (-ds2)}) result = -dt assert_equal(result, expected) def test_unary_op_inherited_coords(self) -> None: tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) actual = -tree actual_dataset = actual.children["foo"].to_dataset(inherit=False) assert "x" not in 
actual_dataset.coords expected = tree.copy() # unary ops are not applied to coordinate variables, only data variables expected["/foo/bar"].data = np.array([-4, -5, -6]) assert_identical(actual, expected) def test_binary_op_on_int(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) expected = DataTree.from_dict({"/": ds1 * 5, "/subnode": ds2 * 5}) result = dt * 5 assert_equal(result, expected) def test_binary_op_on_dataarray(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict( { "/": ds1, "/subnode": ds2, } ) other_da = xr.DataArray(name="z", data=[0.1, 0.2], dims="z") expected = DataTree.from_dict( { "/": ds1 * other_da, "/subnode": ds2 * other_da, } ) result = dt * other_da assert_equal(result, expected) def test_binary_op_on_dataset(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict( { "/": ds1, "/subnode": ds2, } ) other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) expected = DataTree.from_dict( { "/": ds1 * other_ds, "/subnode": ds2 * other_ds, } ) result = dt * other_ds assert_equal(result, expected) def test_binary_op_on_datatree(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) expected = DataTree.from_dict({"/": ds1 * ds1, "/subnode": ds2 * ds2}) result = dt * dt assert_equal(result, expected) def test_binary_op_order_invariant(self) -> None: tree_ab = DataTree.from_dict({"/a": Dataset({"a": 1}), "/b": Dataset({"b": 2})}) tree_ba = DataTree.from_dict({"/b": Dataset({"b": 2}), "/a": Dataset({"a": 1})}) expected = DataTree.from_dict( {"/a": Dataset({"a": 2}), "/b": Dataset({"b": 4})} ) actual = tree_ab + tree_ba assert_identical(expected, actual) def test_arithmetic_inherited_coords(self) -> None: 
tree = DataTree(xr.Dataset(coords={"x": [1, 2, 3]})) tree["/foo"] = DataTree(xr.Dataset({"bar": ("x", [4, 5, 6])})) actual = 2 * tree actual_dataset = actual.children["foo"].to_dataset(inherit=False) assert "x" not in actual_dataset.coords expected = tree.copy() expected["/foo/bar"].data = np.array([8, 10, 12]) assert_identical(actual, expected) def test_binary_op_commutativity_with_dataset(self) -> None: # regression test for #9365 ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict( { "/": ds1, "/subnode": ds2, } ) other_ds = xr.Dataset({"z": ("z", [0.1, 0.2])}) expected = DataTree.from_dict( { "/": ds1 * other_ds, "/subnode": ds2 * other_ds, } ) result = other_ds * dt assert_equal(result, expected) def test_inplace_binary_op(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) expected = DataTree.from_dict({"/": ds1 + 1, "/subnode": ds2 + 1}) dt += 1 assert_equal(dt, expected) def test_dont_broadcast_single_node_tree(self) -> None: # regression test for https://github.com/pydata/xarray/issues/9365#issuecomment-2291622577 ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) dt = DataTree.from_dict({"/": ds1, "/subnode": ds2}) node = dt["/subnode"] with pytest.raises( xr.TreeIsomorphismError, match=re.escape(r"children at root node do not match: ['subnode'] vs []"), ): dt * node class TestUFuncs: @pytest.mark.xfail(reason="__array_ufunc__ not implemented yet") def test_tree(self, create_test_datatree): dt = create_test_datatree() expected = create_test_datatree(modify=lambda ds: np.sin(ds)) result_tree = np.sin(dt) assert_equal(result_tree, expected) class Closer: def __init__(self): self.closed = False def close(self): if self.closed: raise RuntimeError("already closed") self.closed = True @pytest.fixture() def tree_and_closers(): tree = 
DataTree.from_dict({"/child/grandchild": None}) closers = { "/": Closer(), "/child": Closer(), "/child/grandchild": Closer(), } for path, closer in closers.items(): tree[path].set_close(closer.close) return tree, closers class TestClose: def test_close(self, tree_and_closers): tree, closers = tree_and_closers assert not any(closer.closed for closer in closers.values()) tree.close() assert all(closer.closed for closer in closers.values()) tree.close() # should not error def test_context_manager(self, tree_and_closers): tree, closers = tree_and_closers assert not any(closer.closed for closer in closers.values()) with tree: pass assert all(closer.closed for closer in closers.values()) def test_close_child(self, tree_and_closers): tree, closers = tree_and_closers assert not any(closer.closed for closer in closers.values()) tree["child"].close() # should only close descendants assert not closers["/"].closed assert closers["/child"].closed assert closers["/child/grandchild"].closed def test_close_datasetview(self, tree_and_closers): tree, _ = tree_and_closers with pytest.raises( AttributeError, match=re.escape( r"cannot close a DatasetView(). 
Close the associated DataTree node instead" ), ): tree.dataset.close() with pytest.raises( AttributeError, match=re.escape(r"cannot modify a DatasetView()") ): tree.dataset.set_close(None) def test_close_dataset(self, tree_and_closers): tree, closers = tree_and_closers ds = tree.to_dataset() # should discard closers ds.close() assert not closers["/"].closed # with tree: # pass @requires_dask class TestDask: def test_chunksizes(self): ds1 = xr.Dataset({"a": ("x", np.arange(10))}) ds2 = xr.Dataset({"b": ("y", np.arange(5))}) ds3 = xr.Dataset({"c": ("z", np.arange(4))}) ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) groups = { "/": ds1.chunk({"x": 5}), "/group1": ds2.chunk({"y": 3}), "/group2": ds3.chunk({"z": 2}), "/group1/subgroup1": ds4.chunk({"x": 5}), } tree = xr.DataTree.from_dict(groups) expected_chunksizes = {path: node.chunksizes for path, node in groups.items()} assert tree.chunksizes == expected_chunksizes def test_load(self): ds1 = xr.Dataset({"a": ("x", np.arange(10))}) ds2 = xr.Dataset({"b": ("y", np.arange(5))}) ds3 = xr.Dataset({"c": ("z", np.arange(4))}) ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) groups = {"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} expected = xr.DataTree.from_dict(groups) tree = xr.DataTree.from_dict( { "/": ds1.chunk({"x": 5}), "/group1": ds2.chunk({"y": 3}), "/group2": ds3.chunk({"z": 2}), "/group1/subgroup1": ds4.chunk({"x": 5}), } ) expected_chunksizes: Mapping[str, Mapping] expected_chunksizes = {node.path: {} for node in tree.subtree} actual = tree.load() assert_identical(actual, expected) assert tree.chunksizes == expected_chunksizes assert actual.chunksizes == expected_chunksizes tree = xr.DataTree.from_dict(groups) actual = tree.load() assert_identical(actual, expected) assert actual.chunksizes == expected_chunksizes def test_compute(self): ds1 = xr.Dataset({"a": ("x", np.arange(10))}) ds2 = xr.Dataset({"b": ("y", np.arange(5))}) ds3 = xr.Dataset({"c": ("z", np.arange(4))}) ds4 = 
xr.Dataset({"d": ("x", np.arange(-5, 5))}) expected = xr.DataTree.from_dict( {"/": ds1, "/group1": ds2, "/group2": ds3, "/group1/subgroup1": ds4} ) tree = xr.DataTree.from_dict( { "/": ds1.chunk({"x": 5}), "/group1": ds2.chunk({"y": 3}), "/group2": ds3.chunk({"z": 2}), "/group1/subgroup1": ds4.chunk({"x": 5}), } ) original_chunksizes = tree.chunksizes expected_chunksizes: Mapping[str, Mapping] expected_chunksizes = {node.path: {} for node in tree.subtree} actual = tree.compute() assert_identical(actual, expected) assert actual.chunksizes == expected_chunksizes, "mismatching chunksizes" assert tree.chunksizes == original_chunksizes, "original tree was modified" def test_persist(self): ds1 = xr.Dataset({"a": ("x", np.arange(10))}) ds2 = xr.Dataset({"b": ("y", np.arange(5))}) ds3 = xr.Dataset({"c": ("z", np.arange(4))}) ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))}) def fn(x): return 2 * x expected = xr.DataTree.from_dict( { "/": fn(ds1).chunk({"x": 5}), "/group1": fn(ds2).chunk({"y": 3}), "/group2": fn(ds3).chunk({"z": 2}), "/group1/subgroup1": fn(ds4).chunk({"x": 5}), } ) # Add trivial second layer to the task graph, persist should reduce to one tree = xr.DataTree.from_dict( { "/": fn(ds1.chunk({"x": 5})), "/group1": fn(ds2.chunk({"y": 3})), "/group2": fn(ds3.chunk({"z": 2})), "/group1/subgroup1": fn(ds4.chunk({"x": 5})), } ) original_chunksizes = tree.chunksizes original_hlg_depths = { node.path: len(node.dataset.__dask_graph__().layers) for node in tree.subtree } actual = tree.persist() actual_hlg_depths = { node.path: len(node.dataset.__dask_graph__().layers) for node in actual.subtree } assert_identical(actual, expected) assert actual.chunksizes == original_chunksizes, "chunksizes were modified" assert ( tree.chunksizes == original_chunksizes ), "original chunksizes were modified" assert all( d == 1 for d in actual_hlg_depths.values() ), "unexpected dask graph depth" assert all( d == 2 for d in original_hlg_depths.values() ), "original dask graph was modified" 
def test_chunk(self):
    """Check that ``DataTree.chunk`` chunks every node and rejects bad input.

    A single chunk mapping covering all dimensions in the tree must give the
    same result as chunking each group's dataset individually.  ``None`` and
    a tuple must raise ``TypeError``; a dimension name that is absent from
    the tree must raise ``ValueError``.
    """
    # For each group path: the unchunked dataset and the chunking that is
    # expected on that node after chunking the whole tree.
    cases = {
        "/": (xr.Dataset({"a": ("x", np.arange(10))}), {"x": 5}),
        "/group1": (xr.Dataset({"b": ("y", np.arange(5))}), {"y": 3}),
        "/group2": (xr.Dataset({"c": ("z", np.arange(4))}), {"z": 2}),
        "/group1/subgroup1": (
            xr.Dataset({"d": ("x", np.arange(-5, 5))}),
            {"x": 5},
        ),
    }

    expected = xr.DataTree.from_dict(
        {path: ds.chunk(spec) for path, (ds, spec) in cases.items()}
    )
    tree = xr.DataTree.from_dict({path: ds for path, (ds, _) in cases.items()})

    actual = tree.chunk({"x": 5, "y": 3, "z": 2})
    assert_identical(actual, expected)
    assert actual.chunksizes == expected.chunksizes

    # Anything that is not a valid chunk specification is a type error.
    for bad_chunks in (None, (1, 2)):
        with pytest.raises(TypeError, match="invalid type"):
            tree.chunk(bad_chunks)

    # "u" is not a dimension of any node in the tree.
    with pytest.raises(ValueError, match="not found in data dimensions"):
        tree.chunk({"u": 2})