from __future__ import annotations import pickle from unittest.mock import patch import numpy as np import pytest import xarray as xr import xarray.ufuncs as xu from xarray.tests import assert_allclose, assert_array_equal, mock, requires_dask from xarray.tests import assert_identical as assert_identical_ def assert_identical(a, b): assert type(a) is type(b) or float(a) == float(b) if isinstance(a, xr.DataArray | xr.Dataset | xr.Variable): assert_identical_(a, b) else: assert_array_equal(a, b) @pytest.mark.parametrize( "a", [ xr.Variable(["x"], [0, 0]), xr.DataArray([0, 0], dims="x"), xr.Dataset({"y": ("x", [0, 0])}), ], ) def test_unary(a): assert_allclose(a + 1, np.cos(a)) def test_binary(): args = [ 0, np.zeros(2), xr.Variable(["x"], [0, 0]), xr.DataArray([0, 0], dims="x"), xr.Dataset({"y": ("x", [0, 0])}), ] for n, t1 in enumerate(args): for t2 in args[n:]: assert_identical(t2 + 1, np.maximum(t1, t2 + 1)) assert_identical(t2 + 1, np.maximum(t2, t1 + 1)) assert_identical(t2 + 1, np.maximum(t1 + 1, t2)) assert_identical(t2 + 1, np.maximum(t2 + 1, t1)) def test_binary_out(): args = [ 1, np.ones(2), xr.Variable(["x"], [1, 1]), xr.DataArray([1, 1], dims="x"), xr.Dataset({"y": ("x", [1, 1])}), ] for arg in args: actual_mantissa, actual_exponent = np.frexp(arg) assert_identical(actual_mantissa, 0.5 * arg) assert_identical(actual_exponent, arg) def test_groupby(): ds = xr.Dataset({"a": ("x", [0, 0, 0])}, {"c": ("x", [0, 0, 1])}) ds_grouped = ds.groupby("c") group_mean = ds_grouped.mean("x") arr_grouped = ds["a"].groupby("c") assert_identical(ds, np.maximum(ds_grouped, group_mean)) assert_identical(ds, np.maximum(group_mean, ds_grouped)) assert_identical(ds, np.maximum(arr_grouped, group_mean)) assert_identical(ds, np.maximum(group_mean, arr_grouped)) assert_identical(ds, np.maximum(ds_grouped, group_mean["a"])) assert_identical(ds, np.maximum(group_mean["a"], ds_grouped)) assert_identical(ds.a, np.maximum(arr_grouped, group_mean.a)) assert_identical(ds.a, np.maximum(group_mean.a, arr_grouped)) with pytest.raises(ValueError, match=r"mismatched lengths for dimension"): np.maximum(ds.a.variable, ds_grouped) def test_alignment(): ds1 = xr.Dataset({"a": ("x", [1, 2])}, {"x": [0, 1]}) ds2 = xr.Dataset({"a": ("x", [2, 3]), "b": 4}, {"x": [1, 2]}) actual = np.add(ds1, ds2) expected = xr.Dataset({"a": ("x", [4])}, {"x": [1]}) assert_identical_(actual, expected) with xr.set_options(arithmetic_join="outer"): actual = np.add(ds1, ds2) expected = xr.Dataset( {"a": ("x", [np.nan, 4, np.nan]), "b": np.nan}, coords={"x": [0, 1, 2]} ) assert_identical_(actual, expected) def test_kwargs(): x = xr.DataArray(0) result = np.add(x, 1, dtype=np.float64) assert result.dtype == np.float64 def test_xarray_defers_to_unrecognized_type(): class Other: def __array_ufunc__(self, *args, **kwargs): return "other" xarray_obj = xr.DataArray([1, 2, 3]) other = Other() assert np.maximum(xarray_obj, other) == "other" assert np.sin(xarray_obj, out=other) == "other" def test_xarray_handles_dask(): da = pytest.importorskip("dask.array") x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) y = da.ones((2, 2), chunks=(2, 2)) result = np.add(x, y) assert result.chunks == ((2,), (2,)) assert isinstance(result, xr.DataArray) def test_dask_defers_to_xarray(): da = pytest.importorskip("dask.array") x = xr.DataArray(np.ones((2, 2)), dims=["x", "y"]) y = da.ones((2, 2), chunks=(2, 2)) result = np.add(y, x) assert result.chunks == ((2,), (2,)) assert isinstance(result, xr.DataArray) def test_gufunc_methods(): xarray_obj = xr.DataArray([1, 2, 3]) with pytest.raises(NotImplementedError, match=r"reduce method"): np.add.reduce(xarray_obj, 1) def test_out(): xarray_obj = xr.DataArray([1, 2, 3]) # xarray out arguments should raise with pytest.raises(NotImplementedError, match=r"`out` argument"): np.add(xarray_obj, 1, out=xarray_obj) # but non-xarray should be OK other = np.zeros((3,)) np.add(other, xarray_obj, out=other) assert_identical(other, np.array([1, 2, 3])) def test_gufuncs(): xarray_obj = xr.DataArray([1, 2, 3]) fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) class DuckArray(np.ndarray): # Minimal subclassed duck array with its own self-contained namespace, # which implements a few ufuncs def __new__(cls, array): obj = np.asarray(array).view(cls) return obj def __array_namespace__(self): return DuckArray @staticmethod def sin(x): return np.sin(x) @staticmethod def add(x, y): return x + y class DuckArray2(DuckArray): def __array_namespace__(self): return DuckArray2 class TestXarrayUfuncs: @pytest.fixture(autouse=True) def setUp(self): self.x = xr.DataArray([1, 2, 3]) self.xd = xr.DataArray(DuckArray([1, 2, 3])) self.xd2 = xr.DataArray(DuckArray2([1, 2, 3])) self.xt = xr.DataArray(np.datetime64("2021-01-01", "ns")) @pytest.mark.filterwarnings("ignore::RuntimeWarning") @pytest.mark.parametrize("name", xu.__all__) def test_ufuncs(self, name, request): xu_func = getattr(xu, name) np_func = getattr(np, name, None) if np_func is None and np.lib.NumpyVersion(np.__version__) < "2.0.0": pytest.skip(f"Ufunc {name} is not available in numpy {np.__version__}.") if name == "isnat": args = (self.xt,) elif hasattr(np_func, "nin") and np_func.nin == 2: args = (self.x, self.x) else: args = (self.x,) expected = np_func(*args) actual = xu_func(*args) if name in ["angle", "iscomplex"]: np.testing.assert_equal(expected, actual.values) else: assert_identical(actual, expected) def test_ufunc_pickle(self): a = 1.0 cos_pickled = pickle.loads(pickle.dumps(xu.cos)) assert_identical(cos_pickled(a), xu.cos(a)) def test_ufunc_scalar(self): actual = xu.sin(1) assert isinstance(actual, float) def test_ufunc_duck_array_dataarray(self): actual = xu.sin(self.xd) assert isinstance(actual.data, DuckArray) def test_ufunc_duck_array_variable(self): actual = xu.sin(self.xd.variable) assert isinstance(actual.data, DuckArray) def test_ufunc_duck_array_dataset(self): ds = xr.Dataset({"a": self.xd}) actual = xu.sin(ds) assert isinstance(actual.a.data, DuckArray) @requires_dask def test_ufunc_duck_dask(self): import dask.array as da x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) actual = xu.sin(x) assert isinstance(actual.data._meta, DuckArray) @requires_dask @pytest.mark.xfail(reason="dask ufuncs currently dispatch to numpy") def test_ufunc_duck_dask_no_array_ufunc(self): import dask.array as da # dask ufuncs currently only preserve duck arrays that implement __array_ufunc__ with patch.object(DuckArray, "__array_ufunc__", new=None, create=True): x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) actual = xu.sin(x) assert isinstance(actual.data._meta, DuckArray) def test_ufunc_mixed_arrays_compatible(self): actual = xu.add(self.xd, self.x) assert isinstance(actual.data, DuckArray) def test_ufunc_mixed_arrays_incompatible(self): with pytest.raises(ValueError, match=r"Mixed array types"): xu.add(self.xd, self.xd2)