# Source: CCR/.venv/lib/python3.12/site-packages/xarray/util/generate_ops.py
# (317 lines, 10 KiB, Python)

"""Generate module and stub file for arithmetic operators of various xarray classes.
For internal xarray development use only. Requires that jinja2 is installed.
Usage:
python -m pip install jinja2
python xarray/util/generate_ops.py > xarray/core/_typed_ops.py
"""
# Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some
# background to some of the design choices made here.
from __future__ import annotations
from collections.abc import Iterator, Sequence
from typing import Any
import jinja2
# Operator tables.
#
# Every entry is a ``(method name, function name)`` pair. The method name is the
# special method emitted into the generated class; the function name is the
# dotted expression the emitted method hands to ``_binary_op``,
# ``_inplace_binary_op`` or ``_unary_op``.

# Equality operators use the nputils helpers instead of the operator module.
BINOPS_EQNE = (
    ("__eq__", "nputils.array_eq"),
    ("__ne__", "nputils.array_ne"),
)

# Ordering comparisons.
BINOPS_CMP = (
    ("__lt__", "operator.lt"),
    ("__le__", "operator.le"),
    ("__gt__", "operator.gt"),
    ("__ge__", "operator.ge"),
)

# Arithmetic, bitwise and shift operators.
BINOPS_NUM = (
    ("__add__", "operator.add"),
    ("__sub__", "operator.sub"),
    ("__mul__", "operator.mul"),
    ("__pow__", "operator.pow"),
    ("__truediv__", "operator.truediv"),
    ("__floordiv__", "operator.floordiv"),
    ("__mod__", "operator.mod"),
    ("__and__", "operator.and_"),
    ("__xor__", "operator.xor"),
    ("__or__", "operator.or_"),
    ("__lshift__", "operator.lshift"),
    ("__rshift__", "operator.rshift"),
)

# Reflected operators. They reuse the forward operator functions; the reflexive
# template passes ``reflexive=True`` alongside them.
# NOTE(review): there are no ``__rlshift__``/``__rrshift__`` entries even though
# BINOPS_NUM contains the shift operators -- confirm this is intentional.
BINOPS_REFLEXIVE = (
    ("__radd__", "operator.add"),
    ("__rsub__", "operator.sub"),
    ("__rmul__", "operator.mul"),
    ("__rpow__", "operator.pow"),
    ("__rtruediv__", "operator.truediv"),
    ("__rfloordiv__", "operator.floordiv"),
    ("__rmod__", "operator.mod"),
    ("__rand__", "operator.and_"),
    ("__rxor__", "operator.xor"),
    ("__ror__", "operator.or_"),
)

# Augmented-assignment operators.
BINOPS_INPLACE = (
    ("__iadd__", "operator.iadd"),
    ("__isub__", "operator.isub"),
    ("__imul__", "operator.imul"),
    ("__ipow__", "operator.ipow"),
    ("__itruediv__", "operator.itruediv"),
    ("__ifloordiv__", "operator.ifloordiv"),
    ("__imod__", "operator.imod"),
    ("__iand__", "operator.iand"),
    ("__ixor__", "operator.ixor"),
    ("__ior__", "operator.ior"),
    ("__ilshift__", "operator.ilshift"),
    ("__irshift__", "operator.irshift"),
)

# Unary operators.
UNARY_OPS = (
    ("__neg__", "operator.neg"),
    ("__pos__", "operator.pos"),
    ("__abs__", "operator.abs"),
    ("__invert__", "operator.invert"),
)

# round method and numpy/pandas unary methods which don't modify the data shape,
# so the result should still be wrapped in a Variable/DataArray/Dataset
OTHER_UNARY_METHODS = (
    ("round", "ops.round_"),
    ("argsort", "ops.argsort"),
    ("conj", "ops.conj"),
    ("conjugate", "ops.conjugate"),
)
# Jinja2 snippets rendered into the body of each generated class.
#
# The leading whitespace inside these strings is significant: every snippet is
# emitted directly below a ``class ...:`` header, so method headers need four
# spaces of indentation and method bodies eight. Without it the generated
# module is not valid Python.

