# Tests for chunk manager discovery and selection in xarray.namedarray.parallelcompat.
from __future__ import annotations
|
|
|
|
from importlib.metadata import EntryPoint
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from xarray import set_options
|
|
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
|
|
from xarray.namedarray._typing import _Chunks
|
|
from xarray.namedarray.daskmanager import DaskManager
|
|
from xarray.namedarray.parallelcompat import (
|
|
KNOWN_CHUNKMANAGERS,
|
|
ChunkManagerEntrypoint,
|
|
get_chunked_array_type,
|
|
guess_chunkmanager,
|
|
list_chunkmanagers,
|
|
load_chunkmanagers,
|
|
)
|
|
from xarray.tests import requires_dask
|
|
|
|
|
|
class DummyChunkedArray(np.ndarray):
    """
    Minimal stand-in for a chunked array type.

    This is a plain ``numpy.ndarray`` subclass that carries an extra
    (non-functional) ``.chunks`` attribute, following this recipe from the
    numpy subclassing guide:
    https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
    """

    chunks: T_NormalizedChunks

    def __new__(
        cls,
        shape,
        dtype=float,
        buffer=None,
        offset=0,
        strides=None,
        order=None,
        chunks=None,
    ):
        """Build the array through ``ndarray.__new__`` and attach ``chunks`` to it."""
        instance = super().__new__(
            cls, shape, dtype, buffer, offset, strides, order
        )
        instance.chunks = chunks
        return instance

    def __array_finalize__(self, obj):
        """Inherit ``chunks`` from ``obj`` (``None`` if ``obj`` has no such attribute)."""
        # When ``obj`` is None there is nothing to inherit from; ``__new__``
        # assigns ``chunks`` itself right after the array has been created.
        if obj is not None:
            self.chunks = getattr(obj, "chunks", None)

    def rechunk(self, chunks, **kwargs):
        """
        Return a copy of this array whose ``chunks`` attribute is ``chunks``.

        The data itself is not rearranged, and any extra keyword arguments
        are accepted but ignored.
        """
        rechunked = self.copy()
        rechunked.chunks = chunks
        return rechunked
|
|
|
|
|
|
class DummyChunkManager(ChunkManagerEntrypoint):
    """
    Mock-up of a ChunkManager class for DummyChunkedArray.

    Array detection, ``chunks`` and ``rechunk`` work on DummyChunkedArray
    directly. ``normalize_chunks``, ``from_array``, ``compute`` and
    ``apply_gufunc`` forward to dask, which is imported lazily inside each of
    those methods so that merely defining this class does not require dask.
    """

    def __init__(self):
        # The array type this manager claims responsibility for.
        self.array_cls = DummyChunkedArray

    def is_chunked_array(self, data: Any) -> bool:
        """Return True only for DummyChunkedArray instances."""
        return isinstance(data, DummyChunkedArray)

    def chunks(self, data: DummyChunkedArray) -> T_NormalizedChunks:
        """Return the ``chunks`` attribute stored on ``data``."""
        return data.chunks

    def normalize_chunks(
        self,
        chunks: T_Chunks | T_NormalizedChunks,
        shape: tuple[int, ...] | None = None,
        limit: int | None = None,
        dtype: np.dtype | None = None,
        previous_chunks: T_NormalizedChunks | None = None,
    ) -> T_NormalizedChunks:
        """Forward all arguments, in order, to dask's ``normalize_chunks``."""
        from dask.array.core import normalize_chunks as dask_normalize_chunks

        return dask_normalize_chunks(chunks, shape, limit, dtype, previous_chunks)

    def from_array(
        self, data: T_DuckArray | np.typing.ArrayLike, chunks: _Chunks, **kwargs
    ) -> DummyChunkedArray:
        """
        Forward to ``dask.array.from_array``.

        Note that the object returned is whatever dask produces, despite the
        DummyChunkedArray return annotation.
        """
        from dask.array import from_array as dask_from_array

        return dask_from_array(data, chunks, **kwargs)

    def rechunk(self, data: DummyChunkedArray, chunks, **kwargs) -> DummyChunkedArray:
        """Delegate to the ``rechunk`` method of ``data``."""
        return data.rechunk(chunks, **kwargs)

    def compute(self, *data: DummyChunkedArray, **kwargs) -> tuple[np.ndarray, ...]:
        """Forward every positional and keyword argument to dask's ``compute``."""
        import dask.array as da

        return da.compute(*data, **kwargs)

    def apply_gufunc(
        self,
        func,
        signature,
        *args,
        axes=None,
        axis=None,
        keepdims=False,
        output_dtypes=None,
        output_sizes=None,
        vectorize=None,
        allow_rechunk=False,
        meta=None,
        **kwargs,
    ):
        """Forward to dask's ``apply_gufunc`` with every option passed through unchanged."""
        from dask.array.gufunc import apply_gufunc as dask_apply_gufunc

        # The named options are gathered first; they cannot clash with
        # ``kwargs`` because the signature above already captured those names.
        gufunc_options = dict(
            axes=axes,
            axis=axis,
            keepdims=keepdims,
            output_dtypes=output_dtypes,
            output_sizes=output_sizes,
            vectorize=vectorize,
            allow_rechunk=allow_rechunk,
            meta=meta,
        )
        return dask_apply_gufunc(func, signature, *args, **gufunc_options, **kwargs)
