"""isort:skip_file"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pickle
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
if TYPE_CHECKING:
|
|
import dask
|
|
import dask.array as da
|
|
import distributed
|
|
else:
|
|
dask = pytest.importorskip("dask")
|
|
da = pytest.importorskip("dask.array")
|
|
distributed = pytest.importorskip("distributed")
|
|
|
|
from dask.distributed import Client, Lock
|
|
from distributed.client import futures_of
|
|
from distributed.utils_test import ( # noqa: F401
|
|
cleanup,
|
|
client,
|
|
cluster,
|
|
cluster_fixture,
|
|
gen_cluster,
|
|
loop,
|
|
loop_in_thread,
|
|
)
|
|
|
|
import xarray as xr
|
|
from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
|
|
from xarray.tests import (
|
|
assert_allclose,
|
|
assert_identical,
|
|
has_h5netcdf,
|
|
has_netCDF4,
|
|
has_scipy,
|
|
requires_cftime,
|
|
requires_netCDF4,
|
|
requires_zarr,
|
|
)
|
|
from xarray.tests.test_backends import (
|
|
ON_WINDOWS,
|
|
create_tmp_file,
|
|
)
|
|
from xarray.tests.test_dataset import create_test_data
|
|
|
|
loop = loop # loop is an imported fixture, which flake8 has issues ack-ing
|
|
client = client # client is an imported fixture, which flake8 has issues ack-ing


@pytest.fixture
def tmp_netcdf_filename(tmpdir):
    return str(tmpdir.join("testfile.nc"))
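

# Only engines whose backends are importable in the current environment are
# exercised; the parametrized tests below skip any (engine, format) pair whose
# engine is not in ENGINES.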
ENGINES = []
if has_scipy:
    ENGINES.append("scipy")
if has_netCDF4:
    ENGINES.append("netcdf4")
if has_h5netcdf:
    ENGINES.append("h5netcdf")

NC_FORMATS = {
    "netcdf4": [
        "NETCDF3_CLASSIC",
        "NETCDF3_64BIT_OFFSET",
        "NETCDF3_64BIT_DATA",
        "NETCDF4_CLASSIC",
        "NETCDF4",
    ],
    "scipy": ["NETCDF3_CLASSIC", "NETCDF3_64BIT"],
    "h5netcdf": ["NETCDF4"],
}

ENGINES_AND_FORMATS = [
    ("netcdf4", "NETCDF3_CLASSIC"),
    ("netcdf4", "NETCDF4_CLASSIC"),
    ("netcdf4", "NETCDF4"),
    ("h5netcdf", "NETCDF4"),
    ("scipy", "NETCDF3_64BIT"),
]
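

# Write a chunked dataset to netCDF through a dask.distributed Client, then read
# it back lazily and compare against the original. The scipy engine is expected
# to raise NotImplementedError for this distributed write.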
@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_netcdf_roundtrip(
    loop, tmp_netcdf_filename, engine, nc_format
):
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            original = create_test_data().chunk(chunks)

            if engine == "scipy":
                with pytest.raises(NotImplementedError):
                    original.to_netcdf(
                        tmp_netcdf_filename, engine=engine, format=nc_format
                    )
                return

            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

            with xr.open_dataset(
                tmp_netcdf_filename, chunks=chunks, engine=engine
            ) as restored:
                assert isinstance(restored.var1.data, da.Array)
                computed = restored.compute()
                assert_allclose(original, computed)
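

# Writing a dataset whose only variable is a zero-dimensional (dimensionless)
# dask array should succeed through a distributed Client and round-trip with an
# empty shape.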
@requires_netCDF4
def test_dask_distributed_write_netcdf_with_dimensionless_variables(
    loop, tmp_netcdf_filename
):
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            original = xr.Dataset({"x": da.zeros(())})
            original.to_netcdf(tmp_netcdf_filename)

            with xr.open_dataset(tmp_netcdf_filename) as actual:
                assert actual.x.shape == ()
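

# open_mfdataset should handle a file indexed by a cftime coordinate (360_day
# calendar), both with and without parallel=True, while a distributed Client is
# active.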
@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path):
    T = xr.cftime_range("20010101", "20010501", calendar="360_day")
    Lon = np.arange(100)
    data = np.random.random((T.size, Lon.size))
    da = xr.DataArray(data, coords={"time": T, "Lon": Lon}, name="test")
    file_path = tmp_path / "test.nc"
    da.to_netcdf(file_path)
    with cluster() as (s, [a, b]):
        with Client(s["address"]):
            with xr.open_mfdataset(file_path, parallel=parallel) as tf:
                assert_identical(tf["test"], da)
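

# The same cftime-indexed data split across ten files: open_mfdataset with
# combine="nested" should reassemble the original array under a distributed
# Client.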
@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path):
    lon = np.arange(100)
    time = xr.cftime_range("20010101", periods=100, calendar="360_day")
    data = np.random.random((time.size, lon.size))
    da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

    fnames = []
    for i in range(0, 100, 10):
        fname = tmp_path / f"test_{i}.nc"
        da.isel(time=slice(i, i + 10)).to_netcdf(fname)
        fnames.append(fname)

    with cluster() as (s, [a, b]):
        with Client(s["address"]):
            with xr.open_mfdataset(
                fnames, parallel=parallel, concat_dim="time", combine="nested"
            ) as tf:
                assert_identical(tf["test"], da)
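

# Variant of the test above that exercises the non-distributed schedulers
# (threaded, multiprocessing, synchronous); the parallel=True case is skipped
# as flaky in CI.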
# TODO: move this to test_backends.py
@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path):
    if parallel:
        pytest.skip(
            "Flaky in CI. Would be a welcome contribution to make a similar test reliable."
        )
    lon = np.arange(100)
    time = xr.cftime_range("20010101", periods=100, calendar="360_day")
    data = np.random.random((time.size, lon.size))
    da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

    fnames = []
    for i in range(0, 100, 10):
        fname = tmp_path / f"test_{i}.nc"
        da.isel(time=slice(i, i + 10)).to_netcdf(fname)
        fnames.append(fname)

    for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]:
        with dask.config.set(scheduler=get):
            with xr.open_mfdataset(
                fnames, parallel=parallel, concat_dim="time", combine="nested"
            ) as tf:
                assert_identical(tf["test"], da)
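

# Complement to the roundtrip test above: the file is written eagerly (without
# chunking) and then read back lazily, in chunks, under a distributed Client.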
@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_read_netcdf_integration_test(
    loop, tmp_netcdf_filename, engine, nc_format
):
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            original = create_test_data()
            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

            with xr.open_dataset(
                tmp_netcdf_filename, chunks=chunks, engine=engine
            ) as restored:
                assert isinstance(restored.var1.data, da.Array)
                computed = restored.compute()
                assert_allclose(original, computed)


# fixture vendored from dask
# heads-up: this uses a fairly private zarr API
# https://github.com/dask/dask/blob/e04734b4d8959ba259801f2e2a490cb4ee8d891f/dask/tests/test_distributed.py#L338-L358
@pytest.fixture(scope="function")
def zarr(client):
    zarr_lib = pytest.importorskip("zarr")
    # Zarr-Python 3 lazily allocates a dedicated thread/IO loop
    # to execute async tasks. To avoid having this thread
    # be picked up as a "leaked thread", we manually trigger its
    # creation before using zarr
    try:
        _ = zarr_lib.core.sync._get_loop()
        _ = zarr_lib.core.sync._get_executor()
        yield zarr_lib
    except AttributeError:
        yield zarr_lib
    finally:
        # Zarr-Python 3 lazily allocates an IO thread, a thread pool executor, and
        # an IO loop. Here we clean up these resources to avoid leaking threads.
        # In normal operation, this is done by an atexit handler when Zarr
        # is shutting down.
        try:
            zarr_lib.core.sync.cleanup_resources()
        except AttributeError:
            pass
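

# Round-trip a chunked dataset through a zarr store under a distributed Client,
# covering consolidated/unconsolidated metadata and eager vs. delayed writes
# (compute=False returns a lazy object that is computed explicitly).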
@requires_zarr
@pytest.mark.parametrize("consolidated", [True, False])
@pytest.mark.parametrize("compute", [True, False])
def test_dask_distributed_zarr_integration_test(
    client,
    zarr,
    consolidated: bool,
    compute: bool,
) -> None:
    if consolidated:
        write_kwargs: dict[str, Any] = {"consolidated": True}
        read_kwargs: dict[str, Any] = {"backend_kwargs": {"consolidated": True}}
    else:
        write_kwargs = read_kwargs = {}
    chunks = {"dim1": 4, "dim2": 3, "dim3": 5}
    original = create_test_data().chunk(chunks)
    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc") as filename:
        maybe_futures = original.to_zarr(  # type: ignore[call-overload]  # mypy bug?
            filename, compute=compute, **write_kwargs
        )
        if not compute:
            maybe_futures.compute()
        with xr.open_dataset(
            filename, chunks="auto", engine="zarr", **read_kwargs
        ) as restored:
            assert isinstance(restored.var1.data, da.Array)
            computed = restored.compute()
            assert_allclose(original, computed)
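

# Exercise the async distributed client: chunked xarray objects remain dask
# collections, persist() yields a smaller graph backed by futures, and awaiting
# client.compute() returns a fully realized dataset.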
@gen_cluster(client=True)
async def test_async(c, s, a, b) -> None:
    x = create_test_data()
    assert not dask.is_dask_collection(x)
    y = x.chunk({"dim2": 4}) + 10
    assert dask.is_dask_collection(y)
    assert dask.is_dask_collection(y.var1)
    assert dask.is_dask_collection(y.var2)

    z = y.persist()
    assert str(z)

    assert dask.is_dask_collection(z)
    assert dask.is_dask_collection(z.var1)
    assert dask.is_dask_collection(z.var2)
    assert len(y.__dask_graph__()) > len(z.__dask_graph__())

    assert not futures_of(y)
    assert futures_of(z)

    future = c.compute(z)
    w = await future
    assert not dask.is_dask_collection(w)
    assert_allclose(x + 10, w)

    assert s.tasks
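

# HDF5_LOCK is expected to be a SerializableLock (a picklable lock) so that it
# can accompany tasks sent to distributed workers.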
def test_hdf5_lock() -> None:
    assert isinstance(HDF5_LOCK, SerializableLock)
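

# The lock types used by the netCDF backends (including CombinedLock wrappers)
# must work when mapped across a cluster and must survive a pickle round-trip.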
@gen_cluster(client=True)
async def test_serializable_locks(c, s, a, b) -> None:
    def f(x, lock=None):
        with lock:
            return x + 1

    # note: the creation of Lock needs to be done inside a cluster
    for lock in [
        HDF5_LOCK,
        Lock(),
        Lock("filename.nc"),
        CombinedLock([HDF5_LOCK]),
        CombinedLock([HDF5_LOCK, Lock("filename.nc")]),
    ]:
        futures = c.map(f, list(range(10)), lock=lock)
        await c.gather(futures)

        lock2 = pickle.loads(pickle.dumps(lock))
        assert type(lock) is type(lock2)