CCR/.venv/lib/python3.12/site-packages/xarray/namedarray/_array_api.py

222 lines
5.8 KiB
Python

from __future__ import annotations
from types import ModuleType
from typing import Any
import numpy as np
from xarray.namedarray._typing import (
Default,
_arrayapi,
_Axes,
_Axis,
_default,
_Dim,
_DType,
_ScalarType,
_ShapeType,
_SupportsImag,
_SupportsReal,
)
from xarray.namedarray.core import NamedArray
def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
    """
    Return the array namespace to use for the data wrapped by ``x``.

    Parameters
    ----------
    x : NamedArray
        Array whose underlying ``_data`` is inspected.

    Returns
    -------
    out : ModuleType
        The result of ``x._data.__array_namespace__()`` when the data is an
        instance of ``_arrayapi``; otherwise the ``numpy`` module.
    """
    data = x._data
    return data.__array_namespace__() if isinstance(data, _arrayapi) else np
# %% Creation Functions
def astype(
    x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
) -> NamedArray[_ShapeType, _DType]:
    """
    Copies an array to a specified data type irrespective of Type Promotion Rules.

    Parameters
    ----------
    x : NamedArray
        Array to cast.
    dtype : _DType
        Desired data type.
    copy : bool, optional
        Specifies whether to copy an array when the specified dtype matches the data
        type of the input array x.
        If True, a newly allocated array must always be returned.
        If False and the specified dtype matches the data type of the input array,
        the input array must be returned; otherwise, a newly allocated array must be
        returned. Default: True.

    Returns
    -------
    out : NamedArray
        An array having the specified data type. The returned array must have the
        same shape as x.

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5]))
    >>> narr
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1.5, 2.5])
    >>> astype(narr, np.dtype(np.int32))
    <xarray.NamedArray (x: 2)> Size: 8B
    array([1, 2], dtype=int32)
    """
    payload = x._data
    if isinstance(payload, _arrayapi):
        # Array-API data: cast through the namespace-level ``astype`` function.
        namespace = payload.__array_namespace__()
        converted = namespace.astype(payload, dtype, copy=copy)
    else:
        # Other data: fall back to the object's own ``astype`` method.
        converted = payload.astype(dtype, copy=copy)  # type: ignore[attr-defined]
    return x._new(data=converted)
# %% Elementwise Functions
def imag(
    x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the imaginary component of a complex number for each element x_i of the
    input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must have a
        floating-point data type with the same floating-point precision as x
        (e.g., if x is complex64, the returned array must have the floating-point
        data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> imag(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([2., 4.])
    """
    # Delegate to the namespace matching the wrapped data and re-wrap the result.
    namespace = _get_data_namespace(x)
    return x._new(data=namespace.imag(x._data))
def real(
    x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the real component of a complex number for each element x_i of the
    input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must have a
        floating-point data type with the same floating-point precision as x
        (e.g., if x is complex64, the returned array must have the floating-point
        data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> real(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1., 2.])
    """
    # Delegate to the namespace matching the wrapped data and re-wrap the result.
    namespace = _get_data_namespace(x)
    return x._new(data=namespace.real(x._data))
# %% Manipulation functions
def expand_dims(
    x: NamedArray[Any, _DType],
    /,
    *,
    dim: _Dim | Default = _default,
    axis: _Axis = 0,
) -> NamedArray[Any, _DType]:
    """
    Expands the shape of an array by inserting a new dimension of size one at the
    position specified by axis.

    Parameters
    ----------
    x :
        Array to expand.
    dim :
        Dimension name. New dimension will be stored in the axis position.
        When omitted, the name ``f"dim_{len(x.dims)}"`` is used.
    axis :
        (Not recommended) Axis position (zero-based). Default is 0.
        A negative value is interpreted relative to the expanded array, so
        ``axis=-1`` appends the new dimension after the existing ones.

    Returns
    -------
    out :
        An expanded output array having the same data type as x.

    Examples
    --------
    >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]]))
    >>> expand_dims(x)
    <xarray.NamedArray (dim_2: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> expand_dims(x, dim="z")
    <xarray.NamedArray (z: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> expand_dims(x, dim="z", axis=-1).dims
    ('x', 'y', 'z')
    """
    xp = _get_data_namespace(x)
    dims = x.dims
    if dim is _default:
        dim = f"dim_{len(dims)}"

    # ``xp.expand_dims`` resolves a negative axis against the *expanded* rank
    # (ndim + 1), while ``list.insert`` resolves it against the current list
    # length. Normalise here so the new name lands on the same axis as the new
    # size-one dimension in the data (e.g. axis=-1 must append, not insert
    # before the last name).
    position = axis + len(dims) + 1 if axis < 0 else axis

    d = list(dims)
    d.insert(position, dim)
    # The original ``axis`` is forwarded unchanged so that the backend keeps
    # performing its own range validation.
    return x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
    """
    Permutes the dimensions of an array.

    Parameters
    ----------
    x :
        Array to permute.
    axes :
        Permutation of the dimensions of x.

    Returns
    -------
    out :
        An array with permuted dimensions. The returned array must have the same
        data type as x.
    """
    current_dims = x.dims
    # Reorder the dimension names with the same permutation applied to the data.
    reordered = tuple(current_dims[position] for position in axes)

    if not isinstance(x._data, _arrayapi):
        # Data that is not an ``_arrayapi`` instance: use its ``transpose`` method.
        return x._new(dims=reordered, data=x._data.transpose(axes))  # type: ignore[attr-defined]

    namespace = _get_data_namespace(x)
    return x._new(dims=reordered, data=namespace.permute_dims(x._data, axes))