CCR/.venv/lib/python3.12/site-packages/xarray/tests/test_weighted.py

794 lines
24 KiB
Python

from __future__ import annotations
from collections.abc import Iterable
from typing import Any
import numpy as np
import pytest
import xarray as xr
from xarray import DataArray, Dataset
from xarray.tests import (
assert_allclose,
assert_equal,
raise_if_dask_computes,
requires_cftime,
requires_dask,
)
@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_non_DataArray_weights(as_dataset: bool) -> None:
    """Passing a plain Python list as ``weights`` must be rejected."""
    obj: DataArray | Dataset = DataArray([1, 2])
    if as_dataset:
        obj = obj.to_dataset(name="data")
    list_weights = [1, 2]
    with pytest.raises(ValueError, match=r"`weights` must be a DataArray"):
        obj.weighted(list_weights)  # type: ignore[arg-type]
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises(as_dataset: bool, weights: list[float]) -> None:
    """Eager weights containing NaN raise as soon as ``.weighted`` is called."""
    obj: DataArray | Dataset = DataArray([1, 2])
    if as_dataset:
        obj = obj.to_dataset(name="data")
    nan_weights = DataArray(weights)
    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
        obj.weighted(nan_weights)
@requires_dask
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises_dask(as_dataset, weights):
    """With dask-backed weights, the NaN check must not run eagerly.

    Building the weighted object must not trigger a dask compute. The
    ValueError is only expected once the reduction is loaded.
    """
    obj = DataArray([1, 2]).chunk({"dim_0": -1})
    if as_dataset:
        obj = obj.to_dataset(name="data")
    lazy_weights = DataArray(weights).chunk({"dim_0": -1})

    with raise_if_dask_computes():
        weighted_obj = obj.weighted(lazy_weights)

    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
        weighted_obj.sum().load()
@requires_cftime
@requires_dask
@pytest.mark.parametrize("time_chunks", (1, 5))
@pytest.mark.parametrize("resample_spec", ("1YS", "5YS", "10YS"))
def test_weighted_lazy_resample(time_chunks, resample_spec):
    """A weighted mean mapped over resample groups must not trigger a compute.

    Regression test for https://github.com/pydata/xarray/issues/4625
    """

    def weighted_time_mean(group):
        # the weights are carried along with each group as a coordinate
        return group.weighted(group.weights).mean("time")

    times = xr.cftime_range(start="2000", periods=20, freq="1YS")
    n_times = len(times)
    weights = xr.DataArray(
        np.random.rand(n_times), dims=["time"], coords={"time": times}
    )
    values = xr.DataArray(
        np.random.rand(n_times),
        dims=["time"],
        coords={"time": times, "weights": weights},
    )
    ds = xr.Dataset({"data": values}).chunk({"time": time_chunks})

    with raise_if_dask_computes():
        ds.resample(time=resample_spec).map(weighted_time_mean)
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),
)
def test_weighted_sum_of_weights_no_nan(weights, expected):
    """Without NaNs in the data, sum_of_weights totals the weights.

    A total of zero is reported as NaN.
    """
    actual = DataArray([1, 2]).weighted(DataArray(weights)).sum_of_weights()
    assert_equal(DataArray(expected), actual)
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)),
)
def test_weighted_sum_of_weights_nan(weights, expected):
    """Weights at positions where the data is NaN do not count towards the total."""
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).sum_of_weights()
    assert_equal(DataArray(expected), actual)
def test_weighted_sum_of_weights_bool():
    """Boolean weights count as one each in sum_of_weights.

    https://github.com/pydata/xarray/issues/4074
    """
    bool_weights = DataArray([True, True])
    actual = DataArray([1, 2]).weighted(bool_weights).sum_of_weights()
    assert_equal(DataArray(2), actual)
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("factor", [0, 1, 3.14])
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_equal_weights(da, factor, skipna):
    """Constant weights equal to ``factor`` scale the plain sum by ``factor``."""
    da = DataArray(da)
    constant_weights = xr.full_like(da, factor)
    reference = da.sum(skipna=skipna) * factor
    assert_equal(reference, da.weighted(constant_weights).sum(skipna=skipna))
