forked from astropy/astropy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsoco.py
More file actions
179 lines (148 loc) · 5.06 KB
/
soco.py
File metadata and controls
179 lines (148 loc) · 5.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
The SCEngine class uses the ``sortedcontainers`` package to implement an
Index engine for Tables.
"""
from collections import OrderedDict
from itertools import starmap
from astropy.utils.compat.optional_deps import HAS_SORTEDCONTAINERS
if HAS_SORTEDCONTAINERS:
from sortedcontainers import SortedList
class Node:
    """
    One ``(key, value)`` entry held in the SCEngine's sorted list.

    ``key`` is the index key and ``value`` is the row number paired with it.
    Two ``Node`` objects are ordered by the pair ``(key, value)``, so
    entries that share a key are ordered by row.  When the other operand's
    class is not exactly ``Node``, only ``key`` is compared against it; this
    is what allows a bare key to be used for membership tests and range
    bounds on the sorted list.

    Because ``__eq__`` is defined, instances are explicitly unhashable.
    """

    __slots__ = ("key", "value")

    __hash__ = None

    def __init__(self, key, value):
        self.key = key
        self.value = value

    def __repr__(self):
        return f"Node({self.key!r}, {self.value!r})"

    # Every comparison handles the "bare key" operand first.  The class test
    # is an exact identity check (not isinstance), so only a true Node is
    # compared on the full (key, value) pair.

    def __eq__(self, other):
        if other.__class__ is not Node:
            return self.key == other
        return (self.key, self.value) == (other.key, other.value)

    def __ne__(self, other):
        if other.__class__ is not Node:
            return self.key != other
        return (self.key, self.value) != (other.key, other.value)

    def __lt__(self, other):
        if other.__class__ is not Node:
            return self.key < other
        return (self.key, self.value) < (other.key, other.value)

    def __le__(self, other):
        if other.__class__ is not Node:
            return self.key <= other
        return (self.key, self.value) <= (other.key, other.value)

    def __gt__(self, other):
        if other.__class__ is not Node:
            return self.key > other
        return (self.key, self.value) > (other.key, other.value)

    def __ge__(self, other):
        if other.__class__ is not Node:
            return self.key >= other
        return (self.key, self.value) >= (other.key, other.value)
class SCEngine:
    """
    Fast tree-based implementation for indexing, using the
    ``sortedcontainers`` package.

    Every index entry is stored as a ``Node(key, row)`` inside a
    ``sortedcontainers.SortedList``, so entries are kept ordered by key and
    then by row number.

    Parameters
    ----------
    data : Table
        Sorted columns of the original table
    row_index : Column object
        Row numbers corresponding to data columns
    unique : bool
        Whether the values of the index must be unique.
        Defaults to False.
    """

    def __init__(self, data, row_index, unique=False):
        if not HAS_SORTEDCONTAINERS:
            raise ImportError("sortedcontainers is needed for using SCEngine")
        # Each row of ``data`` becomes a tuple key, paired with its row number.
        keyed_rows = zip(map(tuple, data), row_index)
        self._nodes = SortedList(starmap(Node, keyed_rows))
        self._unique = unique

    def add(self, key, value):
        """
        Insert the pair ``(key, value)``.

        Raises
        ------
        ValueError
            If the index is unique and ``key`` is already present.
        """
        # Membership works with a bare key because Node compares its own
        # ``key`` against non-Node operands.
        if self._unique and key in self._nodes:
            raise ValueError(f"duplicate {key!r} in unique index")
        self._nodes.add(Node(key, value))

    def find(self, key):
        """
        Return the row values stored under ``key``, in sorted order.
        """
        matches = self._nodes.irange(key, key)
        return [match.value for match in matches]

    def remove(self, key, data=None):
        """
        Remove entries stored under ``key``.

        When ``data`` is given, only the single ``(key, data)`` entry is
        removed; otherwise every entry with that key is removed.

        Returns
        -------
        bool
            True if at least one entry was removed, False otherwise.
        """
        if data is None:
            # Materialize the matches first: the list is modified while
            # they are being removed.
            doomed = list(self._nodes.irange(key, key))
            for node in doomed:
                self._nodes.remove(node)
            return bool(doomed)

        target = Node(key, data)
        try:
            self._nodes.remove(target)
        except ValueError:
            # SortedList.remove signals an absent entry with ValueError.
            return False
        return True

    def shift_left(self, row):
        """
        Decrement by one every stored row number greater than ``row``.
        """
        # Values are edited in place.  Every affected value moves by the same
        # amount and (assuming integer rows) stays >= ``row``, so the list's
        # (key, value) ordering is not disturbed.
        affected = (node for node in self._nodes if node.value > row)
        for node in affected:
            node.value -= 1

    def shift_right(self, row):
        """
        Increment by one every stored row number greater than or equal to
        ``row``.
        """
        # In-place edit; affected values only move further above unaffected
        # ones, so the sorted order is kept.
        affected = (node for node in self._nodes if node.value >= row)
        for node in affected:
            node.value += 1

    def items(self):
        """
        Group row values by key.

        Returns
        -------
        items view
            The ``items()`` view of an ``OrderedDict`` that maps each key to
            the list of its row values; keys appear in sorted order.
        """
        grouped = OrderedDict()
        for node in self._nodes:
            grouped.setdefault(node.key, []).append(node.value)
        return grouped.items()

    def sort(self):
        """
        Renumber rows so that row order matches key order.
        """
        # Positions are assigned in list order, so the new values increase
        # monotonically and the sorted order is preserved.
        for position, node in enumerate(self._nodes):
            node.value = position

    def sorted_data(self):
        """
        Return every row value, ordered by key.
        """
        return [entry.value for entry in self._nodes]

    def range(self, lower, upper, bounds=(True, True)):
        """
        Return the row values whose keys lie between ``lower`` and ``upper``.

        ``bounds`` is a pair of booleans telling whether ``lower`` and
        ``upper`` themselves are included; it is passed straight through as
        the ``inclusive`` argument of ``SortedList.irange``.
        """
        selected = self._nodes.irange(lower, upper, bounds)
        return [node.value for node in selected]

    def replace_rows(self, row_map):
        """
        Renumber rows according to ``row_map``.

        Entries whose row number is a key of ``row_map`` receive the mapped
        row number; entries whose row number is absent from ``row_map`` are
        dropped.
        """
        survivors = [node for node in self._nodes if node.value in row_map]
        for node in survivors:
            node.value = row_map[node.value]
        # The new row numbers need not follow the old order, so the list is
        # emptied and refilled, which re-sorts the surviving nodes.
        self._nodes.clear()
        self._nodes.update(survivors)

    def __repr__(self):
        if len(self._nodes) <= 6:
            shown = self._nodes
        else:
            # Abbreviate long indexes to the first and last three entries.
            shown = [*self._nodes[:3], "...", *self._nodes[-3:]]
        listing = ", ".join(map(str, shown))
        return f"<{self.__class__.__name__} nodes={listing}>"