222 lines
6.6 KiB
Python
222 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib
|
|
import warnings
|
|
from collections.abc import Hashable, Iterable, Iterator, Mapping
|
|
from functools import lru_cache
|
|
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
|
|
|
import numpy as np
|
|
from packaging.version import Version
|
|
|
|
from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import TypeGuard
|
|
|
|
from numpy.typing import NDArray
|
|
|
|
try:
|
|
from dask.array.core import Array as DaskArray
|
|
from dask.typing import DaskCollection
|
|
except ImportError:
|
|
DaskArray = NDArray # type: ignore[assignment, misc]
|
|
DaskCollection: Any = NDArray # type: ignore[no-redef]
|
|
|
|
from xarray.namedarray._typing import _Dim, duckarray
|
|
|
|
|
|
# Generic type variables for annotations. ``T`` is used in the signature of
# ``either_dict_or_kwargs``; ``K`` and ``V`` are not referenced in the visible
# part of this module.
K = TypeVar("K")
V = TypeVar("V")
T = TypeVar("T")
|
|
|
|
|
|
@lru_cache
def module_available(module: str, minversion: str | None = None) -> bool:
    """Checks whether a module is installed without importing it.

    Use this for a lightweight check and lazy imports. Results are memoized
    by ``lru_cache``, so each ``(module, minversion)`` pair is only looked up
    once per process.

    Parameters
    ----------
    module : str
        Name of the module. For a dotted name (``"pkg.sub"``) the parent
        package is imported by ``importlib.util.find_spec`` in order to
        locate the submodule.
    minversion : str, optional
        Minimum version of the module. When given, the installed version is
        read from the distribution metadata registered under ``module``.

    Returns
    -------
    available : bool
        Whether the module is installed (and, if ``minversion`` is given,
        whether its version is at least ``minversion``).

    Raises
    ------
    importlib.metadata.PackageNotFoundError
        If ``minversion`` is given and the module can be found but no
        distribution metadata exists under the name ``module``.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` and
    # ``metadata`` submodules have been loaded, so import them explicitly.
    import importlib.metadata
    import importlib.util

    try:
        spec = importlib.util.find_spec(module)
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package cannot be imported;
        # for an availability check that simply means "not available".
        return False

    if spec is None:
        return False

    if minversion is not None:
        version = importlib.metadata.version(module)

        return Version(version) >= Version(minversion)

    return True
|
|
|
|
|
|
def is_dask_collection(x: object) -> TypeGuard[DaskCollection]:
    """Report whether ``x`` is a dask collection.

    Returns ``False`` straight away when ``module_available("dask")`` is
    false; otherwise the decision is delegated to
    ``dask.base.is_dask_collection``.
    """
    if not module_available("dask"):
        return False

    # use is_dask_collection function instead of dask.typing.DaskCollection
    # see https://github.com/pydata/xarray/pull/8241#discussion_r1476276023
    from dask.base import is_dask_collection as _dask_is_dask_collection

    return _dask_is_dask_collection(x)
|
|
|
|
|
|
def is_duck_array(value: Any) -> TypeGuard[duckarray[Any, Any]]:
    """Return True if ``value`` looks like an array.

    ``numpy.ndarray`` instances always qualify. Any other object must expose
    ``ndim``, ``shape`` and ``dtype`` and additionally implement either both
    ``__array_function__`` and ``__array_ufunc__`` or ``__array_namespace__``.
    """
    # TODO: replace is_duck_array with runtime checks via _arrayfunction_or_api protocol on
    # python 3.12 and higher (see https://github.com/pydata/xarray/issues/8696#issuecomment-1924588981)
    if isinstance(value, np.ndarray):
        return True

    # Basic array metadata, probed in the order ndim -> shape -> dtype.
    for attribute in ("ndim", "shape", "dtype"):
        if not hasattr(value, attribute):
            return False

    implements_numpy_protocols = hasattr(value, "__array_function__") and hasattr(
        value, "__array_ufunc__"
    )
    return implements_numpy_protocols or hasattr(value, "__array_namespace__")
|
|
|
|
|
|
def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]:
    """Return True when ``x`` is both a duck array and a dask collection."""
    if not is_duck_array(x):
        return False
    return is_dask_collection(x)
|
|
|
|
|
|
def to_0d_object_array(
    value: object,
) -> NDArray[np.object_]:
    """Wrap ``value`` in a zero-dimensional numpy array of dtype object.

    The array is allocated empty first and ``value`` is then stored through
    item assignment at the empty index ``()``.
    """
    wrapped = np.empty(shape=(), dtype=np.object_)
    wrapped[()] = value
    return wrapped
|
|
|
|
|
|
def is_dict_like(value: Any) -> TypeGuard[Mapping[Any, Any]]:
    """Return True if ``value`` exposes both ``keys`` and ``__getitem__``."""
    required_attributes = ("keys", "__getitem__")
    return all(hasattr(value, name) for name in required_attributes)
|
|
|
|
|
|
def drop_missing_dims(
    supplied_dims: Iterable[_Dim],
    dims: Iterable[_Dim],
    missing_dims: ErrorOptionsWithWarn,
) -> _DimsLike:
    """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
    are not present in dims.

    An ellipsis (``...``) in ``supplied_dims`` is never treated as a missing
    dimension: it is neither reported nor dropped.

    Parameters
    ----------
    supplied_dims : Iterable of Hashable
        Dimensions requested by the caller, optionally containing ``...``.
    dims : Iterable of Hashable
        Dimensions that actually exist.
    missing_dims : {"raise", "warn", "ignore"}
        What to do with entries of ``supplied_dims`` that are not in ``dims``:

        - "raise": raise a ``ValueError``
        - "warn": emit a warning and drop them
        - "ignore": drop them silently

    Returns
    -------
    With "raise", ``supplied_dims`` itself (as a list if an iterator was
    passed). With "warn" or "ignore", a new list holding the entries of
    ``supplied_dims`` that are in ``dims`` or are ``...``, in their original
    order.

    Raises
    ------
    ValueError
        If ``missing_dims`` is "raise" and some supplied dimension is not in
        ``dims``, or if ``missing_dims`` is not one of the recognised options.
    """
    # Both arguments are walked more than once below; a one-shot iterator
    # would be exhausted after the first pass, so materialise it up front.
    # Re-iterable inputs are left untouched so return values and messages
    # keep their original type/formatting.
    if isinstance(supplied_dims, Iterator):
        supplied_dims = list(supplied_dims)
    if isinstance(dims, Iterator):
        dims = list(dims)

    if missing_dims == "raise":
        supplied_dims_set = {val for val in supplied_dims if val is not ...}
        if invalid := supplied_dims_set - set(dims):
            raise ValueError(
                f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
            )

        return supplied_dims

    elif missing_dims == "warn":
        # Exclude the ellipsis here as well (as in the "raise" branch),
        # otherwise ``...`` itself would be reported as a missing dimension.
        supplied_dims_set = {val for val in supplied_dims if val is not ...}
        if invalid := supplied_dims_set - set(dims):
            warnings.warn(
                f"Dimensions {invalid} do not exist. Expected one or more of {dims}",
                stacklevel=2,
            )

        return [val for val in supplied_dims if val in dims or val is ...]

    elif missing_dims == "ignore":
        return [val for val in supplied_dims if val in dims or val is ...]

    else:
        raise ValueError(
            f"Unrecognised option {missing_dims} for missing_dims argument"
        )
|
|
|
|
|
|
def infix_dims(
    dims_supplied: Iterable[_Dim],
    dims_all: Iterable[_Dim],
    missing_dims: ErrorOptionsWithWarn = "raise",
) -> Iterator[_Dim]:
    """
    Resolves a supplied list containing an ellipsis representing other items, to
    a generator with the 'realized' list of all items

    Because this is a generator function, nothing is validated until the
    result is first iterated.

    Parameters
    ----------
    dims_supplied : Iterable of Hashable
        Requested dimension order; may contain a single ``...`` standing for
        every dimension of ``dims_all`` that is not listed explicitly.
    dims_all : Iterable of Hashable
        All available dimensions.
    missing_dims : {"raise", "warn", "ignore"}, default: "raise"
        Passed on to ``drop_missing_dims`` to decide how supplied dimensions
        that are absent from ``dims_all`` are handled.

    Yields
    ------
    The supplied dimensions in order, with ``...`` replaced by the remaining
    dimensions of ``dims_all`` in their original order.

    Raises
    ------
    ValueError
        If ``...`` is used while ``dims_all`` contains repeated dimensions, if
        more than one ``...`` is supplied, or if no ``...`` is supplied and the
        retained dimensions are not a permutation of ``dims_all``. Errors
        raised by ``drop_missing_dims`` propagate unchanged.
    """
    # Both arguments are iterated several times below. A one-shot iterator
    # would be drained by the first pass (e.g. the ``... in`` test), so
    # materialise it; re-iterable inputs are kept as-is so error messages
    # keep their original formatting.
    if isinstance(dims_supplied, Iterator):
        dims_supplied = list(dims_supplied)
    if isinstance(dims_all, Iterator):
        dims_all = list(dims_all)

    if ... in dims_supplied:
        dims_all_list = list(dims_all)
        if len(set(dims_all_list)) != len(dims_all_list):
            raise ValueError("Cannot use ellipsis with repeated dims")
        if list(dims_supplied).count(...) > 1:
            raise ValueError("More than one ellipsis supplied")
        # Dimensions the ellipsis stands for, in their original order.
        other_dims = [d for d in dims_all_list if d not in dims_supplied]
        existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
        for d in existing_dims:
            if d is ...:
                yield from other_dims
            else:
                yield d
    else:
        existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
        # Without an ellipsis, the retained dims must cover dims_all exactly.
        if set(existing_dims) ^ set(dims_all):
            raise ValueError(
                f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
            )
        yield from existing_dims
|
|
|
|
|
|
def either_dict_or_kwargs(
    pos_kwargs: Mapping[Any, T] | None,
    kw_kwargs: Mapping[str, T],
    func_name: str,
) -> Mapping[Hashable, T]:
    """Return whichever of the positional mapping or the keyword mapping was given.

    Parameters
    ----------
    pos_kwargs : Mapping or None
        Mapping passed positionally. ``None`` or an empty dict means "not
        given".
    kw_kwargs : Mapping
        Mapping collected from keyword arguments.
    func_name : str
        Name of the calling method, used only in error messages.

    Returns
    -------
    ``kw_kwargs`` when ``pos_kwargs`` was not given, otherwise ``pos_kwargs``.

    Raises
    ------
    ValueError
        If ``pos_kwargs`` is given but is not dict-like, or if both
        ``pos_kwargs`` and a non-empty ``kw_kwargs`` are given.
    """
    positional_missing = pos_kwargs is None or pos_kwargs == {}
    if positional_missing:
        # Need an explicit cast to appease mypy due to invariance; see
        # https://github.com/python/mypy/issues/6228
        return cast(Mapping[Hashable, T], kw_kwargs)

    if not is_dict_like(pos_kwargs):
        raise ValueError(f"the first argument to .{func_name} must be a dictionary")

    if not kw_kwargs:
        return pos_kwargs

    raise ValueError(
        f"cannot specify both keyword and positional arguments to .{func_name}"
    )
|
|
|
|
|
|
class ReprObject:
    """Object that prints as the given value, for use with sentinel values.

    Two instances compare equal when they wrap the same string; comparison
    with any other type is False. The hash is derived from the class and the
    wrapped string.
    """

    __slots__ = ("_value",)

    # Text returned by ``repr()``.
    _value: str

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        return self._value

    def __eq__(self, other: ReprObject | Any) -> bool:
        # TODO: What type can other be? ArrayLike?
        if not isinstance(other, ReprObject):
            return False
        return self._value == other._value

    def __hash__(self) -> int:
        identity = (type(self), self._value)
        return hash(identity)

    def __dask_tokenize__(self) -> object:
        # Imported lazily so that dask is only needed when this hook is called.
        from dask.base import normalize_token

        identity = (type(self), self._value)
        return normalize_token(identity)
|