204 lines
7.1 KiB
Python
204 lines
7.1 KiB
Python
from __future__ import annotations
|
|
|
|
from numbers import Number
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import xarray as xr
|
|
from xarray.backends.api import _get_default_engine
|
|
from xarray.tests import (
|
|
assert_identical,
|
|
assert_no_warnings,
|
|
requires_dask,
|
|
requires_netCDF4,
|
|
requires_scipy,
|
|
)
|
|
|
|
|
|
@requires_netCDF4
@requires_scipy
def test__get_default_engine() -> None:
    """Check which engine name `_get_default_engine` returns for three kinds of path."""
    # A remote URL, with remote access explicitly allowed.
    assert (
        _get_default_engine("http://example.org/test.nc", allow_remote=True)
        == "netcdf4"
    )

    # A path ending in ".gz".
    assert _get_default_engine("/example.gz") == "scipy"

    # A plain local path without any extension.
    assert _get_default_engine("/example") == "netcdf4"
|
|
|
|
|
|
def test_custom_engine() -> None:
    """Open a dataset through a backend entrypoint class passed directly as `engine`."""
    data_vars = {"a": 2 * np.arange(5)}
    coords = {"x": ("x", np.arange(5), {"units": "s"})}
    expected = xr.Dataset(data_vars, coords=coords)

    # This backend ignores whatever it is asked to open and hands back a deep
    # copy of `expected`, so the test needs no file on disk.
    class CustomBackend(xr.backends.BackendEntrypoint):
        def open_dataset(
            self,
            filename_or_obj,
            drop_variables=None,
            **kwargs,
        ) -> xr.Dataset:
            return expected.copy(deep=True)

    opened = xr.open_dataset("fake_filename", engine=CustomBackend)
    assert_identical(expected, opened)
|
|
|
|
|
|
def test_multiindex() -> None:
    """Round-trip a dataset with a stacked (multi-index) coordinate through a backend."""
    # GH7139
    # Check that we properly handle backends that change index variables
    flat = xr.Dataset(coords={"coord1": ["A", "B"], "coord2": [1, 2]})
    stacked = flat.stack(z=["coord1", "coord2"])

    # The backend returns a deep copy of the stacked dataset regardless of the
    # filename it receives.
    class MultiindexBackend(xr.backends.BackendEntrypoint):
        def open_dataset(
            self,
            filename_or_obj,
            drop_variables=None,
            **kwargs,
        ) -> xr.Dataset:
            return stacked.copy(deep=True)

    loaded = xr.open_dataset("fake_filename", engine=MultiindexBackend)
    assert_identical(stacked, loaded)
|
|
|
|
|
|
class PassThroughBackendEntrypoint(xr.backends.BackendEntrypoint):
    """Backend entrypoint whose `open_dataset` returns the object it receives."""

    def open_dataset(self, dataset, *, drop_variables=None):
        """Return `dataset` as-is.

        `drop_variables` is accepted (keyword-only) but never used.
        """
        return dataset
