412 lines
13 KiB
Python
412 lines
13 KiB
Python
from sortedcontainers import SortedDict
|
|
from collections.abc import MutableMapping, Mapping
|
|
from .const import Bound
|
|
from .interval import Interval
|
|
|
|
|
|
def _sortkey(i):
    """
    Key function used to order the keys of the internal SortedDict.

    Keys are ordered by the lower bound of their first component ``i[0]``;
    on equal lower bounds, a CLOSED left bound (``False``) sorts before an
    OPEN one (``True``).

    :param i: an Interval used as a storage key.
    :return: a (lower bound, left-is-open) tuple.
    """
    first = i[0]
    return (first.lower, first.left is Bound.OPEN)
|
|
|
|
|
|
class IntervalDict(MutableMapping):
    """
    An IntervalDict is a dict-like data structure that maps from intervals to data,
    where keys can be single values or Interval instances.

    When keys are Interval instances, its behaviour merely corresponds to
    range queries and it returns IntervalDict instances corresponding to the
    subset of values covered by the given interval. If no matching value is
    found, an empty IntervalDict is returned.
    When keys are "single values", its behaviour corresponds to the one of Python
    built-in dict. When no matching value is found, a KeyError is raised.

    Internally, each distinct value is stored once, associated with the union
    of all intervals bound to it, so the stored keys are pairwise disjoint
    Interval instances kept sorted by their lower bound (see ``_sortkey``).

    Note that this class does not aim to have the best performance, but is
    provided mainly for convenience. Its performance mainly depends on the
    number of distinct values (not keys) that are stored.
    """

    __slots__ = ("_storage",)

    # Class to use when creating Interval instances
    _klass = Interval

    def __init__(self, mapping_or_iterable=None):
        """
        Return a new IntervalDict.

        If no argument is given, an empty IntervalDict is created. If an argument
        is given, and is a mapping object (e.g., another IntervalDict), a
        new IntervalDict with the same key-value pairs is created. If an
        iterable is provided, it has to be a list of (key, value) pairs.

        :param mapping_or_iterable: optional mapping or iterable.
        """
        self._storage = SortedDict(_sortkey)  # Mapping from intervals to values

        if mapping_or_iterable is not None:
            self.update(mapping_or_iterable)

    @classmethod
    def _from_items(cls, items):
        """
        Fast creation of an IntervalDict with the provided items.

        The items have to satisfy the two following properties: (1) all keys
        must be disjoint intervals and (2) all values must be distinct.
        These properties are NOT checked: items are written straight into
        storage, bypassing the merging logic of __setitem__.

        :param items: list of (key, value) pairs.
        :return: an IntervalDict
        """
        d = cls()
        for key, value in items:
            d._storage[key] = value

        return d

    def _as_interval(self, key):
        """
        Return key if it is an Interval, or the singleton interval [key, key].

        :param key: a single value or an Interval instance.
        :return: an Interval instance.
        """
        if isinstance(key, Interval):
            return key
        return self._klass.from_atomic(Bound.CLOSED, key, key, Bound.CLOSED)

    def _apply_changes(self, removed_keys, added_items):
        """
        Remove the given keys from storage, then insert the given items.

        Removals must happen before insertions: an added key may be equal to a
        removed one (e.g. when a value is re-assigned to a sub-interval of its
        current key), and it must survive the update.

        :param removed_keys: keys currently in storage to remove.
        :param added_items: (key, value) pairs to insert afterwards.
        """
        for old_key in removed_keys:
            self._storage.pop(old_key)

        for new_key, new_value in added_items:
            self._storage[new_key] = new_value

    def clear(self):
        """
        Remove all items from the IntervalDict.
        """
        self._storage.clear()

    def copy(self):
        """
        Return a shallow copy.

        :return: a shallow copy.
        """
        return self.__class__._from_items(self.items())

    def get(self, key, default=None):
        """
        Return the values associated to given key.

        If the key is a single value, it returns a single value (if it exists) or
        the default value. If the key is an Interval, it returns a new IntervalDict
        restricted to given interval. In that case, the default value is used to
        "fill the gaps" (if any) w.r.t. given key.

        :param key: a single value or an Interval instance.
        :param default: default value (default to None).
        :return: an IntervalDict, or a single value if key is not an Interval.
        """
        if isinstance(key, Interval):
            d = self[key]
            # Parts of key not covered by any stored interval get the default
            d[key - d.domain()] = default
            return d
        else:
            try:
                return self[key]
            except KeyError:
                return default

    def find(self, value):
        """
        Return a (possibly empty) Interval i such that self[i] = value, and
        self[~i] != value.

        :param value: value to look for.
        :return: an Interval instance.
        """
        return self._klass(*(i for i, v in self._storage.items() if v == value))

    def items(self):
        """
        Return a view object on the contained items sorted by their key
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.items()

    def keys(self):
        """
        Return a view object on the contained keys (sorted)
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.keys()

    def values(self):
        """
        Return a view object on the contained values sorted by their key
        (see https://docs.python.org/3/library/stdtypes.html#dict-views).

        :return: a view object.
        """
        return self._storage.values()

    def domain(self):
        """
        Return an Interval corresponding to the domain of this IntervalDict.

        :return: an Interval.
        """
        return self._klass(*self._storage.keys())

    def pop(self, key, default=None):
        """
        Remove key and return the corresponding value if key is not an Interval.
        If key is an interval, it returns an IntervalDict instance.

        This method combines self[key] and del self[key]. If a default value
        is provided and is not None, it uses self.get(key, default) instead of
        self[key].

        :param key: a single value or an Interval instance.
        :param default: optional default value.
        :return: an IntervalDict, or a single value if key is not an Interval.
        :raises KeyError: if key is a single value that is not covered and
            default is None.
        """
        if default is None:
            value = self[key]
            del self[key]
            return value
        else:
            value = self.get(key, default)
            try:
                del self[key]
            except KeyError:
                pass
            return value

    def popitem(self):
        """
        Remove and return some (key, value) pair as a 2-tuple.
        Raise KeyError if the IntervalDict is empty.

        :return: a (key, value) pair.
        """
        return self._storage.popitem()

    def setdefault(self, key, default=None):
        """
        Return the value associated to given key. If it does not exist, set
        its value to given default and return it.

        If key is an Interval, the parts of key that are not covered are set
        to default, and an IntervalDict restricted to key is returned.

        :param key: a single value or an Interval instance.
        :param default: default value (default to None).
        :return: an IntervalDict, or a single value if key is not an Interval.
        """
        if isinstance(key, Interval):
            value = self.get(key, default)
            self.update(value)
            return value
        else:
            try:
                return self[key]
            except KeyError:
                self[key] = default
                return default

    def update(self, mapping_or_iterable):
        """
        Update current IntervalDict with provided values.

        If a mapping is provided, it must map Interval instances to values (e.g.,
        another IntervalDict). If an iterable is provided, it must consist of a
        list of (key, value) pairs.

        :param mapping_or_iterable: mapping or iterable.
        """
        if isinstance(mapping_or_iterable, Mapping):
            data = mapping_or_iterable.items()
        else:
            data = mapping_or_iterable

        for i, v in data:
            self[i] = v

    def combine(self, other, how, *, missing=..., pass_interval=False):
        """
        Return a new IntervalDict that combines the values from current and
        provided IntervalDict.

        If d = d1.combine(d2, f), then d contains (1) all values from d1 whose
        keys do not intersect the ones of d2, (2) all values from d2 whose keys
        do not intersect the ones of d1, and (3) f(x, y) for x in d1, y in d2 for
        intersecting keys.

        When missing is set, the how function is called even for non-intersecting
        keys using the value of missing to replace the missing values. That is,
        case (1) corresponds to f(x, missing) and case (2) to f(missing, y).

        If pass_interval is set to True, the current interval will be passed to
        the "how" function as third parameter.

        :param other: another IntervalDict instance.
        :param how: a function combining two values.
        :param missing: if set, use this value for missing values when calling "how".
        :param pass_interval: if set, provide the current interval to the "how" function.
        :return: a new IntervalDict instance.
        """
        new_items = []

        if pass_interval:
            _how = how
        else:
            # Adapt a two-argument "how" to the three-argument calls below
            def _how(x, y, i):
                return how(x, y)

        dom1, dom2 = self.domain(), other.domain()

        # Ellipsis is the "not set" marker for missing (None is a valid value)
        if missing is Ellipsis:
            new_items.extend(self[dom1 - dom2].items())
            new_items.extend(other[dom2 - dom1].items())
        else:
            for i, v in self[dom1 - dom2].items():
                new_items.append((i, _how(v, missing, i)))
            for i, v in other[dom2 - dom1].items():
                new_items.append((i, _how(missing, v, i)))

        intersection = dom1 & dom2
        d1, d2 = self[intersection], other[intersection]

        for i1, v1 in d1.items():
            for i2, v2 in d2.items():
                if i1.overlaps(i2):
                    i = i1 & i2
                    v = _how(v1, v2, i)
                    new_items.append((i, v))

        # Go through the constructor (not _from_items): combined values may
        # repeat and must be merged by __setitem__
        return self.__class__(new_items)

    def as_dict(self, atomic=False):
        """
        Return the content as a classical Python dict.

        :param atomic: if True, each key is split into the intervals obtained
            by iterating over it, each of them mapped to the key's value.
        :return: a Python dict.
        """
        if atomic:
            d = dict()
            for interval, v in self._storage.items():
                for i in interval:
                    d[i] = v
            return d
        else:
            return dict(self._storage)

    def __getitem__(self, key):
        """
        Return an IntervalDict restricted to key if key is an Interval,
        otherwise the value associated to the single value key.

        :raises KeyError: if key is a single value not covered by any key.
        """
        if isinstance(key, Interval):
            items = []
            for i, v in self._storage.items():
                # Early out: keys are sorted by lower bound, so no later key
                # can intersect once key ends before the current one starts
                if key.upper < i.lower:
                    break

                intersection = key & i
                if not intersection.empty:
                    items.append((intersection, v))
            return self.__class__._from_items(items)
        else:
            for i, v in self._storage.items():
                # Early out (same reasoning as above)
                if key < i.lower:
                    break
                if key in i:
                    return v
            raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Bind value to key (a single value or an Interval).

        Overlapping parts of keys bound to other values are removed, and if
        value is already stored, its key is extended instead of adding a new
        entry, so that each value is stored once.
        """
        interval = self._as_interval(key)

        if interval.empty:
            return

        removed_keys = []
        added_items = []

        found = False
        for i, v in self._storage.items():
            if value == v:
                found = True
                # Extend existing key
                removed_keys.append(i)
                added_items.append((i | interval, v))
            elif i.overlaps(interval):
                # Reduce existing key
                remaining = i - interval
                removed_keys.append(i)
                if not remaining.empty:
                    added_items.append((remaining, v))

        if not found:
            added_items.append((interval, value))

        # Storage is only mutated once the iteration over it is over
        self._apply_changes(removed_keys, added_items)

    def __delitem__(self, key):
        """
        Remove key (a single value or an Interval) from the stored keys.

        :raises KeyError: if key is a single value not covered by any key.
            Deleting an Interval that covers nothing is a no-op.
        """
        interval = self._as_interval(key)

        if interval.empty:
            return

        removed_keys = []
        added_items = []

        found = False
        for i, v in self._storage.items():
            # Early out: keys are sorted by lower bound
            if interval.upper < i.lower:
                break

            if i.overlaps(interval):
                found = True
                remaining = i - interval
                removed_keys.append(i)
                if not remaining.empty:
                    added_items.append((remaining, v))

        if not found and not isinstance(key, Interval):
            raise KeyError(key)

        # Storage is only mutated once the iteration over it is over
        self._apply_changes(removed_keys, added_items)

    def __or__(self, other):
        """Return a copy of self updated with other (see update)."""
        d = self.copy()
        d.update(other)
        return d

    def __ior__(self, other):
        """Update self in place with other (see update) and return self."""
        self.update(other)
        return self

    def __iter__(self):
        return iter(self._storage)

    def __len__(self):
        return len(self._storage)

    def __contains__(self, key):
        # Membership is tested against the domain, so key may be a single
        # value or an Interval
        return key in self.domain()

    def __repr__(self):
        return "{}{}{}".format(
            "{",
            ", ".join("{!r}: {!r}".format(i, v) for i, v in self.items()),
            "}",
        )

    def __eq__(self, other):
        if isinstance(other, IntervalDict):
            return self.as_dict() == other.as_dict()
        else:
            return NotImplemented
|