# Stub for the hook that all rendered binary operators delegate to.
required_method_binary = """
    def _binary_op(
        self, other: {{ other_type }}, f: Callable, reflexive: bool = False
    ) -> {{ return_type }}:
        raise NotImplementedError"""

# Plain binary operator; ``type_ignore`` is either empty or a trailing comment.
template_binop = """
    def {{ method }}(self, other: {{ other_type }}) -> {{ return_type }}:{{ type_ignore }}
        return self._binary_op(other, {{ func }})"""

# Binary operator with one ``@overload`` per entry of ``overload_types`` (each
# mapping that type to itself), a final overload for ``other_type``, and the
# implementation whose return annotation is the union of all of them.
# ``overload_type_ignore`` is only appended to the first overload.
template_binop_overload = """
{%- for overload_type in overload_types %}
    @overload{{ overload_type_ignore if overload_type == overload_types[0] else "" }}
    def {{ method }}(self, other: {{ overload_type }}) -> {{ overload_type }}: ...
{% endfor %}
    @overload
    def {{method}}(self, other: {{ other_type }}) -> {{ return_type }}: ...
    def {{ method }}(self, other: {{ other_type }}) -> {{ return_type }} | {{ ' | '.join(overload_types) }}:{{ type_ignore }}
        return self._binary_op(other, {{ func }})"""

# Reflected binary operator: same hook, called with ``reflexive=True``.
template_reflexive = """
    def {{ method }}(self, other: {{ other_type }}) -> {{ return_type }}:
        return self._binary_op(other, {{ func }}, reflexive=True)"""

# Stub for the hook that all rendered in-place operators delegate to.
required_method_inplace = """
    def _inplace_binary_op(self, other: {{ other_type }}, f: Callable) -> Self:
        raise NotImplementedError"""

# In-place (augmented assignment) operator.
template_inplace = """
    def {{ method }}(self, other: {{ other_type }}) -> Self:{{type_ignore}}
        return self._inplace_binary_op(other, {{ func }})"""

# Stub for the hook that all rendered unary operators and methods delegate to.
required_method_unary = """
    def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
        raise NotImplementedError"""

# Unary operator without extra arguments.
template_unary = """
    def {{ method }}(self) -> Self:
        return self._unary_op({{ func }})"""

# Unary method that forwards arbitrary positional and keyword arguments.
template_other_unary = """
    def {{ method }}(self, *args: Any, **kwargs: Any) -> Self:
        return self._unary_op({{ func }}, *args, **kwargs)"""

# Declaration emitted alongside __eq__/__ne__ to mark the class as unhashable.
unhashable = """
    # When __eq__ is defined but __hash__ is not, then an object is unhashable,
    # and it should be declared as follows:
    __hash__: None # type:ignore[assignment]"""
# For some methods we override return type `bool` defined by base class `object`.
# We need to add "# type: ignore[override]"
# Keep an eye out for:
# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240
# The type ignores might not be necessary anymore at some point.
#
# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray
# In reality this returns NotImplemented, but this is not a valid type in python 3.9.
# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable)
# TODO: change once python 3.10 is the minimum.
#
# Mypy seems to require that __iadd__ and __add__ have the same signature.
# This requires some extra type: ignores[misc] in the inplace methods :/
def _type_ignore(ignore: str) -> str:
return f" # type:ignore[{ignore}]" if ignore else ""
# A sequence of (method name, function name) pairs. Both entries are None in the
# single-pair placeholder used for snippets that are rendered exactly once
# (required-method stubs and the __hash__ declaration).
FuncType = Sequence[tuple[str | None, str | None]]
# One renderable block: (method/function pairs, template text, extra template variables).
OpsType = tuple[FuncType, str, dict[str, Any]]
def binops(
    other_type: str, return_type: str = "Self", type_ignore_eq: str = "override"
) -> list[OpsType]:
    """Describe the binary-operator section of a generated class.

    The returned blocks are, in order: the ``_binary_op`` stub, the numeric and
    comparison operators, ``__eq__``/``__ne__``, the ``__hash__`` declaration,
    and the reflected operators.

    Parameters
    ----------
    other_type : str
        Annotation rendered for the ``other`` parameter.
    return_type : str, default: "Self"
        Annotation rendered as the return type.
    type_ignore_eq : str, default: "override"
        Error code for the type-ignore comment appended to ``__eq__`` and
        ``__ne__``; an empty string omits the comment.
    """
    common = {"other_type": other_type, "return_type": return_type}
    without_ignore = common | {"type_ignore": ""}
    with_eq_ignore = common | {"type_ignore": _type_ignore(type_ignore_eq)}

    blocks: list[OpsType] = []
    blocks.append(([(None, None)], required_method_binary, common))
    blocks.append((BINOPS_NUM + BINOPS_CMP, template_binop, without_ignore))
    blocks.append((BINOPS_EQNE, template_binop, with_eq_ignore))
    blocks.append(([(None, None)], unhashable, common))
    blocks.append((BINOPS_REFLEXIVE, template_reflexive, common))
    return blocks
