139 lines
4.7 KiB
Python
139 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
from importlib import import_module
|
|
from types import ModuleType
|
|
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
import numpy as np
|
|
from packaging.version import Version
|
|
|
|
from xarray.core.utils import is_scalar
|
|
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array
|
|
|
|
# Integer scalar types: builtin ``int`` plus NumPy integer scalars.
# Presumably used elsewhere for ``isinstance`` checks — verify against callers.
integer_types = (int, np.integer)

if TYPE_CHECKING:
    # Only needed for static type checking; never imported at runtime.
    from xarray.namedarray._typing import _DType, _ShapeType, duckarray

    # Module names understood by ``DuckArrayModule`` and its wrappers.
    ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
    # Tuple of array classes, usable as the second argument of ``isinstance``.
    DuckArrayTypes = tuple[type[Any], ...]  # TODO: improve this? maybe Generic
|
|
|
|
|
|
class DuckArrayModule:
    """
    Solely for internal isinstance and version checks.

    Motivated by having to only import pint when required (as pint currently imports xarray)
    https://github.com/pydata/xarray/pull/5561#discussion_r664815718

    Attributes
    ----------
    module : ModuleType or None
        The imported module, or ``None`` if importing/inspecting it failed.
    version : Version
        Parsed ``module.__version__``; ``Version("0.0.0")`` when unavailable.
    type : tuple of type
        Array class(es) for ``isinstance`` checks. Empty when the module is
        unavailable, and always empty for ``"numbagg"`` (not a duck array
        library, only loaded lazily through this mechanism).
    available : bool
        ``True`` if the module was imported successfully.
    """

    # Every name handled in ``__init__``. The name is validated *before*
    # importing, so an unsupported name always raises NotImplementedError.
    # Previously, an unsupported name with no matching module was swallowed by
    # the ``except ImportError`` and silently reported as "not installed".
    # Validating first also avoids importing an unrelated module only to
    # reject it afterwards.
    _supported_modules = frozenset(
        ("dask", "pint", "cupy", "sparse", "cubed", "numbagg")
    )

    module: ModuleType | None
    version: Version
    type: DuckArrayTypes
    available: bool

    def __init__(self, mod: ModType) -> None:
        """
        Try to import ``mod`` and record its module, version and array type(s).

        Parameters
        ----------
        mod
            Name of the optional dependency to look up.

        Raises
        ------
        NotImplementedError
            If ``mod`` is not one of the supported module names.
        """
        if mod not in self._supported_modules:
            raise NotImplementedError(f"unsupported duck array module: {mod!r}")

        duck_array_module: ModuleType | None
        duck_array_version: Version
        duck_array_type: DuckArrayTypes
        try:
            duck_array_module = import_module(mod)
            duck_array_version = Version(duck_array_module.__version__)

            if mod == "dask":
                # dask's array class is looked up on the ``dask.array`` submodule.
                duck_array_type = (import_module("dask.array").Array,)
            elif mod == "pint":
                duck_array_type = (duck_array_module.Quantity,)
            elif mod == "cupy":
                duck_array_type = (duck_array_module.ndarray,)
            elif mod == "sparse":
                duck_array_type = (duck_array_module.SparseArray,)
            elif mod == "cubed":
                duck_array_type = (duck_array_module.Array,)
            # Not a duck array module, but using this system regardless, to get lazy imports
            elif mod == "numbagg":
                duck_array_type = ()
            else:  # pragma: no cover
                # Unreachable while the branches above mirror ``_supported_modules``.
                raise NotImplementedError(f"unsupported duck array module: {mod!r}")

        except (ImportError, AttributeError):  # pragma: no cover
            # A missing module, or one lacking ``__version__`` or the expected
            # array class, is reported as unavailable rather than raising.
            duck_array_module = None
            duck_array_version = Version("0.0.0")
            duck_array_type = ()

        self.module = duck_array_module
        self.version = duck_array_version
        self.type = duck_array_type
        self.available = duck_array_module is not None
|
|
|
|
|
|
# Module-level memo of DuckArrayModule instances, keyed by module name, so a
# successfully constructed entry is reused instead of re-importing the module.
_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}


def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
    """Return the memoized ``DuckArrayModule`` for ``mod``, building it on first use."""
    cached = _cached_duck_array_modules.get(mod)
    if cached is None:
        # Nothing is stored if construction raises, so a later call retries.
        cached = DuckArrayModule(mod)
        _cached_duck_array_modules[mod] = cached
    return cached
