# CCR/.venv/lib/python3.12/site-packages/xarray/tests/test_datatree_mapping.py
#
# 233 lines
# 8.6 KiB
# Python

import re
import numpy as np
import pytest
import xarray as xr
from xarray.core.datatree_mapping import map_over_datasets
from xarray.core.treenode import TreeIsomorphismError
from xarray.testing import assert_equal, assert_identical
# Module-level empty Dataset. Nothing in the tests visible here references it;
# NOTE(review): presumably kept for use elsewhere or left over — confirm before removing.
empty = xr.Dataset()
class TestMapOverSubTree:
    """Tests for ``map_over_datasets`` and the ``DataTree.map_over_datasets`` method."""

    def test_no_trees_passed(self):
        """Calling with no DataTree among the arguments raises TypeError."""
        with pytest.raises(TypeError, match="must pass at least one tree object"):
            map_over_datasets(lambda x: x, "dt")

    def test_not_isomorphic(self, create_test_datatree):
        """Trees whose children differ raise TreeIsomorphismError naming the node."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        dt2["set1/set2/extra"] = xr.DataTree(name="extra")

        with pytest.raises(
            TreeIsomorphismError,
            match=re.escape(
                r"children at node 'set1/set2' do not match: [] vs ['extra']"
            ),
        ):
            map_over_datasets(lambda x, y: None, dt1, dt2)

    def test_no_trees_returned(self, create_test_datatree):
        """A func returning None everywhere yields the same node paths with no data."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        # Map every node path of dt1 to None, i.e. an empty node at each path.
        expected = xr.DataTree.from_dict(dict.fromkeys(dt1.to_dict()))
        actual = map_over_datasets(lambda x, y: None, dt1, dt2)
        assert_equal(expected, actual)

    def test_single_tree_arg(self, create_test_datatree):
        """A one-argument func is applied to the dataset at every node."""
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda x: 10.0 * x)
        result_tree = map_over_datasets(lambda x: 10 * x, dt)
        assert_equal(result_tree, expected)

    def test_single_tree_arg_plus_arg(self, create_test_datatree):
        """Non-tree arguments are passed through unchanged, in either position."""
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: (10.0 * ds))

        result_tree = map_over_datasets(lambda x, y: x * y, dt, 10.0)
        assert_equal(result_tree, expected)

        result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
        assert_equal(result_tree, expected)

    def test_multiple_tree_args(self, create_test_datatree):
        """Corresponding nodes of several isomorphic trees are passed together."""
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: 2.0 * ds)
        result = map_over_datasets(lambda x, y: x + y, dt1, dt2)
        assert_equal(result, expected)

    def test_return_multiple_trees(self, create_test_datatree):
        """A func returning a tuple of datasets produces a tuple of trees."""
        dt = create_test_datatree()
        dt_min, dt_max = map_over_datasets(lambda x: (x.min(), x.max()), dt)

        expected_min = create_test_datatree(modify=lambda ds: ds.min())
        assert_equal(dt_min, expected_min)

        expected_max = create_test_datatree(modify=lambda ds: ds.max())
        assert_equal(dt_max, expected_max)

    def test_return_wrong_type(self, simple_datatree):
        """Returning something other than Dataset/None/tuple raises TypeError."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: "string", dt1)  # type: ignore[arg-type,return-value]

    def test_return_tuple_of_wrong_types(self, simple_datatree):
        """A returned tuple containing a non-Dataset element raises TypeError."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: (x, "string"), dt1)  # type: ignore[arg-type,return-value]

    def test_return_inconsistent_number_of_results(self, simple_datatree):
        """Returning tuples of different lengths at different nodes raises TypeError."""
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                r"Calling func on the nodes at position set1 returns a tuple "
                "of 0 datasets, whereas calling func on the nodes at position "
                ". instead returns a tuple of 2 datasets."
            ),
        ):
            # Datasets in simple_datatree have different numbers of dims, so the
            # tuple length varies from node to node.
            map_over_datasets(lambda ds: (None,) * len(ds.dims), dt1)

    def test_wrong_number_of_arguments_for_func(self, simple_datatree):
        """Passing more trees than func accepts surfaces Python's TypeError."""
        dt = simple_datatree

        with pytest.raises(
            TypeError, match="takes 1 positional argument but 2 were given"
        ):
            map_over_datasets(lambda x: 10 * x, dt, dt)

    def test_map_single_dataset_against_whole_tree(self, create_test_datatree):
        """A plain Dataset argument is handed to func alongside every node."""
        dt = create_test_datatree()

        def nodewise_merge(node_ds, fixed_ds):
            return xr.merge([node_ds, fixed_ds])

        other_ds = xr.Dataset({"z": ("z", [0])})
        expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds]))
        result_tree = map_over_datasets(nodewise_merge, dt, other_ds)
        assert_equal(result_tree, expected)

    @pytest.mark.xfail(
        reason="placeholder: needs good tests for renaming nodes first",
        raises=NotImplementedError,
    )
    def test_trees_with_different_node_names(self):
        # TODO test this after I've got good tests for renaming nodes
        raise NotImplementedError

    def test_tree_method(self, create_test_datatree):
        """The method form forwards extra positional arguments to func."""
        dt = create_test_datatree()

        def multiply(ds, times):
            return times * ds

        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
        result_tree = dt.map_over_datasets(multiply, 10.0)
        assert_equal(result_tree, expected)

    def test_discard_ancestry(self, create_test_datatree):
        """Mapping over a subtree matches the same subtree of the expected tree."""
        # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
        dt = create_test_datatree()
        subtree = dt["set1"]
        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
        result_tree = map_over_datasets(lambda x: 10.0 * x, subtree)
        assert_equal(result_tree, expected)

    def test_keep_attrs_on_empty_nodes(self, create_test_datatree):
        """Node attrs survive mapping an identity func over the tree."""
        # GH278
        dt = create_test_datatree()
        dt["set1/set2"].attrs["foo"] = "bar"

        def empty_func(ds):
            return ds

        result = dt.map_over_datasets(empty_func)
        assert result["set1/set2"].attrs == dt["set1/set2"].attrs

    def test_error_contains_path_of_offending_node(self, create_test_datatree):
        """An exception raised by func is annotated with the failing node's path."""
        dt = create_test_datatree()
        dt["set1"]["bad_var"] = 0

        def fail_on_specific_node(ds):
            if "bad_var" in ds:
                raise ValueError("Failed because 'bad_var' present in dataset")

        with pytest.raises(
            ValueError,
            match=re.escape(
                r"Raised whilst mapping function over node with path 'set1'"
            ),
        ):
            dt.map_over_datasets(fail_on_specific_node)

    def test_inherited_coordinates_with_index(self):
        """An identity map keeps inherited coords inherited, not copied into children."""
        root = xr.Dataset(coords={"x": [1, 2]})
        child = xr.Dataset({"foo": ("x", [0, 1])})  # no coordinates
        tree = xr.DataTree.from_dict({"/": root, "/child": child})

        actual = tree.map_over_datasets(lambda ds: ds)  # identity
        assert isinstance(actual, xr.DataTree)
        assert_identical(tree, actual)

        # Without inheritance the child must still look exactly like the input.
        actual_child = actual.children["child"].to_dataset(inherit=False)
        assert_identical(actual_child, child)
