"""Tests for ``map_over_datasets`` applied to ``xarray.DataTree`` objects."""

import re

import numpy as np
import pytest

import xarray as xr
from xarray.core.datatree_mapping import map_over_datasets
from xarray.core.treenode import TreeIsomorphismError
from xarray.testing import assert_equal, assert_identical

# NOTE(review): not referenced anywhere in this module; kept because it is a
# module-level name that other code could import.
empty = xr.Dataset()


class TestMapOverSubTree:
    """Tests of ``map_over_datasets`` as a function and as a tree method."""

    def test_no_trees_passed(self):
        """Passing no tree argument raises ``TypeError``."""
        with pytest.raises(TypeError, match="must pass at least one tree object"):
            map_over_datasets(lambda x: x, "dt")

    def test_not_isomorphic(self, create_test_datatree):
        """Trees whose children differ raise ``TreeIsomorphismError``."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        # Give dt2 a node that dt1 lacks so the two trees no longer line up.
        dt2["set1/set2/extra"] = xr.DataTree(name="extra")

        with pytest.raises(
            TreeIsomorphismError,
            match=re.escape(
                r"children at node 'set1/set2' do not match: [] vs ['extra']"
            ),
        ):
            map_over_datasets(lambda x, y: None, dt1, dt2)

    def test_no_trees_returned(self, create_test_datatree):
        """A function returning ``None`` yields a tree built from ``None`` values."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        expected = xr.DataTree.from_dict({k: None for k in dt1.to_dict()})
        actual = map_over_datasets(lambda x, y: None, dt1, dt2)
        assert_equal(expected, actual)

    def test_single_tree_arg(self, create_test_datatree):
        """A one-argument function is applied across a single tree."""
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda x: 10.0 * x)
        result_tree = map_over_datasets(lambda x: 10 * x, dt)
        assert_equal(result_tree, expected)

    def test_single_tree_arg_plus_arg(self, create_test_datatree):
        """A non-tree argument may be passed before or after the tree."""
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: (10.0 * ds))

        # Scalar after the tree.
        result_tree = map_over_datasets(lambda x, y: x * y, dt, 10.0)
        assert_equal(result_tree, expected)

        # Scalar before the tree.
        result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
        assert_equal(result_tree, expected)

    def test_multiple_tree_args(self, create_test_datatree):
        """A two-argument function is applied across two trees."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: 2.0 * ds)
        result = map_over_datasets(lambda x, y: x + y, dt1, dt2)
        assert_equal(result, expected)

    def test_return_multiple_trees(self, create_test_datatree):
        """A function returning a 2-tuple produces two result trees."""
        dt = create_test_datatree()
        dt_min, dt_max = map_over_datasets(lambda x: (x.min(), x.max()), dt)

        expected_min = create_test_datatree(modify=lambda ds: ds.min())
        assert_equal(dt_min, expected_min)

        expected_max = create_test_datatree(modify=lambda ds: ds.max())
        assert_equal(dt_max, expected_max)

    def test_return_wrong_type(self, simple_datatree):
        """Returning a string from the mapped function raises ``TypeError``."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: "string", dt1)  # type: ignore[arg-type,return-value]

    def test_return_tuple_of_wrong_types(self, simple_datatree):
        """Returning a tuple that contains a string raises ``TypeError``."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: (x, "string"), dt1)  # type: ignore[arg-type,return-value]

    def test_return_inconsistent_number_of_results(self, simple_datatree):
        """Returning tuples of differing lengths per node raises ``TypeError``."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                r"Calling func on the nodes at position set1 returns a tuple "
                "of 0 datasets, whereas calling func on the nodes at position "
                ". instead returns a tuple of 2 datasets."
            ),
        ):
            # Datasets in simple_datatree have different numbers of dims
            map_over_datasets(lambda ds: tuple((None,) * len(ds.dims)), dt1)

    def test_wrong_number_of_arguments_for_func(self, simple_datatree):
        """Passing more trees than the function accepts raises ``TypeError``."""
        dt = simple_datatree

        with pytest.raises(
            TypeError, match="takes 1 positional argument but 2 were given"
        ):
            map_over_datasets(lambda x: 10 * x, dt, dt)

    def test_map_single_dataset_against_whole_tree(self, create_test_datatree):
        """A plain ``Dataset`` argument is passed alongside each node's dataset."""
        dt = create_test_datatree()

        def nodewise_merge(node_ds, fixed_ds):
            return xr.merge([node_ds, fixed_ds])

        other_ds = xr.Dataset({"z": ("z", [0])})
        expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds]))
        result_tree = map_over_datasets(nodewise_merge, dt, other_ds)
        assert_equal(result_tree, expected)

    @pytest.mark.xfail
    def test_trees_with_different_node_names(self):
        """Placeholder: not implemented yet, marked as an expected failure."""
        # TODO test this after I've got good tests for renaming nodes
        raise NotImplementedError

    def test_tree_method(self, create_test_datatree):
        """``DataTree.map_over_datasets`` forwards extra positional arguments."""
        dt = create_test_datatree()

        def multiply(ds, times):
            return times * ds

        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
        result_tree = dt.map_over_datasets(multiply, 10.0)
        assert_equal(result_tree, expected)

    def test_discard_ancestry(self, create_test_datatree):
        """Mapping over a subtree matches the same subtree of the expected tree."""
        # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
        dt = create_test_datatree()
        subtree = dt["set1"]
        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
        result_tree = map_over_datasets(lambda x: 10.0 * x, subtree)
        assert_equal(result_tree, expected)

    def test_keep_attrs_on_empty_nodes(self, create_test_datatree):
        """Attributes set on ``set1/set2`` survive an identity mapping."""
        # GH278
        dt = create_test_datatree()
        dt["set1/set2"].attrs["foo"] = "bar"

        def empty_func(ds):
            return ds

        result = dt.map_over_datasets(empty_func)
        assert result["set1/set2"].attrs == dt["set1/set2"].attrs

    def test_error_contains_path_of_offending_node(self, create_test_datatree):
        """An error raised by the mapped function reports the failing node's path."""
        dt = create_test_datatree()
        # Plant a marker variable so the mapped function fails on "set1".
        dt["set1"]["bad_var"] = 0

        def fail_on_specific_node(ds):
            if "bad_var" in ds:
                raise ValueError("Failed because 'bad_var' present in dataset")

        with pytest.raises(
            ValueError,
            match=re.escape(
                r"Raised whilst mapping function over node with path 'set1'"
            ),
        ):
            dt.map_over_datasets(fail_on_specific_node)

    def test_inherited_coordinates_with_index(self):
        """An identity mapping leaves the tree and the child's own dataset intact."""
        root = xr.Dataset(coords={"x": [1, 2]})
        child = xr.Dataset({"foo": ("x", [0, 1])})  # no coordinates
        tree = xr.DataTree.from_dict({"/": root, "/child": child})
        actual = tree.map_over_datasets(lambda ds: ds)  # identity
        assert isinstance(actual, xr.DataTree)
        assert_identical(tree, actual)

        # With inherit=False the child must still equal the coordinate-free
        # dataset it was built from.
        actual_child = actual.children["child"].to_dataset(inherit=False)
        assert_identical(actual_child, child)