@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0))
)
def test_weighted_sum_no_nan(weights, expected):
    """Weighted sum of NaN-free data equals ``sum(data * weights)``."""
    actual = DataArray([1, 2]).weighted(DataArray(weights)).sum()
    assert_equal(DataArray(expected), actual)
@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_nan(weights, expected, skipna):
    """NaN data is skipped with ``skipna=True``; otherwise it makes the sum NaN."""
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).sum(skipna=skipna)
    reference = DataArray(expected if skipna else np.nan)
    assert_equal(reference, actual)
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_mean_equal_weights(da, skipna, factor):
    """Equal non-zero weights reproduce the unweighted ``mean``.

    Any warning raised along the way fails the test.
    """
    da = DataArray(da)
    # every weight equals ``factor`` (not necessarily 1)
    uniform = xr.full_like(da, factor)
    reference = da.mean(skipna=skipna)
    assert_equal(reference, da.weighted(uniform).mean(skipna=skipna))
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan))
)
def test_weighted_mean_no_nan(weights, expected):
    """Weighted mean of NaN-free data; all-zero weights give NaN."""
    reference = DataArray(expected)
    actual = DataArray([1, 2]).weighted(DataArray(weights)).mean()
    assert_equal(reference, actual)
@pytest.mark.parametrize(
    ("weights", "expected"),
    (
        (
            [0.25, 0.05, 0.15, 0.25, 0.15, 0.1, 0.05],
            [1.554595, 2.463784, 3.000000, 3.518378],
        ),
        (
            [0.05, 0.05, 0.1, 0.15, 0.15, 0.25, 0.25],
            [2.840000, 3.632973, 4.076216, 4.523243],
        ),
    ),
)
def test_weighted_quantile_no_nan(weights, expected):
    """Compare weighted quantiles of NaN-free data against reference values.

    The expected numbers were produced with the reference implementation
    described in https://aakinshin.net/posts/weighted-quantiles/
    """
    probs = [0.2, 0.4, 0.6, 0.8]
    values = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5])
    actual = values.weighted(DataArray(weights)).quantile(probs)
    assert_allclose(DataArray(expected, coords={"quantile": probs}), actual)
def test_weighted_quantile_zero_weights():
    """Zero-weighted points behave as if they were absent from the data."""
    values = DataArray([0, 1, 2, 3])
    mask_weights = DataArray([1, 0, 1, 0])
    actual = values.weighted(mask_weights).quantile(0.75)
    # only the points carrying non-zero weight (0 and 2) should contribute
    assert_allclose(DataArray([0, 2]).quantile(0.75), actual)
def test_weighted_quantile_simple():
    """Weighted quantiles match ``np.quantile`` on the non-zero-weight subset.

    Tiny (but non-zero) weights should give nearly the same answer as exact zeros.
    """
    q = 0.75
    values = DataArray([0, 1, 2, 3])
    exact_zero = DataArray([1, 0, 1, 0])
    near_zero = DataArray([1, 0.0001, 1, 0.0001])
    expected = DataArray(np.quantile([0, 2], q), coords={"quantile": q})  # 1.5

    assert_equal(expected, values.weighted(exact_zero).quantile(q))
    assert_allclose(expected, values.weighted(near_zero).quantile(q), rtol=0.001)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_quantile_nan(skipna):
    """NaN data is skipped when ``skipna`` is true and yields NaN otherwise."""
    q = [0.5, 0.75]
    values = DataArray([0, 1, 2, 3, np.nan])
    actual = values.weighted(DataArray([1, 0, 1, 0, 1])).quantile(q, skipna=skipna)
    if not skipna:
        expected_values = np.full(len(q), np.nan)
    else:
        expected_values = np.quantile([0, 2], q)
    assert_allclose(DataArray(expected_values, coords={"quantile": q}), actual)
