# xarray/namedarray/daskmanager.py
# Dask-backed implementation of xarray's ChunkManagerEntrypoint interface.
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
import numpy as np
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
from xarray.namedarray.utils import is_duck_dask_array, module_available
if TYPE_CHECKING:
from xarray.namedarray._typing import (
T_Chunks,
_DType_co,
_NormalizedChunks,
duckarray,
)
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray[Any, Any]
dask_available = module_available("dask")
class DaskManager(ChunkManagerEntrypoint["DaskArray"]):
    """Chunk-manager entrypoint that dispatches chunked-array operations to dask."""

    # Concrete dask Array class; assigned in __init__ once dask is imported.
    array_cls: type[DaskArray]
    # Whether dask is importable in this environment (computed at module load).
    available: bool = dask_available

    def __init__(self) -> None:
        # TODO can we replace this with a class attribute instead?
        from dask.array import Array

        self.array_cls = Array
    def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
        """Return True if ``data`` duck-types as a dask array."""
        return is_duck_dask_array(data)
    def chunks(self, data: Any) -> _NormalizedChunks:
        """Return ``data.chunks``.

        Assumes ``data`` exposes dask-style normalized chunks (a tuple of
        tuples of block sizes) — TODO confirm for non-dask duck arrays.
        """
        return data.chunks  # type: ignore[no-any-return]
def normalize_chunks(
self,
chunks: T_Chunks | _NormalizedChunks,
shape: tuple[int, ...] | None = None,
limit: int | None = None,
dtype: _DType_co | None = None,
previous_chunks: _NormalizedChunks | None = None,
) -> Any:
"""Called by open_dataset"""
from dask.array.core import normalize_chunks
return normalize_chunks(
chunks,
shape=shape,
limit=limit,
dtype=dtype,
previous_chunks=previous_chunks,
) # type: ignore[no-untyped-call]
def from_array(
self, data: Any, chunks: T_Chunks | _NormalizedChunks, **kwargs: Any
) -> DaskArray | Any:
import dask.array as da
if isinstance(data, ImplicitToExplicitIndexingAdapter):
# lazily loaded backend array classes should use NumPy array operations.
kwargs["meta"] = np.ndarray
return da.from_array(
data,
chunks,
**kwargs,
) # type: ignore[no-untyped-call]
def compute(
self, *data: Any, **kwargs: Any
) -> tuple[np.ndarray[Any, _DType_co], ...]:
from dask.array import compute
return compute(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return]
def persist(self, *data: Any, **kwargs: Any) -> tuple[DaskArray | Any, ...]:
from dask import persist
return persist(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return]
@property
def array_api(self) -> Any:
from dask import array as da
return da
def reduction(
self,
arr: T_ChunkedArray,
func: Callable[..., Any],
combine_func: Callable[..., Any] | None = None,
aggregate_func: Callable[..., Any] | None = None,
axis: int | Sequence[int] | None = None,
dtype: _DType_co | None = None,
keepdims: bool = False,
) -> DaskArray | Any:
from dask.array import reduction
return reduction(
arr,
chunk=func,
combine=combine_func,
aggregate=aggregate_func,
axis=axis,
dtype=dtype,
keepdims=keepdims,
) # type: ignore[no-untyped-call]
def scan(
self,
func: Callable[..., Any],
binop: Callable[..., Any],
ident: float,
arr: T_ChunkedArray,
axis: int | None = None,
dtype: _DType_co | None = None,
**kwargs: Any,
) -> DaskArray | Any:
from dask.array.reductions import cumreduction
return cumreduction(
func,
binop,
ident,
arr,
axis=axis,
dtype=dtype,
**kwargs,
) # type: ignore[no-untyped-call]
def apply_gufunc(
self,
func: Callable[..., Any],
signature: str,
*args: Any,
axes: Sequence[tuple[int, ...]] | None = None,
axis: int | None = None,
keepdims: bool = False,
output_dtypes: Sequence[_DType_co] | None = None,
output_sizes: dict[str, int] | None = None,
vectorize: bool | None = None,
allow_rechunk: bool = False,
meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
**kwargs: Any,
) -> Any:
from dask.array.gufunc import apply_gufunc
return apply_gufunc(
func,
signature,
*args,
axes=axes,
axis=axis,
keepdims=keepdims,
output_dtypes=output_dtypes,
output_sizes=output_sizes,
vectorize=vectorize,
allow_rechunk=allow_rechunk,
meta=meta,
**kwargs,
) # type: ignore[no-untyped-call]
def map_blocks(
self,
func: Callable[..., Any],
*args: Any,
dtype: _DType_co | None = None,
chunks: tuple[int, ...] | None = None,
drop_axis: int | Sequence[int] | None = None,
new_axis: int | Sequence[int] | None = None,
**kwargs: Any,
) -> Any:
from dask.array import map_blocks
# pass through name, meta, token as kwargs
return map_blocks(
func,
*args,
dtype=dtype,
chunks=chunks,
drop_axis=drop_axis,
new_axis=new_axis,
**kwargs,
) # type: ignore[no-untyped-call]
def blockwise(
self,
func: Callable[..., Any],
out_ind: Iterable[Any],
*args: Any,
# can't type this as mypy assumes args are all same type, but dask blockwise args alternate types
name: str | None = None,
token: Any | None = None,
dtype: _DType_co | None = None,
adjust_chunks: dict[Any, Callable[..., Any]] | None = None,
new_axes: dict[Any, int] | None = None,
align_arrays: bool = True,
concatenate: bool | None = None,
meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
**kwargs: Any,
) -> DaskArray | Any:
from dask.array import blockwise
return blockwise(
func,
out_ind,
*args,
name=name,
token=token,
dtype=dtype,
adjust_chunks=adjust_chunks,
new_axes=new_axes,
align_arrays=align_arrays,
concatenate=concatenate,
meta=meta,
**kwargs,
) # type: ignore[no-untyped-call]
def unify_chunks(
self,
*args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
**kwargs: Any,
) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]:
from dask.array.core import unify_chunks
return unify_chunks(*args, **kwargs) # type: ignore[no-any-return, no-untyped-call]
def store(
self,
sources: Any | Sequence[Any],
targets: Any,
**kwargs: Any,
) -> Any:
from dask.array import store
return store(
sources=sources,
targets=targets,
**kwargs,
)
def shuffle(
self, x: DaskArray, indexer: list[list[int]], axis: int, chunks: T_Chunks
) -> DaskArray:
import dask.array
if not module_available("dask", minversion="2024.08.1"):
raise ValueError(
"This method is very inefficient on dask<2024.08.1. Please upgrade."
)
if chunks is None:
chunks = "auto"
if chunks != "auto":
raise NotImplementedError("Only chunks='auto' is supported at present.")
return dask.array.shuffle(x, indexer, axis, chunks="auto")