"""Coders for strings."""

from __future__ import annotations

from functools import partial

import numpy as np

from xarray.coding.variables import (
    VariableCoder,
    lazy_elemwise_func,
    pop_to,
    safe_setitem,
    unpack_for_decoding,
    unpack_for_encoding,
)
from xarray.core import indexing
from xarray.core.utils import module_available
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array

HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")


def create_vlen_dtype(element_type):
    if element_type not in (str, bytes):
        raise TypeError(f"unsupported type for vlen_dtype: {element_type!r}")
    # based on h5py.special_dtype
    return np.dtype("O", metadata={"element_type": element_type})


def check_vlen_dtype(dtype):
    if dtype.kind != "O" or dtype.metadata is None:
        return None
    else:
        # check xarray (element_type) as well as h5py (vlen)
        return dtype.metadata.get("element_type", dtype.metadata.get("vlen"))


def is_unicode_dtype(dtype):
    return dtype.kind == "U" or check_vlen_dtype(dtype) is str


def is_bytes_dtype(dtype):
    return dtype.kind == "S" or check_vlen_dtype(dtype) is bytes
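

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): how the vlen-dtype helpers above fit together. The "_example_"
# function below exists only to document behaviour and is never called.
def _example_vlen_dtype_helpers():
    dtype = create_vlen_dtype(str)  # object dtype tagged as variable-length str
    assert check_vlen_dtype(dtype) is str  # the metadata tag round-trips
    assert is_unicode_dtype(dtype)  # variable-length str counts as unicode
    assert not is_bytes_dtype(dtype)
    assert check_vlen_dtype(np.dtype("float64")) is None  # plain dtypes carry no tag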


class EncodedStringCoder(VariableCoder):
    """Transforms between unicode strings and fixed-width UTF-8 bytes."""

    def __init__(self, allows_unicode=True):
        self.allows_unicode = allows_unicode

    def encode(self, variable: Variable, name=None) -> Variable:
        dims, data, attrs, encoding = unpack_for_encoding(variable)

        contains_unicode = is_unicode_dtype(data.dtype)
        encode_as_char = encoding.get("dtype") == "S1"
        if encode_as_char:
            del encoding["dtype"]  # no longer relevant

        if contains_unicode and (encode_as_char or not self.allows_unicode):
            if "_FillValue" in attrs:
                raise NotImplementedError(
                    f"variable {name!r} has a _FillValue specified, but "
                    "_FillValue is not yet supported on unicode strings: "
                    "https://github.com/pydata/xarray/issues/1647"
                )

            string_encoding = encoding.pop("_Encoding", "utf-8")
            safe_setitem(attrs, "_Encoding", string_encoding, name=name)
            # TODO: figure out how to handle this in a lazy way with dask
            data = encode_string_array(data, string_encoding)

            return Variable(dims, data, attrs, encoding)
        else:
            variable.encoding = encoding
            return variable

    def decode(self, variable: Variable, name=None) -> Variable:
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if "_Encoding" in attrs:
            string_encoding = pop_to(attrs, encoding, "_Encoding")
            func = partial(decode_bytes_array, encoding=string_encoding)
            data = lazy_elemwise_func(data, func, np.dtype(object))

        return Variable(dims, data, attrs, encoding)
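

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): round-tripping a small unicode variable through EncodedStringCoder.
# Requesting dtype "S1" in the variable's encoding is one way to trigger the
# byte-encoding branch of encode() above; the helper is never called here.
def _example_encoded_string_roundtrip():
    var = Variable(("x",), np.array(["foo", "ba"]), encoding={"dtype": "S1"})
    coder = EncodedStringCoder(allows_unicode=True)
    encoded = coder.encode(var, name="s")  # unicode -> fixed-width UTF-8 bytes
    assert encoded.dtype.kind == "S"
    assert encoded.attrs["_Encoding"] == "utf-8"
    decoded = coder.decode(encoded, name="s")  # bytes -> str, decoded lazily per element
    assert list(decoded.values) == ["foo", "ba"]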


def decode_bytes_array(bytes_array, encoding="utf-8"):
    # This is faster than using np.char.decode() or np.vectorize()
    bytes_array = np.asarray(bytes_array)
    decoded = [x.decode(encoding) for x in bytes_array.ravel()]
    return np.array(decoded, dtype=object).reshape(bytes_array.shape)


def encode_string_array(string_array, encoding="utf-8"):
    string_array = np.asarray(string_array)
    encoded = [x.encode(encoding) for x in string_array.ravel()]
    return np.array(encoded, dtype=bytes).reshape(string_array.shape)
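

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): the two helpers above are exact inverses for arrays of UTF-8
# encodable strings; shape is preserved through the round trip.
def _example_string_array_roundtrip():
    strings = np.array([["foo", "bar"], ["baz", "qux"]], dtype=object)
    as_bytes = encode_string_array(strings, "utf-8")  # fixed-width bytes, same shape
    back = decode_bytes_array(as_bytes, "utf-8")  # object dtype of str
    assert back.shape == strings.shape
    assert (back == strings).all()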


def ensure_fixed_length_bytes(var: Variable) -> Variable:
    """Ensure that a variable with vlen bytes is converted to fixed width."""
    if check_vlen_dtype(var.dtype) is bytes:
        dims, data, attrs, encoding = unpack_for_encoding(var)
        # TODO: figure out how to handle this with dask
        data = np.asarray(data, dtype=np.bytes_)
        return Variable(dims, data, attrs, encoding)
    else:
        return var


class CharacterArrayCoder(VariableCoder):
    """Transforms between arrays containing bytes and character arrays."""

    def encode(self, variable, name=None):
        variable = ensure_fixed_length_bytes(variable)

        dims, data, attrs, encoding = unpack_for_encoding(variable)
        if data.dtype.kind == "S" and encoding.get("dtype") is not str:
            data = bytes_to_char(data)
            if "char_dim_name" in encoding.keys():
                char_dim_name = encoding.pop("char_dim_name")
            else:
                char_dim_name = f"string{data.shape[-1]}"
            dims = dims + (char_dim_name,)
        return Variable(dims, data, attrs, encoding)

    def decode(self, variable, name=None):
        dims, data, attrs, encoding = unpack_for_decoding(variable)

        if data.dtype == "S1" and dims:
            encoding["char_dim_name"] = dims[-1]
            dims = dims[:-1]
            data = char_to_bytes(data)
        return Variable(dims, data, attrs, encoding)
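

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): CharacterArrayCoder moves the string width into a trailing
# character dimension on encode and folds it back on decode. The function is
# never called; it only documents the expected round trip.
def _example_character_array_roundtrip():
    var = Variable(("x",), np.array([b"abc", b"de"], dtype="S3"))
    coder = CharacterArrayCoder()
    encoded = coder.encode(var, name="s")
    assert encoded.dims == ("x", "string3")  # char dim named after the width by default
    assert encoded.dtype == np.dtype("S1")
    decoded = coder.decode(encoded, name="s")
    assert decoded.dims == ("x",)
    assert decoded.encoding["char_dim_name"] == "string3"
    assert decoded.values[1] == b"de"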


def bytes_to_char(arr):
    """Convert numpy/dask arrays from fixed width bytes to characters."""
    if arr.dtype.kind != "S":
        raise ValueError("argument must have a fixed-width bytes dtype")

    if is_chunked_array(arr):
        chunkmanager = get_chunked_array_type(arr)

        return chunkmanager.map_blocks(
            _numpy_bytes_to_char,
            arr,
            dtype="S1",
            chunks=arr.chunks + ((arr.dtype.itemsize,)),
            new_axis=[arr.ndim],
        )
    return _numpy_bytes_to_char(arr)


def _numpy_bytes_to_char(arr):
    """Like netCDF4.stringtochar, but faster and more flexible."""
    # adapt handling of copy-kwarg to numpy 2.0
    # see https://github.com/numpy/numpy/issues/25916
    # and https://github.com/numpy/numpy/pull/25922
    copy = None if HAS_NUMPY_2_0 else False
    # ensure the array is contiguous
    arr = np.array(arr, copy=copy, order="C", dtype=np.bytes_)
    return arr.reshape(arr.shape + (1,)).view("S1")


def char_to_bytes(arr):
    """Convert numpy/dask arrays from characters to fixed width bytes."""
    if arr.dtype != "S1":
        raise ValueError("argument must have dtype='S1'")

    if not arr.ndim:
        # no dimension to concatenate along
        return arr

    size = arr.shape[-1]

    if not size:
        # can't make an S0 dtype
        return np.zeros(arr.shape[:-1], dtype=np.bytes_)

    if is_chunked_array(arr):
        chunkmanager = get_chunked_array_type(arr)

        if len(arr.chunks[-1]) > 1:
            raise ValueError(
                "cannot stack a dask character array with "
                f"multiple chunks in the last dimension: {arr}"
            )

        dtype = np.dtype("S" + str(arr.shape[-1]))
        return chunkmanager.map_blocks(
            _numpy_char_to_bytes,
            arr,
            dtype=dtype,
            chunks=arr.chunks[:-1],
            drop_axis=[arr.ndim - 1],
        )
    else:
        return StackedBytesArray(arr)


def _numpy_char_to_bytes(arr):
    """Like netCDF4.chartostring, but faster and more flexible."""
    # adapt handling of copy-kwarg to numpy 2.0
    # see https://github.com/numpy/numpy/issues/25916
    # and https://github.com/numpy/numpy/pull/25922
    copy = None if HAS_NUMPY_2_0 else False
    # based on: https://stackoverflow.com/a/10984878/809705
    arr = np.array(arr, copy=copy, order="C")
    dtype = "S" + str(arr.shape[-1])
    return arr.view(dtype).reshape(arr.shape[:-1])
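

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): the in-memory numpy path of the two converters above. On the dask
# path the same work is deferred through the chunk manager's map_blocks.
def _example_char_bytes_roundtrip():
    fixed = np.array([b"abc", b"xy"], dtype="S3")
    chars = bytes_to_char(fixed)  # reinterpreted as shape (2, 3), dtype "S1"
    assert chars.shape == (2, 3) and chars.dtype == np.dtype("S1")
    stacked = char_to_bytes(chars)  # lazy StackedBytesArray wrapper over the chars
    assert stacked.dtype == np.dtype("S3")
    assert np.asarray(stacked)[1] == b"xy"  # trailing NUL padding is dropped by numpy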


class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Wrapper around array-like objects to create a new indexable object where
    values, when accessed, are automatically stacked along the last dimension.

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[indexer]
    array(b'abc', dtype='|S3')
    """

    def __init__(self, array):
        """
        Parameters
        ----------
        array : array-like
            Original array of values to wrap.
        """
        if array.dtype != "S1":
            raise ValueError(
                "can only use StackedBytesArray if argument has dtype='S1'"
            )
        self.array = indexing.as_indexable(array)

    @property
    def dtype(self):
        return np.dtype("S" + str(self.array.shape[-1]))

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape[:-1]

    def __repr__(self):
        return f"{type(self).__name__}({self.array!r})"

    def _vindex_get(self, key):
        return _numpy_char_to_bytes(self.array.vindex[key])

    def _oindex_get(self, key):
        return _numpy_char_to_bytes(self.array.oindex[key])

    def __getitem__(self, key):
        # require slicing the last dimension completely
        key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
        if key.tuple[-1] != slice(None):
            raise IndexError("too many indices")
        return _numpy_char_to_bytes(self.array[key])
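

# Illustrative usage sketch (added for this edit, not part of the upstream
# module): StackedBytesArray exposes the stacked dtype and shape without
# touching the underlying data, and stacks characters only on access, as in
# the class docstring's doctest above.
def _example_stacked_bytes_array():
    chars = np.array([[b"a", b"b", b"c"], [b"x", b"y", b""]], dtype="S1")
    stacked = StackedBytesArray(chars)
    assert stacked.dtype == np.dtype("S3")  # last axis folded into the dtype width
    assert stacked.shape == (2,)
    indexer = indexing.BasicIndexer((slice(None),))
    assert stacked[indexer][0] == b"abc"  # stacking happens when indexed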