CCR/.venv/lib/python3.12/site-packages/xarray/tests/arrays.py

186 lines
5.9 KiB
Python

from collections.abc import Callable, Iterable
from typing import Any
import numpy as np
from xarray.core import utils
from xarray.core.indexing import ExplicitlyIndexed
"""
This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access.
"""
class UnexpectedDataAccess(Exception):
    """Raised by the test arrays in this module when their values are read."""
class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Disallows any loading.

    Every route to the wrapped values -- ``get_duck_array``, coercion through
    ``__array__`` and indexing -- raises :class:`UnexpectedDataAccess`.
    """

    def __init__(self, array):
        # The wrapped object is kept on ``.array``; it is never read here.
        self.array = array

    def __getitem__(self, key):
        raise UnexpectedDataAccess("Tried accessing data.")

    def get_duck_array(self):
        raise UnexpectedDataAccess("Tried accessing data")

    def __array__(
        self,
        dtype: np.typing.DTypeLike = None,
        /,
        *,
        copy: bool | None = None,
    ) -> np.ndarray:
        raise UnexpectedDataAccess("Tried accessing data")
class FirstElementAccessibleArray(InaccessibleArray):
    """Inaccessible array that still allows indexing with a single indexer.

    Indexing with a key whose ``.tuple`` holds at most one entry is forwarded to
    the wrapped ``.array``; longer keys raise :class:`UnexpectedDataAccess`.
    """

    def __getitem__(self, key):
        # NOTE(review): ``key`` is presumably an xarray explicit indexer exposing
        # ``.tuple`` -- confirm against callers.
        indexer = key.tuple
        if len(indexer) <= 1:
            return self.array[indexer]
        raise UnexpectedDataAccess("Tried accessing more than one element.")
class DuckArrayWrapper(utils.NDArrayMixin):
    """Array-like that prevents casting to array.

    Modeled after cupy: indexing yields another wrapper of the same type, while
    coercion to a NumPy array via ``__array__`` raises
    :class:`UnexpectedDataAccess`.
    """

    def __init__(self, array: np.ndarray):
        self.array = array

    def __getitem__(self, key):
        # Re-wrap the indexed result so it stays a non-coercible duck array.
        wrapper_cls = type(self)
        selected = self.array[key]
        return wrapper_cls(selected)

    def __array__(
        self,
        dtype: np.typing.DTypeLike = None,
        /,
        *,
        copy: bool | None = None,
    ) -> np.ndarray:
        raise UnexpectedDataAccess("Tried accessing data")

    def __array_namespace__(self):
        """Present to satisfy is_duck_array test."""
        return None
# Registry used by ConcatenatableArray.__array_function__: maps a NumPy API
# function object (e.g. ``np.concatenate``) -- not its name -- to the override
# that should run in its place.
CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function: Callable) -> Callable[[Callable], Callable]:
    """Register an __array_function__ implementation for ConcatenatableArray objects.

    Parameters
    ----------
    numpy_function
        The NumPy function being overridden; it is used as the registry key.

    Returns
    -------
    A decorator that stores the decorated function in
    ``CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS`` under ``numpy_function`` and
    returns the decorated function unchanged.
    """

    def decorator(func: Callable) -> Callable:
        CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
        return func

    return decorator
@implements(np.concatenate)
def concatenate(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Override of ``np.concatenate`` for ConcatenatableArray inputs.

    The wrapped arrays are concatenated along ``axis`` and the result is wrapped
    in a new ConcatenatableArray.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize once: ``arrays`` may be a one-shot iterator, which the type
    # check below would otherwise exhaust before the concatenation reads it.
    arrays = list(arrays)
    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("concatenate only supports ConcatenatableArray inputs")
    result = np.concatenate([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
@implements(np.stack)
def stack(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Override of ``np.stack`` for ConcatenatableArray inputs.

    The wrapped arrays are stacked along a new ``axis`` and the result is wrapped
    in a new ConcatenatableArray.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize once: ``arrays`` may be a one-shot iterator, which the type
    # check below would otherwise exhaust before the stacking reads it.
    arrays = list(arrays)
    if any(not isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("stack only supports ConcatenatableArray inputs")
    result = np.stack([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """Called by xarray to ensure all arguments to concat have the same dtype.

    Each argument is converted with ``np.dtype(...)``. Unlike NumPy's version, no
    type promotion happens: the common dtype is returned only when every
    argument has exactly that dtype.

    Raises
    ------
    ValueError
        If no arguments are given, or if the dtypes are not all identical.
    """
    if not arrays_and_dtypes:
        # Mirror np.result_type, which also rejects an empty call.
        raise ValueError("at least one array or dtype is required")
    first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes)
    for other_dtype in other_dtypes:
        if other_dtype != first_dtype:
            raise ValueError(
                f"dtypes not all consistent: {first_dtype} vs {other_dtype}"
            )
    return first_dtype
@implements(np.broadcast_to)
def broadcast_to(
    x: "ConcatenatableArray", /, shape: tuple[int, ...]
) -> "ConcatenatableArray":
    """
    Override of ``np.broadcast_to`` for a ConcatenatableArray.

    The wrapped array is broadcast to ``shape`` with ``np.broadcast_to`` and the
    result is wrapped in a new ConcatenatableArray.

    Raises
    ------
    TypeError
        If ``x`` is not a ConcatenatableArray.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("broadcast_to only supports ConcatenatableArray inputs")
    result = np.broadcast_to(x._array, shape=shape)
    return ConcatenatableArray(result)
class ConcatenatableArray:
"""Disallows loading or coercing to an index but does support concatenation / stacking."""
def __init__(self, array):
# use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit)
self._array = array
@property
def dtype(self: Any) -> np.dtype:
return self._array.dtype
@property
def shape(self: Any) -> tuple[int, ...]:
return self._array.shape
@property
def ndim(self: Any) -> int:
return self._array.ndim
def __repr__(self: Any) -> str:
return f"{type(self).__name__}(array={self._array!r})"
def get_duck_array(self):
raise UnexpectedDataAccess("Tried accessing data")
def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
) -> np.ndarray:
raise UnexpectedDataAccess("Tried accessing data")
def __getitem__(self, key) -> "ConcatenatableArray":
"""Some cases of concat require supporting expanding dims by dimensions of size 1"""
# see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing
arr = self._array
for axis, indexer_1d in enumerate(key):
if indexer_1d is None:
arr = np.expand_dims(arr, axis)
elif indexer_1d is Ellipsis:
pass
else:
raise UnexpectedDataAccess("Tried accessing data.")
return ConcatenatableArray(arr)
def __array_function__(self, func, types, args, kwargs) -> Any:
if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle ManifestArray objects
if not all(issubclass(t, ConcatenatableArray) for t in types):
return NotImplemented
return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
"""We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
return NotImplemented
def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> "ConcatenatableArray":
"""Needed because xarray will call this even when it's a no-op"""
if dtype != self.dtype:
raise NotImplementedError()
else:
return self