"""isort:skip_file""" from __future__ import annotations import pickle from typing import TYPE_CHECKING, Any import numpy as np import pytest if TYPE_CHECKING: import dask import dask.array as da import distributed else: dask = pytest.importorskip("dask") da = pytest.importorskip("dask.array") distributed = pytest.importorskip("distributed") from dask.distributed import Client, Lock from distributed.client import futures_of from distributed.utils_test import ( # noqa: F401 cleanup, client, cluster, cluster_fixture, gen_cluster, loop, loop_in_thread, ) import xarray as xr from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock from xarray.tests import ( assert_allclose, assert_identical, has_h5netcdf, has_netCDF4, has_scipy, requires_cftime, requires_netCDF4, requires_zarr, ) from xarray.tests.test_backends import ( ON_WINDOWS, create_tmp_file, ) from xarray.tests.test_dataset import create_test_data loop = loop # loop is an imported fixture, which flake8 has issues ack-ing client = client # client is an imported fixture, which flake8 has issues ack-ing @pytest.fixture def tmp_netcdf_filename(tmpdir): return str(tmpdir.join("testfile.nc")) ENGINES = [] if has_scipy: ENGINES.append("scipy") if has_netCDF4: ENGINES.append("netcdf4") if has_h5netcdf: ENGINES.append("h5netcdf") NC_FORMATS = { "netcdf4": [ "NETCDF3_CLASSIC", "NETCDF3_64BIT_OFFSET", "NETCDF3_64BIT_DATA", "NETCDF4_CLASSIC", "NETCDF4", ], "scipy": ["NETCDF3_CLASSIC", "NETCDF3_64BIT"], "h5netcdf": ["NETCDF4"], } ENGINES_AND_FORMATS = [ ("netcdf4", "NETCDF3_CLASSIC"), ("netcdf4", "NETCDF4_CLASSIC"), ("netcdf4", "NETCDF4"), ("h5netcdf", "NETCDF4"), ("scipy", "NETCDF3_64BIT"), ] @pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS) def test_dask_distributed_netcdf_roundtrip( loop, tmp_netcdf_filename, engine, nc_format ): if engine not in ENGINES: pytest.skip("engine not available") chunks = {"dim1": 4, "dim2": 3, "dim3": 6} with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): original = create_test_data().chunk(chunks) if engine == "scipy": with pytest.raises(NotImplementedError): original.to_netcdf( tmp_netcdf_filename, engine=engine, format=nc_format ) return original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) with xr.open_dataset( tmp_netcdf_filename, chunks=chunks, engine=engine ) as restored: assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) @requires_netCDF4 def test_dask_distributed_write_netcdf_with_dimensionless_variables( loop, tmp_netcdf_filename ): with cluster() as (s, [a, b]): with Client(s["address"], loop=loop): original = xr.Dataset({"x": da.zeros(())}) original.to_netcdf(tmp_netcdf_filename) with xr.open_dataset(tmp_netcdf_filename) as actual: assert actual.x.shape == () @requires_cftime @requires_netCDF4 @pytest.mark.parametrize("parallel", (True, False)) def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path): T = xr.cftime_range("20010101", "20010501", calendar="360_day") Lon = np.arange(100) data = np.random.random((T.size, Lon.size)) da = xr.DataArray(data, coords={"time": T, "Lon": Lon}, name="test") file_path = tmp_path / "test.nc" da.to_netcdf(file_path) with cluster() as (s, [a, b]): with Client(s["address"]): with xr.open_mfdataset(file_path, parallel=parallel) as tf: assert_identical(tf["test"], da) @requires_cftime @requires_netCDF4 @pytest.mark.parametrize("parallel", (True, False)) def test_open_mfdataset_multiple_files_parallel_distributed(parallel, 
    lon = np.arange(100)
    time = xr.cftime_range("20010101", periods=100, calendar="360_day")
    data = np.random.random((time.size, lon.size))
    da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

    fnames = []
    for i in range(0, 100, 10):
        fname = tmp_path / f"test_{i}.nc"
        da.isel(time=slice(i, i + 10)).to_netcdf(fname)
        fnames.append(fname)

    with cluster() as (s, [a, b]):
        with Client(s["address"]):
            with xr.open_mfdataset(
                fnames, parallel=parallel, concat_dim="time", combine="nested"
            ) as tf:
                assert_identical(tf["test"], da)


# TODO: move this to test_backends.py
@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path):
    if parallel:
        pytest.skip(
            "Flaky in CI. Would be a welcome contribution to make a similar test reliable."
        )

    lon = np.arange(100)
    time = xr.cftime_range("20010101", periods=100, calendar="360_day")
    data = np.random.random((time.size, lon.size))
    da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

    fnames = []
    for i in range(0, 100, 10):
        fname = tmp_path / f"test_{i}.nc"
        da.isel(time=slice(i, i + 10)).to_netcdf(fname)
        fnames.append(fname)

    for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]:
        with dask.config.set(scheduler=get):
            with xr.open_mfdataset(
                fnames, parallel=parallel, concat_dim="time", combine="nested"
            ) as tf:
                assert_identical(tf["test"], da)