@pytest.mark.parametrize(
    "da",
    (
        pytest.param([1, 1.9, 2.2, 3, 3.7, 4.1, 5], id="nonan"),
        pytest.param([1, 1.9, 2.2, 3, 3.7, 4.1, np.nan], id="singlenan"),
        pytest.param(
            [np.nan, np.nan, np.nan],
            id="allnan",
            marks=pytest.mark.filterwarnings(
                "ignore:All-NaN slice encountered:RuntimeWarning"
            ),
        ),
    ),
)
@pytest.mark.parametrize("q", (0.5, (0.2, 0.8)))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 3.14])
def test_weighted_quantile_equal_weights(
    da: list[float], q: float | tuple[float, ...], skipna: bool, factor: float
) -> None:
    """Equal non-zero weights reproduce the unweighted ``quantile``."""
    values = DataArray(da)
    uniform = xr.full_like(values, factor)
    reference = values.quantile(q, skipna=skipna)
    actual = values.weighted(uniform).quantile(q, skipna=skipna)
    assert_allclose(reference, actual)
@pytest.mark.skip(reason="`method` argument is not currently exposed")
@pytest.mark.parametrize(
    "da",
    (
        [1, 1.9, 2.2, 3, 3.7, 4.1, 5],
        [1, 1.9, 2.2, 3, 3.7, 4.1, np.nan],
        [np.nan, np.nan, np.nan],
    ),
)
@pytest.mark.parametrize("q", (0.5, (0.2, 0.8)))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 3.14])
@pytest.mark.parametrize(
    "method",
    [
        "linear",
        "interpolated_inverted_cdf",
        "hazen",
        "weibull",
        "median_unbiased",
        "normal_unbiased",
    ],
)
def test_weighted_quantile_equal_weights_all_methods(da, q, skipna, factor, method):
    """Equal weights reproduce the unweighted quantile for each interpolation method.

    Currently skipped because ``method`` is not exposed on the weighted
    ``quantile``. ``factor`` is parametrized so that the test collects without
    a "fixture 'factor' not found" error once the skip is lifted.
    """
    da = DataArray(da)
    # constant weights of any non-zero value must not change the result
    weights = xr.full_like(da, factor)
    expected = da.quantile(q, skipna=skipna, method=method)
    result = da.weighted(weights).quantile(q, skipna=skipna, method=method)
    assert_allclose(expected, result)
def test_weighted_quantile_bool():
    """Boolean weights are accepted by ``quantile``.

    https://github.com/pydata/xarray/issues/4074
    """
    q = 0.5
    # a 0-d result that still carries the scalar ``quantile`` coordinate
    expected = DataArray([1], coords={"quantile": [q]}).squeeze()
    actual = DataArray([1, 1]).weighted(DataArray([True, True])).quantile(q)
    assert_equal(expected, actual)
@pytest.mark.parametrize("q", (-1, 1.1, (0.5, 1.1), ((0.2, 0.4), (0.6, 0.8))))
def test_weighted_quantile_with_invalid_q(q):
    """Out-of-range ``q`` and ``q`` with more than one dimension both raise ValueError."""
    values = DataArray([1, 1.9, 2.2, 3, 3.7, 4.1, 5])
    q = np.asarray(q)
    weights = xr.ones_like(values)
    if q.ndim > 1:
        message = "q must be a scalar or 1d"
    else:
        message = "q values must be between 0 and 1"
    with pytest.raises(ValueError, match=message):
        values.weighted(weights).quantile(q)
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_mean_nan(weights, expected, skipna):
    """NaN data is ignored with ``skipna=True``; otherwise the mean is NaN."""
    reference = DataArray(expected) if skipna else DataArray(np.nan)
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).mean(skipna=skipna)
    assert_equal(reference, actual)