|
|
|
|
|
|
@pytest.fixture
def register_dummy_chunkmanager(monkeypatch):
    """
    Pretend that an additional ChunkManagerEntrypoint named "dummy" is registered.

    The chunk managers that were already registered are kept, so a test that
    needs both this one and DaskManager to be returned from
    list_chunkmanagers() at the same time still works.

    The patch replaces list_chunkmanagers as looked up inside
    xarray.namedarray.parallelcompat; the name imported into this tests file
    is not affected.
    """
    # Snapshot taken before patching.
    # Should include DaskManager iff dask is available to be imported.
    already_registered = list_chunkmanagers()

    def _list_with_dummy():
        # A fresh DummyChunkManager is built on every call; entries from the
        # snapshot are merged in on the right-hand side.
        return {"dummy": DummyChunkManager()} | already_registered

    monkeypatch.setattr(
        "xarray.namedarray.parallelcompat.list_chunkmanagers",
        _list_with_dummy,
    )
    yield
|
|
|
|
|
|
class TestGetChunkManager:
    """Tests of how ``guess_chunkmanager`` resolves a manager name (or ``None``)."""

    def test_get_chunkmanger(self, register_dummy_chunkmanager) -> None:
        """A registered manager requested by name is returned."""
        manager = guess_chunkmanager("dummy")
        assert isinstance(manager, DummyChunkManager)

    def test_get_chunkmanger_via_set_options(self, register_dummy_chunkmanager) -> None:
        """With no explicit name, the ``chunk_manager`` option decides."""
        with set_options(chunk_manager="dummy"):
            manager = guess_chunkmanager(None)
            assert isinstance(manager, DummyChunkManager)

    def test_fail_on_known_but_missing_chunkmanager(
        self, register_dummy_chunkmanager, monkeypatch
    ) -> None:
        """A name listed in KNOWN_CHUNKMANAGERS but not registered raises ImportError."""
        monkeypatch.setitem(KNOWN_CHUNKMANAGERS, "test", "test-package")

        expected_message = "chunk manager 'test' is not available.+test-package"
        with pytest.raises(ImportError, match=expected_message):
            guess_chunkmanager("test")

    def test_fail_on_nonexistent_chunkmanager(
        self, register_dummy_chunkmanager
    ) -> None:
        """A name that is neither known nor registered raises ValueError."""
        expected_message = "unrecognized chunk manager 'foo'"
        with pytest.raises(ValueError, match=expected_message):
            guess_chunkmanager("foo")

    @requires_dask
    def test_get_dask_if_installed(self) -> None:
        """With dask installed and nothing requested, DaskManager is returned."""
        manager = guess_chunkmanager(None)
        assert isinstance(manager, DaskManager)

    def test_no_chunk_manager_available(self, monkeypatch) -> None:
        """With an empty registry, an arbitrary name raises ImportError."""
        # ``dict`` called with no arguments returns {}, i.e. nothing registered.
        patch_target = "xarray.namedarray.parallelcompat.list_chunkmanagers"
        monkeypatch.setattr(patch_target, dict)

        with pytest.raises(ImportError, match="no chunk managers available"):
            guess_chunkmanager("foo")

    def test_no_chunk_manager_available_but_known_manager_requested(
        self, monkeypatch
    ) -> None:
        """With an empty registry, requesting 'dask' raises ImportError naming it."""
        # ``dict`` called with no arguments returns {}, i.e. nothing registered.
        patch_target = "xarray.namedarray.parallelcompat.list_chunkmanagers"
        monkeypatch.setattr(patch_target, dict)

        with pytest.raises(ImportError, match="chunk manager 'dask' is not available"):
            guess_chunkmanager("dask")

    @requires_dask
    def test_choose_dask_over_other_chunkmanagers(
        self, register_dummy_chunkmanager
    ) -> None:
        """With both dask and the dummy registered, DaskManager is chosen."""
        manager = guess_chunkmanager(None)
        assert isinstance(manager, DaskManager)
