228 lines
6.1 KiB
Python
228 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
import xarray as xr
|
|
from xarray import DataArray, Dataset, DataTree
|
|
from xarray.tests import create_test_data, requires_dask
|
|
|
|
|
|
@pytest.fixture(params=["numpy", pytest.param("dask", marks=requires_dask)])
|
|
def backend(request):
|
|
return request.param
|
|
|
|
|
|
@pytest.fixture(params=["numbagg", "bottleneck", None])
|
|
def compute_backend(request):
|
|
if request.param is None:
|
|
options = dict(use_bottleneck=False, use_numbagg=False)
|
|
elif request.param == "bottleneck":
|
|
options = dict(use_bottleneck=True, use_numbagg=False)
|
|
elif request.param == "numbagg":
|
|
options = dict(use_bottleneck=False, use_numbagg=True)
|
|
else:
|
|
raise ValueError
|
|
|
|
with xr.set_options(**options):
|
|
yield request.param
|
|
|
|
|
|
@pytest.fixture(params=[1])
|
|
def ds(request, backend):
|
|
if request.param == 1:
|
|
ds = Dataset(
|
|
dict(
|
|
z1=(["y", "x"], np.random.randn(2, 8)),
|
|
z2=(["time", "y"], np.random.randn(10, 2)),
|
|
),
|
|
dict(
|
|
x=("x", np.linspace(0, 1.0, 8)),
|
|
time=("time", np.linspace(0, 1.0, 10)),
|
|
c=("y", ["a", "b"]),
|
|
y=range(2),
|
|
),
|
|
)
|
|
elif request.param == 2:
|
|
ds = Dataset(
|
|
dict(
|
|
z1=(["time", "y"], np.random.randn(10, 2)),
|
|
z2=(["time"], np.random.randn(10)),
|
|
z3=(["x", "time"], np.random.randn(8, 10)),
|
|
),
|
|
dict(
|
|
x=("x", np.linspace(0, 1.0, 8)),
|
|
time=("time", np.linspace(0, 1.0, 10)),
|
|
c=("y", ["a", "b"]),
|
|
y=range(2),
|
|
),
|
|
)
|
|
elif request.param == 3:
|
|
ds = create_test_data()
|
|
else:
|
|
raise ValueError
|
|
|
|
if backend == "dask":
|
|
return ds.chunk()
|
|
|
|
return ds
|
|
|
|
|
|
@pytest.fixture(params=[1])
|
|
def da(request, backend):
|
|
if request.param == 1:
|
|
times = pd.date_range("2000-01-01", freq="1D", periods=21)
|
|
da = DataArray(
|
|
np.random.random((3, 21, 4)),
|
|
dims=("a", "time", "x"),
|
|
coords=dict(time=times),
|
|
)
|
|
|
|
if request.param == 2:
|
|
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
|
|
|
|
if request.param == "repeating_ints":
|
|
da = DataArray(
|
|
np.tile(np.arange(12), 5).reshape(5, 4, 3),
|
|
coords={"x": list("abc"), "y": list("defg")},
|
|
dims=list("zyx"),
|
|
)
|
|
|
|
if backend == "dask":
|
|
return da.chunk()
|
|
elif backend == "numpy":
|
|
return da
|
|
else:
|
|
raise ValueError
|
|
|
|
|
|
@pytest.fixture(params=[Dataset, DataArray])
|
|
def type(request):
|
|
return request.param
|
|
|
|
|
|
@pytest.fixture(params=[1])
|
|
def d(request, backend, type) -> DataArray | Dataset:
|
|
"""
|
|
For tests which can test either a DataArray or a Dataset.
|
|
"""
|
|
result: DataArray | Dataset
|
|
if request.param == 1:
|
|
ds = Dataset(
|
|
dict(
|
|
a=(["x", "z"], np.arange(24).reshape(2, 12)),
|
|
b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)),
|
|
),
|
|
dict(
|
|
x=("x", np.linspace(0, 1.0, 2)),
|
|
y=range(3),
|
|
z=("z", pd.date_range("2000-01-01", periods=12)),
|
|
w=("x", ["a", "b"]),
|
|
),
|
|
)
|
|
if type == DataArray:
|
|
result = ds["a"].assign_coords(w=ds.coords["w"])
|
|
elif type == Dataset:
|
|
result = ds
|
|
else:
|
|
raise ValueError
|
|
else:
|
|
raise ValueError
|
|
|
|
if backend == "dask":
|
|
return result.chunk()
|
|
elif backend == "numpy":
|
|
return result
|
|
else:
|
|
raise ValueError
|
|
|
|
|
|
@pytest.fixture
|
|
def byte_attrs_dataset():
|
|
"""For testing issue #9407"""
|
|
null_byte = b"\x00"
|
|
other_bytes = bytes(range(1, 256))
|
|
ds = Dataset({"x": 1}, coords={"x_coord": [1]})
|
|
ds["x"].attrs["null_byte"] = null_byte
|
|
ds["x"].attrs["other_bytes"] = other_bytes
|
|
|
|
expected = ds.copy()
|
|
expected["x"].attrs["null_byte"] = ""
|
|
expected["x"].attrs["other_bytes"] = other_bytes.decode(errors="replace")
|
|
|
|
return {
|
|
"input": ds,
|
|
"expected": expected,
|
|
"h5netcdf_error": r"Invalid value provided for attribute .*: .*\. Null characters .*",
|
|
}
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def create_test_datatree():
|
|
"""
|
|
Create a test datatree with this structure:
|
|
|
|
<xarray.DataTree>
|
|
|-- set1
|
|
| |-- <xarray.Dataset>
|
|
| | Dimensions: ()
|
|
| | Data variables:
|
|
| | a int64 0
|
|
| | b int64 1
|
|
| |-- set1
|
|
| |-- set2
|
|
|-- set2
|
|
| |-- <xarray.Dataset>
|
|
| | Dimensions: (x: 2)
|
|
| | Data variables:
|
|
| | a (x) int64 2, 3
|
|
| | b (x) int64 0.1, 0.2
|
|
| |-- set1
|
|
|-- set3
|
|
|-- <xarray.Dataset>
|
|
| Dimensions: (x: 2, y: 3)
|
|
| Data variables:
|
|
| a (y) int64 6, 7, 8
|
|
| set0 (x) int64 9, 10
|
|
|
|
The structure has deliberately repeated names of tags, variables, and
|
|
dimensions in order to better check for bugs caused by name conflicts.
|
|
"""
|
|
|
|
def _create_test_datatree(modify=lambda ds: ds):
|
|
set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
|
|
set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
|
|
root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
|
|
|
|
root = DataTree.from_dict(
|
|
{
|
|
"/": root_data,
|
|
"/set1": set1_data,
|
|
"/set1/set1": None,
|
|
"/set1/set2": None,
|
|
"/set2": set2_data,
|
|
"/set2/set1": None,
|
|
"/set3": None,
|
|
}
|
|
)
|
|
|
|
return root
|
|
|
|
return _create_test_datatree
|
|
|
|
|
|
@pytest.fixture(scope="module")
|
|
def simple_datatree(create_test_datatree):
|
|
"""
|
|
Invoke create_test_datatree fixture (callback).
|
|
|
|
Returns a DataTree.
|
|
"""
|
|
return create_test_datatree()
|
|
|
|
|
|
@pytest.fixture(scope="module", params=["s", "ms", "us", "ns"])
|
|
def time_unit(request):
|
|
return request.param
|