Skip to content

Commit 3a4dfde

Browse files
committed
Reworks the logic for adding selectors, improves some documentation, adds partial instead of lambda functions, improves typing in highlight selector
1 parent 174f861 commit 3a4dfde

2 files changed

Lines changed: 73 additions & 40 deletions

File tree

fastplotlib/graphics/selectors/_highlight_selector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def selection(self) -> tuple[int | None, ...] | dict[str, tuple]:
681681
return {k: tuple(v) for k, v in self._selection.items()}
682682

683683
@selection.setter
684-
def selection(self, value: Iterable[int] | dict[Literal["rows", "cols", "pixels"], list] | None) -> None:
684+
def selection(self, value: Iterable[int | None] | dict[Literal["rows", "cols", "pixels"], list] | None) -> None:
685685
if self._selection_options is not None:
686686
if value is None:
687687
self._selected_indices = list()

fastplotlib/graphics/selectors/_selection_vector.py

Lines changed: 72 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,54 @@
11
from collections.abc import Callable
22
from functools import partial
3-
from typing import Any, Sequence
3+
from typing import Any, Sequence, TypeAlias
44
from numbers import Integral
5+
56
import numpy as np
7+
68
from ._protocols import SelectorProtocol, MultiSelectorProtocol
7-
from fastplotlib.graphics.features._base import GraphicFeatureEvent
9+
10+
Mapping: TypeAlias = np.ndarray | dict[int, int] | Callable
811

912
def identity(val: Any) -> Any:
1013
return val
1114

15+
def array_map(arr: np.ndarray, index: Integral):
16+
"""
17+
Used to map local to global indices
18+
"""
19+
return None if np.isnan(arr[index]) else arr[index]
20+
21+
def inv_array_map(arr: np.ndarray,
22+
value: int) -> None | Integral:
23+
"""
24+
arr[i] gives the global index
25+
"""
26+
x = np.flatnonzero(arr == value)
27+
return None if x.size == 0 else x[0]
28+
29+
def dict_map(my_dict: dict, key: Integral):
30+
if key is None:
31+
return None
32+
elif int(key) not in my_dict:
33+
return None
34+
else:
35+
return my_dict[key]
36+
37+
1238
class SelectionVector:
1339
"""
1440
A class for performing coordinated selections across multiple selectors.
1541
For each selector in the selection vector, the user specifies how the global indices (shared across selectors)
16-
maps to the local indices.
42+
maps to the local indices (each selector has its own local index space).
1743
18-
The SelectionVector manages everything else, including the coordinated updating of indices whenever a selection changes
44+
The SelectionVector coordinates across individual selectors, including the coordinated updating of indices whenever a selection changes
1945
"""
2046
def __init__(self, max_size: int = None):
2147
# selector -> (map, map_inv)
48+
49+
## Key is a selector, value is a (1) local to global index map (2) global to local index map (3) list of event handlers
2250
self._selectors: dict[
23-
SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable, list]
51+
SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable, list[Callable]]
2452
] = dict()
2553
self._selection: list[Any] = list()
2654
self._block_reentrance = False
@@ -40,12 +68,11 @@ def selection(self, new: Integral | Sequence[Any]):
4068
self._selection = [i for i in new]
4169
# iterate through each selector that operates in its own "local" space
4270
for selector_local, (map_, map_inv, handler) in self._selectors.items():
43-
cumulated_output = []
71+
local_indices = []
4472
for value in new:
4573
curr_indices = map_(value)
46-
cumulated_output.append(curr_indices)
47-
# indices_local = map_(new)
48-
selector_local.selection = cumulated_output
74+
local_indices.append(curr_indices)
75+
selector_local.selection = local_indices
4976
self._block_reentrance = False
5077

5178
def append(self, index):
@@ -54,63 +81,69 @@ def append(self, index):
5481
if not isinstance(selector, MultiSelectorProtocol):
5582
continue
5683

57-
index_local = map_([index])
84+
index_local = map_(index)
5885
selector.append(index_local)
5986