|
|
|
|
|
|
class TestGetChunkedArrayType:
    """Tests of how ``get_chunked_array_type`` picks a manager from its arguments."""

    def test_detect_chunked_arrays(self, register_dummy_chunkmanager) -> None:
        """A lone DummyChunkedArray is matched to DummyChunkManager."""
        chunked = DummyChunkedArray([1, 2, 3])

        manager = get_chunked_array_type(chunked)
        assert isinstance(manager, DummyChunkManager)

    def test_ignore_inmemory_arrays(self, register_dummy_chunkmanager) -> None:
        """Scalars and numpy arrays next to a chunked array do not affect the result."""
        chunked = DummyChunkedArray([1, 2, 3])

        manager = get_chunked_array_type(chunked, 1.0, np.array([5, 6]))
        assert isinstance(manager, DummyChunkManager)

        # A scalar on its own is not a chunked array at all.
        with pytest.raises(TypeError, match="Expected a chunked array"):
            get_chunked_array_type(5.0)

    def test_raise_if_no_arrays_chunked(self, register_dummy_chunkmanager) -> None:
        """Passing only in-memory objects raises TypeError."""
        with pytest.raises(TypeError, match="Expected a chunked array "):
            get_chunked_array_type(1.0, np.array([5, 6]))

    def test_raise_if_no_matching_chunkmanagers(self) -> None:
        """Without the dummy manager registered, a DummyChunkedArray is unrecognised."""
        chunked = DummyChunkedArray([1, 2, 3])

        expected_message = "Could not find a Chunk Manager which recognises"
        with pytest.raises(TypeError, match=expected_message):
            get_chunked_array_type(chunked)

    @requires_dask
    def test_detect_dask_if_installed(self) -> None:
        """A dask array is matched to DaskManager."""
        import dask.array as da

        lazy = da.from_array([1, 2, 3], chunks=(1,))

        manager = get_chunked_array_type(lazy)
        assert isinstance(manager, DaskManager)

    @requires_dask
    def test_raise_on_mixed_array_types(self, register_dummy_chunkmanager) -> None:
        """Mixing a dask array with a DummyChunkedArray raises TypeError."""
        import dask.array as da

        chunked = DummyChunkedArray([1, 2, 3])
        lazy = da.from_array([1, 2, 3], chunks=(1,))

        with pytest.raises(TypeError, match="received multiple types"):
            get_chunked_array_type(lazy, chunked)
|
|
|
|
|
|
def test_bogus_entrypoint() -> None:
    """An entry point that fails to load triggers a warning and yields no managers."""
    # Mimic a user who broke their setup.cfg, or who is still actively
    # developing their new chunk manager.
    broken_entry_point = EntryPoint(
        name="bogus",
        value="xarray.bogus.doesnotwork",
        group="xarray.chunkmanagers",
    )
    with pytest.warns(UserWarning, match="Failed to load chunk manager"):
        loaded = load_chunkmanagers([broken_entry_point])
        assert len(loaded) == 0
|