def test_weighted_mean_bool():
    """Boolean weights are accepted by ``mean``.

    https://github.com/pydata/xarray/issues/4074
    """
    actual = DataArray([1, 1]).weighted(DataArray([True, True])).mean()
    assert_equal(DataArray(1), actual)
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2 / 3), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_no_nan(weights, expected):
    """Weighted sum of squared deviations from the weighted mean (NaN-free data)."""
    actual = DataArray([1, 2]).weighted(DataArray(weights)).sum_of_squares()
    assert_equal(DataArray(expected), actual)
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 0), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_nan(weights, expected):
    """With only one non-NaN data point, every case is expected to give 0."""
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).sum_of_squares()
    assert_equal(DataArray(expected), actual)
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_var_equal_weights(da, skipna, factor):
    """Equal non-zero weights reproduce the unweighted ``var``.

    Any warning raised along the way fails the test.
    """
    da = DataArray(da)
    # every weight equals ``factor`` (not necessarily 1)
    uniform = xr.full_like(da, factor)
    reference = da.var(skipna=skipna)
    assert_equal(reference, da.weighted(uniform).var(skipna=skipna))
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0.24), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_var_no_nan(weights, expected):
    """Weighted variance of NaN-free data; all-zero weights give NaN."""
    reference = DataArray(expected)
    actual = DataArray([1, 2]).weighted(DataArray(weights)).var()
    assert_equal(reference, actual)
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_var_nan(weights, expected):
    """Weighted variance when the first data point is NaN."""
    reference = DataArray(expected)
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).var()
    assert_equal(reference, actual)
def test_weighted_var_bool():
    """Boolean weights are accepted by ``var``.

    https://github.com/pydata/xarray/issues/4074
    """
    actual = DataArray([1, 1]).weighted(DataArray([True, True])).var()
    assert_equal(DataArray(0), actual)
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_std_equal_weights(da, skipna, factor):
    """Equal non-zero weights reproduce the unweighted ``std``.

    Any warning raised along the way fails the test.
    """
    da = DataArray(da)
    # every weight equals ``factor`` (not necessarily 1)
    uniform = xr.full_like(da, factor)
    reference = da.std(skipna=skipna)
    assert_equal(reference, da.weighted(uniform).std(skipna=skipna))
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], np.sqrt(0.24)), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_std_no_nan(weights, expected):
    """Weighted standard deviation of NaN-free data; all-zero weights give NaN."""
    reference = DataArray(expected)
    actual = DataArray([1, 2]).weighted(DataArray(weights)).std()
    assert_equal(reference, actual)
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_std_nan(weights, expected):
    """Weighted standard deviation when the first data point is NaN."""
    reference = DataArray(expected)
    actual = DataArray([np.nan, 2]).weighted(DataArray(weights)).std()
    assert_equal(reference, actual)
def test_weighted_std_bool():
    """Boolean weights are accepted by ``std``.

    https://github.com/pydata/xarray/issues/4074
    """
    actual = DataArray([1, 1]).weighted(DataArray([True, True])).std()
    assert_equal(DataArray(0), actual)
def expected_weighted(da, weights, dim, skipna, operation):
    """
    Generate expected result using ``*`` and ``sum``. This is checked against
    the result of da.weighted which uses ``dot``

    Parameters
    ----------
    da : DataArray or Dataset
        Data to reduce.
    weights : DataArray
        Weights, combined with ``da`` via ``*``.
    dim : str, sequence of str or None
        Dimension(s) to reduce over.
    skipna : bool or None
        Forwarded to the ``sum`` calls over the weighted data.
    operation : str
        One of "sum", "sum_of_weights", "mean", "sum_of_squares", "var", "std".

    Returns
    -------
    The reference result for ``operation``.

    Raises
    ------
    ValueError
        If ``operation`` is not one of the supported names. Previously an
        unknown name silently returned ``None``.
    """
    supported = ("sum", "sum_of_weights", "mean", "sum_of_squares", "var", "std")
    if operation not in supported:
        raise ValueError(
            f"unknown operation {operation!r}; expected one of {supported}"
        )

    weighted_sum = (da * weights).sum(dim=dim, skipna=skipna)
    if operation == "sum":
        return weighted_sum

    # only weights at positions where the data is valid count towards the total
    masked_weights = weights.where(da.notnull())
    sum_of_weights = masked_weights.sum(dim=dim, skipna=True)
    # mask zero totals to NaN so the divisions below never divide by zero
    valid_weights = sum_of_weights != 0
    sum_of_weights = sum_of_weights.where(valid_weights)
    if operation == "sum_of_weights":
        return sum_of_weights

    weighted_mean = weighted_sum / sum_of_weights
    if operation == "mean":
        return weighted_mean

    demeaned = da - weighted_mean
    sum_of_squares = ((demeaned**2) * weights).sum(dim=dim, skipna=skipna)
    if operation == "sum_of_squares":
        return sum_of_squares

    var = sum_of_squares / sum_of_weights
    if operation == "var":
        return var

    # only "std" remains after the validation above
    return np.sqrt(var)