|
|
|
|
|
|
def explicit_chunks(chunks, shape):
|
|
"""Return explicit chunks, expanding any integer member to a tuple of integers."""
|
|
# Emulate `dask.array.core.normalize_chunks` but for simpler inputs.
|
|
return tuple(
|
|
(
|
|
(
|
|
(size // chunk) * (chunk,)
|
|
+ ((size % chunk,) if size % chunk or size == 0 else ())
|
|
)
|
|
if isinstance(chunk, Number)
|
|
else chunk
|
|
)
|
|
for chunk, size in zip(chunks, shape, strict=True)
|
|
)
|
|
|
|
|
|
@requires_dask
class TestPreferredChunks:
    """Tests for how `open_dataset` treats a backend's preferred chunk sizes."""

    # Name of the single data variable in every dataset built by these tests.
    var_name = "data"

    def create_dataset(self, shape, pref_chunks):
        """Build a one-variable dataset of the given shape.

        The variable's dimensions are named ``dim_0``, ``dim_1``, ... and its
        encoding stores `pref_chunks` under ``"preferred_chunks"``, keyed by
        dimension name.
        """
        dims = tuple(f"dim_{axis}" for axis in range(len(shape)))
        values = np.empty(shape, dtype=np.dtype("V1"))
        encoding = {"preferred_chunks": dict(zip(dims, pref_chunks, strict=True))}
        variable = xr.Variable(dims, values, encoding=encoding)
        return xr.Dataset({self.var_name: variable})

    def check_dataset(self, initial, final, expected_chunks):
        """Assert that `final` is identical to `initial` and has `expected_chunks`."""
        assert_identical(initial, final)
        actual_chunks = final[self.var_name].chunks
        assert actual_chunks == expected_chunks

    @pytest.mark.parametrize(
        "shape,pref_chunks",
        [
            # Represent preferred chunking with int.
            ((5,), (2,)),
            # Represent preferred chunking with tuple.
            ((5,), ((2, 2, 1),)),
            # Represent preferred chunking with int in two dims.
            ((5, 6), (4, 2)),
            # Represent preferred chunking with tuple in second dim.
            ((5, 6), (4, (2, 2, 2))),
        ],
    )
    @pytest.mark.parametrize("request_with_empty_map", [False, True])
    def test_honor_chunks(self, shape, pref_chunks, request_with_empty_map):
        """The preferred chunks are used when no explicit sizes are requested."""
        initial = self.create_dataset(shape, pref_chunks)
        # To keep the backend's preferred chunks, the `chunks` argument must be an
        # empty mapping or map dimensions to `None`.
        if request_with_empty_map:
            chunks = {}
        else:
            chunks = {dim: None for dim in initial[self.var_name].dims}
        final = xr.open_dataset(
            initial, engine=PassThroughBackendEntrypoint, chunks=chunks
        )
        expected = explicit_chunks(pref_chunks, shape)
        self.check_dataset(initial, final, expected)

    @pytest.mark.parametrize(
        "shape,pref_chunks,req_chunks",
        [
            # Preferred chunking is int; requested chunking is int.
            ((5,), (2,), (3,)),
            # Preferred chunking is int; requested chunking is tuple.
            ((5,), (2,), ((2, 1, 1, 1),)),
            # Preferred chunking is tuple; requested chunking is int.
            ((5,), ((2, 2, 1),), (3,)),
            # Preferred chunking is tuple; requested chunking is tuple.
            ((5,), ((2, 2, 1),), ((2, 1, 1, 1),)),
            # Split chunks along a dimension other than the first.
            ((1, 5), (1, 2), (1, 3)),
        ],
    )
    def test_split_chunks(self, shape, pref_chunks, req_chunks):
        """A `UserWarning` is expected when requested chunks split preferred ones."""
        initial = self.create_dataset(shape, pref_chunks)
        with pytest.warns(UserWarning):
            requested = dict(zip(initial[self.var_name].dims, req_chunks, strict=True))
            final = xr.open_dataset(
                initial,
                engine=PassThroughBackendEntrypoint,
                chunks=requested,
            )
        expected = explicit_chunks(req_chunks, shape)
        self.check_dataset(initial, final, expected)

    @pytest.mark.parametrize(
        "shape,pref_chunks,req_chunks",
        [
            # Keep preferred chunks using int representation.
            ((5,), (2,), (2,)),
            # Keep preferred chunks using tuple representation.
            ((5,), (2,), ((2, 2, 1),)),
            # Join chunks, leaving a final short chunk.
            ((5,), (2,), (4,)),
            # Join all chunks with an int larger than the dimension size.
            ((5,), (2,), (6,)),
            # Join one chunk using tuple representation.
            ((5,), (1,), ((1, 1, 2, 1),)),
            # Join one chunk using int representation.
            ((5,), ((1, 1, 2, 1),), (2,)),
            # Join multiple chunks using tuple representation.
            ((5,), ((1, 1, 2, 1),), ((2, 3),)),
            # Join chunks in multiple dimensions.
            ((5, 5), (2, (1, 1, 2, 1)), (4, (2, 3))),
        ],
    )
    def test_join_chunks(self, shape, pref_chunks, req_chunks):
        """No warning is expected when requested chunks keep or merge preferred ones."""
        initial = self.create_dataset(shape, pref_chunks)
        with assert_no_warnings():
            requested = dict(zip(initial[self.var_name].dims, req_chunks, strict=True))
            final = xr.open_dataset(
                initial,
                engine=PassThroughBackendEntrypoint,
                chunks=requested,
            )
        expected = explicit_chunks(req_chunks, shape)
        self.check_dataset(initial, final, expected)
|