from __future__ import annotations

import importlib
import platform
import string
import warnings
from contextlib import contextmanager, nullcontext
from unittest import mock  # noqa: F401

import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal  # noqa: F401
from packaging.version import Version
from pandas.testing import assert_frame_equal  # noqa: F401

import xarray.testing
from xarray import Dataset
from xarray.core.duck_array_ops import allclose_or_equiv  # noqa: F401
from xarray.core.extension_array import PandasExtensionArray
from xarray.core.options import set_options
from xarray.core.variable import IndexVariable
from xarray.testing import (  # noqa: F401
    assert_chunks_equal,
    assert_duckarray_allclose,
    assert_duckarray_equal,
)
from xarray.tests.arrays import (  # noqa: F401
    ConcatenatableArray,
    DuckArrayWrapper,
    FirstElementAccessibleArray,
    InaccessibleArray,
    UnexpectedDataAccess,
)

# import mpl and change the backend before other mpl imports
try:
    import matplotlib as mpl

    # Order of imports is important here.
    # Using a different backend makes Travis CI work
    mpl.use("Agg")
except ImportError:
    pass

# https://github.com/pydata/xarray/issues/7322
warnings.filterwarnings("ignore", "'urllib3.contrib.pyopenssl' module is deprecated")
warnings.filterwarnings("ignore", "Deprecated call to `pkg_resources.declare_namespace")
warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API")

arm_xfail = pytest.mark.xfail(
    platform.machine() == "aarch64" or "arm" in platform.machine(),
    reason="expected failure on ARM",
)


def assert_writeable(ds):
    readonly = [
        name
        for name, var in ds.variables.items()
        if not isinstance(var, IndexVariable)
        and not isinstance(var.data, PandasExtensionArray)
        and not var.data.flags.writeable
    ]
    assert not readonly, readonly

def _importorskip(
    modname: str, minversion: str | None = None
) -> tuple[bool, pytest.MarkDecorator]:
    try:
        mod = importlib.import_module(modname)
        has = True
        if minversion is not None:
            v = getattr(mod, "__version__", "999")
            if Version(v) < Version(minversion):
                raise ImportError("Minimum version not satisfied")
    except ImportError:
        has = False

    reason = f"requires {modname}"
    if minversion is not None:
        reason += f">={minversion}"
    func = pytest.mark.skipif(not has, reason=reason)
    return has, func

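
# A minimal usage sketch (illustrative; "somepkg" is a placeholder module
# name, not a real dependency): the boolean gates conditional code paths,
# while the mark skips the test when the import or minimum version is missing.
#
#     has_somepkg, requires_somepkg = _importorskip("somepkg", "1.2")
#
#     @requires_somepkg
#     def test_needs_somepkg():
#         ...
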
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
has_scipy_ge_1_13, requires_scipy_ge_1_13 = _importorskip("scipy", "1.13")
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="'cgi' is deprecated and slated for removal in Python 3.13",
        category=DeprecationWarning,
    )
    has_pydap, requires_pydap = _importorskip("pydap.client")
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")
with warnings.catch_warnings():
    # see https://github.com/pydata/xarray/issues/8537
    warnings.filterwarnings(
        "ignore",
        message="h5py is running against HDF5 1.14.3",
        category=UserWarning,
    )

    has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_cftime, requires_cftime = _importorskip("cftime")
has_dask, requires_dask = _importorskip("dask")
has_dask_ge_2024_08_1, requires_dask_ge_2024_08_1 = _importorskip(
    "dask", minversion="2024.08.1"
)
has_dask_ge_2024_11_0, requires_dask_ge_2024_11_0 = _importorskip("dask", "2024.11.0")
has_dask_ge_2025_1_0, requires_dask_ge_2025_1_0 = _importorskip("dask", "2025.1.0")
if has_dask_ge_2025_1_0:
    has_dask_expr = True
    requires_dask_expr = pytest.mark.skipif(not has_dask_expr, reason="should not skip")
else:
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="The current Dask DataFrame implementation is deprecated.",
            category=DeprecationWarning,
        )
        has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
# TODO: switch to "3" once Zarr V3 is released
has_zarr_v3, requires_zarr_v3 = _importorskip("zarr", "2.99")
has_fsspec, requires_fsspec = _importorskip("fsspec")
has_iris, requires_iris = _importorskip("iris")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_pyarrow, requires_pyarrow = _importorskip("pyarrow")
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="is_categorical_dtype is deprecated and will be removed in a future version.",
        category=DeprecationWarning,
    )
    # seaborn uses the deprecated `pandas.is_categorical_dtype`
    has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cupy, requires_cupy = _importorskip("cupy")
has_cartopy, requires_cartopy = _importorskip("cartopy")
has_pint, requires_pint = _importorskip("pint")
has_numexpr, requires_numexpr = _importorskip("numexpr")
has_flox, requires_flox = _importorskip("flox")
has_netcdf, requires_netcdf = _importorskip("netcdf")
has_pandas_ge_2_2, requires_pandas_ge_2_2 = _importorskip("pandas", "2.2")
has_pandas_3, requires_pandas_3 = _importorskip("pandas", "3.0.0.dev0")


# some special cases
has_scipy_or_netCDF4 = has_scipy or has_netCDF4
requires_scipy_or_netCDF4 = pytest.mark.skipif(
    not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
has_numbagg_or_bottleneck = has_numbagg or has_bottleneck
requires_numbagg_or_bottleneck = pytest.mark.skipif(
    not has_numbagg_or_bottleneck, reason="requires numbagg or bottleneck"
)
has_numpy_2, requires_numpy_2 = _importorskip("numpy", "2.0.0")
has_flox_0_9_12, requires_flox_0_9_12 = _importorskip("flox", "0.9.12")

has_array_api_strict, requires_array_api_strict = _importorskip("array_api_strict")

def _importorskip_h5netcdf_ros3(has_h5netcdf: bool):
    if not has_h5netcdf:
        return has_h5netcdf, pytest.mark.skipif(
            not has_h5netcdf, reason="requires h5netcdf"
        )

    import h5py

    h5py_with_ros3 = h5py.get_config().ros3

    return h5py_with_ros3, pytest.mark.skipif(
        not h5py_with_ros3,
        reason="requires h5netcdf>=1.3.0 and h5py with ros3 support",
    )


has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip_h5netcdf_ros3(has_h5netcdf)
has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip(
    "netCDF4", "1.6.2"
)

has_h5netcdf_1_4_0_or_above, requires_h5netcdf_1_4_0_or_above = _importorskip(
    "h5netcdf", "1.4.0.dev"
)

has_netCDF4_1_7_0_or_above, requires_netCDF4_1_7_0_or_above = _importorskip(
    "netCDF4", "1.7.0"
)

# change some global options for tests
set_options(warn_for_unclosed_files=True)

if has_dask:
    import dask


class CountingScheduler:
    """Simple dask scheduler counting the number of computes.

    Reference: https://stackoverflow.com/questions/53289286/"""

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                f"Too many computes. Total: {self.total_computes} > max: {self.max_computes}."
            )
        return dask.get(dsk, keys, **kwargs)


def raise_if_dask_computes(max_computes=0):
    # return a dummy context manager so that this can be used for non-dask objects
    if not has_dask:
        return nullcontext()
    scheduler = CountingScheduler(max_computes)
    return dask.config.set(scheduler=scheduler)

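
# A minimal usage sketch (illustrative; ``ds`` stands in for any dask-backed
# xarray object): assert that ``mean()`` stays lazy and only the final
# ``compute()`` call triggers a single dask compute.
#
#     with raise_if_dask_computes(max_computes=1):
#         result = ds.chunk().mean()
#         result.compute()
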
flaky = pytest.mark.flaky
network = pytest.mark.network


class ReturnItem:
    def __getitem__(self, key):
        return key


class IndexerMaker:
    def __init__(self, indexer_cls):
        self._indexer_cls = indexer_cls

    def __getitem__(self, key):
        if not isinstance(key, tuple):
            key = (key,)
        return self._indexer_cls(key)

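
# A minimal usage sketch (illustrative): ``ReturnItem`` turns indexing syntax
# into a plain key object, and ``IndexerMaker`` wraps such keys in an indexer
# class (here assumed to be ``xarray.core.indexing.OuterIndexer``).
#
#     I = ReturnItem()
#     I[1:2, ..., 3]  # -> (slice(1, 2, None), Ellipsis, 3)
#
#     make = IndexerMaker(indexing.OuterIndexer)
#     make[0]  # -> OuterIndexer((0,))
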
def source_ndarray(array):
    """Given an ndarray, return the base object which holds its memory, or the
    object itself.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "DatetimeIndex.base")
        warnings.filterwarnings("ignore", "TimedeltaIndex.base")
        base = getattr(array, "base", np.asarray(array).base)
    if base is None:
        base = array
    return base

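
# A minimal usage sketch (illustrative): views share their base array while
# copies own their memory, which is exactly what this helper distinguishes.
#
#     x = np.arange(5)
#     source_ndarray(x[1:]) is x     # True: the view's memory belongs to x
#     source_ndarray(x.copy()) is x  # False: the copy owns its own memory
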
def format_record(record) -> str:
    """Format warning record like `FutureWarning('Function will be deprecated...')`"""
    return f"{str(record.category)[8:-2]}('{record.message}')"

@contextmanager
def assert_no_warnings():
    with warnings.catch_warnings(record=True) as record:
        yield record
    assert (
        len(record) == 0
    ), f"Got {len(record)} unexpected warning(s): {[format_record(r) for r in record]}"

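
# A minimal usage sketch (illustrative; ``quiet_operation`` is a placeholder
# for whatever code is under test):
#
#     def test_is_quiet():
#         with assert_no_warnings():
#             quiet_operation()
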
# Internal versions of xarray's test functions that validate additional
# invariants


def assert_equal(a, b, check_default_indexes=True):
    __tracebackhide__ = True
    xarray.testing.assert_equal(a, b)
    xarray.testing._assert_internal_invariants(a, check_default_indexes)
    xarray.testing._assert_internal_invariants(b, check_default_indexes)


def assert_identical(a, b, check_default_indexes=True):
    __tracebackhide__ = True
    xarray.testing.assert_identical(a, b)
    xarray.testing._assert_internal_invariants(a, check_default_indexes)
    xarray.testing._assert_internal_invariants(b, check_default_indexes)


def assert_allclose(a, b, check_default_indexes=True, **kwargs):
    __tracebackhide__ = True
    xarray.testing.assert_allclose(a, b, **kwargs)
    xarray.testing._assert_internal_invariants(a, check_default_indexes)
    xarray.testing._assert_internal_invariants(b, check_default_indexes)

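
# A minimal usage sketch (illustrative): these wrappers behave like their
# ``xarray.testing`` counterparts while also validating internal invariants.
#
#     assert_identical(ds, ds.copy(deep=True))
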
_DEFAULT_TEST_DIM_SIZES = (8, 9, 10)


def create_test_data(
    seed: int = 12345,
    add_attrs: bool = True,
    dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES,
    use_extension_array: bool = False,
) -> Dataset:
    rs = np.random.default_rng(seed)
    _vars = {
        "var1": ["dim1", "dim2"],
        "var2": ["dim1", "dim2"],
        "var3": ["dim3", "dim1"],
    }
    _dims = {"dim1": dim_sizes[0], "dim2": dim_sizes[1], "dim3": dim_sizes[2]}

    obj = Dataset()
    obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
    if _dims["dim3"] > 26:
        raise RuntimeError(
            f'Not enough letters for filling this dimension size ({_dims["dim3"]})'
        )
    obj["dim3"] = ("dim3", list(string.ascii_lowercase[0 : _dims["dim3"]]))
    obj["time"] = (
        "time",
        pd.date_range(
            "2000-01-01",
            periods=20,
            unit="ns",
        ),
    )
    for v, dims in sorted(_vars.items()):
        data = rs.normal(size=tuple(_dims[d] for d in dims))
        obj[v] = (dims, data)
        if add_attrs:
            obj[v].attrs = {"foo": "variable"}
    if use_extension_array:
        obj["var4"] = (
            "dim1",
            pd.Categorical(
                rs.choice(
                    list(string.ascii_lowercase[: rs.integers(1, 5)]),
                    size=dim_sizes[0],
                )
            ),
        )
    if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
        numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
    else:
        numbers_values = rs.integers(0, 3, _dims["dim3"], dtype="int64")
    obj.coords["numbers"] = ("dim3", numbers_values)
    obj.encoding = {"foo": "bar"}
    assert_writeable(obj)
    return obj

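
# A minimal usage sketch (illustrative): the default dataset plus a smaller
# variant; the "time" coordinate always has 20 entries regardless of
# ``dim_sizes``.
#
#     ds = create_test_data(seed=0)
#     small = create_test_data(dim_sizes=(2, 3, 4))
#     assert dict(small.sizes) == {"dim1": 2, "dim2": 3, "dim3": 4, "time": 20}
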
_CFTIME_CALENDARS = [
    "365_day",
    "360_day",
    "julian",
    "all_leap",
    "366_day",
    "gregorian",
    "proleptic_gregorian",
    "standard",
]