def binops_overload(
    other_type: str,
    overload_types: list[str],
    return_type: str = "Self",
    type_ignore_eq: str = "override",
) -> list[OpsType]:
    """Describe the binary-operator section of a class that needs ``@overload``s.

    Same block order as ``binops``, but the forward operators are rendered with
    the overload template. For ``__eq__``/``__ne__`` the type-ignore comment is
    supplied as ``overload_type_ignore`` (the template attaches it to the first
    overload); the ``type_ignore`` on the implementation line is always empty.

    Parameters
    ----------
    other_type : str
        Annotation rendered for the ``other`` parameter.
    overload_types : list of str
        Type names that each get their own overload and are joined into the
        implementation's return annotation.
    return_type : str, default: "Self"
        Annotation rendered as the return type for ``other_type``.
    type_ignore_eq : str, default: "override"
        Error code for the type-ignore comment on ``__eq__``/``__ne__``; an
        empty string omits the comment.
    """
    common = {"other_type": other_type, "return_type": return_type}
    overload_common = common | {
        "overload_types": overload_types,
        "type_ignore": "",
    }
    numeric_extras = overload_common | {"overload_type_ignore": ""}
    eqne_extras = overload_common | {
        "overload_type_ignore": _type_ignore(type_ignore_eq)
    }

    blocks: list[OpsType] = []
    blocks.append(([(None, None)], required_method_binary, common))
    blocks.append((BINOPS_NUM + BINOPS_CMP, template_binop_overload, numeric_extras))
    blocks.append((BINOPS_EQNE, template_binop_overload, eqne_extras))
    blocks.append(([(None, None)], unhashable, common))
    blocks.append((BINOPS_REFLEXIVE, template_reflexive, common))
    return blocks
def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]:
    """Describe the in-place operator section of a generated class.

    Returns the ``_inplace_binary_op`` stub followed by every operator in
    ``BINOPS_INPLACE``.

    Parameters
    ----------
    other_type : str
        Annotation rendered for the ``other`` parameter.
    type_ignore : str, default: ""
        Error code for the type-ignore comment appended to each in-place
        method; an empty string omits the comment.
    """
    common = {"other_type": other_type}
    stub_block = ([(None, None)], required_method_inplace, common)
    methods_block = (
        BINOPS_INPLACE,
        template_inplace,
        common | {"type_ignore": _type_ignore(type_ignore)},
    )
    return [stub_block, methods_block]
def unops() -> list[OpsType]:
    """Describe the unary section: the ``_unary_op`` stub, operators, then methods."""
    stub_block = ([(None, None)], required_method_unary, {})
    operator_block = (UNARY_OPS, template_unary, {})
    method_block = (OTHER_UNARY_METHODS, template_other_unary, {})
    return [stub_block, operator_block, method_block]