def check_weighted_operations(data, weights, dim, skipna):
    """Compare every weighted reduction against ``expected_weighted``.

    ``sum_of_weights`` is called without ``skipna``; every other reduction
    forwards ``skipna``. A fresh weighted object is built for each operation.
    """
    operations = ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
    for operation in operations:
        reducer = getattr(data.weighted(weights), operation)
        if operation == "sum_of_weights":
            actual = reducer(dim)
        else:
            actual = reducer(dim, skipna=skipna)
        reference = expected_weighted(data, weights, dim, skipna, operation)
        assert_allclose(reference, actual)
@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
@pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt")
def test_weighted_operations_3D(dim, add_nans, skipna):
    """Run all weighted reductions on random 3D data, as DataArray and Dataset.

    The weights come from ``randn`` and can be negative. That is presumably
    why ``sqrt`` may warn about invalid values, which is ignored above.
    """
    dims = ("a", "b", "c")
    coords = {name: [0, 1, 2, 3] for name in dims}
    weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords)
    values = np.random.randn(4, 4, 4)

    if add_nans:
        # knock out ~25% of the entries (https://stackoverflow.com/a/32182680/3010700)
        n_missing = int(values.size * 0.25)
        missing = np.random.choice(values.size, n_missing, replace=False)
        values.ravel()[missing] = np.nan

    data_array = DataArray(values, dims=dims, coords=coords)
    check_weighted_operations(data_array, weights, dim, skipna)
    check_weighted_operations(data_array.to_dataset(name="data"), weights, dim, skipna)
@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
@pytest.mark.parametrize("q", (0.5, (0.1, 0.9), (0.2, 0.4, 0.6, 0.8)))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_quantile_3D(dim, q, add_nans, skipna):
    """Unit weights must reproduce the unweighted quantile over any dim(s)."""
    dims = ("a", "b", "c")
    coords = dict(a=[0, 1, 2], b=[0, 1, 2, 3], c=[0, 1, 2, 3, 4])
    values = np.arange(60, dtype=float).reshape(3, 4, 5)

    if add_nans:
        # knock out ~25% of the entries (https://stackoverflow.com/a/32182680/3010700)
        n_missing = int(values.size * 0.25)
        missing = np.random.choice(values.size, n_missing, replace=False)
        values.ravel()[missing] = np.nan

    da = DataArray(values, dims=dims, coords=coords)
    # all-ones weights so the result can be compared with DataArray.quantile
    unit_weights = xr.ones_like(da)
    expected = da.quantile(q, dim=dim, skipna=skipna)

    actual_da = da.weighted(unit_weights).quantile(q, dim=dim, skipna=skipna)
    assert_allclose(expected, actual_da)

    ds = da.to_dataset(name="data")
    actual_ds = ds.weighted(unit_weights).quantile(q, dim=dim, skipna=skipna)
    assert_allclose(expected, actual_ds.data)
@pytest.mark.parametrize(
    "coords_weights, coords_data, expected_value_at_weighted_quantile",
    [
        ([0, 1, 2, 3], [1, 2, 3, 4], 2.5),  # no weights for coord a == 4
        ([0, 1, 2, 3], [2, 3, 4, 5], 1.8),  # no weights for coord a == 4 or 5
        ([2, 3, 4, 5], [0, 1, 2, 3], 3.8),  # no weights for coord a == 0 or 1
    ],
)
def test_weighted_operations_nonequal_coords(
    coords_weights: Iterable[Any],
    coords_data: Iterable[Any],
    expected_value_at_weighted_quantile: float,
) -> None:
    """Check that weighted operations work when data and weights have unequal coords.

    Parameters
    ----------
    coords_weights : Iterable[Any]
        The coords for the weights.
    coords_data : Iterable[Any]
        The coords for the data.
    expected_value_at_weighted_quantile : float
        The expected value for the quantile of the weighted data.
    """
    q = 0.5
    weights = DataArray(
        [0.5, 1.0, 1.0, 2.0], dims=("a",), coords=dict(a=coords_weights)
    )
    da = DataArray([1, 2, 3, 4], dims=("a",), coords=dict(a=coords_data))
    expected = DataArray(
        [expected_value_at_weighted_quantile], coords={"quantile": [q]}
    ).squeeze()

    # DataArray input
    check_weighted_operations(da, weights, dim="a", skipna=None)
    assert_allclose(da.weighted(weights).quantile(q, dim="a"), expected)

    # the same checks with the data wrapped in a Dataset
    ds = da.to_dataset(name="data")
    check_weighted_operations(ds, weights, dim="a", skipna=None)
    assert_allclose(
        ds.weighted(weights).quantile(q, dim="a"), expected.to_dataset(name="data")
    )
