186 lines
5.9 KiB
Python
186 lines
5.9 KiB
Python
from collections.abc import Callable, Iterable
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from xarray.core import utils
|
|
from xarray.core.indexing import ExplicitlyIndexed
|
|
|
|
"""
|
|
This module contains various lazy array classes which can be wrapped and manipulated by xarray objects but will raise on data access.
|
|
"""
|
|
|
|
|
|
class UnexpectedDataAccess(Exception):
    """Raised by the test arrays in this module when something tries to read their data."""
|
|
|
|
|
|
class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Disallows any loading.

    The wrapped ``array`` is kept on ``self.array``. NDArrayMixin presumably
    derives metadata such as shape/dtype from it (TODO confirm). Every method
    defined here that would hand back values raises UnexpectedDataAccess.
    """

    def __init__(self, array):
        self.array = array

    def get_duck_array(self):
        """Refuse to hand out the underlying duck array."""
        raise UnexpectedDataAccess("Tried accessing data")

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        """Refuse coercion through the NumPy ``__array__`` protocol."""
        raise UnexpectedDataAccess("Tried accessing data")

    def __getitem__(self, key):
        """Refuse every indexing operation, whatever ``key`` is."""
        raise UnexpectedDataAccess("Tried accessing data.")
|
|
|
|
|
|
class FirstElementAccessibleArray(InaccessibleArray):
    """InaccessibleArray variant that permits indexing with a single-component key.

    Only ``__getitem__`` is relaxed. ``get_duck_array`` and ``__array__`` are
    inherited and still raise.
    """

    def __getitem__(self, key):
        # ``key`` must expose ``.tuple``. Presumably it is an xarray
        # ExplicitIndexer (e.g. BasicIndexer); verify against callers.
        indexers = key.tuple
        # NOTE(review): this counts indexer components, not elements. A
        # one-component key such as a slice can still read many values.
        if len(indexers) <= 1:
            return self.array[indexers]
        raise UnexpectedDataAccess("Tried accessing more than one element.")
|
|
|
|
|
|
class DuckArrayWrapper(utils.NDArrayMixin):
    """Array-like that prevents casting to array.

    Modeled after cupy: indexing yields another wrapper, ``__array__`` raises,
    and ``__array_namespace__`` exists only so the object passes duck-array
    checks.
    """

    def __init__(self, array: np.ndarray):
        self.array = array

    def __getitem__(self, key):
        # Re-wrap the indexed result so callers never receive a bare array.
        wrapper_cls = type(self)
        indexed = self.array[key]
        return wrapper_cls(indexed)

    def __array__(
        self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
    ) -> np.ndarray:
        """Refuse coercion through the NumPy ``__array__`` protocol."""
        raise UnexpectedDataAccess("Tried accessing data")

    def __array_namespace__(self):
        """Present to satisfy is_duck_array test; no namespace is provided."""
        return None
|
|
|
|
|
|
# Registry used by ``ConcatenatableArray.__array_function__``. It maps a NumPy
# API function object (e.g. ``np.concatenate``, not its name) to the function
# that implements it for ConcatenatableArray inputs.
CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}


def implements(numpy_function: Callable) -> Callable[[Callable], Callable]:
    """Register an __array_function__ implementation for ConcatenatableArray objects.

    Parameters
    ----------
    numpy_function
        The NumPy function being overridden, e.g. ``np.concatenate``.

    Returns
    -------
    A decorator that records the decorated function in
    ``CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS`` under ``numpy_function``
    and returns the function unchanged.
    """

    def decorator(func: Callable) -> Callable:
        CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[numpy_function] = func
        return func

    return decorator
|
|
|
|
|
|
@implements(np.concatenate)
def concatenate(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Concatenate ConcatenatableArrays along an existing axis.

    Registered as the ``np.concatenate`` override for ConcatenatableArray.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize once. ``arrays`` may be a one-shot iterator, and it is
    # traversed twice below (type check, then unwrapping).
    arrays = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("concatenate only supports ConcatenatableArray inputs")

    result = np.concatenate([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
|
|
|
|
|
|
@implements(np.stack)
def stack(
    arrays: Iterable["ConcatenatableArray"], /, *, axis=0
) -> "ConcatenatableArray":
    """Stack ConcatenatableArrays along a new axis.

    Registered as the ``np.stack`` override for ConcatenatableArray.

    Raises
    ------
    TypeError
        If any element of ``arrays`` is not a ConcatenatableArray.
    """
    # Materialize once. ``arrays`` may be a one-shot iterator, and it is
    # traversed twice below (type check, then unwrapping).
    arrays = list(arrays)
    if not all(isinstance(arr, ConcatenatableArray) for arr in arrays):
        raise TypeError("stack only supports ConcatenatableArray inputs")

    result = np.stack([arr._array for arr in arrays], axis=axis)
    return ConcatenatableArray(result)
|
|
|
|
|
|
@implements(np.result_type)
def result_type(*arrays_and_dtypes) -> np.dtype:
    """Called by xarray to ensure all arguments to concat have the same dtype.

    Unlike ``np.result_type``, this does no type promotion. Every argument must
    resolve (via ``np.dtype``) to the same dtype, and that dtype is returned.

    Raises
    ------
    ValueError
        If no arguments are given, or if the dtypes are not all identical.
    """
    if not arrays_and_dtypes:
        raise ValueError("at least one array or dtype is required")

    # ``np.dtype`` accepts dtype-likes as well as objects exposing a ``.dtype``
    # attribute (such as ConcatenatableArray).
    first_dtype, *other_dtypes = (np.dtype(obj) for obj in arrays_and_dtypes)
    if any(other_dtype != first_dtype for other_dtype in other_dtypes):
        raise ValueError("dtypes not all consistent")
    return first_dtype
|
|
|
|
|
|
@implements(np.broadcast_to)
def broadcast_to(
    x: "ConcatenatableArray", /, shape: tuple[int, ...]
) -> "ConcatenatableArray":
    """Broadcast a ConcatenatableArray to ``shape``.

    Registered as the ``np.broadcast_to`` override. It applies
    ``np.broadcast_to`` to the wrapped array and re-wraps the result, so the
    caller never receives the underlying data.

    Raises
    ------
    TypeError
        If ``x`` is not a ConcatenatableArray.
    """
    if not isinstance(x, ConcatenatableArray):
        raise TypeError("broadcast_to only supports ConcatenatableArray inputs")

    result = np.broadcast_to(x._array, shape=shape)
    return ConcatenatableArray(result)
|
|
|
|
|
|
class ConcatenatableArray:
|
|
"""Disallows loading or coercing to an index but does support concatenation / stacking."""
|
|
|
|
def __init__(self, array):
|
|
# use ._array instead of .array because we don't want this to be accessible even to xarray's internals (e.g. create_default_index_implicit)
|
|
self._array = array
|
|
|
|
@property
|
|
def dtype(self: Any) -> np.dtype:
|
|
return self._array.dtype
|
|
|
|
@property
|
|
def shape(self: Any) -> tuple[int, ...]:
|
|
return self._array.shape
|
|
|
|
@property
|
|
def ndim(self: Any) -> int:
|
|
return self._array.ndim
|
|
|
|
def __repr__(self: Any) -> str:
|
|
return f"{type(self).__name__}(array={self._array!r})"
|
|
|
|
def get_duck_array(self):
|
|
raise UnexpectedDataAccess("Tried accessing data")
|
|
|
|
def __array__(
|
|
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
|
|
) -> np.ndarray:
|
|
raise UnexpectedDataAccess("Tried accessing data")
|
|
|
|
def __getitem__(self, key) -> "ConcatenatableArray":
|
|
"""Some cases of concat require supporting expanding dims by dimensions of size 1"""
|
|
# see https://data-apis.org/array-api/2022.12/API_specification/indexing.html#multi-axis-indexing
|
|
arr = self._array
|
|
for axis, indexer_1d in enumerate(key):
|
|
if indexer_1d is None:
|
|
arr = np.expand_dims(arr, axis)
|
|
elif indexer_1d is Ellipsis:
|
|
pass
|
|
else:
|
|
raise UnexpectedDataAccess("Tried accessing data.")
|
|
return ConcatenatableArray(arr)
|
|
|
|
def __array_function__(self, func, types, args, kwargs) -> Any:
|
|
if func not in CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS:
|
|
return NotImplemented
|
|
|
|
# Note: this allows subclasses that don't override
|
|
# __array_function__ to handle ManifestArray objects
|
|
if not all(issubclass(t, ConcatenatableArray) for t in types):
|
|
return NotImplemented
|
|
|
|
return CONCATENATABLEARRAY_HANDLED_ARRAY_FUNCTIONS[func](*args, **kwargs)
|
|
|
|
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Any:
|
|
"""We have to define this in order to convince xarray that this class is a duckarray, even though we will never support ufuncs."""
|
|
return NotImplemented
|
|
|
|
def astype(self, dtype: np.dtype, /, *, copy: bool = True) -> "ConcatenatableArray":
|
|
"""Needed because xarray will call this even when it's a no-op"""
|
|
if dtype != self.dtype:
|
|
raise NotImplementedError()
|
|
else:
|
|
return self
|