class TestMutableOperations:
    """Tests for funcs that construct or mutate datasets while mapped over a tree."""

    def test_construct_using_type(self):
        """Funcs that build new objects through ``type()`` work when mapped.

        From datatree GH issue https://github.com/xarray-contrib/datatree/issues/188:
        xarray's ``.weighted`` is unusual because it uses ``type()`` to create a
        Dataset/DataArray.
        """
        # Seeded generator so any failure is reproducible.
        rng = np.random.default_rng(0)
        a = xr.DataArray(
            rng.random((3, 4, 10)),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], rng.random((3, 4)))},
        ).to_dataset(name="data")
        b = xr.DataArray(
            rng.random((2, 6, 14)),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], rng.random((2, 6)))},
        ).to_dataset(name="data")
        dt = xr.DataTree.from_dict({"a": a, "b": b})

        def weighted_mean(ds):
            # The root node built by from_dict has no "area" coordinate.
            if "area" not in ds.coords:
                return None
            return ds.weighted(ds.area).mean(["x", "y"])

        result = dt.map_over_datasets(weighted_mean)

        # The weighted mean reduced away x and y, leaving only time.
        assert result["a"]["data"].dims == ("time",)
        assert result["b"]["data"].dims == ("time",)

    def test_alter_inplace_forbidden(self):
        """Assigning into the dataset handed to func raises AttributeError."""
        simpsons = xr.DataTree.from_dict(
            {
                "/": xr.Dataset({"age": 83}),
                "/Herbert": xr.Dataset({"age": 40}),
                "/Homer": xr.Dataset({"age": 39}),
                "/Homer/Bart": xr.Dataset({"age": 10}),
                "/Homer/Lisa": xr.Dataset({"age": 8}),
                "/Homer/Maggie": xr.Dataset({"age": 1}),
            },
            name="Abe",
        )

        def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset:
            """Add some years to the age, but by altering the given dataset"""
            ds["age"] = ds["age"] + years
            return ds

        with pytest.raises(AttributeError):
            simpsons.map_over_datasets(fast_forward, 10)