Skip to content

Commit 7d4e420

Browse files
committed
index_mappings is working I think, lightly tested on p dim
1 parent 3731997 commit 7d4e420

2 files changed

Lines changed: 33 additions & 12 deletions

File tree

fastplotlib/widgets/nd_widget/_nd_positions.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ def get(self, indices: tuple[Any, ...]):
184184
Note that we do not use __getitem__ here since the index is a tuple specifying a single integer
185185
index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here.
186186
"""
187+
# apply any slider index mappings
188+
indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)])
189+
187190
if len(indices) > 1:
188191
# there are dims in addition to the n_datapoints dim
189192
# apply window funcs
@@ -195,7 +198,8 @@ def get(self, indices: tuple[Any, ...]):
195198
# TODO: window function on the `p` n_datapoints dimension
196199

197200
if self.display_window is not None:
198-
dw = self.display_window
201+
# display window is interpreted using the index mapping for the `p` dim
202+
dw = self.index_mappings[-1](self.display_window)
199203

200204
if dw == 1:
201205
slices = [slice(indices[-1], indices[-1] + 1)]
@@ -244,7 +248,7 @@ def get(self, indices: tuple[Any, ...]):
244248
apply_dims = self.datapoints_window_func[1]
245249
ws = self.datapoints_window_size
246250

247-
# apply user's window func and return
251+
# apply user's window func
248252
# result will be of shape [n, p, 2 | 3]
249253
if apply_dims == "all":
250254
# windows will be of shape [n, p, 1 | 2 | 3, ws]
@@ -260,11 +264,11 @@ def get(self, indices: tuple[Any, ...]):
260264
).squeeze()
261265

262266
# this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary
263-
graphic_data[..., : self.display_window, dims] = wf(
267+
graphic_data[..., : dw, dims] = wf(
264268
windows, axis=-1
265-
).reshape(graphic_data.shape[0], self.display_window, len(dims))
269+
).reshape(graphic_data.shape[0], dw, len(dims))
266270

267-
return graphic_data[..., : self.display_window, :]
271+
return graphic_data[..., : dw, :]
268272

269273
return graphic_data
270274

@@ -285,6 +289,7 @@ def __init__(
285289
display_window: int = 10,
286290
window_funcs: tuple[WindowFuncCallable | None] | None = None,
287291
window_sizes: tuple[int | None] | None = None,
292+
index_mappings: tuple[Callable[[Any], int] | None] | None = None,
288293
):
289294
if issubclass(graphic, LineCollection):
290295
multi = True
@@ -295,6 +300,7 @@ def __init__(
295300
display_window=display_window,
296301
window_funcs=window_funcs,
297302
window_sizes=window_sizes,
303+
index_mappings=index_mappings,
298304
)
299305

300306
self._indices = tuple([0] * self._processor.n_slider_dims)

fastplotlib/widgets/nd_widget/_processor_base.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike]
1212

1313

14+
def identity(index: int) -> int:
15+
return index
16+
17+
1418
class NDProcessor:
1519
def __init__(
1620
self,
@@ -23,7 +27,7 @@ def __init__(
2327
spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None,
2428
):
2529
self._data = self._validate_data(data)
26-
self._index_mappings = self._validate_index_mappings(index_mappings)
30+
self._index_mappings = tuple(self._validate_index_mappings(index_mappings))
2731

2832
self.window_funcs = window_funcs
2933
self.window_sizes = window_sizes
@@ -205,19 +209,30 @@ def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None:
205209
# pass
206210

207211
@property
208-
def index_mappings(self) -> tuple[Callable[[Any], int] | None, ...]:
212+
def index_mappings(self) -> tuple[Callable[[Any], int]]:
209213
return self._index_mappings
210214

211215
@index_mappings.setter
212-
def index_mappings(self, maps):
213-
self._index_mappings = self._validate_index_mappings(maps)
216+
def index_mappings(self, maps: tuple[Callable[[Any], int] | None] | None):
217+
self._index_mappings = tuple(self._validate_index_mappings(maps))
214218

215219
def _validate_index_mappings(self, maps):
216-
if maps is not None:
217-
if not all([callable(m) or m is None for m in maps]):
220+
if maps is None:
221+
return tuple([identity] * self.n_slider_dims)
222+
223+
if len(maps) != self.n_slider_dims:
224+
raise IndexError
225+
226+
_maps = list()
227+
for m in maps:
228+
if m is None:
229+
_maps.append(identity)
230+
elif callable(m):
231+
_maps.append(identity)
232+
else:
218233
raise TypeError
219234

220-
return maps
235+
return tuple(maps)
221236

222237
def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol:
223238
pass

0 commit comments

Comments
 (0)