# We use short names such as T_DA to keep the generated lines below 88
# characters so ruff does not reformat everything. When reformatting, the
# type-ignores end up on the wrong line :/
# Maps each generated mixin class name to the ordered blocks rendered into it.
ops_info = {
    # TODO add inplace ops for DataTree?
    "DataTreeOpsMixin": binops(other_type="DtCompatible") + unops(),
    "DatasetOpsMixin": (
        binops_overload(other_type="DsCompatible", overload_types=["DataTree"])
        + inplace(other_type="DsCompatible", type_ignore="misc")
        + unops()
    ),
    "DataArrayOpsMixin": (
        binops_overload(
            other_type="DaCompatible", overload_types=["Dataset", "DataTree"]
        )
        + inplace(other_type="DaCompatible", type_ignore="misc")
        + unops()
    ),
    "VariableOpsMixin": (
        binops_overload(
            other_type="VarCompatible",
            overload_types=["T_DA", "Dataset", "DataTree"],
        )
        + inplace(other_type="VarCompatible", type_ignore="misc")
        + unops()
    ),
    "DatasetGroupByOpsMixin": binops(
        other_type="Dataset | DataArray", return_type="Dataset"
    ),
    "DataArrayGroupByOpsMixin": binops(
        other_type="T_Xarray", return_type="T_Xarray"
    ),
}
# Literal text emitted once at the top of the generated module. The imports
# under ``if TYPE_CHECKING:`` must be indented for the output to be valid Python.
MODULE_PREAMBLE = '''\
"""Mixin classes with arithmetic operators."""
# This file was generated using xarray.util.generate_ops. Do not edit manually.
from __future__ import annotations
import operator
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, overload
from xarray.core import nputils, ops
from xarray.core.types import (
    DaCompatible,
    DsCompatible,
    DtCompatible,
    Self,
    T_Xarray,
    VarCompatible,
)
if TYPE_CHECKING:
    from xarray.core.dataarray import DataArray
    from xarray.core.dataset import Dataset
    from xarray.core.datatree import DataTree
    from xarray.core.types import T_DataArray as T_DA'''

# ``str.format`` template for a class header; ``newline`` is passed in as a
# format argument. ``__slots__`` is indented so it belongs to the class body.
CLASS_PREAMBLE = """{newline}
class {cls_name}:
    __slots__ = ()"""

# ``str.format`` template for copying a function's docstring onto the generated
# method. It is indented so the assignment runs inside the class body, where the
# method name is in scope.
COPY_DOCSTRING = """\
    {method}.__doc__ = {func}.__doc__"""
def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]:
    """Yield the text of the generated module, piece by piece.

    The module preamble comes first; then, for each class in *ops_info* (in
    mapping order), its class header followed by its rendered body.
    """
    yield MODULE_PREAMBLE
    for class_name, blocks in ops_info.items():
        class_header = CLASS_PREAMBLE.format(newline="\n", cls_name=class_name)
        yield class_header
        yield from _render_classbody(blocks)
def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]:
    """Yield the rendered body of one generated class.

    First every block's template is rendered once per (method, func) pair,
    with the block's extra variables supplied as template context. Then an
    empty string is yielded, followed by one docstring-copy line for each
    pair whose method and func are both set (the ``(None, None)`` placeholder
    pairs are skipped).

    Parameters
    ----------
    method_blocks : list of OpsType
        ``(method/function pairs, template text, extra variables)`` tuples.
    """
    environment = jinja2.Environment()

    for method_func_pairs, template, extra in method_blocks:
        if not template:
            continue
        # The template text is identical for every pair in the block, so
        # compile it once here rather than once per rendered method.
        compiled = environment.from_string(template)
        for method, func in method_func_pairs:
            yield compiled.render(method=method, func=func, **extra)

    # Separator between the method definitions and the docstring assignments.
    yield ""
    for method_func_pairs, *_ in method_blocks:
        for method, func in method_func_pairs:
            if method and func:
                yield COPY_DOCSTRING.format(method=method, func=func)
if __name__ == "__main__":
    # Write the generated module to stdout, one rendered piece per print call.
    for chunk in render(ops_info):
        print(chunk)