from __future__ import annotations

from types import ModuleType
from typing import Any

import numpy as np

from xarray.namedarray._typing import (
    Default,
    _arrayapi,
    _Axes,
    _Axis,
    _default,
    _Dim,
    _DType,
    _ScalarType,
    _ShapeType,
    _SupportsImag,
    _SupportsReal,
)
from xarray.namedarray.core import NamedArray


def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
    """
    Return the array namespace that should operate on the data wrapped by x.

    If the wrapped data is an instance of ``_arrayapi`` its own
    ``__array_namespace__()`` is used; otherwise NumPy is the fallback.
    """
    if isinstance(x._data, _arrayapi):
        return x._data.__array_namespace__()

    return np


# %% Creation Functions


def astype(
    x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
) -> NamedArray[_ShapeType, _DType]:
    """
    Copies an array to a specified data type irrespective of Type Promotion Rules.

    Parameters
    ----------
    x : NamedArray
        Array to cast.
    dtype : _DType
        Desired data type.
    copy : bool, optional
        Specifies whether to copy an array when the specified dtype matches the
        data type of the input array x.
        If True, a newly allocated array must always be returned.
        If False and the specified dtype matches the data type of the input
        array, the input array must be returned; otherwise, a newly allocated
        array must be returned. Default: True.

    Returns
    -------
    out : NamedArray
        An array having the specified data type. The returned array must have
        the same shape as x.

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.5, 2.5]))
    >>> narr
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1.5, 2.5])
    >>> astype(narr, np.dtype(np.int32))
    <xarray.NamedArray (x: 2)> Size: 8B
    array([1, 2], dtype=int32)
    """
    if isinstance(x._data, _arrayapi):
        xp = x._data.__array_namespace__()
        return x._new(data=xp.astype(x._data, dtype, copy=copy))

    # np.astype doesn't exist yet, so use the method on the wrapped data instead:
    return x._new(data=x._data.astype(dtype, copy=copy))  # type: ignore[attr-defined]


# %% Elementwise Functions


def imag(
    x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the imaginary component of a complex number for each element x_i of
    the input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must
        have a floating-point data type with the same floating-point precision
        as x (e.g., if x is complex64, the returned array must have the
        floating-point data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> imag(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([2., 4.])
    """
    xp = _get_data_namespace(x)
    return x._new(data=xp.imag(x._data))


def real(
    x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]],  # type: ignore[type-var]
    /,
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
    """
    Returns the real component of a complex number for each element x_i of the
    input array x.

    Parameters
    ----------
    x : NamedArray
        Input array. Should have a complex floating-point data type.

    Returns
    -------
    out : NamedArray
        An array containing the element-wise results. The returned array must
        have a floating-point data type with the same floating-point precision
        as x (e.g., if x is complex64, the returned array must have the
        floating-point data type float32).

    Examples
    --------
    >>> narr = NamedArray(("x",), np.asarray([1.0 + 2j, 2 + 4j]))
    >>> real(narr)
    <xarray.NamedArray (x: 2)> Size: 16B
    array([1., 2.])
    """
    xp = _get_data_namespace(x)
    return x._new(data=xp.real(x._data))


# %% Manipulation functions


def expand_dims(
    x: NamedArray[Any, _DType],
    /,
    *,
    dim: _Dim | Default = _default,
    axis: _Axis = 0,
) -> NamedArray[Any, _DType]:
    """
    Expands the shape of an array by inserting a new dimension of size one at
    the position specified by axis.

    Parameters
    ----------
    x :
        Array to expand.
    dim :
        Dimension name. New dimension will be stored in the axis position.
        If omitted, the name ``dim_N`` is used, where N is the number of
        dimensions of x.
    axis :
        (Not recommended) Axis position (zero-based). Default is 0.
        A negative value counts from the end of the expanded array, so -1
        appends the new dimension after all existing ones.

    Returns
    -------
    out :
        An expanded output array having the same data type as x.

    Examples
    --------
    >>> x = NamedArray(("x", "y"), np.asarray([[1.0, 2.0], [3.0, 4.0]]))
    >>> expand_dims(x)
    <xarray.NamedArray (dim_2: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> expand_dims(x, dim="z")
    <xarray.NamedArray (z: 1, x: 2, y: 2)> Size: 32B
    array([[[1., 2.],
            [3., 4.]]])
    >>> expand_dims(x, dim="z", axis=-1).dims
    ('x', 'y', 'z')
    """
    xp = _get_data_namespace(x)
    # Expand the data first: any error the namespace raises for an unusable
    # axis then surfaces before the dimension names are touched.
    data = xp.expand_dims(x._data, axis=axis)

    dims = x.dims
    if dim is _default:
        dim = f"dim_{len(dims)}"

    # list.insert() places a negative index *before* the element it refers to
    # (insert(-1, ...) lands second-to-last), whereas a negative axis for
    # expand_dims refers to a position in the expanded array. Convert to the
    # equivalent non-negative position so names stay aligned with the data.
    position = axis + len(dims) + 1 if axis < 0 else axis
    new_dims = list(dims)
    new_dims.insert(position, dim)

    return x._new(dims=tuple(new_dims), data=data)


def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
    """
    Permutes the dimensions of an array.

    Parameters
    ----------
    x :
        Array to permute.
    axes :
        Permutation of the dimensions of x.

    Returns
    -------
    out :
        An array with permuted dimensions. The returned array must have the same
        data type as x.
    """
    dims = x.dims
    # Reorder the dimension names with the same permutation applied to the data.
    new_dims = tuple(dims[i] for i in axes)

    if isinstance(x._data, _arrayapi):
        xp = _get_data_namespace(x)
        return x._new(dims=new_dims, data=xp.permute_dims(x._data, axes))

    # Data that is not an _arrayapi instance is permuted via its own
    # transpose method instead of a namespace-level permute_dims.
    return x._new(dims=new_dims, data=x._data.transpose(axes))  # type: ignore[attr-defined]