|
|
|
|
|
|
def array_type(mod: ModType) -> DuckArrayTypes:
    """
    Return the tuple of array classes provided by ``mod``.

    The tuple is empty when the module is not available, so the result can
    always be passed to ``isinstance`` (which is then simply ``False``).
    """
    duck_module = _get_cached_duck_array_module(mod)
    return duck_module.type
|
|
|
|
|
|
def mod_version(mod: ModType) -> Version:
    """
    Return the parsed version of ``mod``.

    ``Version("0.0.0")`` is returned when the module is not available.
    """
    duck_module = _get_cached_duck_array_module(mod)
    return duck_module.version
|
|
|
|
|
|
def is_chunked_array(x: duckarray[Any, Any]) -> bool:
    """
    Return whether ``x`` is a chunked array that a chunk manager must compute.

    Dask arrays always qualify; any other duck array qualifies when it exposes
    a ``chunks`` attribute.
    """
    dask_result = is_duck_dask_array(x)
    if dask_result:
        return dask_result
    return is_duck_array(x) and hasattr(x, "chunks")
|
|
|
|
|
|
def is_0d_dask_array(x: duckarray[Any, Any]) -> bool:
    """Return whether ``x`` is a dask array that ``is_scalar`` also accepts (presumably 0-d)."""
    dask_result = is_duck_dask_array(x)
    # Same short-circuit result as ``dask_result and is_scalar(x)``.
    return is_scalar(x) if dask_result else dask_result
|
|
|
|
|
|
def to_numpy(
    data: duckarray[Any, Any], **kwargs: Any
) -> np.ndarray[Any, np.dtype[Any]]:
    """
    Convert a duck array to a NumPy array.

    The input is unwrapped in a fixed order:

    1. explicitly-indexed wrappers are resolved via ``get_duck_array()``;
    2. chunked arrays are computed by their chunk manager;
    3. CuPy arrays are converted via ``.get()``;
    4. pint quantities are reduced to their ``.magnitude``;
    5. sparse arrays are densified via ``.todense()``;
    6. the result is passed through ``np.asarray``.

    NOTE(review): each layer is checked once, in this order, so e.g. a pint
    Quantity wrapping a CuPy array reaches ``np.asarray`` without ``.get()``.

    Parameters
    ----------
    data
        The array to convert.
    **kwargs
        Forwarded to the chunk manager's ``compute`` when ``data`` is chunked.

    Returns
    -------
    numpy.ndarray
    """
    from xarray.core.indexing import ExplicitlyIndexed
    from xarray.namedarray.parallelcompat import get_chunked_array_type

    if isinstance(data, ExplicitlyIndexed):
        data = data.get_duck_array()  # type: ignore[no-untyped-call]

    # TODO first attempt to call .to_numpy() once some libraries implement it
    if is_chunked_array(data):
        chunkmanager = get_chunked_array_type(data)
        # ``compute`` returns a sequence of results; only the first is needed.
        data, *_ = chunkmanager.compute(data, **kwargs)
    if isinstance(data, array_type("cupy")):
        data = data.get()
    # pint has to be imported dynamically as pint imports xarray
    if isinstance(data, array_type("pint")):
        data = data.magnitude
    if isinstance(data, array_type("sparse")):
        data = data.todense()
    data = np.asarray(data)

    return data
|
|
|
|
|
|
def to_duck_array(data: Any, **kwargs: Any) -> duckarray[_ShapeType, _DType]:
    """
    Return ``data`` as a duck array, computing it first if it is chunked.

    The first matching case wins:

    - chunked arrays are computed by their chunk manager;
    - explicitly-indexed wrappers are resolved via ``get_duck_array()``;
    - objects that already are duck arrays are returned unchanged;
    - anything else is converted with ``np.asarray``.

    NOTE(review): the chunked check runs before unwrapping ``ExplicitlyIndexed``,
    so a wrapper whose ``get_duck_array()`` yields a chunked array is returned
    uncomputed — confirm this is intended.

    Parameters
    ----------
    data
        The object to convert.
    **kwargs
        Forwarded to the chunk manager's ``compute`` when ``data`` is chunked.
    """
    from xarray.core.indexing import ExplicitlyIndexed
    from xarray.namedarray.parallelcompat import get_chunked_array_type

    if is_chunked_array(data):
        chunkmanager = get_chunked_array_type(data)
        # ``compute`` returns a sequence of results; only the first is needed.
        loaded_data, *_ = chunkmanager.compute(data, **kwargs)  # type: ignore[var-annotated]
        return loaded_data

    if isinstance(data, ExplicitlyIndexed):
        return data.get_duck_array()  # type: ignore[no-untyped-call, no-any-return]
    elif is_duck_array(data):
        return data
    else:
        return np.asarray(data)  # type: ignore[return-value]
|