CCR/.venv/lib/python3.12/site-packages/xarray/coding/strings.py

264 lines
8.8 KiB
Python

"""Coders for strings."""
from __future__ import annotations
from functools import partial
import numpy as np
from xarray.coding.variables import (
VariableCoder,
lazy_elemwise_func,
pop_to,
safe_setitem,
unpack_for_decoding,
unpack_for_encoding,
)
from xarray.core import indexing
from xarray.core.utils import module_available
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
# Flag for numpy >= 2.0; the "2.0.0.dev0" minimum version is chosen so that
# 2.0 development builds are admitted as well.
HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")
def create_vlen_dtype(element_type):
    """Build an object dtype tagged with a variable-length element type.

    Parameters
    ----------
    element_type : type
        Either ``str`` or ``bytes``.

    Returns
    -------
    numpy.dtype
        An object (``"O"``) dtype whose ``metadata`` maps ``"element_type"``
        to ``element_type``.

    Raises
    ------
    TypeError
        If ``element_type`` is neither ``str`` nor ``bytes``.
    """
    supported = element_type in (str, bytes)
    if not supported:
        raise TypeError(f"unsupported type for vlen_dtype: {element_type!r}")

    # based on h5py.special_dtype
    tag = {"element_type": element_type}
    return np.dtype("O", metadata=tag)
def check_vlen_dtype(dtype):
    """Return the variable-length element type recorded on ``dtype``, if any.

    Parameters
    ----------
    dtype : numpy.dtype
        The dtype to inspect.

    Returns
    -------
    object or None
        ``None`` unless ``dtype`` is an object dtype carrying metadata.
        Otherwise the value stored under ``"element_type"``, falling back to
        the value stored under ``"vlen"``, or ``None`` if neither key exists.
    """
    if dtype.kind == "O":
        metadata = dtype.metadata
        if metadata is not None:
            # check xarray (element_type) as well as h5py (vlen)
            h5py_tag = metadata.get("vlen")
            return metadata.get("element_type", h5py_tag)
    return None
def is_unicode_dtype(dtype):
    """Return whether ``dtype`` holds unicode text.

    True for fixed-width unicode dtypes (kind ``"U"``) and for dtypes whose
    variable-length element type, as reported by ``check_vlen_dtype``, is
    ``str``.
    """
    if dtype.kind == "U":
        return True
    return check_vlen_dtype(dtype) is str
def is_bytes_dtype(dtype):
    """Return whether ``dtype`` holds byte strings.

    True for fixed-width bytes dtypes (kind ``"S"``) and for dtypes whose
    variable-length element type, as reported by ``check_vlen_dtype``, is
    ``bytes``.
    """
    if dtype.kind == "S":
        return True
    return check_vlen_dtype(dtype) is bytes
class EncodedStringCoder(VariableCoder):
    """Transforms between unicode strings and fixed-width UTF-8 bytes."""

    def __init__(self, allows_unicode=True):
        # When False, unicode data is always converted to bytes by ``encode``.
        self.allows_unicode = allows_unicode

    def encode(self, variable: Variable, name=None) -> Variable:
        """Encode unicode data to bytes when the target requires it.

        Unicode data is encoded when the requested encoding dtype is
        ``"S1"`` or when this coder does not allow unicode. The codec comes
        from ``encoding["_Encoding"]`` (default ``"utf-8"``) and is recorded
        in the ``_Encoding`` attribute. In every other case the input
        variable is returned with its ``encoding`` reassigned.

        Raises
        ------
        NotImplementedError
            If the data must be encoded and ``_FillValue`` is in the attrs.
        """
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        contains_unicode = is_unicode_dtype(data.dtype)
        encode_as_char = encoding.get("dtype") == "S1"
        if encode_as_char:
            del encoding["dtype"]  # no longer relevant

        must_encode = contains_unicode and (
            encode_as_char or not self.allows_unicode
        )
        if not must_encode:
            # Nothing to convert: return the input, carrying the unpacked
            # encoding (without "dtype" if it was "S1").
            variable.encoding = encoding
            return variable

        if "_FillValue" in attrs:
            raise NotImplementedError(
                f"variable {name!r} has a _FillValue specified, but "
                "_FillValue is not yet supported on unicode strings: "
                "https://github.com/pydata/xarray/issues/1647"
            )

        string_encoding = encoding.pop("_Encoding", "utf-8")
        safe_setitem(attrs, "_Encoding", string_encoding, name=name)
        # TODO: figure out how to handle this in a lazy way with dask
        encoded = encode_string_array(data, string_encoding)
        return Variable(dims, encoded, attrs, encoding)

    def decode(self, variable: Variable, name=None) -> Variable:
        """Decode bytes to strings when an ``_Encoding`` attribute is present.

        The attribute is moved from attrs to encoding via ``pop_to`` and the
        data is wrapped so that decoding is applied element-wise through
        ``lazy_elemwise_func`` with an object result dtype.
        """
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if "_Encoding" not in attrs:
            return Variable(dims, data, attrs, encoding)

        string_encoding = pop_to(attrs, encoding, "_Encoding")
        decoder = partial(decode_bytes_array, encoding=string_encoding)
        decoded = lazy_elemwise_func(data, decoder, np.dtype(object))
        return Variable(dims, decoded, attrs, encoding)