@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_read_netcdf_integration_test(
    loop, tmp_netcdf_filename, engine, nc_format
):
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            original = create_test_data()
            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

            with xr.open_dataset(
                tmp_netcdf_filename, chunks=chunks, engine=engine
            ) as restored:
                assert isinstance(restored.var1.data, da.Array)
                computed = restored.compute()
                assert_allclose(original, computed)


# fixture vendored from dask
# heads-up, this is using quite private zarr API
# https://github.com/dask/dask/blob/e04734b4d8959ba259801f2e2a490cb4ee8d891f/dask/tests/test_distributed.py#L338-L358
@pytest.fixture(scope="function")
def zarr(client):
    zarr_lib = pytest.importorskip("zarr")
    # Zarr-Python 3 lazily allocates a dedicated thread/IO loop
    # to execute async tasks. To avoid having this thread
    # be picked up as a "leaked thread", we manually trigger its
    # creation before using zarr.
    try:
        _ = zarr_lib.core.sync._get_loop()
        _ = zarr_lib.core.sync._get_executor()
        yield zarr_lib
    except AttributeError:
        yield zarr_lib
    finally:
        # Zarr-Python 3 lazily allocates an IO thread, a thread pool executor,
        # and an IO loop. Here we clean up these resources to avoid leaking
        # threads. In normal operation, this is done by an atexit handler when
        # Zarr is shutting down.
        try:
            zarr_lib.core.sync.cleanup_resources()
        except AttributeError:
            pass


@requires_zarr
@pytest.mark.parametrize("consolidated", [True, False])
@pytest.mark.parametrize("compute", [True, False])
def test_dask_distributed_zarr_integration_test(
    client,
    zarr,
    consolidated: bool,
    compute: bool,
) -> None:
    if consolidated:
        write_kwargs: dict[str, Any] = {"consolidated": True}
        read_kwargs: dict[str, Any] = {"backend_kwargs": {"consolidated": True}}
    else:
        write_kwargs = read_kwargs = {}
    chunks = {"dim1": 4, "dim2": 3, "dim3": 5}
    original = create_test_data().chunk(chunks)
    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc") as filename:
        maybe_futures = original.to_zarr(  # type: ignore[call-overload]  # mypy bug?
            filename, compute=compute, **write_kwargs
        )
        if not compute:
            maybe_futures.compute()
        with xr.open_dataset(
            filename, chunks="auto", engine="zarr", **read_kwargs
        ) as restored:
            assert isinstance(restored.var1.data, da.Array)
            computed = restored.compute()
            assert_allclose(original, computed)


@gen_cluster(client=True)
async def test_async(c, s, a, b) -> None:
    x = create_test_data()
    assert not dask.is_dask_collection(x)
    y = x.chunk({"dim2": 4}) + 10
    assert dask.is_dask_collection(y)
    assert dask.is_dask_collection(y.var1)
    assert dask.is_dask_collection(y.var2)

    z = y.persist()
    assert str(z)

    assert dask.is_dask_collection(z)
    assert dask.is_dask_collection(z.var1)
    assert dask.is_dask_collection(z.var2)
    assert len(y.__dask_graph__()) > len(z.__dask_graph__())

    assert not futures_of(y)
    assert futures_of(z)

    future = c.compute(z)
    w = await future
    assert not dask.is_dask_collection(w)
    assert_allclose(x + 10, w)

    assert s.tasks


def test_hdf5_lock() -> None:
    assert isinstance(HDF5_LOCK, SerializableLock)


@gen_cluster(client=True)
async def test_serializable_locks(c, s, a, b) -> None:
    def f(x, lock=None):
        with lock:
            return x + 1

    # note, the creation of Lock needs to be done inside a cluster
    for lock in [
        HDF5_LOCK,
        Lock(),
        Lock("filename.nc"),
        CombinedLock([HDF5_LOCK]),
        CombinedLock([HDF5_LOCK, Lock("filename.nc")]),
    ]:
        futures = c.map(f, list(range(10)), lock=lock)
        await c.gather(futures)

        lock2 = pickle.loads(pickle.dumps(lock))
        assert type(lock) is type(lock2)