class TestMutableOperations:
    """Tests of how mapped functions interact with the datasets they receive."""

    def test_construct_using_type(self):
        """Mapping a function that uses ``.weighted`` completes without raising.

        This test makes no assertions on the result; it only checks that the
        call runs.
        """
        # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188
        # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray

        a = xr.DataArray(
            np.random.rand(3, 4, 10),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], np.random.rand(3, 4))},
        ).to_dataset(name="data")
        b = xr.DataArray(
            np.random.rand(2, 6, 14),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], np.random.rand(2, 6))},
        ).to_dataset(name="data")
        dt = xr.DataTree.from_dict({"a": a, "b": b})

        def weighted_mean(ds):
            if "area" not in ds.coords:
                return None
            return ds.weighted(ds.area).mean(["x", "y"])

        dt.map_over_datasets(weighted_mean)

    def test_alter_inplace_forbidden(self):
        """Assigning into the dataset handed to the function raises ``AttributeError``."""
        simpsons = xr.DataTree.from_dict(
            {
                "/": xr.Dataset({"age": 83}),
                "/Herbert": xr.Dataset({"age": 40}),
                "/Homer": xr.Dataset({"age": 39}),
                "/Homer/Bart": xr.Dataset({"age": 10}),
                "/Homer/Lisa": xr.Dataset({"age": 8}),
                "/Homer/Maggie": xr.Dataset({"age": 1}),
            },
            name="Abe",
        )

        def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset:
            """Add some years to the age, but by altering the given dataset"""
            ds["age"] = ds["age"] + years
            return ds

        with pytest.raises(AttributeError):
            simpsons.map_over_datasets(fast_forward, 10)