def decode_bytes_array(bytes_array, encoding="utf-8"):
    """Decode every element of an array of bytes into ``str``.

    Parameters
    ----------
    bytes_array : array-like
        Values exposing a ``decode`` method (e.g. bytes).
    encoding : str, default "utf-8"
        Codec passed to each element's ``decode``.

    Returns
    -------
    numpy.ndarray
        Object-dtype array of decoded strings with the input's shape.
    """
    source = np.asarray(bytes_array)
    # This is faster than using np.char.decode() or np.vectorize()
    texts = [raw.decode(encoding) for raw in source.flat]
    flat_result = np.array(texts, dtype=object)
    return flat_result.reshape(source.shape)
def encode_string_array(string_array, encoding="utf-8"):
    """Encode every element of an array of strings into bytes.

    Parameters
    ----------
    string_array : array-like
        Values exposing an ``encode`` method (e.g. str).
    encoding : str, default "utf-8"
        Codec passed to each element's ``encode``.

    Returns
    -------
    numpy.ndarray
        Bytes-dtype array of encoded values with the input's shape.
    """
    source = np.asarray(string_array)
    payloads = [text.encode(encoding) for text in source.flat]
    flat_result = np.array(payloads, dtype=bytes)
    return flat_result.reshape(source.shape)
def ensure_fixed_length_bytes(var: Variable) -> Variable:
    """Ensure that a variable with vlen bytes is converted to fixed width.

    Variables whose dtype is not tagged as variable-length ``bytes`` are
    returned unchanged. Otherwise a new ``Variable`` is built whose data has
    been converted with ``np.asarray(..., dtype=np.bytes_)``.
    """
    if check_vlen_dtype(var.dtype) is not bytes:
        return var

    dims, data, attrs, encoding = unpack_for_encoding(var)
    # TODO: figure out how to handle this with dask
    fixed_width = np.asarray(data, dtype=np.bytes_)
    return Variable(dims, fixed_width, attrs, encoding)
class CharacterArrayCoder(VariableCoder):
    """Transforms between arrays containing bytes and character arrays."""

    def encode(self, variable, name=None):
        """Split fixed-width bytes data into a trailing character dimension.

        Applies only when the data has a bytes (``"S"``) dtype and the
        encoding's ``"dtype"`` is not ``str``. The new dimension is named by
        ``encoding["char_dim_name"]`` (removed from the encoding) or, when
        absent, ``"string<N>"`` where N is the length of the new axis.
        """
        variable = ensure_fixed_length_bytes(variable)
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        wants_chars = data.dtype.kind == "S" and encoding.get("dtype") is not str
        if not wants_chars:
            return Variable(dims, data, attrs, encoding)

        data = bytes_to_char(data)
        if "char_dim_name" in encoding:
            char_dim_name = encoding.pop("char_dim_name")
        else:
            char_dim_name = f"string{data.shape[-1]}"
        return Variable(dims + (char_dim_name,), data, attrs, encoding)

    def decode(self, variable, name=None):
        """Merge a trailing character dimension back into fixed-width bytes.

        Applies only to ``"S1"`` data with at least one dimension. The name
        of the dropped dimension is stored in ``encoding["char_dim_name"]``.
        """
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if data.dtype == "S1" and dims:
            *leading_dims, char_dim = dims
            encoding["char_dim_name"] = char_dim
            dims = dims[:-1]
            data = char_to_bytes(data)

        return Variable(dims, data, attrs, encoding)
def bytes_to_char(arr):
    """Convert numpy/dask arrays from fixed width bytes to characters.

    Chunked arrays are processed block-wise through their chunk manager,
    adding one new trailing axis whose chunk size is the dtype's itemsize.
    Any other array is converted directly.

    Raises
    ------
    ValueError
        If ``arr`` does not have a bytes (``"S"``) dtype.
    """
    if arr.dtype.kind != "S":
        raise ValueError("argument must have a fixed-width bytes dtype")

    if not is_chunked_array(arr):
        return _numpy_bytes_to_char(arr)

    chunkmanager = get_chunked_array_type(arr)
    # One extra entry for the new character axis: a single chunk spanning
    # the full item width.
    out_chunks = arr.chunks + (arr.dtype.itemsize,)
    return chunkmanager.map_blocks(
        _numpy_bytes_to_char,
        arr,
        dtype="S1",
        chunks=out_chunks,
        new_axis=[arr.ndim],
    )
