diff --git a/fastplotlib/graphics/selectors/_highlight_selector.py b/fastplotlib/graphics/selectors/_highlight_selector.py index 3ba08a676..dc757c07a 100644 --- a/fastplotlib/graphics/selectors/_highlight_selector.py +++ b/fastplotlib/graphics/selectors/_highlight_selector.py @@ -681,7 +681,7 @@ def selection(self) -> tuple[int | None, ...] | dict[str, tuple]: return {k: tuple(v) for k, v in self._selection.items()} @selection.setter - def selection(self, value: Iterable[int] | dict[Literal["rows", "cols", "pixels"], list]) -> None: + def selection(self, value: Iterable[int | None] | dict[Literal["rows", "cols", "pixels"], list] | None) -> None: if self._selection_options is not None: if value is None: self._selected_indices = list() diff --git a/fastplotlib/graphics/selectors/_linear.py b/fastplotlib/graphics/selectors/_linear.py index 4ea454ee8..f652a3d9e 100644 --- a/fastplotlib/graphics/selectors/_linear.py +++ b/fastplotlib/graphics/selectors/_linear.py @@ -27,7 +27,7 @@ def selection(self) -> float: return self._selection.value @selection.setter - def selection(self, value: int): + def selection(self, value: float): graphic = self._parent if isinstance(graphic, GraphicCollection): diff --git a/fastplotlib/graphics/selectors/_selection_vector.py b/fastplotlib/graphics/selectors/_selection_vector.py index a1e0bed10..10b2885ea 100644 --- a/fastplotlib/graphics/selectors/_selection_vector.py +++ b/fastplotlib/graphics/selectors/_selection_vector.py @@ -1,89 +1,161 @@ from collections.abc import Callable from functools import partial -from typing import Any, Sequence +from typing import Any, Sequence, TypeAlias +from numbers import Integral + +import numpy as np from ._protocols import SelectorProtocol, MultiSelectorProtocol +Mapping: TypeAlias = np.ndarray | dict[int, int] | Callable def identity(val: Any) -> Any: return val +def array_map(arr: np.ndarray, index: Integral): + """ + Used to map local to global indices + """ + return None if np.isnan(arr[index]) else int(arr[index]) + +def inv_array_map(arr: np.ndarray, + value: int) -> None | Integral: + """ + arr[i] gives the global index + """ + x = np.flatnonzero(arr == value) + return None if x.size == 0 else int(x[0]) + +def dict_map(my_dict: dict, key: Integral): + if key is None: + return None + elif int(key) not in my_dict: + return None + else: + return my_dict[key] + class SelectionVector: + """ + A class for performing coordinated selections across multiple selectors. + For each selector in the selection vector, the user specifies how the global indices (shared across selectors) + maps to the local indices (each selector has its own local index space). + + The SelectionVector coordinates across individual selectors, including the coordinated updating of indices whenever a selection changes + """ def __init__(self, max_size: int = None): # selector -> (map, map_inv) + + ## Key is a selector, value is a (1) local to global index map (2) global to local index map (3) list of event handlers self._selectors: dict[ - SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable] + SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable, list[Callable]] ] = dict() self._selection: list[Any] = list() + self._block_reentrance = False @property def selection(self) -> tuple[Any]: return tuple(self._selection) @selection.setter - def selection(self, new: Sequence[Any]): - # iterate through each selector that operates in its own "local" space - for selector_local, (map_, map_inv) in self._selectors.items(): - indices_local = map_(new) - selector_local.selection = indices_local + def selection(self, new: Integral | Sequence[Any]): + if self._block_reentrance: + return + else: + self._block_reentrance = True + if isinstance(new, Integral): + new = [new] + self._selection = [i for i in new] + # iterate through each selector that operates in its own "local" space + for selector_local, (map_, map_inv, handler) in self._selectors.items(): + local_indices = [] + for value in new: + curr_indices = map_(value) + local_indices.append(curr_indices) + selector_local.selection = local_indices + self._block_reentrance = False def append(self, index): self._selection.append(index) - for selector, (map_, map_inv) in self._selectors.items(): + for selector, (map_, map_inv, handler_list) in self._selectors.items(): if not isinstance(selector, MultiSelectorProtocol): continue - index_local = map_([index]) - selector.append(index_local[0]) - - def clear(self): - self._selection.clear() - # TODO: clear selectors + index_local = map_(index) + selector.append(index_local) def add_selector( self, new: ( SelectorProtocol - | tuple[SelectorProtocol, Callable] - | tuple[SelectorProtocol, Callable, Callable] + | tuple[SelectorProtocol, dict] + | tuple[SelectorProtocol, np.ndarray] + |tuple[SelectorProtocol, Callable, Callable] ), ): - selector: SelectorProtocol - map_: Callable - map_inv: Callable - + """ + User specifies (1) the selector and (2) The master --> local index mapping. This + mapping is given either as: + - A 1D np.ndarray of integers. The array index is the global index, and the array value is the local index + - A dictionary where keys (master indices) and values (local indices) are both integers + - Two callables. The first callable defines the global index --> local index map, the second specifies the local index --> global index map. + All callables take as input nonnegative integers and output nonnegative integers. + """ if isinstance(new, (tuple, list)): if not isinstance(new[0], SelectorProtocol): raise TypeError - if len(new) not in (2, 3): - raise TypeError - - if not all(callable(c) for c in new[1:]): - raise TypeError + if len(new) == 3: + if isinstance(new[1], Callable) and isinstance(new[2], Callable): + master_to_local = new[1] + local_to_master = new[2] + else: + raise ValueError(f"Both index mappings must be Callables, you provided {type(new[1])} and {type(new[2])}") + elif len(new) == 2: + if isinstance(new[1], dict): + ## Construct inverse mapping + inverse_dict = dict() + for key, val in new[1].items(): + inverse_dict[int(val)] = int(key) + master_to_local = partial(dict_map, new[1]) + local_to_master = partial(dict_map, inverse_dict) + + elif isinstance(new[1], np.ndarray): + if not new[1].ndim == 1: + raise ValueError("If you pass in an array mapping, it must be 1-D") + master_to_local = partial(array_map, new[1]) + local_to_master = partial(inv_array_map, new[1]) + else: + raise ValueError(f"Must either provide a single dict or numpy array specifying the local to global index mapping, or two callables" + f"specifying the mapping in both directions") selector = new[0] - map_ = new[1] - map_inv = new[2] if len(new) == 3 else identity elif isinstance(new, SelectorProtocol): - selector, map_, map_inv = new, identity, identity + selector, master_to_local, local_to_master = new, identity, identity else: raise ValueError - selector.add_event_handler(partial(self._inv_handler, map_inv)) - - self._selectors[selector] = (map_, map_inv) - - def _inv_handler(self, map_inv: Callable, local_selection): - return - # when a selectable changes its selection, set global index change using map inverse - # self._selection = map_inv(local_selection) - - def remove(self): - pass - - def clear_selectables(self): - self._selectors.clear() + handler = selector.add_event_handler(partial(self._inv_handler, local_to_master)) + self._selectors[selector] = (master_to_local, local_to_master, [handler]) + + def _inv_handler(self, map_inv: Callable, local_selection: dict): + """ + HighlightSelector and VisibilitySelector emit a dictionary with keys selector and value + """ + input_to_map = local_selection['value'] + self.selection = [map_inv(input_to_map[i]) for i in range(len(input_to_map))] + + def remove_selector(self, selector: SelectorProtocol | MultiSelectorProtocol): + if selector in self._selectors: + map, map_inv, handler_list = self._selectors.pop(selector) + for handler in handler_list: + selector.remove_event_handler(handler) + if isinstance(selector, MultiSelectorProtocol): + selector.clear() + + def clear_selectors(self): + for selector in self._selectors.keys(): + if isinstance(selector, MultiSelectorProtocol): + selector.clear() \ No newline at end of file