11from collections .abc import Callable
22from functools import partial
3- from typing import Any , Sequence
3+ from typing import Any , Sequence , TypeAlias
44from numbers import Integral
5+
56import numpy as np
7+
68from ._protocols import SelectorProtocol , MultiSelectorProtocol
7- from fastplotlib .graphics .features ._base import GraphicFeatureEvent
9+
10+ Mapping : TypeAlias = np .ndarray | dict [int , int ] | Callable
811
912def identity (val : Any ) -> Any :
1013 return val
1114
15+ def array_map (arr : np .ndarray , index : Integral ):
16+ """
17+ Used to map local to global indices
18+ """
19+ return None if np .isnan (arr [index ]) else arr [index ]
20+
21+ def inv_array_map (arr : np .ndarray ,
22+ value : int ) -> None | Integral :
23+ """
24+ arr[i] gives the global index
25+ """
26+ x = np .flatnonzero (arr == value )
27+ return None if x .size == 0 else x [0 ]
28+
29+ def dict_map (my_dict : dict , key : Integral ):
30+ if key is None :
31+ return None
32+ elif int (key ) not in my_dict :
33+ return None
34+ else :
35+ return my_dict [key ]
36+
37+
1238class SelectionVector :
1339 """
1440 A class for performing coordinated selections across multiple selectors.
1541 For each selector in the selection vector, the user specifies how the global indices (shared across selectors)
16- maps to the local indices.
42+ maps to the local indices (each selector has its own local index space) .
1743
18- The SelectionVector manages everything else , including the coordinated updating of indices whenever a selection changes
44+ The SelectionVector coordinates across individual selectors , including the coordinated updating of indices whenever a selection changes
1945 """
2046 def __init__ (self , max_size : int = None ):
2147 # selector -> (map, map_inv)
48+
49+ ## Key is a selector, value is a (1) local to global index map (2) global to local index map (3) list of event handlers
2250 self ._selectors : dict [
23- SelectorProtocol | MultiSelectorProtocol , tuple [Callable , Callable , list ]
51+ SelectorProtocol | MultiSelectorProtocol , tuple [Callable , Callable , list [ Callable ] ]
2452 ] = dict ()
2553 self ._selection : list [Any ] = list ()
2654 self ._block_reentrance = False
@@ -40,12 +68,11 @@ def selection(self, new: Integral | Sequence[Any]):
4068 self ._selection = [i for i in new ]
4169 # iterate through each selector that operates in its own "local" space
4270 for selector_local , (map_ , map_inv , handler ) in self ._selectors .items ():
43- cumulated_output = []
71+ local_indices = []
4472 for value in new :
4573 curr_indices = map_ (value )
46- cumulated_output .append (curr_indices )
47- # indices_local = map_(new)
48- selector_local .selection = cumulated_output
74+ local_indices .append (curr_indices )
75+ selector_local .selection = local_indices
4976 self ._block_reentrance = False
5077
5178 def append (self , index ):
@@ -54,63 +81,69 @@ def append(self, index):
5481 if not isinstance (selector , MultiSelectorProtocol ):
5582 continue
5683
57- index_local = map_ ([ index ] )
84+ index_local = map_ (index )
5885 selector .append (index_local )
5986
6087 def add_selector (
6188 self ,
6289 new : (
6390 SelectorProtocol
64- | tuple [SelectorProtocol , np .ndarray | dict [int , int ]]
91+ | tuple [SelectorProtocol , dict ]
92+ | tuple [SelectorProtocol , np .ndarray ]
93+ | tuple [SelectorProtocol , Callable , Callable ]
6594 ),
6695 ):
6796 """
6897 User specifies (1) the selector and (2) The master --> local index mapping. This
6998 mapping is given either as:
7099 - A 1D np.ndarray of integers. The array index is the global index, and the array value is the local index
71100 - A dictionary where keys (master indices) and values (local indices) are both integers
101+ - Two callables. The first callable defines the global index --> local index map, the second specifies the local index --> global index map.
72102 """
73- selector : SelectorProtocol
74103 if isinstance (new , (tuple , list )):
75104 if not isinstance (new [0 ], SelectorProtocol ):
76105 raise TypeError
77106
78- if len (new ) != 2 :
79- raise TypeError
107+ if len (new ) == 3 :
108+ if isinstance (new [1 ], Callable ) and isinstance (new [2 ], Callable ):
109+ master_to_local = new [1 ]
110+ local_to_master = new [2 ]
111+ else :
112+ raise ValueError (f"Both index mappings must be Callables, you provided { type (new [1 ])} and { type (new [2 ])} " )
113+ elif len (new ) == 2 :
114+ if isinstance (new [1 ], dict ):
115+ ## Construct inverse mapping
116+ inverse_dict = dict ()
117+ for key , val in new [1 ].items ():
118+ inverse_dict [int (val )] = int (key )
119+ master_to_local = partial (dict_map , new [1 ])
120+ local_to_master = partial (dict_map , inverse_dict )
121+
122+ elif isinstance (new [1 ], np .ndarray ):
123+ if not new [1 ].ndim == 1 :
124+ raise ValueError ("If you pass in an array mapping, it must be 1-D" )
125+ master_to_local = partial (array_map , new [1 ])
126+ local_to_master = partial (inv_array_map , new [1 ])
127+ else :
128+ raise ValueError (f"Must either provide a single dict or numpy array specifying the local to global index mapping, or two callables"
129+ f"specifying the mapping in both directions" )
80130
81131 selector = new [0 ]
82- master_to_local = new [1 ]
83- if isinstance (master_to_local , np .ndarray ):
84- if not master_to_local .ndim == 1 :
85- raise ValueError ("If you pass in an array mapping, it must be 1-D" )
86- master_to_local = dict (enumerate (master_to_local ))
87-
88- ## Construct inverse mapping
89- inverse_dict = dict ()
90- for key , val in master_to_local .items ():
91- inverse_dict [int (val )] = int (key )
92-
93- ## Define the partial functions
94- master_to_local_map = lambda x :master_to_local [int (x )] if int (x ) in master_to_local else None
95- local_to_master_map = lambda x :inverse_dict [int (x )] if x in inverse_dict and x is not None else None
96132
97133 elif isinstance (new , SelectorProtocol ):
98- selector , master_to_local_map , local_to_master_map = new , identity , identity
134+ selector , master_to_local , local_to_master = new , identity , identity
99135
100136 else :
101137 raise ValueError
102138
103- handler = selector .add_event_handler (partial (self ._inv_handler , local_to_master_map ))
104- self ._selectors [selector ] = (master_to_local_map , local_to_master_map , [handler ])
139+ handler = selector .add_event_handler (partial (self ._inv_handler , local_to_master ))
140+ self ._selectors [selector ] = (master_to_local , local_to_master , [handler ])
105141
106- def _inv_handler (self , map_inv : Callable , local_selection : dict | GraphicFeatureEvent ):
107- if isinstance (local_selection , dict ):
108- input_to_map = local_selection ['value' ]
109- # local_selection = list(local_selection.items())[0][1]
110- elif isinstance (local_selection , GraphicFeatureEvent ):
111- input_to_map = local_selection .info ['value' ]
112- else :
113- raise ValueError ("Input to inverse handler should either be dictionary or GraphicFeatureEvent" )
142+ def _inv_handler (self , map_inv : Callable , local_selection : dict ):
143+ """
144+ HighlightSelector and VisibilitySelector emit a dictionary with keys selector and value
145+ """
146+ input_to_map = local_selection ['value' ]
114147 self .selection = [map_inv (input_to_map [i ]) for i in range (len (input_to_map ))]
115148
116149 def remove_selector (self , selector : SelectorProtocol | MultiSelectorProtocol ):
0 commit comments