def _numpy_bytes_to_char(arr):
"""Like netCDF4.stringtochar, but faster and more flexible."""
# adapt handling of copy-kwarg to numpy 2.0
# see https://github.com/numpy/numpy/issues/25916
# and https://github.com/numpy/numpy/pull/25922
copy = None if HAS_NUMPY_2_0 else False
# ensure the array is contiguous
arr = np.array(arr, copy=copy, order="C", dtype=np.bytes_)
return arr.reshape(arr.shape + (1,)).view("S1")
def char_to_bytes(arr):
    """Convert numpy/dask arrays from characters to fixed width bytes.

    Zero-dimensional input is returned as-is, and an empty last axis yields
    an all-empty bytes array without that axis. Chunked arrays are converted
    block-wise through their chunk manager; anything else is wrapped in a
    ``StackedBytesArray``.

    Raises
    ------
    ValueError
        If ``arr`` is not of dtype ``"S1"``, or if it is chunked with more
        than one chunk along the last dimension.
    """
    if arr.dtype != "S1":
        raise ValueError("argument must have dtype='S1'")

    if arr.ndim == 0:
        # no dimension to concatenate along
        return arr

    size = arr.shape[-1]
    if size == 0:
        # can't make an S0 dtype
        return np.zeros(arr.shape[:-1], dtype=np.bytes_)

    if not is_chunked_array(arr):
        return StackedBytesArray(arr)

    chunkmanager = get_chunked_array_type(arr)
    last_axis_chunks = arr.chunks[-1]
    if len(last_axis_chunks) > 1:
        raise ValueError(
            "cannot stacked dask character array with "
            f"multiple chunks in the last dimension: {arr}"
        )

    stacked_dtype = np.dtype(f"S{size}")
    return chunkmanager.map_blocks(
        _numpy_char_to_bytes,
        arr,
        dtype=stacked_dtype,
        chunks=arr.chunks[:-1],
        drop_axis=[arr.ndim - 1],
    )
def _numpy_char_to_bytes(arr):
"""Like netCDF4.chartostring, but faster and more flexible."""
# adapt handling of copy-kwarg to numpy 2.0
# see https://github.com/numpy/numpy/issues/25916
# and https://github.com/numpy/numpy/pull/25922
copy = None if HAS_NUMPY_2_0 else False
# based on: https://stackoverflow.com/a/10984878/809705
arr = np.array(arr, copy=copy, order="C")
dtype = "S" + str(arr.shape[-1])
return arr.view(dtype).reshape(arr.shape[:-1])
class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Wrapper around array-like objects to create a new indexable object where
    values, when accessed, are automatically stacked along the last dimension.

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[indexer]
    array(b'abc', dtype='|S3')
    """

    def __init__(self, array):
        """
        Parameters
        ----------
        array : array-like
            Original array of values to wrap.

        Raises
        ------
        ValueError
            If ``array`` does not have dtype ``"S1"``.
        """
        if array.dtype != "S1":
            raise ValueError(
                "can only use StackedBytesArray if argument has dtype='S1'"
            )
        self.array = indexing.as_indexable(array)

    @property
    def dtype(self):
        """Bytes dtype as wide as the wrapped array's last axis."""
        width = self.array.shape[-1]
        return np.dtype(f"S{width}")

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the wrapped array without its last axis."""
        *leading, _ = self.array.shape
        return tuple(leading)

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}({self.array!r})"

    def _vindex_get(self, key):
        """Index the wrapped array through ``vindex`` and stack the result."""
        selected = self.array.vindex[key]
        return _numpy_char_to_bytes(selected)

    def _oindex_get(self, key):
        """Index the wrapped array through ``oindex`` and stack the result."""
        selected = self.array.oindex[key]
        return _numpy_char_to_bytes(selected)

    def __getitem__(self, key):
        """Index the wrapped array and stack the characters of the result.

        Raises
        ------
        IndexError
            If the expanded key does not take the whole last dimension.
        """
        # require slicing the last dimension completely
        expanded = indexing.expanded_indexer(key.tuple, self.array.ndim)
        full_key = type(key)(expanded)
        if full_key.tuple[-1] != slice(None):
            raise IndexError("too many indices")
        selected = self.array[full_key]
        return _numpy_char_to_bytes(selected)