Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fastplotlib/graphics/selectors/_highlight_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ def selection(self) -> tuple[int | None, ...] | dict[str, tuple]:
return {k: tuple(v) for k, v in self._selection.items()}

@selection.setter
def selection(self, value: Iterable[int] | dict[Literal["rows", "cols", "pixels"], list]) -> None:
def selection(self, value: Iterable[int | None] | dict[Literal["rows", "cols", "pixels"], list] | None) -> None:
if self._selection_options is not None:
if value is None:
self._selected_indices = list()
Expand Down
2 changes: 1 addition & 1 deletion fastplotlib/graphics/selectors/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def selection(self) -> float:
return self._selection.value

@selection.setter
def selection(self, value: int):
def selection(self, value: float):
graphic = self._parent

if isinstance(graphic, GraphicCollection):
Expand Down
156 changes: 114 additions & 42 deletions fastplotlib/graphics/selectors/_selection_vector.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,161 @@
from collections.abc import Callable
from functools import partial
from typing import Any, Sequence
from typing import Any, Sequence, TypeAlias
from numbers import Integral

import numpy as np

from ._protocols import SelectorProtocol, MultiSelectorProtocol

Mapping: TypeAlias = np.ndarray | dict[int, int] | Callable

def identity(val: Any) -> Any:
return val

def array_map(arr: np.ndarray, index: Integral):
"""
Used to map local to global indices
"""
return None if np.isnan(arr[index]) else int(arr[index])

def inv_array_map(arr: np.ndarray,
value: int) -> None | Integral:
"""
arr[i] gives the global index
"""
x = np.flatnonzero(arr == value)
return None if x.size == 0 else int(x[0])

def dict_map(my_dict: dict, key: Integral):
if key is None:
return None
elif int(key) not in my_dict:
return None
else:
return my_dict[key]


class SelectionVector:
"""
A class for performing coordinated selections across multiple selectors.
For each selector in the selection vector, the user specifies how the global indices (shared across selectors)
maps to the local indices (each selector has its own local index space).

The SelectionVector coordinates across individual selectors, including the coordinated updating of indices whenever a selection changes
"""
def __init__(self, max_size: int = None):
# selector -> (map, map_inv)

## Key is a selector, value is a (1) local to global index map (2) global to local index map (3) list of event handlers
self._selectors: dict[
SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable]
SelectorProtocol | MultiSelectorProtocol, tuple[Callable, Callable, list[Callable]]
] = dict()
self._selection: list[Any] = list()
self._block_reentrance = False

@property
def selection(self) -> tuple[Any]:
return tuple(self._selection)

@selection.setter
def selection(self, new: Sequence[Any]):
# iterate through each selector that operates in its own "local" space
for selector_local, (map_, map_inv) in self._selectors.items():
indices_local = map_(new)
selector_local.selection = indices_local
def selection(self, new: Integral | Sequence[Any]):
if self._block_reentrance:
return
else:
self._block_reentrance = True
if isinstance(new, Integral):
new = [new]
self._selection = [i for i in new]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't do anything, can just be self._selection = new

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kushalkolar thought about this -- this does something. if the input is any iterable other than a list, it makes the input a list.

# iterate through each selector that operates in its own "local" space
for selector_local, (map_, map_inv, handler) in self._selectors.items():
local_indices = []
for value in new:
curr_indices = map_(value)
local_indices.append(curr_indices)
selector_local.selection = local_indices
self._block_reentrance = False

def append(self, index):
self._selection.append(index)
for selector, (map_, map_inv) in self._selectors.items():
for selector, (map_, map_inv, handler_list) in self._selectors.items():
if not isinstance(selector, MultiSelectorProtocol):
continue

index_local = map_([index])
selector.append(index_local[0])

def clear(self):
self._selection.clear()
# TODO: clear selectors
index_local = map_(index)
selector.append(index_local)

def add_selector(
self,
new: (
SelectorProtocol
| tuple[SelectorProtocol, Callable]
| tuple[SelectorProtocol, Callable, Callable]
| tuple[SelectorProtocol, dict]
| tuple[SelectorProtocol, np.ndarray]
|tuple[SelectorProtocol, Callable, Callable]
),
):
selector: SelectorProtocol
map_: Callable
map_inv: Callable

"""
User specifies (1) the selector and (2) The master --> local index mapping. This

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they can also specify a mapping Callable directly

mapping is given either as:
- A 1D np.ndarray of integers. The array index is the global index, and the array value is the local index
- A dictionary where keys (master indices) and values (local indices) are both integers
- Two callables. The first callable defines the global index --> local index map, the second specifies the local index --> global index map.
All callables take as input nonnegative integers and output nonnegative integers.
"""
if isinstance(new, (tuple, list)):
if not isinstance(new[0], SelectorProtocol):
raise TypeError

if len(new) not in (2, 3):
raise TypeError

if not all(callable(c) for c in new[1:]):
raise TypeError
if len(new) == 3:
if isinstance(new[1], Callable) and isinstance(new[2], Callable):
master_to_local = new[1]
local_to_master = new[2]
else:
raise ValueError(f"Both index mappings must be Callables, you provided {type(new[1])} and {type(new[2])}")
elif len(new) == 2:
if isinstance(new[1], dict):
## Construct inverse mapping
inverse_dict = dict()
for key, val in new[1].items():
inverse_dict[int(val)] = int(key)
master_to_local = partial(dict_map, new[1])
local_to_master = partial(dict_map, inverse_dict)

elif isinstance(new[1], np.ndarray):
if not new[1].ndim == 1:
raise ValueError("If you pass in an array mapping, it must be 1-D")
master_to_local = partial(array_map, new[1])
local_to_master = partial(inv_array_map, new[1])
else:
raise ValueError(f"Must either provide a single dict or numpy array specifying the local to global index mapping, or two callables"
f"specifying the mapping in both directions")

selector = new[0]
map_ = new[1]
map_inv = new[2] if len(new) == 3 else identity

elif isinstance(new, SelectorProtocol):
selector, map_, map_inv = new, identity, identity
selector, master_to_local, local_to_master = new, identity, identity

else:
raise ValueError

selector.add_event_handler(partial(self._inv_handler, map_inv))

self._selectors[selector] = (map_, map_inv)

def _inv_handler(self, map_inv: Callable, local_selection):
return
# when a selectable changes its selection, set global index change using map inverse
# self._selection = map_inv(local_selection)

def remove(self):
pass

def clear_selectables(self):
self._selectors.clear()
handler = selector.add_event_handler(partial(self._inv_handler, local_to_master))
self._selectors[selector] = (master_to_local, local_to_master, [handler])

def _inv_handler(self, map_inv: Callable, local_selection: dict):
"""
HighlightSelector and VisibilitySelector emit a dictionary with keys selector and value
"""
input_to_map = local_selection['value']
self.selection = [map_inv(input_to_map[i]) for i in range(len(input_to_map))]

def remove_selector(self, selector: SelectorProtocol | MultiSelectorProtocol):
if selector in self._selectors:
map, map_inv, handler_list = self._selectors.pop(selector)
for handler in handler_list:
selector.remove_event_handler(handler)
if isinstance(selector, MultiSelectorProtocol):
selector.clear()

def clear_selectors(self):
for selector in self._selectors.keys():
if isinstance(selector, MultiSelectorProtocol):
selector.clear()