6087
def add_selector(
6188
self,
6289
new: (
6390
SelectorProtocol
64-
| tuple[SelectorProtocol, np.ndarray | dict[int, int]]
91+
| tuple[SelectorProtocol, dict]
92+
| tuple[SelectorProtocol, np.ndarray]
93+
|tuple[SelectorProtocol, Callable, Callable]
6594
),
6695
):
6796
"""
6897
User specifies (1) the selector and (2) The master --> local index mapping. This
6998
mapping is given either as:
7099
- A 1D np.ndarray of integers. The array index is the global index, and the array value is the local index
71100
- A dictionary where keys (master indices) and values (local indices) are both integers
101+
- Two callables. The first callable defines the global index --> local index map, the second specifies the local index --> global index map.
72102
"""
73-
selector: SelectorProtocol
74103
if isinstance(new, (tuple, list)):
75104
if not isinstance(new[0], SelectorProtocol):
76105
raise TypeError
77106

78-
if len(new) != 2:
79-
raise TypeError
107+
if len(new) == 3:
108+
if isinstance(new[1], Callable) and isinstance(new[2], Callable):
109+
master_to_local = new[1]
110+
local_to_master = new[2]
111+
else:
112+
raise ValueError(f"Both index mappings must be Callables, you provided {type(new[1])} and {type(new[2])}")
113+
elif len(new) == 2:
114+
if isinstance(new[1], dict):
115+
## Construct inverse mapping
116+
inverse_dict = dict()
117+
for key, val in new[1].items():
118+
inverse_dict[int(val)] = int(key)
119+
master_to_local = partial(dict_map, new[1])
120+
local_to_master = partial(dict_map, inverse_dict)
121+
122+
elif isinstance(new[1], np.ndarray):
123+
if not new[1].ndim == 1:
124+
raise ValueError("If you pass in an array mapping, it must be 1-D")
125+
master_to_local = partial(array_map, new[1])
126+
local_to_master = partial(inv_array_map, new[1])
127+
else:
128+
raise ValueError(f"Must either provide a single dict or numpy array specifying the local to global index mapping, or two callables"
129+
f"specifying the mapping in both directions")
80130

81131
selector = new[0]
82-
master_to_local = new[1]
83-
if isinstance(master_to_local, np.ndarray):
84-
if not master_to_local.ndim == 1:
85-
raise ValueError("If you pass in an array mapping, it must be 1-D")
86-
master_to_local = dict(enumerate(master_to_local))
87-
88-
## Construct inverse mapping
89-
inverse_dict = dict()
90-
for key, val in master_to_local.items():
91-
inverse_dict[int(val)] = int(key)
92-
93-
## Define the partial functions
94-
master_to_local_map = lambda x:master_to_local[int(x)] if int(x) in master_to_local else None
95-
local_to_master_map = lambda x:inverse_dict[int(x)] if x in inverse_dict and x is not None else None
96132

97133
elif isinstance(new, SelectorProtocol):
98-
selector, master_to_local_map, local_to_master_map = new, identity, identity
134+
selector, master_to_local, local_to_master = new, identity, identity
99135

100136
else:
101137
raise ValueError
102138

103-
handler = selector.add_event_handler(partial(self._inv_handler, local_to_master_map))
104-
self._selectors[selector] = (master_to_local_map, local_to_master_map, [handler])
139+
handler = selector.add_event_handler(partial(self._inv_handler, local_to_master))
140+
self._selectors[selector] = (master_to_local, local_to_master, [handler])
105141

106-
def _inv_handler(self, map_inv: Callable, local_selection: dict | GraphicFeatureEvent):
107-
if isinstance(local_selection, dict):
108-
input_to_map = local_selection['value']
109-
# local_selection = list(local_selection.items())[0][1]
110-
elif isinstance(local_selection, GraphicFeatureEvent):
111-
input_to_map = local_selection.info['value']
112-
else:
113-
raise ValueError("Input to inverse handler should either be dictionary or GraphicFeatureEvent")
142+
def _inv_handler(self, map_inv: Callable, local_selection: dict):
143+
"""
144+
HighlightSelector and VisibilitySelector emit a dictionary with keys selector and value
145+
"""
146+
input_to_map = local_selection['value']
114147
self.selection = [map_inv(input_to_map[i]) for i in range(len(input_to_map))]
115148

116149
def remove_selector(self, selector: SelectorProtocol | MultiSelectorProtocol):

0 commit comments

Comments
 (0)