@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
@pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt")
def test_weighted_operations_different_shapes(
    shape_data, shape_weights, add_nans, skipna
):
    """Weighted reductions work when data and weights have different shapes.

    Each case runs over "dim_0" and over all dims (``None``), first for a
    DataArray and then for the equivalent Dataset.
    """
    weights = DataArray(np.random.randn(*shape_weights))
    values = np.random.randn(*shape_data)

    if add_nans:
        # replace roughly a quarter of the entries by NaN
        n_missing = int(values.size * 0.25)
        missing = np.random.choice(values.size, n_missing, replace=False)
        values.ravel()[missing] = np.nan

    da = DataArray(values)
    for obj in (da, da.to_dataset(name="data")):
        for dim in ("dim_0", None):
            check_weighted_operations(obj, weights, dim, skipna)
@pytest.mark.parametrize(
    "operation",
    ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"),
)
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("keep_attrs", (True, False, None))
def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
    """``keep_attrs`` decides whether the reduction result carries attributes.

    NOTE(review): ``data`` is given exactly the same attrs as ``weights``, so
    this test cannot tell which of the two the result's attrs were copied
    from. It only checks that attrs are kept or dropped.
    """
    weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights"))
    data = DataArray(np.random.randn(2, 2))
    if as_dataset:
        data = data.to_dataset(name="data")
    data.attrs = dict(attr="weights")

    kwargs = {"keep_attrs": keep_attrs}
    if operation == "quantile":
        kwargs["q"] = 0.5
    result = getattr(data.weighted(weights), operation)(**kwargs)

    # sum_of_weights is checked against the weights' attrs; every other
    # operation against both sources (identical here, see the note above)
    sources = (weights,) if operation == "sum_of_weights" else (weights, data)
    for source in sources:
        assert result.attrs == (source.attrs if keep_attrs else {})
@pytest.mark.parametrize(
    "operation",
    ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std", "quantile"),
)
def test_weighted_operations_keep_attr_da_in_ds(operation):
    """With ``keep_attrs=True``, a Dataset's data variable keeps its own attrs.

    GH #3595
    """
    weights = DataArray(np.random.randn(2, 2))
    ds = DataArray(np.random.randn(2, 2), attrs=dict(attr="data")).to_dataset(name="a")

    kwargs = {"keep_attrs": True}
    if operation == "quantile":
        kwargs["q"] = 0.5
    result = getattr(ds.weighted(weights), operation)(**kwargs)

    assert ds.a.attrs == result.a.attrs
@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean", "quantile"))
@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_bad_dim(operation, as_dataset):
    """Reducing over a dimension that does not exist raises a ValueError."""
    data = DataArray(np.random.randn(2, 2))
    weights = xr.ones_like(data)
    if as_dataset:
        data = data.to_dataset(name="data")

    kwargs = {"dim": "bad_dim"}
    if operation == "quantile":
        kwargs["q"] = 0.5

    expected_message = (
        f"Dimensions \\('bad_dim',\\) not found in {type(data).__name__}Weighted "
        # the order of (dim_0, dim_1) varies
        "dimensions \\(('dim_0', 'dim_1'|'dim_1', 'dim_0')\\)"
    )
    with pytest.raises(ValueError, match=expected_message):
        getattr(data.weighted(weights), operation)(**kwargs)