From 9a01cd5ee7fda8e4dd923670f0466d1233bc6de0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 25 Dec 2025 01:04:06 -0800 Subject: [PATCH 001/101] start ndprocessors --- fastplotlib/utils/__init__.py | 1 + fastplotlib/utils/_protocols.py | 12 ++ fastplotlib/widgets/nd_widget/_processor.py | 141 ++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 fastplotlib/utils/_protocols.py create mode 100644 fastplotlib/widgets/nd_widget/_processor.py diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..8001ae375 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import ArrayProtocol @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..c168ecfa4 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,12 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ArrayProtocol(Protocol): + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py new file mode 100644 index 000000000..9e5299118 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -0,0 +1,141 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDProcessor: + def __init__( + self, + data: ArrayProtocol, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: + pass + + @property + def window_sizes(self) -> tuple[int | None] | None: + pass + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + @property + def slider_dims(self) -> tuple[int, ...] | None: + pass + + @property + def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._slider_index_maps + + @slider_index_maps.setter + def slider_index_maps(self, maps): + self._maps = self._validate_slider_index_maps(maps) + + def _validate_slider_index_maps(self, maps): + if maps is not None: + if not all([callable(m) or m is None for m in maps]): + raise TypeError + + return maps + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") + + +class NDTimeSeriesProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol, + graphic: Literal["line", "heatmap"] = "line", + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + display_window: int | float | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + super().__init__( + data=data, + n_display_dims=n_display_dims, + slider_index_maps=slider_index_maps, + ) + + self._display_window = display_window + + def _validate_data(self, data: ArrayProtocol): + data = super()._validate_data(data) + + # need to make shape be [n_lines, n_datapoints, 2] + # this will work for displaying a linestack and heatmap + # for heatmap just slice: [..., 1] + # TODO: Think about how to allow n-dimensional lines, + # maybe [d1, d2, ..., d(n - 1), n_lines, n_datapoint, 2] + # and dn is the x-axis values?? + if data.ndim == 1: + pass + + @property + def display_window(self) -> int | float | None: + """display window in the reference units along the x-axis""" + return self._display_window + + def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: + if self.display_window is not None: + # map reference units -> array int indices if necessary + if self.slider_index_maps is not None: + indices_window = self.slider_index_maps(self.display_window) + else: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = indices + hw + + # slice dim would be ndim - 1 + + return self.data[start:stop] From c46455ff71e460772148bf629dd906beffaf3cca Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 03:52:32 -0800 Subject: [PATCH 002/101] basic timeseries --- fastplotlib/widgets/nd_widget/_processor.py | 187 +++++++++++++++++--- 1 file changed, 159 insertions(+), 28 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py index 9e5299118..d0a8e66ab 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -5,6 +5,7 @@ import numpy as np from numpy.typing import ArrayLike +from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic from ...utils import subsample_array, ArrayProtocol @@ -14,13 +15,13 @@ class NDProcessor: def __init__( - self, - data: ArrayProtocol, - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + self, + data, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data) self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) @@ -84,17 +85,30 @@ def _validate_n_display_dims(self, n_display_dims): raise ValueError("`n_display_dims` must be") +VALID_TIMESERIES_Y_DATA_SHAPES = ( + "[n_datapoints] for 1D array of y-values, [n_datapoints, 2] " + "for a 1D array of y and z-values, [n_lines, n_datapoints] for a 2D stack of lines with y-values, " + "or [n_lines, n_datapoints, 2] for a stack of lines with y and z-values." +) + + +# Limitation, no heatmap if z-values present, I don't think you can visualize that class NDTimeSeriesProcessor(NDProcessor): def __init__( - self, - data: ArrayProtocol, - graphic: Literal["line", "heatmap"] = "line", - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - display_window: int | float | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + self, + data: list[ + ArrayProtocol, ArrayProtocol + ], # list: [x_vals_array, y_vals_and_z_vals_array] + x_values: ArrayProtocol = None, + cmap: str = None, + cmap_transform: ArrayProtocol = None, + display_graphic: Literal["line", "heatmap"] = "line", + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + display_window: int | float | None = 100, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): super().__init__( data=data, @@ -104,23 +118,73 @@ def __init__( self._display_window = display_window - def _validate_data(self, data: ArrayProtocol): - data = super()._validate_data(data) + self._display_graphic = None + self.display_graphic = display_graphic - # need to make shape be [n_lines, n_datapoints, 2] - # this will work for displaying a linestack and heatmap - # for heatmap just slice: [..., 1] - # TODO: Think about how to allow n-dimensional lines, - # maybe [d1, d2, ..., d(n - 1), n_lines, n_datapoint, 2] - # and dn is the x-axis values?? - if data.ndim == 1: - pass + self._uniform_x_values: ArrayProtocol | None = None + self._interp_yz: ArrayProtocol | None = None + + @property + def data(self) -> list[ArrayProtocol, ArrayProtocol]: + return self._data + + @data.setter + def data(self, data: list[ArrayProtocol, ArrayProtocol]): + self._data = self._validate_data(data) + + def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): + x_vals, yz_vals = data + + if x_vals.ndim != 1: + raise ("data x values must be 1D") + + if data[1].ndim > 3: + raise ValueError( + f"data yz values must be of shape: {VALID_TIMESERIES_Y_DATA_SHAPES}. You passed data of shape: {yz_vals.shape}" + ) + + return data + + @property + def display_graphic(self) -> Literal["line", "heatmap"]: + return self._display_graphic + + @display_graphic.setter + def display_graphic(self, dg: Literal["line", "heatmap"]): + dg = self._validate_display_graphic(dg) + + if dg == "heatmap": + # check if x-vals uniformly spaced + norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) + if norm > 10 ** -12: + # need to create evenly spaced x-values + x0 = self.data[0][0] + xn = self.data[0][-1] + self._uniform_x_values = np.linspace(x0, xn, num=len(self.data[0])) + + # TODO: interpolate yz values on the fly only when within the display window + + def _validate_display_graphic(self, dg): + if dg not in ("line", "heatmap"): + raise ValueError + + return dg @property def display_window(self) -> int | float | None: """display window in the reference units along the x-axis""" return self._display_window + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: if self.display_window is not None: # map reference units -> array int indices if necessary @@ -134,8 +198,75 @@ def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: # for now assume just a single index provided that indicates x axis value start = max(indices - hw, 0) - stop = indices + hw + stop = start + indices_window # slice dim would be ndim - 1 + return self.data[0][start:stop], self.data[1][:, start:stop] + + +class NDTimeSeries: + def __init__(self, processor: NDTimeSeriesProcessor, display_graphic): + self._processor = processor + + self._indices = 0 + + if display_graphic == "line": + self._create_line_stack() + + @property + def processor(self) -> NDTimeSeriesProcessor: + return self._processor + + @property + def graphic(self) -> LineStack | ImageGraphic: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def display_window(self) -> int | float | None: + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + # create new graphic if it changed + if dw != self.display_window: + create_new_graphic = True + else: + create_new_graphic = False + + self.processor.display_window = dw + + if create_new_graphic: + if isinstance(self.graphic, LineStack): + self.set_index(self._indices) + + def set_index(self, indices: tuple[Any, ...]): + # set the graphic at the given data indices + data_slice = self.processor[indices] + + if isinstance(self.graphic, LineStack): + line_stack_data = self._create_line_stack_data(data_slice) + + for g, line_data in zip(self.graphic.graphics, line_stack_data): + if line_data.shape[1] == 2: + # only x and y values + g.data[:, :-1] = line_data + else: + # has z values too + g.data[:] = line_data + + self._indices = indices + + def _create_line_stack_data(self, data_slice): + xs = data_slice[0] # 1D + yz = data_slice[1] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals + + # need to go from x_vals and yz_vals arrays to an array of shape: [n_lines, n_datapoints, 2 | 3] + return np.dstack([np.repeat(xs[None], repeats=yz.shape[0], axis=0), yz]) + + def _create_line_stack(self): + data_slice = self.processor[self._indices] + + ls_data = self._create_line_stack_data(data_slice) - return self.data[start:stop] + self._graphic = LineStack(ls_data) From d93fa5d5fdc685b8d7f2b7bc38a95abb100f31da Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 03:52:52 -0800 Subject: [PATCH 003/101] add __init__ --- fastplotlib/widgets/nd_widget/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/__init__.py diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py new file mode 100644 index 000000000..e69de29bb From fddefb826f44f443c2504557c6d5e76b2e50c05f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 17:46:54 -0800 Subject: [PATCH 004/101] heatmap for timeseries works! --- fastplotlib/widgets/nd_widget/_processor.py | 55 ++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py index d0a8e66ab..0add36594 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -155,6 +155,7 @@ def display_graphic(self, dg: Literal["line", "heatmap"]): if dg == "heatmap": # check if x-vals uniformly spaced + # this is very fast to do on the fly, especially for typical small display windows norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) if norm > 10 ** -12: # need to create evenly spaced x-values @@ -205,13 +206,17 @@ def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: class NDTimeSeries: - def __init__(self, processor: NDTimeSeriesProcessor, display_graphic): + def __init__(self, processor: NDTimeSeriesProcessor, graphic): self._processor = processor self._indices = 0 - if display_graphic == "line": + if graphic == "line": self._create_line_stack() + elif graphic == "heatmap": + self._create_heatmap() + else: + raise ValueError @property def processor(self) -> NDTimeSeriesProcessor: @@ -222,6 +227,19 @@ def graphic(self) -> LineStack | ImageGraphic: """LineStack or ImageGraphic for heatmaps""" return self._graphic + @graphic.setter + def graphic(self, g: Literal["line", "heatmap"]): + if g == "line": + # TODO: remove existing graphic + self._create_line_stack() + + elif g == "heatmap": + # make sure "yz" data is only ys and no z values + # can't represent y and z vals in a heatmap + if self.processor.data[1].ndim > 2: + raise ValueError("Only y-values are supported for heatmaps, not yz-values") + self._create_heatmap() + @property def display_window(self) -> int | float | None: return self.processor.display_window @@ -255,6 +273,10 @@ def set_index(self, indices: tuple[Any, ...]): # has z values too g.data[:] = line_data + elif isinstance(self.graphic, ImageGraphic): + hm_data, scale = self._create_heatmap_data(data_slice) + self.graphic.data = hm_data + self._indices = indices def _create_line_stack_data(self, data_slice): @@ -270,3 +292,32 @@ def _create_line_stack(self): ls_data = self._create_line_stack_data(data_slice) self._graphic = LineStack(ls_data) + + def _create_heatmap_data(self, data_slice) -> tuple[ArrayProtocol, float]: + """Returns [n_lines, y_values] array and scale factor for x dimension""" + # check if x-vals uniformly spaced + # this is very fast to do on the fly, especially for typical small display windows + x, y = data_slice + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + if norm > 10 ** -12: + # need to create evenly spaced x-values + x_uniform = np.linspace(x[0], x[-1], num=x.size) + # yz is [n_lines, n_datapoints] + y_interp = np.zeros(shape=y.shape, dtype=np.float32) + for i in range(y.shape[0]): + y_interp[i] = np.interp(x_uniform, x, y[i]) + + else: + y_interp = y + + x_scale = x[-1] / x.size + + return y_interp, x_scale + + def _create_heatmap(self): + data_slice = self.processor[self._indices] + + hm_data, x_scale = self._create_heatmap_data(data_slice) + + self._graphic = ImageGraphic(hm_data) + self._graphic.world_object.world.scale_x = x_scale \ No newline at end of file From d5e4c7d45901b1f5f2de89e68ef4d416d0ea7dde Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 29 Dec 2025 01:44:11 -0800 Subject: [PATCH 005/101] NDPositions, basics work, reorganize, increase default scatter size --- fastplotlib/graphics/scatter.py | 2 +- fastplotlib/widgets/nd_widget/_nd_image.py | 13 ++ .../widgets/nd_widget/_nd_positions.py | 137 ++++++++++++++++++ .../{_processor.py => _nd_timeseries.py} | 104 +------------ .../widgets/nd_widget/_processor_base.py | 74 ++++++++++ 5 files changed, 227 insertions(+), 103 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_nd_image.py create mode 100644 fastplotlib/widgets/nd_widget/_nd_positions.py rename fastplotlib/widgets/nd_widget/{_processor.py => _nd_timeseries.py} (70%) create mode 100644 fastplotlib/widgets/nd_widget/_processor_base.py diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index a2e696a82..5268dcc51 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -53,7 +53,7 @@ def __init__( image: np.ndarray = None, point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: float | np.ndarray | Sequence[float] = 1, + sizes: float | np.ndarray | Sequence[float] = 5, uniform_size: bool = False, size_space: str = "screen", isolated_buffer: bool = True, diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py new file mode 100644 index 000000000..f115e146e --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -0,0 +1,13 @@ +from typing import Literal + +from ._processor_base import NDProcessor + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py new file mode 100644 index 000000000..db8c80e72 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -0,0 +1,137 @@ +import inspect +from typing import Literal, Callable, Any, Type +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + +from ...graphics import ImageGraphic, LineGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor + +# TODO: Maybe get rid of n_display_dims in NDProcessor, +# we will know the display dims automatically here from the last dim +# so maybe we only need it for images? +class NDPositionsProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol, + multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + display_window: int | float | None = 100, # window for n_datapoints dim only + ): + super().__init__(data=data) + + self._display_window = display_window + + self.multi = multi + + def _validate_data(self, data: ArrayProtocol): + # TODO: determine right validation shape etc. + return data + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + @property + def multi(self) -> bool: + return self._multi + + @multi.setter + def multi(self, m: bool): + if m and self.data.ndim < 3: + # p is p-datapoints, n is how many lines/scatter to show simultaneously + raise ValueError("ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]") + + self._multi = m + + def __getitem__(self, indices: tuple[Any, ...]): + """sliders through all slider dims and outputs an array that can be used to set graphic data""" + if self.display_window is not None: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = start + indices_window + + slices = [slice(start, stop)] + + # TODO: implement slicing for multiple slider dims, i.e. [s1, s2, ... n_datapoints, 2 | 3] + # this currently assumes the shape is: [n_datapoints, 2 | 3] + if self.multi: + # n - 2 dim is n_lines or n_scatters + slices.insert(0, slice(None)) + + return self.data[tuple(slices)] + + +class NDPositions: + def __init__(self, data, graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], multi: bool = False): + self._indices = 0 + + if issubclass(graphic, LineCollection): + multi = True + + self._processor = NDPositionsProcessor(data, multi=multi) + + self._create_graphic(graphic) + + @property + def processor(self) -> NDPositionsProcessor: + return self._processor + + @property + def graphic(self) -> LineGraphic | LineCollection | LineStack | ScatterGraphic | list[ScatterGraphic]: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, indices): + data_slice = self.processor[indices] + + if isinstance(self.graphic, list): + # list of scatter + for i in range(len(self.graphic)): + # data_slice shape is [n_scatters, n_datapoints, 2 | 3] + # by using data_slice.shape[-1] it will auto-select if the data is only xy or has xyz + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + elif isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + self.graphic.data[:, :data_slice.shape[-1]] = data_slice + + elif isinstance(self.graphic, LineCollection): + for i in range(len(self.graphic)): + # data_slice shape is [n_lines, n_datapoints, 2 | 3] + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + def _create_graphic(self, graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic]): + if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): + # make list of scatters + self._graphic = list() + data_slice = self.processor[self.indices] + for d in data_slice: + scatter = graphic_cls(d) + self._graphic.append(scatter) + + else: + data_slice = self.processor[self.indices] + self._graphic = graphic_cls(data_slice) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_nd_timeseries.py similarity index 70% rename from fastplotlib/widgets/nd_widget/_processor.py rename to fastplotlib/widgets/nd_widget/_nd_timeseries.py index 0add36594..8630044cf 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_nd_timeseries.py @@ -5,84 +5,10 @@ import numpy as np from numpy.typing import ArrayLike -from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic from ...utils import subsample_array, ArrayProtocol - -# must take arguments: array-like, `axis`: int, `keepdims`: bool -WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] - - -class NDProcessor: - def __init__( - self, - data, - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, - ): - self._data = self._validate_data(data) - self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) - - @property - def data(self) -> ArrayProtocol: - return self._data - - @data.setter - def data(self, data: ArrayProtocol): - self._data = self._validate_data(data) - - def _validate_data(self, data: ArrayProtocol): - if not isinstance(data, ArrayProtocol): - raise TypeError("`data` must implement the ArrayProtocol") - - return data - - @property - def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: - pass - - @property - def window_sizes(self) -> tuple[int | None] | None: - pass - - @property - def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: - pass - - @property - def slider_dims(self) -> tuple[int, ...] | None: - pass - - @property - def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: - return self._slider_index_maps - - @slider_index_maps.setter - def slider_index_maps(self, maps): - self._maps = self._validate_slider_index_maps(maps) - - def _validate_slider_index_maps(self, maps): - if maps is not None: - if not all([callable(m) or m is None for m in maps]): - raise TypeError - - return maps - - def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: - pass - - -class NDImageProcessor(NDProcessor): - @property - def n_display_dims(self) -> Literal[2, 3]: - pass - - def _validate_n_display_dims(self, n_display_dims): - if n_display_dims not in (2, 3): - raise ValueError("`n_display_dims` must be") +from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor, WindowFuncCallable VALID_TIMESERIES_Y_DATA_SHAPES = ( @@ -145,32 +71,6 @@ def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): return data - @property - def display_graphic(self) -> Literal["line", "heatmap"]: - return self._display_graphic - - @display_graphic.setter - def display_graphic(self, dg: Literal["line", "heatmap"]): - dg = self._validate_display_graphic(dg) - - if dg == "heatmap": - # check if x-vals uniformly spaced - # this is very fast to do on the fly, especially for typical small display windows - norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) - if norm > 10 ** -12: - # need to create evenly spaced x-values - x0 = self.data[0][0] - xn = self.data[0][-1] - self._uniform_x_values = np.linspace(x0, xn, num=len(self.data[0])) - - # TODO: interpolate yz values on the fly only when within the display window - - def _validate_display_graphic(self, dg): - if dg not in ("line", "heatmap"): - raise ValueError - - return dg - @property def display_window(self) -> int | float | None: """display window in the reference units along the x-axis""" diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py new file mode 100644 index 000000000..fa56e4b52 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -0,0 +1,74 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDProcessor: + def __init__( + self, + data, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: + pass + + @property + def window_sizes(self) -> tuple[int | None] | None: + pass + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + @property + def slider_dims(self) -> tuple[int, ...] | None: + pass + + @property + def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._slider_index_maps + + @slider_index_maps.setter + def slider_index_maps(self, maps): + self._maps = self._validate_slider_index_maps(maps) + + def _validate_slider_index_maps(self, maps): + if maps is not None: + if not all([callable(m) or m is None for m in maps]): + raise TypeError + + return maps + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass From 074669b084068784bed7a5f54e6cfe014ea0abf5 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 25 Jan 2026 23:58:14 -0500 Subject: [PATCH 006/101] black --- .../widgets/nd_widget/_nd_positions.py | 47 ++++++++++++++----- .../widgets/nd_widget/_nd_timeseries.py | 16 ++++--- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index db8c80e72..10215d351 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -7,18 +7,25 @@ from ...utils import subsample_array, ArrayProtocol -from ...graphics import ImageGraphic, LineGraphic, LineStack, LineCollection, ScatterGraphic +from ...graphics import ( + ImageGraphic, + LineGraphic, + LineStack, + LineCollection, + ScatterGraphic, +) from ._processor_base import NDProcessor + # TODO: Maybe get rid of n_display_dims in NDProcessor, # we will know the display dims automatically here from the last dim # so maybe we only need it for images? class NDPositionsProcessor(NDProcessor): def __init__( - self, - data: ArrayProtocol, - multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points - display_window: int | float | None = 100, # window for n_datapoints dim only + self, + data: ArrayProtocol, + multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + display_window: int | float | None = 100, # window for n_datapoints dim only ): super().__init__(data=data) @@ -52,8 +59,10 @@ def multi(self) -> bool: @multi.setter def multi(self, m: bool): if m and self.data.ndim < 3: - # p is p-datapoints, n is how many lines/scatter to show simultaneously - raise ValueError("ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]") + # p is p-datapoints, n is how many lines to show simultaneously (for line collection/stack) + raise ValueError( + "ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]" + ) self._multi = m @@ -81,7 +90,12 @@ def __getitem__(self, indices: tuple[Any, ...]): class NDPositions: - def __init__(self, data, graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], multi: bool = False): + def __init__( + self, + data, + graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], + multi: bool = False, + ): self._indices = 0 if issubclass(graphic, LineCollection): @@ -96,7 +110,11 @@ def processor(self) -> NDPositionsProcessor: return self._processor @property - def graphic(self) -> LineGraphic | LineCollection | LineStack | ScatterGraphic | list[ScatterGraphic]: + def graphic( + self, + ) -> ( + LineGraphic | LineCollection | LineStack | ScatterGraphic + ): """LineStack or ImageGraphic for heatmaps""" return self._graphic @@ -113,17 +131,20 @@ def indices(self, indices): for i in range(len(self.graphic)): # data_slice shape is [n_scatters, n_datapoints, 2 | 3] # by using data_slice.shape[-1] it will auto-select if the data is only xy or has xyz - self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + self.graphic[i].data[:, : data_slice.shape[-1]] = data_slice[i] elif isinstance(self.graphic, (LineGraphic, ScatterGraphic)): - self.graphic.data[:, :data_slice.shape[-1]] = data_slice + self.graphic.data[:, : data_slice.shape[-1]] = data_slice elif isinstance(self.graphic, LineCollection): for i in range(len(self.graphic)): # data_slice shape is [n_lines, n_datapoints, 2 | 3] - self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + self.graphic[i].data[:, : data_slice.shape[-1]] = data_slice[i] - def _create_graphic(self, graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic]): + def _create_graphic( + self, + graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], + ): if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): # make list of scatters self._graphic = list() diff --git a/fastplotlib/widgets/nd_widget/_nd_timeseries.py b/fastplotlib/widgets/nd_widget/_nd_timeseries.py index 8630044cf..49b9231c3 100644 --- a/fastplotlib/widgets/nd_widget/_nd_timeseries.py +++ b/fastplotlib/widgets/nd_widget/_nd_timeseries.py @@ -137,15 +137,17 @@ def graphic(self, g: Literal["line", "heatmap"]): # make sure "yz" data is only ys and no z values # can't represent y and z vals in a heatmap if self.processor.data[1].ndim > 2: - raise ValueError("Only y-values are supported for heatmaps, not yz-values") + raise ValueError( + "Only y-values are supported for heatmaps, not yz-values" + ) self._create_heatmap() @property - def display_window(self) -> int | float | None: + def display_window(self) -> int | float | None: return self.processor.display_window @display_window.setter - def display_window(self, dw: int | float | None): + def display_window(self, dw: int | float | None): # create new graphic if it changed if dw != self.display_window: create_new_graphic = True @@ -181,7 +183,9 @@ def set_index(self, indices: tuple[Any, ...]): def _create_line_stack_data(self, data_slice): xs = data_slice[0] # 1D - yz = data_slice[1] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals + yz = data_slice[ + 1 + ] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals # need to go from x_vals and yz_vals arrays to an array of shape: [n_lines, n_datapoints, 2 | 3] return np.dstack([np.repeat(xs[None], repeats=yz.shape[0], axis=0), yz]) @@ -199,7 +203,7 @@ def _create_heatmap_data(self, data_slice) -> tuple[ArrayProtocol, float]: # this is very fast to do on the fly, especially for typical small display windows x, y = data_slice norm = np.linalg.norm(np.diff(np.diff(x))) / x.size - if norm > 10 ** -12: + if norm > 10**-12: # need to create evenly spaced x-values x_uniform = np.linspace(x[0], x[-1], num=x.size) # yz is [n_lines, n_datapoints] @@ -220,4 +224,4 @@ def _create_heatmap(self): hm_data, x_scale = self._create_heatmap_data(data_slice) self._graphic = ImageGraphic(hm_data) - self._graphic.world_object.world.scale_x = x_scale \ No newline at end of file + self._graphic.world_object.world.scale_x = x_scale From ff5c5783c235376a049732f889cc477cdfc5ca9a Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 29 Jan 2026 20:27:59 -0500 Subject: [PATCH 007/101] NDPositions working with multi-dim stack of lines, need to test window funcs --- .../widgets/nd_widget/_nd_positions.py | 113 +++++++++++-- .../widgets/nd_widget/_processor_base.py | 157 +++++++++++++++++- 2 files changed, 247 insertions(+), 23 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 10215d351..dfcb263c5 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -26,6 +26,7 @@ def __init__( data: ArrayProtocol, multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points display_window: int | float | None = 100, # window for n_datapoints dim only + n_slider_dims: int = 0, ): super().__init__(data=data) @@ -33,6 +34,8 @@ def __init__( self.multi = multi + self.n_slider_dims = n_slider_dims + def _validate_data(self, data: ArrayProtocol): # TODO: determine right validation shape etc. return data @@ -66,27 +69,108 @@ def multi(self, m: bool): self._multi = m - def __getitem__(self, indices: tuple[Any, ...]): - """sliders through all slider dims and outputs an array that can be used to set graphic data""" + def _apply_window_functions(self, indices: tuple[int, ...]): + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + if winds is None or funcs is None: + # no window funcs or window sizes, just slice data and return + # clamp to max bounds + indexer = list() + for dim, i in enumerate(indices): + i = min(self.shape[dim] - 1, i) + indexer.append(i) + + return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index] - 1, i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def get(self, indices: tuple[Any, ...]): + """ + slices through all slider dims and outputs an array that can be used to set graphic data + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + """ + # apply window funcs + # this array should be of shape [n_datapoints, 2 | 3] + window_output = self._apply_window_functions(indices[:-1]).squeeze() + + # TODO: window function on the `p` n_datapoints dimension + if self.display_window is not None: - indices_window = self.display_window + dw = self.display_window # half window size - hw = indices_window // 2 + hw = dw // 2 # for now assume just a single index provided that indicates x axis value - start = max(indices - hw, 0) - stop = start + indices_window + start = max(indices[-1] - hw, 0) + stop = start + dw slices = [slice(start, stop)] - # TODO: implement slicing for multiple slider dims, i.e. [s1, s2, ... n_datapoints, 2 | 3] - # this currently assumes the shape is: [n_datapoints, 2 | 3] if self.multi: # n - 2 dim is n_lines or n_scatters slices.insert(0, slice(None)) - return self.data[tuple(slices)] + return window_output[tuple(slices)] class NDPositions: @@ -96,12 +180,11 @@ def __init__( graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], multi: bool = False, ): - self._indices = 0 - if issubclass(graphic, LineCollection): multi = True - self._processor = NDPositionsProcessor(data, multi=multi) + self._processor = NDPositionsProcessor(data, multi=multi, display_window=100, n_slider_dims=2) + self._indices = tuple([0] * (2 + 1)) self._create_graphic(graphic) @@ -124,7 +207,7 @@ def indices(self) -> tuple: @indices.setter def indices(self, indices): - data_slice = self.processor[indices] + data_slice = self.processor.get(indices) if isinstance(self.graphic, list): # list of scatter @@ -148,11 +231,11 @@ def _create_graphic( if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): # make list of scatters self._graphic = list() - data_slice = self.processor[self.indices] + data_slice = self.processor.get(self.indices) for d in data_slice: scatter = graphic_cls(d) self._graphic.append(scatter) else: - data_slice = self.processor[self.indices] + data_slice = self.processor.get(self.indices) self._graphic = graphic_cls(data_slice) diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py index fa56e4b52..3350fff8f 100644 --- a/fastplotlib/widgets/nd_widget/_processor_base.py +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -7,7 +7,6 @@ from ...utils import subsample_array, ArrayProtocol - # must take arguments: array-like, `axis`: int, `keepdims`: bool WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] @@ -20,11 +19,16 @@ def __init__( slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, window_funcs: tuple[WindowFuncCallable | None] | None = None, window_sizes: tuple[int | None] | None = None, + window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data) self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + @property def data(self) -> ArrayProtocol: return self._data @@ -33,6 +37,14 @@ def data(self) -> ArrayProtocol: def data(self, data: ArrayProtocol): self._data = self._validate_data(data) + @property + def shape(self) -> tuple[int, ...]: + return self.data.shape + + @property + def ndim(self) -> int: + return int(np.prod(self.shape)) + def _validate_data(self, data: ArrayProtocol): if not isinstance(data, ArrayProtocol): raise TypeError("`data` must implement the ArrayProtocol") @@ -40,21 +52,150 @@ def _validate_data(self, data: ArrayProtocol): return data @property - def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: - pass + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = None + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + if all([f is None for f in window_funcs]): + self._window_funcs = None + return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) @property - def window_sizes(self) -> tuple[int | None] | None: - pass + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = None + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + # if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + # raise IndexError( + # f"number of `window_sizes` must be the same as the number of slider dims, " + # f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + # f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + # ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) @property - def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: - pass + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) @property - def slider_dims(self) -> tuple[int, ...] | None: + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: pass + # @property + # def slider_dims(self) -> tuple[int, ...] | None: + # pass + @property def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: return self._slider_index_maps From 3f412c514e204279b70dad1bbb0c7c2b06796405 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 29 Jan 2026 20:48:25 -0500 Subject: [PATCH 008/101] scatter collection --- fastplotlib/graphics/__init__.py | 3 +- fastplotlib/graphics/scatter_collection.py | 517 ++++++++++++++++++ fastplotlib/layouts/_graphic_methods_mixin.py | 84 ++- 3 files changed, 602 insertions(+), 2 deletions(-) create mode 100644 fastplotlib/graphics/scatter_collection.py diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 3d01e4a35..8734a5e72 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -7,7 +7,7 @@ from .mesh import MeshGraphic, SurfaceGraphic, PolygonGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack - +from .scatter_collection import ScatterCollection __all__ = [ "Graphic", @@ -22,4 +22,5 @@ "TextGraphic", "LineCollection", "LineStack", + "ScatterCollection", ] diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py new file mode 100644 index 000000000..4d671b0ac --- /dev/null +++ b/fastplotlib/graphics/scatter_collection.py @@ -0,0 +1,517 @@ +from typing import * + +import numpy as np + +import pygfx + +from ..utils import parse_cmap_values +from ._collection_base import CollectionIndexer, GraphicCollection, CollectionFeature +from .scatter import ScatterGraphic +from .selectors import ( + LinearRegionSelector, + LinearSelector, + RectangleSelector, + PolygonSelector, +) + + +class _ScatterCollectionProperties: + """Mix-in class for ScatterCollection properties""" + + @property + def colors(self) -> CollectionFeature: + """get or set colors of scatters in the collection""" + return CollectionFeature(self.graphics, "colors") + + @colors.setter + def colors(self, values: str | np.ndarray | tuple[float] | list[float] | list[str]): + if isinstance(values, str): + # set colors of all scatter to one str color + for g in self: + g.colors = values + return + + elif all(isinstance(v, str) for v in values): + # individual str colors for each scatter + if not len(values) == len(self): + raise IndexError + + for g, v in zip(self.graphics, values): + g.colors = v + + return + + if isinstance(values, np.ndarray): + if values.ndim == 2: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + return + + elif len(values) == 4: + # assume RGBA + self.colors[:] = values + + else: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + + @property + def data(self) -> CollectionFeature: + """get or set data of lines in the collection""" + return CollectionFeature(self.graphics, "data") + + @data.setter + def data(self, values): + for g, v in zip(self, values): + g.data = v + + @property + def cmap(self) -> CollectionFeature: + """ + Get or set a cmap along the scatter collection. + + Optionally set using a tuple ("cmap", ) to set the transform. + Example: + + scatter_collection.cmap = ("jet", sine_transform_vals, 0.7) + + """ + return CollectionFeature(self.graphics, "cmap") + + @cmap.setter + def cmap(self, args): + if isinstance(args, str): + name = args + transform = None + elif len(args) == 1: + name = args[0] + transform = None + elif len(args) == 2: + name, transform = args + else: + raise ValueError( + "Too many values for cmap (note that alpha is deprecated, set alpha on the graphic instead)" + ) + + self.colors = parse_cmap_values( + n_colors=len(self), cmap_name=name, transform=transform + ) + + +class ScatterCollectionIndexer(CollectionIndexer, _ScatterCollectionProperties): + """Indexer for scatter collections""" + + pass + + +class ScatterCollection(GraphicCollection, _ScatterCollectionProperties): + _child_type = ScatterGraphic + _indexer = ScatterCollectionIndexer + + def __init__( + self, + data: np.ndarray | List[np.ndarray], + colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", + uniform_colors: bool = False, + cmap: Sequence[str] | str = None, + cmap_transform: np.ndarray | List = None, + sizes: float | Sequence[float] = 2.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + isolated_buffer: bool = True, + kwargs_lines: list[dict] = None, + **kwargs, + ): + """ + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + + super().__init__(name=name, metadata=metadata, **kwargs) + + if names is not None: + if len(names) != len(data): + raise ValueError( + f"len(names) != len(data)\n{len(names)} != {len(data)}" + ) + + if metadatas is not None: + if len(metadatas) != len(data): + raise ValueError( + f"len(metadata) != len(data)\n{len(metadatas)} != {len(data)}" + ) + + if kwargs_lines is not None: + if len(kwargs_lines) != len(data): + raise ValueError( + f"len(kwargs_lines) != len(data)\n" + f"{len(kwargs_lines)} != {len(data)}" + ) + + self._cmap_transform = cmap_transform + self._cmap_str = cmap + + # cmap takes priority over colors + if cmap is not None: + # cmap across lines + if isinstance(cmap, str): + colors = parse_cmap_values( + n_colors=len(data), cmap_name=cmap, transform=cmap_transform + ) + single_color = False + cmap = None + + elif isinstance(cmap, (tuple, list)): + if len(cmap) != len(data): + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + single_color = False + else: + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + else: + if isinstance(colors, np.ndarray): + # single color for all lines in the collection as RGBA + if colors.shape in [(3,), (4,)]: + single_color = True + + # colors specified for each line as array of shape [n_lines, RGBA] + elif colors.shape == (len(data), 4): + single_color = False + + else: + raise ValueError( + f"numpy array colors argument must be of shape (4,) or (n_lines, 4)." + f"You have pass the following shape: {colors.shape}" + ) + + elif isinstance(colors, str): + if colors == "random": + colors = np.random.rand(len(data), 3) + single_color = False + else: + # parse string color + single_color = True + colors = pygfx.Color(colors) + + elif isinstance(colors, (tuple, list)): + if len(colors) == 4: + # single color specified as (R, G, B, A) tuple or list + if all([isinstance(c, (float, int)) for c in colors]): + single_color = True + + elif len(colors) == len(data): + # colors passed as list/tuple of colors, such as list of string + single_color = False + + else: + raise ValueError( + "tuple or list colors argument must be a single color represented as [R, G, B, A], " + "or must be a tuple/list of colors represented by a string with the same length as the data" + ) + + if kwargs_lines is None: + kwargs_lines = dict() + + self._set_world_object(pygfx.Group()) + + for i, d in enumerate(data): + if cmap is None: + _cmap = None + + if single_color: + _c = colors + else: + _c = colors[i] + else: + _cmap = cmap[i] + _c = None + + if metadatas is not None: + _m = metadatas[i] + else: + _m = None + + if names is not None: + _name = names[i] + else: + _name = None + + lg = ScatterGraphic( + data=d, + colors=_c, + uniform_color=uniform_colors, + sizes=sizes, + cmap=_cmap, + name=_name, + metadata=_m, + isolated_buffer=isolated_buffer, + **kwargs_lines, + ) + + self.add_graphic(lg) + + def __getitem__(self, item) -> ScatterCollectionIndexer: + return super().__getitem__(item) + + def add_linear_selector( + self, selection: float = None, padding: float = 0.0, axis: str = "x", **kwargs + ) -> LinearSelector: + """ + Adds a linear selector. + + Parameters + ---------- + Parameters + ---------- + selection: float, optional + selected point on the linear selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to :class:`.LinearSelector` + + Returns + ------- + LinearSelector + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init[0] + + selector = LinearSelector( + selection=selection, + limits=limits, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_linear_region_selector( + self, + selection: tuple[float, float] = None, + padding: float = 0.0, + axis: str = "x", + **kwargs, + ) -> LinearRegionSelector: + """ + Add a :class:`.LinearRegionSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float), optional + the starting bounds of the linear region selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear region selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to ``LinearRegionSelector`` + + Returns + ------- + LinearRegionSelector + linear selection graphic + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init + + # create selector + selector = LinearRegionSelector( + selection=selection, + limits=limits, + size=size, + center=center, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + # PlotArea manages this for garbage collection etc. just like all other Graphics + # so we should only work with a proxy on the user-end + return selector + + def add_rectangle_selector( + self, + selection: tuple[float, float, float] = None, + **kwargs, + ) -> RectangleSelector: + """ + Add a :class:`.RectangleSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float, float, float), optional + initial (xmin, xmax, ymin, ymax) of the selection + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25px = (xmax - xmin) / 4 + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + if selection is None: + selection = (xmin, value_25px, ymin, ymax) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = RectangleSelector( + selection=selection, + limits=limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_polygon_selector( + self, + selection: List[tuple[float, float]] = None, + **kwargs, + ) -> PolygonSelector: + """ + Add a :class:`.PolygonSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: List of positions, optional + Initial points for the polygon. If not given or None, you'll start drawing the selection (clicking adds points to the polygon). + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = PolygonSelector( + selection, + limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def _get_linear_selector_init_args(self, axis, padding): + # use bbox to get size and center + bbox = self.world_object.get_world_bounding_box() + + if axis == "x": + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + # size from orthogonal axis + size = np.ptp(bbox[:, 1]) * 1.5 + # center on orthogonal axis + center = bbox[:, 1].mean() + + elif axis == "y": + ydata = np.array(self.data[:, 1]) + xmin, xmax = (np.nanmin(ydata), np.nanmax(ydata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + + size = np.ptp(bbox[:, 0]) * 1.5 + # center on orthogonal axis + center = bbox[:, 0].mean() + + return bounds, limits, size, center diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index 06a4c7517..3eb018f55 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -570,6 +570,88 @@ def add_polygon( PolygonGraphic, data, mode, colors, mapcoords, cmap, clim, **kwargs ) + def add_scatter_collection( + self, + data: Union[numpy.ndarray, List[numpy.ndarray]], + colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", + uniform_colors: bool = False, + cmap: Union[Sequence[str], str] = None, + cmap_transform: Union[numpy.ndarray, List] = None, + sizes: Union[float, Sequence[float]] = 2.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Union[Sequence[Any], numpy.ndarray] = None, + isolated_buffer: bool = True, + kwargs_lines: list[dict] = None, + **kwargs, + ) -> ScatterCollection: + """ + + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + + """ + return self._create_graphic( + ScatterCollection, + data, + colors, + uniform_colors, + cmap, + cmap_transform, + sizes, + name, + names, + metadata, + metadatas, + isolated_buffer, + kwargs_lines, + **kwargs, + ) + def add_scatter( self, data: Any, @@ -589,7 +671,7 @@ def add_scatter( image: numpy.ndarray = None, point_rotations: float | numpy.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: Union[float, numpy.ndarray, Sequence[float]] = 1, + sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, uniform_size: bool = False, size_space: str = "screen", isolated_buffer: bool = True, From dc30151740ea77414a1b4e8d26009092c3aa4ff0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 30 Jan 2026 00:06:37 -0500 Subject: [PATCH 009/101] progress, need to change to other branch so committing --- fastplotlib/graphics/scatter_collection.py | 2 +- .../widgets/nd_widget/_nd_positions.py | 99 +++++++++++++------ 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py index 4d671b0ac..b1569cacc 100644 --- a/fastplotlib/graphics/scatter_collection.py +++ b/fastplotlib/graphics/scatter_collection.py @@ -117,7 +117,7 @@ def __init__( uniform_colors: bool = False, cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, - sizes: float | Sequence[float] = 2.0, + sizes: float | Sequence[float] = 5.0, name: str = None, names: list[str] = None, metadata: Any = None, diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index dfcb263c5..decd3ec6c 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -13,6 +13,7 @@ LineStack, LineCollection, ScatterGraphic, + ScatterCollection, ) from ._processor_base import NDProcessor @@ -122,7 +123,7 @@ def _apply_window_functions(self, indices: tuple[int, ...]): start = max(0, i - hw) # stop index cannot exceed the bounds of this dimension - stop = min(self.shape[dim_index] - 1, i + hw) + stop = min(self.shape[dim_index], i + hw) s = slice(start, stop, 1) else: @@ -148,23 +149,34 @@ def get(self, indices: tuple[Any, ...]): Note that we do not use __getitem__ here since the index is a tuple specifying a single integer index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. """ - # apply window funcs - # this array should be of shape [n_datapoints, 2 | 3] - window_output = self._apply_window_functions(indices[:-1]).squeeze() + if len(indices) > 1: + # there are dims in addition to the n_datapoints dim + # apply window funcs + # window_output array should be of shape [n_datapoints, 2 | 3] + window_output = self._apply_window_functions(indices[:-1]).squeeze() + else: + window_output = self.data # TODO: window function on the `p` n_datapoints dimension if self.display_window is not None: dw = self.display_window - # half window size - hw = dw // 2 + if dw == 1: + slices = [slice(indices[-1], indices[-1] + 1)] + + else: + # half window size + hw = dw // 2 - # for now assume just a single index provided that indicates x axis value - start = max(indices[-1] - hw, 0) - stop = start + dw + # for now assume just a single index provided that indicates x axis value + start = max(indices[-1] - hw, 0) + stop = start + dw - slices = [slice(start, stop)] + # TODO: uncomment this once we have resizeable buffers!! + # stop = min(indices[-1] + hw, self.shape[-2]) + + slices = [slice(start, stop)] if self.multi: # n - 2 dim is n_lines or n_scatters @@ -177,14 +189,15 @@ class NDPositions: def __init__( self, data, - graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], + graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], multi: bool = False, + display_window: int = 10, ): if issubclass(graphic, LineCollection): multi = True - self._processor = NDPositionsProcessor(data, multi=multi, display_window=100, n_slider_dims=2) - self._indices = tuple([0] * (2 + 1)) + self._processor = NDPositionsProcessor(data, multi=multi, display_window=display_window, n_slider_dims=0) + self._indices = tuple([0] * (0 + 1)) self._create_graphic(graphic) @@ -196,11 +209,19 @@ def processor(self) -> NDPositionsProcessor: def graphic( self, ) -> ( - LineGraphic | LineCollection | LineStack | ScatterGraphic + LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic ): """LineStack or ImageGraphic for heatmaps""" return self._graphic + @graphic.setter + def graphic(self, graphic_type): + plot_area = self._graphic._plot_area + plot_area.delete_graphic(self._graphic) + + self._create_graphic(graphic_type) + plot_area.add_graphic(self._graphic) + @property def indices(self) -> tuple: return self._indices @@ -209,33 +230,47 @@ def indices(self) -> tuple: def indices(self, indices): data_slice = self.processor.get(indices) - if isinstance(self.graphic, list): - # list of scatter - for i in range(len(self.graphic)): - # data_slice shape is [n_scatters, n_datapoints, 2 | 3] - # by using data_slice.shape[-1] it will auto-select if the data is only xy or has xyz - self.graphic[i].data[:, : data_slice.shape[-1]] = data_slice[i] - - elif isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): self.graphic.data[:, : data_slice.shape[-1]] = data_slice - elif isinstance(self.graphic, LineCollection): + elif isinstance(self.graphic, (LineCollection, ScatterCollection)): for i in range(len(self.graphic)): # data_slice shape is [n_lines, n_datapoints, 2 | 3] self.graphic[i].data[:, : data_slice.shape[-1]] = data_slice[i] + elif isinstance(self.graphic, ImageGraphic): + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self.graphic.data = image_data + self.graphic.offset = (x0, *self.graphic.offset[1:]) + def _create_graphic( self, - graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], + graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], ): - if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): - # make list of scatters - self._graphic = list() - data_slice = self.processor.get(self.indices) - for d in data_slice: - scatter = graphic_cls(d) - self._graphic.append(scatter) + + data_slice = self.processor.get(self.indices) + + if issubclass(graphic_cls, ImageGraphic): + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self._graphic = graphic_cls(image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1)) else: - data_slice = self.processor.get(self.indices) self._graphic = graphic_cls(data_slice) + + def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: + if not self.processor.multi: + raise ValueError + + if self.processor.data.shape[-1] != 2: + raise ValueError + + # return [n_rows, n_cols] shape data + + image_data = data_slice[..., 1] + + # assume all x values are the same + x_scale = data_slice[:, -1, 0][0] / data_slice.shape[1] + + x0 = data_slice[0, 0, 0] + + return image_data, x0, x_scale From db98bde60f5b7b8b2bfc6288634616efccd529c0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 30 Jan 2026 00:34:37 -0500 Subject: [PATCH 010/101] better --- fastplotlib/widgets/nd_widget/_nd_positions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index decd3ec6c..bc7b5c242 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -216,6 +216,9 @@ def graphic( @graphic.setter def graphic(self, graphic_type): + if isinstance(self.graphic, graphic_type): + return + plot_area = self._graphic._plot_area plot_area.delete_graphic(self._graphic) From 3629f70f8c351feffe8208a0132f9463e63dd146 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 30 Jan 2026 20:45:47 -0500 Subject: [PATCH 011/101] interpolation for heatmap --- .../widgets/nd_widget/_nd_positions.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index bc7b5c242..f5b13a361 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -261,19 +261,35 @@ def _create_graphic( self._graphic = graphic_cls(data_slice) def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: + """return [n_rows, n_cols] shape data""" if not self.processor.multi: raise ValueError if self.processor.data.shape[-1] != 2: raise ValueError - # return [n_rows, n_cols] shape data + # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense + x = data_slice[0, :, 0] # get x from just the first row - image_data = data_slice[..., 1] + # check if we need to interpolate + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + + if norm > 1e-6: + # x is not uniform upto float32 precision, must interpolate + x_uniform = np.linspace(x[0], x[-1], num=x.size) + y_interp = np.zeros(shape=data_slice[..., 1].shape, dtype=np.float32) + + # this for loop is actually slightly faster than numpy.apply_along_axis() + for i in range(data_slice.shape[0]): + y_interp[i] = np.interp(x_uniform, x, data_slice[i, :, 1]) + + else: + # x is sufficiently uniform + y_interp = data_slice[..., 1] # assume all x values are the same x_scale = data_slice[:, -1, 0][0] / data_slice.shape[1] x0 = data_slice[0, 0, 0] - return image_data, x0, x_scale + return y_interp, x0, x_scale From 87ea418121114b1bf0617893d804c19baaf70a45 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 30 Jan 2026 20:47:08 -0500 Subject: [PATCH 012/101] better place for check --- fastplotlib/widgets/nd_widget/_nd_positions.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index f5b13a361..201bbb800 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -250,10 +250,15 @@ def _create_graphic( self, graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], ): - data_slice = self.processor.get(self.indices) if issubclass(graphic_cls, ImageGraphic): + if not self.processor.multi: + raise ValueError + + if self.processor.data.shape[-1] != 2: + raise ValueError + image_data, x0, x_scale = self._create_heatmap_data(data_slice) self._graphic = graphic_cls(image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1)) @@ -262,12 +267,6 @@ def _create_graphic( def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: """return [n_rows, n_cols] shape data""" - if not self.processor.multi: - raise ValueError - - if self.processor.data.shape[-1] != 2: - raise ValueError - # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense x = data_slice[0, :, 0] # get x from just the first row From e5a8d40e7f2a17f6c0effe5fd577f912cf968211 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Feb 2026 03:43:12 -0500 Subject: [PATCH 013/101] window functions working on n_datapoints dim --- .../widgets/nd_widget/_nd_positions.py | 111 +++++++++++++++--- .../widgets/nd_widget/_processor_base.py | 38 +++--- 2 files changed, 118 insertions(+), 31 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 201bbb800..ec64d4b9f 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -4,6 +4,7 @@ import numpy as np from numpy.typing import ArrayLike +from numpy.lib.stride_tricks import sliding_window_view from ...utils import subsample_array, ArrayProtocol @@ -15,7 +16,7 @@ ScatterGraphic, ScatterCollection, ) -from ._processor_base import NDProcessor +from ._processor_base import NDProcessor, WindowFuncCallable # TODO: Maybe get rid of n_display_dims in NDProcessor, @@ -27,15 +28,21 @@ def __init__( data: ArrayProtocol, multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points display_window: int | float | None = 100, # window for n_datapoints dim only - n_slider_dims: int = 0, + datapoints_window_func: Callable | None = None, + datapoints_window_size: int | None = None, + **kwargs ): - super().__init__(data=data) self._display_window = display_window + # TOOD: this does data validation twice and is a bit messy, cleanup + self._data = self._validate_data(data) self.multi = multi - self.n_slider_dims = n_slider_dims + super().__init__(data=data, **kwargs) + + self._datapoints_window_func = datapoints_window_func + self._datapoints_window_size = datapoints_window_size def _validate_data(self, data: ArrayProtocol): # TODO: determine right validation shape etc. @@ -70,6 +77,28 @@ def multi(self, m: bool): self._multi = m + @property + def slider_dims(self) -> tuple[int, ...]: + """slider dimensions""" + return tuple(range(self.ndim - 2 - int(self.multi))) + (self.ndim - 2,) + + @property + def n_slider_dims(self) -> int: + return self.ndim - 1 - int(self.multi) + + # TODO: validation for datapoints_window_func and size + @property + def datapoints_window_func(self) -> tuple[Callable, str] | None: + """ + Callable and str indicating which dims to apply window function along: + 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' + '""" + return self._datapoints_window_func + + @property + def datapoints_window_size(self) -> Callable | None: + return self._datapoints_window_size + def _apply_window_functions(self, indices: tuple[int, ...]): """applies the window functions for each dimension specified""" # window size for each dim @@ -77,15 +106,21 @@ def _apply_window_functions(self, indices: tuple[int, ...]): # window function for each dim funcs = self._window_funcs - if winds is None or funcs is None: - # no window funcs or window sizes, just slice data and return - # clamp to max bounds - indexer = list() - for dim, i in enumerate(indices): - i = min(self.shape[dim] - 1, i) - indexer.append(i) - - return self.data[tuple(indexer)] + # TODO: use tuple of None for window funcs and sizes to indicate all None, instead of just None + # print(winds) + # print(funcs) + # + # if winds is None or funcs is None: + # # no window funcs or window sizes, just slice data and return + # # clamp to max bounds + # indexer = list() + # print(indices) + # print(self.shape) + # for dim, i in enumerate(indices): + # i = min(self.shape[dim] - 1, i) + # indexer.append(i) + # + # return self.data[tuple(indexer)] # order in which window funcs are applied order = self._window_order @@ -172,6 +207,10 @@ def get(self, indices: tuple[Any, ...]): # for now assume just a single index provided that indicates x axis value start = max(indices[-1] - hw, 0) stop = start + dw + # also add window size of `p` dim so window_func output has the same number of datapoints + if self.datapoints_window_func is not None and self.datapoints_window_size is not None: + stop += self.datapoints_window_size - 1 + # TODO: pad with constant if we're using a window func and the index is near the end # TODO: uncomment this once we have resizeable buffers!! # stop = min(indices[-1] + hw, self.shape[-2]) @@ -182,7 +221,38 @@ def get(self, indices: tuple[Any, ...]): # n - 2 dim is n_lines or n_scatters slices.insert(0, slice(None)) - return window_output[tuple(slices)] + # data that will be used for the graphical representation + # a copy is made, if there were no window functions then this is a view of the original data + graphic_data = window_output[tuple(slices)].copy() + + # apply window function on the `p` n_datapoints dim + if self.datapoints_window_func is not None and self.datapoints_window_size is not None: + # get windows + + # graphic_data will be of shape: [n, p + (ws - 1), 2 | 3] + # where: + # n - number of lines, scatters, heatmap rows + # p - number of datapoints/samples + + # windows will be of shape [n, p, 1 | 2 | 3, ws] + wf = self.datapoints_window_func[0] + apply_dims = self.datapoints_window_func[1] + ws = self.datapoints_window_size + + # apply user's window func and return + # result will be of shape [n, p, 2 | 3] + if apply_dims == "all": + windows = sliding_window_view(graphic_data, ws, axis=-2) + return wf(windows, axis=-1) + + # map user dims str to tuple of numerical dims + dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) + windows = sliding_window_view(graphic_data[..., dims], ws, axis=-2).squeeze() + graphic_data[..., :self.display_window, dims] = wf(windows, axis=-1)[..., None] + + return graphic_data[..., :self.display_window, :] + + return graphic_data class NDPositions: @@ -192,12 +262,21 @@ def __init__( graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], multi: bool = False, display_window: int = 10, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, ): if issubclass(graphic, LineCollection): multi = True - self._processor = NDPositionsProcessor(data, multi=multi, display_window=display_window, n_slider_dims=0) - self._indices = tuple([0] * (0 + 1)) + self._processor = NDPositionsProcessor( + data, + multi=multi, + display_window=display_window, + window_funcs=window_funcs, + window_sizes=window_sizes, + ) + + self._indices = tuple([0] * self._processor.n_slider_dims) self._create_graphic(graphic) diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py index 3350fff8f..974677144 100644 --- a/fastplotlib/widgets/nd_widget/_processor_base.py +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -16,14 +16,14 @@ def __init__( self, data, n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + index_mappings: tuple[Callable[[Any], int] | None, ...] | None = None, window_funcs: tuple[WindowFuncCallable | None] | None = None, window_sizes: tuple[int | None] | None = None, window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data) - self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + self._index_mappings = self._validate_index_mappings(index_mappings) self.window_funcs = window_funcs self.window_sizes = window_sizes @@ -43,7 +43,7 @@ def shape(self) -> tuple[int, ...]: @property def ndim(self) -> int: - return int(np.prod(self.shape)) + return len(self.shape) def _validate_data(self, data: ArrayProtocol): if not isinstance(data, ArrayProtocol): @@ -51,6 +51,14 @@ def _validate_data(self, data: ArrayProtocol): return data + @property + def slider_dims(self): + raise NotImplementedError + + @property + def n_slider_dims(self): + raise NotImplementedError + @property def window_funcs( self, @@ -64,21 +72,21 @@ def window_funcs( window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, ): if window_funcs is None: - self._window_funcs = None + self._window_funcs = tuple([None] * self.n_slider_dims) return if callable(window_funcs): window_funcs = (window_funcs,) # if all are None - if all([f is None for f in window_funcs]): - self._window_funcs = None - return + # if all([f is None for f in window_funcs]): + # self._window_funcs = tuple(window_funcs) + # return self._validate_window_func(window_funcs) self._window_funcs = tuple(window_funcs) - self._recompute_histogram() + # self._recompute_histogram() def _validate_window_func(self, funcs): if isinstance(funcs, (tuple, list)): @@ -112,7 +120,7 @@ def window_sizes(self) -> tuple[int | None, ...] | None: @window_sizes.setter def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): if window_sizes is None: - self._window_sizes = None + self._window_sizes = tuple([None] * self.n_slider_dims) return if isinstance(window_sizes, int): @@ -197,14 +205,14 @@ def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: # pass @property - def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: - return self._slider_index_maps + def index_mappings(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._index_mappings - @slider_index_maps.setter - def slider_index_maps(self, maps): - self._maps = self._validate_slider_index_maps(maps) + @index_mappings.setter + def index_mappings(self, maps): + self._index_mappings = self._validate_index_mappings(maps) - def _validate_slider_index_maps(self, maps): + def _validate_index_mappings(self, maps): if maps is not None: if not all([callable(m) or m is None for m in maps]): raise TypeError From 8d050a76a215c1fa78b764a8eb5e80c38a938c76 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Feb 2026 04:00:44 -0500 Subject: [PATCH 014/101] p dim window funcs working for single and multiple dims I think --- fastplotlib/widgets/nd_widget/_nd_positions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index ec64d4b9f..b20eabb96 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -234,7 +234,6 @@ def get(self, indices: tuple[Any, ...]): # n - number of lines, scatters, heatmap rows # p - number of datapoints/samples - # windows will be of shape [n, p, 1 | 2 | 3, ws] wf = self.datapoints_window_func[0] apply_dims = self.datapoints_window_func[1] ws = self.datapoints_window_size @@ -242,13 +241,18 @@ def get(self, indices: tuple[Any, ...]): # apply user's window func and return # result will be of shape [n, p, 2 | 3] if apply_dims == "all": + # windows will be of shape [n, p, 1 | 2 | 3, ws] windows = sliding_window_view(graphic_data, ws, axis=-2) return wf(windows, axis=-1) # map user dims str to tuple of numerical dims dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) + + # windows will be of shape [n, p, 1 | 2 | 3, ws] windows = sliding_window_view(graphic_data[..., dims], ws, axis=-2).squeeze() - graphic_data[..., :self.display_window, dims] = wf(windows, axis=-1)[..., None] + + # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary + graphic_data[..., :self.display_window, dims] = wf(windows, axis=-1).reshape(graphic_data.shape[0], self.display_window, len(dims)) return graphic_data[..., :self.display_window, :] From 373199786a7126f2759b59afef03fef6980eb3ba Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Feb 2026 18:20:43 -0500 Subject: [PATCH 015/101] black --- .../widgets/nd_widget/_nd_positions.py | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index b20eabb96..c39304996 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -30,7 +30,7 @@ def __init__( display_window: int | float | None = 100, # window for n_datapoints dim only datapoints_window_func: Callable | None = None, datapoints_window_size: int | None = None, - **kwargs + **kwargs, ): self._display_window = display_window @@ -208,7 +208,10 @@ def get(self, indices: tuple[Any, ...]): start = max(indices[-1] - hw, 0) stop = start + dw # also add window size of `p` dim so window_func output has the same number of datapoints - if self.datapoints_window_func is not None and self.datapoints_window_size is not None: + if ( + self.datapoints_window_func is not None + and self.datapoints_window_size is not None + ): stop += self.datapoints_window_size - 1 # TODO: pad with constant if we're using a window func and the index is near the end @@ -226,7 +229,10 @@ def get(self, indices: tuple[Any, ...]): graphic_data = window_output[tuple(slices)].copy() # apply window function on the `p` n_datapoints dim - if self.datapoints_window_func is not None and self.datapoints_window_size is not None: + if ( + self.datapoints_window_func is not None + and self.datapoints_window_size is not None + ): # get windows # graphic_data will be of shape: [n, p + (ws - 1), 2 | 3] @@ -249,12 +255,16 @@ def get(self, indices: tuple[Any, ...]): dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) # windows will be of shape [n, p, 1 | 2 | 3, ws] - windows = sliding_window_view(graphic_data[..., dims], ws, axis=-2).squeeze() + windows = sliding_window_view( + graphic_data[..., dims], ws, axis=-2 + ).squeeze() # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary - graphic_data[..., :self.display_window, dims] = wf(windows, axis=-1).reshape(graphic_data.shape[0], self.display_window, len(dims)) + graphic_data[..., : self.display_window, dims] = wf( + windows, axis=-1 + ).reshape(graphic_data.shape[0], self.display_window, len(dims)) - return graphic_data[..., :self.display_window, :] + return graphic_data[..., : self.display_window, :] return graphic_data @@ -263,7 +273,14 @@ class NDPositions: def __init__( self, data, - graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], + graphic: Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic + ], multi: bool = False, display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, @@ -292,7 +309,12 @@ def processor(self) -> NDPositionsProcessor: def graphic( self, ) -> ( - LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic ): """LineStack or ImageGraphic for heatmaps""" return self._graphic @@ -331,7 +353,14 @@ def indices(self, indices): def _create_graphic( self, - graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic | ScatterCollection | ImageGraphic], + graphic_cls: Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ImageGraphic + ], ): data_slice = self.processor.get(self.indices) @@ -343,7 +372,9 @@ def _create_graphic( raise ValueError image_data, x0, x_scale = self._create_heatmap_data(data_slice) - self._graphic = graphic_cls(image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1)) + self._graphic = graphic_cls( + image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1) + ) else: self._graphic = graphic_cls(data_slice) From 7d4e42024796bc673a5accf575ae469ee1148dc3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Feb 2026 19:12:43 -0500 Subject: [PATCH 016/101] index_mappings is working I think, lightly tested on p dim --- .../widgets/nd_widget/_nd_positions.py | 16 ++++++---- .../widgets/nd_widget/_processor_base.py | 29 ++++++++++++++----- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index c39304996..1871e027e 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -184,6 +184,9 @@ def get(self, indices: tuple[Any, ...]): Note that we do not use __getitem__ here since the index is a tuple specifying a single integer index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. """ + # apply any slider index mappings + indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) + if len(indices) > 1: # there are dims in addition to the n_datapoints dim # apply window funcs @@ -195,7 +198,8 @@ def get(self, indices: tuple[Any, ...]): # TODO: window function on the `p` n_datapoints dimension if self.display_window is not None: - dw = self.display_window + # display window is interpreted using the index mapping for the `p` dim + dw = self.index_mappings[-1](self.display_window) if dw == 1: slices = [slice(indices[-1], indices[-1] + 1)] @@ -244,7 +248,7 @@ def get(self, indices: tuple[Any, ...]): apply_dims = self.datapoints_window_func[1] ws = self.datapoints_window_size - # apply user's window func and return + # apply user's window func # result will be of shape [n, p, 2 | 3] if apply_dims == "all": # windows will be of shape [n, p, 1 | 2 | 3, ws] @@ -260,11 +264,11 @@ def get(self, indices: tuple[Any, ...]): ).squeeze() # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary - graphic_data[..., : self.display_window, dims] = wf( + graphic_data[..., : dw, dims] = wf( windows, axis=-1 - ).reshape(graphic_data.shape[0], self.display_window, len(dims)) + ).reshape(graphic_data.shape[0], dw, len(dims)) - return graphic_data[..., : self.display_window, :] + return graphic_data[..., : dw, :] return graphic_data @@ -285,6 +289,7 @@ def __init__( display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, window_sizes: tuple[int | None] | None = None, + index_mappings: tuple[Callable[[Any], int] | None] | None = None, ): if issubclass(graphic, LineCollection): multi = True @@ -295,6 +300,7 @@ def __init__( display_window=display_window, window_funcs=window_funcs, window_sizes=window_sizes, + index_mappings=index_mappings, ) self._indices = tuple([0] * self._processor.n_slider_dims) diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py index 974677144..225608cca 100644 --- a/fastplotlib/widgets/nd_widget/_processor_base.py +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -11,6 +11,10 @@ WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] +def identity(index: int) -> int: + return index + + class NDProcessor: def __init__( self, @@ -23,7 +27,7 @@ def __init__( spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data) - self._index_mappings = self._validate_index_mappings(index_mappings) + self._index_mappings = tuple(self._validate_index_mappings(index_mappings)) self.window_funcs = window_funcs self.window_sizes = window_sizes @@ -205,19 +209,30 @@ def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: # pass @property - def index_mappings(self) -> tuple[Callable[[Any], int] | None, ...]: + def index_mappings(self) -> tuple[Callable[[Any], int]]: return self._index_mappings @index_mappings.setter - def index_mappings(self, maps): - self._index_mappings = self._validate_index_mappings(maps) + def index_mappings(self, maps: tuple[Callable[[Any], int] | None] | None): + self._index_mappings = tuple(self._validate_index_mappings(maps)) def _validate_index_mappings(self, maps): - if maps is not None: - if not all([callable(m) or m is None for m in maps]): + if maps is None: + return tuple([identity] * self.n_slider_dims) + + if len(maps) != self.n_slider_dims: + raise IndexError + + _maps = list() + for m in maps: + if m is None: + _maps.append(identity) + elif callable(m): + _maps.append(identity) + else: raise TypeError - return maps + return tuple(maps) def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: pass From 6cdcb178913874482dd55ef20daf3113879fb3cf Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Feb 2026 21:07:08 -0500 Subject: [PATCH 017/101] remove nd_timeseries since nd_positions is sufficient --- .../widgets/nd_widget/_nd_timeseries.py | 227 ------------------ 1 file changed, 227 deletions(-) delete mode 100644 fastplotlib/widgets/nd_widget/_nd_timeseries.py diff --git a/fastplotlib/widgets/nd_widget/_nd_timeseries.py b/fastplotlib/widgets/nd_widget/_nd_timeseries.py deleted file mode 100644 index 49b9231c3..000000000 --- a/fastplotlib/widgets/nd_widget/_nd_timeseries.py +++ /dev/null @@ -1,227 +0,0 @@ -import inspect -from typing import Literal, Callable, Any -from warnings import warn - -import numpy as np -from numpy.typing import ArrayLike - -from ...utils import subsample_array, ArrayProtocol - -from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic -from ._processor_base import NDProcessor, WindowFuncCallable - - -VALID_TIMESERIES_Y_DATA_SHAPES = ( - "[n_datapoints] for 1D array of y-values, [n_datapoints, 2] " - "for a 1D array of y and z-values, [n_lines, n_datapoints] for a 2D stack of lines with y-values, " - "or [n_lines, n_datapoints, 2] for a stack of lines with y and z-values." -) - - -# Limitation, no heatmap if z-values present, I don't think you can visualize that -class NDTimeSeriesProcessor(NDProcessor): - def __init__( - self, - data: list[ - ArrayProtocol, ArrayProtocol - ], # list: [x_vals_array, y_vals_and_z_vals_array] - x_values: ArrayProtocol = None, - cmap: str = None, - cmap_transform: ArrayProtocol = None, - display_graphic: Literal["line", "heatmap"] = "line", - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - display_window: int | float | None = 100, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, - ): - super().__init__( - data=data, - n_display_dims=n_display_dims, - slider_index_maps=slider_index_maps, - ) - - self._display_window = display_window - - self._display_graphic = None - self.display_graphic = display_graphic - - self._uniform_x_values: ArrayProtocol | None = None - self._interp_yz: ArrayProtocol | None = None - - @property - def data(self) -> list[ArrayProtocol, ArrayProtocol]: - return self._data - - @data.setter - def data(self, data: list[ArrayProtocol, ArrayProtocol]): - self._data = self._validate_data(data) - - def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): - x_vals, yz_vals = data - - if x_vals.ndim != 1: - raise ("data x values must be 1D") - - if data[1].ndim > 3: - raise ValueError( - f"data yz values must be of shape: {VALID_TIMESERIES_Y_DATA_SHAPES}. You passed data of shape: {yz_vals.shape}" - ) - - return data - - @property - def display_window(self) -> int | float | None: - """display window in the reference units along the x-axis""" - return self._display_window - - @display_window.setter - def display_window(self, dw: int | float | None): - if dw is None: - self._display_window = None - - elif not isinstance(dw, (int, float)): - raise TypeError - - self._display_window = dw - - def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: - if self.display_window is not None: - # map reference units -> array int indices if necessary - if self.slider_index_maps is not None: - indices_window = self.slider_index_maps(self.display_window) - else: - indices_window = self.display_window - - # half window size - hw = indices_window // 2 - - # for now assume just a single index provided that indicates x axis value - start = max(indices - hw, 0) - stop = start + indices_window - - # slice dim would be ndim - 1 - return self.data[0][start:stop], self.data[1][:, start:stop] - - -class NDTimeSeries: - def __init__(self, processor: NDTimeSeriesProcessor, graphic): - self._processor = processor - - self._indices = 0 - - if graphic == "line": - self._create_line_stack() - elif graphic == "heatmap": - self._create_heatmap() - else: - raise ValueError - - @property - def processor(self) -> NDTimeSeriesProcessor: - return self._processor - - @property - def graphic(self) -> LineStack | ImageGraphic: - """LineStack or ImageGraphic for heatmaps""" - return self._graphic - - @graphic.setter - def graphic(self, g: Literal["line", "heatmap"]): - if g == "line": - # TODO: remove existing graphic - self._create_line_stack() - - elif g == "heatmap": - # make sure "yz" data is only ys and no z values - # can't represent y and z vals in a heatmap - if self.processor.data[1].ndim > 2: - raise ValueError( - "Only y-values are supported for heatmaps, not yz-values" - ) - self._create_heatmap() - - @property - def display_window(self) -> int | float | None: - return self.processor.display_window - - @display_window.setter - def display_window(self, dw: int | float | None): - # create new graphic if it changed - if dw != self.display_window: - create_new_graphic = True - else: - create_new_graphic = False - - self.processor.display_window = dw - - if create_new_graphic: - if isinstance(self.graphic, LineStack): - self.set_index(self._indices) - - def set_index(self, indices: tuple[Any, ...]): - # set the graphic at the given data indices - data_slice = self.processor[indices] - - if isinstance(self.graphic, LineStack): - line_stack_data = self._create_line_stack_data(data_slice) - - for g, line_data in zip(self.graphic.graphics, line_stack_data): - if line_data.shape[1] == 2: - # only x and y values - g.data[:, :-1] = line_data - else: - # has z values too - g.data[:] = line_data - - elif isinstance(self.graphic, ImageGraphic): - hm_data, scale = self._create_heatmap_data(data_slice) - self.graphic.data = hm_data - - self._indices = indices - - def _create_line_stack_data(self, data_slice): - xs = data_slice[0] # 1D - yz = data_slice[ - 1 - ] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals - - # need to go from x_vals and yz_vals arrays to an array of shape: [n_lines, n_datapoints, 2 | 3] - return np.dstack([np.repeat(xs[None], repeats=yz.shape[0], axis=0), yz]) - - def _create_line_stack(self): - data_slice = self.processor[self._indices] - - ls_data = self._create_line_stack_data(data_slice) - - self._graphic = LineStack(ls_data) - - def _create_heatmap_data(self, data_slice) -> tuple[ArrayProtocol, float]: - """Returns [n_lines, y_values] array and scale factor for x dimension""" - # check if x-vals uniformly spaced - # this is very fast to do on the fly, especially for typical small display windows - x, y = data_slice - norm = np.linalg.norm(np.diff(np.diff(x))) / x.size - if norm > 10**-12: - # need to create evenly spaced x-values - x_uniform = np.linspace(x[0], x[-1], num=x.size) - # yz is [n_lines, n_datapoints] - y_interp = np.zeros(shape=y.shape, dtype=np.float32) - for i in range(y.shape[0]): - y_interp[i] = np.interp(x_uniform, x, y[i]) - - else: - y_interp = y - - x_scale = x[-1] / x.size - - return y_interp, x_scale - - def _create_heatmap(self): - data_slice = self.processor[self._indices] - - hm_data, x_scale = self._create_heatmap_data(data_slice) - - self._graphic = ImageGraphic(hm_data) - self._graphic.world_object.world.scale_x = x_scale From 4748e5939350c9bdf12c8312448c9ce9106dcd68 Mon Sep 17 00:00:00 2001 From: Kushal Kolar Date: Wed, 4 Feb 2026 12:08:55 -0500 Subject: [PATCH 018/101] auto-replace buffers (#974) * remove isolated_buffer * remove isolated_buffer from mixin * basics works for positions data * replaceable buffers for all positions related features * image data buffer can change * resizeable buffers for volume * black * buffer resize condition checked only if new value is an array * gc for buffer managers * uniform colors WIP * switching color modes works! * typo * balck * update tests for color_mode * update examples * backend tests passing * default for all uniforms is True * update examples * forgot * update test * example tests passing * dereferencing test and fixes * simplify texture array tests a bit * image replace buffer tests pass yay * forgot a file * comments, check image graphic * add image reshaping example * add buffer replace imgui thing for manual testing * black * dont call wgpu_obj.destroy(), seems to work and clear VRAM with normal dereferencing * slower changes * update * update example * fixes and tweaks for test * remove unecessary stuff * update * docstrings * fix example * update example * update example * update docs --- docs/source/api/graphics/LineGraphic.rst | 1 + docs/source/api/graphics/ScatterGraphic.rst | 1 + examples/events/cmap_event.py | 2 +- examples/gridplot/multigraphic_gridplot.py | 2 +- examples/guis/imgui_basic.py | 4 +- examples/image/image_reshaping.py | 50 +++++ examples/line/line_cmap.py | 4 +- examples/line/line_cmap_more.py | 25 ++- examples/line/line_colorslice.py | 4 +- .../line_collection_slicing.py | 1 + examples/machine_learning/kmeans.py | 1 + examples/misc/buffer_replace_gc.py | 91 +++++++++ examples/misc/lorenz_animation.py | 7 +- examples/misc/reshape_lines_scatters.py | 92 +++++++++ examples/misc/scatter_animation.py | 2 +- examples/misc/scatter_sizes_animation.py | 2 +- examples/notebooks/quickstart.ipynb | 4 +- examples/scatter/scatter_iris.py | 1 + examples/scatter/scatter_size.py | 2 +- examples/scatter/scatter_validate.py | 2 + examples/scatter/spinning_spiral.py | 9 +- fastplotlib/graphics/_positions_base.py | 175 ++++++++++++++---- fastplotlib/graphics/features/_base.py | 54 +++--- fastplotlib/graphics/features/_image.py | 12 +- fastplotlib/graphics/features/_mesh.py | 8 +- fastplotlib/graphics/features/_positions.py | 99 ++++++++-- fastplotlib/graphics/features/_scatter.py | 134 +++++++++----- fastplotlib/graphics/features/_vectors.py | 2 - fastplotlib/graphics/features/_volume.py | 12 +- fastplotlib/graphics/image.py | 69 +++++-- fastplotlib/graphics/image_volume.py | 38 +++- fastplotlib/graphics/line.py | 25 +-- fastplotlib/graphics/line_collection.py | 11 +- fastplotlib/graphics/mesh.py | 17 +- fastplotlib/graphics/scatter.py | 88 ++++----- fastplotlib/layouts/_graphic_methods_mixin.py | 141 ++++++-------- tests/test_colors_buffer_manager.py | 12 +- tests/test_markers_buffer_manager.py | 8 +- tests/test_point_rotations_buffer_manager.py | 2 +- tests/test_positions_data_buffer_manager.py | 2 +- tests/test_positions_graphics.py | 55 +++--- tests/test_replace_buffer.py | 155 ++++++++++++++++ tests/test_scatter_graphic.py | 2 +- tests/test_texture_array.py | 134 ++++++-------- tests/utils_textures.py | 64 +++++++ 45 files changed, 1160 insertions(+), 466 deletions(-) create mode 100644 examples/image/image_reshaping.py create mode 100644 examples/misc/buffer_replace_gc.py create mode 100644 examples/misc/reshape_lines_scatters.py create mode 100644 tests/test_replace_buffer.py create mode 100644 tests/utils_textures.py diff --git a/docs/source/api/graphics/LineGraphic.rst b/docs/source/api/graphics/LineGraphic.rst index 428e8ef56..867f1bfbb 100644 --- a/docs/source/api/graphics/LineGraphic.rst +++ b/docs/source/api/graphics/LineGraphic.rst @@ -25,6 +25,7 @@ Properties LineGraphic.axes LineGraphic.block_events LineGraphic.cmap + LineGraphic.color_mode LineGraphic.colors LineGraphic.data LineGraphic.deleted diff --git a/docs/source/api/graphics/ScatterGraphic.rst b/docs/source/api/graphics/ScatterGraphic.rst index cf8e1224d..f9dcd2487 100644 --- a/docs/source/api/graphics/ScatterGraphic.rst +++ b/docs/source/api/graphics/ScatterGraphic.rst @@ -25,6 +25,7 @@ Properties ScatterGraphic.axes ScatterGraphic.block_events ScatterGraphic.cmap + ScatterGraphic.color_mode ScatterGraphic.colors ScatterGraphic.data ScatterGraphic.deleted diff --git a/examples/events/cmap_event.py b/examples/events/cmap_event.py index 62913cb29..f01f06d6a 100644 --- a/examples/events/cmap_event.py +++ b/examples/events/cmap_event.py @@ -34,7 +34,7 @@ xs = np.linspace(0, 4 * np.pi, 100) ys = np.sin(xs) -figure["sine"].add_line(np.column_stack([xs, ys])) +figure["sine"].add_line(np.column_stack([xs, ys]), color_mode="vertex") # make a 2D gaussian cloud cloud_data = np.random.normal(0, scale=3, size=1000).reshape(500, 2) diff --git a/examples/gridplot/multigraphic_gridplot.py b/examples/gridplot/multigraphic_gridplot.py index cbf546e2a..0e89efcdc 100644 --- a/examples/gridplot/multigraphic_gridplot.py +++ b/examples/gridplot/multigraphic_gridplot.py @@ -106,7 +106,7 @@ def make_circle(center, radius: float, n_points: int = 75) -> np.ndarray: gaussian_cloud2 = np.random.multivariate_normal(mean, covariance, n_points) # add the scatter graphics to the figure -figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet") +figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet", color_mode="vertex") figure["scatter"].add_scatter(data=gaussian_cloud2, colors="r", sizes=2) figure.show() diff --git a/examples/guis/imgui_basic.py b/examples/guis/imgui_basic.py index 26b5603c0..26c2c0fca 100644 --- a/examples/guis/imgui_basic.py +++ b/examples/guis/imgui_basic.py @@ -29,10 +29,10 @@ figure = fpl.Figure(size=(700, 560)) # make some scatter points at every 10th point -figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter", uniform_color=True) +figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter") # place a line above the scatter -figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave", uniform_color=True) +figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave") class ImguiExample(EdgeWindow): diff --git a/examples/image/image_reshaping.py b/examples/image/image_reshaping.py new file mode 100644 index 000000000..23264bda1 --- /dev/null +++ b/examples/image/image_reshaping.py @@ -0,0 +1,50 @@ +""" +Image reshaping +=============== + +An example that shows replacement of the image data with new data of a different shape. Under the hood, this creates a +new buffer and a new array of Textures on the GPU that replace the older Textures. Creating a new buffer and textures +has a performance cost, so you should do this only if you need to or if the performance drawback is not a concern for +your use case. + +Note that the vmin-vmax is reset when you replace the buffers. +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data, diagonal sinusoidal bands +xs = np.linspace(0, 2300, 2300, dtype=np.float16) +full_data = np.vstack([np.cos(np.sqrt(xs + (np.pi / 2) * i)) * i for i in range(2_300)]) + +figure = fpl.Figure() + +image = figure[0, 0].add_image(full_data) + +figure.show() + +i, j = 1, 1 + + +def update(): + global i, j + # set the new image data as a subset of the full data + row = np.abs(np.sin(i)) * 2300 + col = np.abs(np.cos(i)) * 2300 + image.data = full_data[: int(row), : int(col)] + + i += 0.01 + j += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/line/line_cmap.py b/examples/line/line_cmap.py index 3d2b5e8c9..6dfc1fe23 100644 --- a/examples/line/line_cmap.py +++ b/examples/line/line_cmap.py @@ -27,7 +27,7 @@ data=sine_data, thickness=10, cmap="plasma", - cmap_transform=sine_data[:, 1] + cmap_transform=sine_data[:, 1], ) # qualitative colormaps, useful for cluster labels or other types of categorical labels @@ -36,7 +36,7 @@ data=cosine_data, thickness=10, cmap="tab10", - cmap_transform=labels + cmap_transform=labels, ) figure.show() diff --git a/examples/line/line_cmap_more.py b/examples/line/line_cmap_more.py index c7c0d80f4..c6e811fb2 100644 --- a/examples/line/line_cmap_more.py +++ b/examples/line/line_cmap_more.py @@ -31,16 +31,35 @@ # set colormap by mapping data using a transform # here we map the color using the y-values of the sine data # i.e., the color is a function of sine(x) -line2 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=sine[:, 1], offset=(0, 4, 0)) +line2 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=sine[:, 1], + offset=(0, 4, 0), +) # make a line and change the cmap afterward, here we are using the cosine instead fot the transform -line3 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=cosine[:, 1], offset=(0, 6, 0)) +line3 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=cosine[:, 1], + offset=(0, 6, 0) +) + # change the cmap line3.cmap = "bwr" # use quantitative colormaps with categorical cmap_transforms labels = [0] * 25 + [1] * 5 + [2] * 50 + [3] * 20 -line4 = figure[0, 0].add_line(sine, thickness=10, cmap="tab10", cmap_transform=labels, offset=(0, 8, 0)) +line4 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="tab10", + cmap_transform=labels, + offset=(0, 8, 0), +) # some text labels for i in range(5): diff --git a/examples/line/line_colorslice.py b/examples/line/line_colorslice.py index b6865eadb..264f944f3 100644 --- a/examples/line/line_colorslice.py +++ b/examples/line/line_colorslice.py @@ -30,7 +30,8 @@ sine = figure[0, 0].add_line( data=sine_data, thickness=5, - colors="magenta" + colors="magenta", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later ) # you can also use colormaps for lines! @@ -56,6 +57,7 @@ data=zeros_data, thickness=8, colors="w", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later offset=(0, 10, 0) ) diff --git a/examples/line_collection/line_collection_slicing.py b/examples/line_collection/line_collection_slicing.py index f829a53c6..98ad97056 100644 --- a/examples/line_collection/line_collection_slicing.py +++ b/examples/line_collection/line_collection_slicing.py @@ -26,6 +26,7 @@ multi_data, thickness=[2, 10, 2, 5, 5, 5, 8, 8, 8, 9, 3, 3, 3, 4, 4], separation=4, + color_mode="vertex", # this will allow us to set per-vertex colors on each line metadatas=list(range(15)), # some metadata names=list("abcdefghijklmno"), # unique name for each line ) diff --git a/examples/machine_learning/kmeans.py b/examples/machine_learning/kmeans.py index f571882ce..4c49844f0 100644 --- a/examples/machine_learning/kmeans.py +++ b/examples/machine_learning/kmeans.py @@ -80,6 +80,7 @@ sizes=5, cmap="tab10", # use a qualitative cmap cmap_transform=kmeans.labels_, # color by the predicted cluster + uniform_size=False, ) # initial index diff --git a/examples/misc/buffer_replace_gc.py b/examples/misc/buffer_replace_gc.py new file mode 100644 index 000000000..e3b0ac104 --- /dev/null +++ b/examples/misc/buffer_replace_gc.py @@ -0,0 +1,91 @@ +""" +Buffer replacement garbage collection test +========================================== + +This is an example that used for a manual test to ensure that GPU VRAM is free when buffers are replaced. + +Use while monitoring VRAM usage with nvidia-smi +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'code' + + +from typing import Literal +import numpy as np +import fastplotlib as fpl +from fastplotlib.ui import EdgeWindow +from imgui_bundle import imgui + + +def generate_dataset(size: int) -> dict[str, np.ndarray]: + return { + "data": np.random.rand(size, 3), + "colors": np.random.rand(size, 4), + # TODO: there's a wgpu bind group issue with edge_colors, will figure out later + # "edge_colors": np.random.rand(size, 4), + "markers": np.random.choice(list("osD+x^v<>*"), size=size), + "sizes": np.random.rand(size) * 5, + "point_rotations": np.random.rand(size) * 180, + } + + +datasets = { + "init": generate_dataset(50_000), + "small": generate_dataset(100), + "large": generate_dataset(5_000_000), +} + + +class UI(EdgeWindow): + def __init__(self, figure): + super().__init__(figure=figure, size=200, location="right", title="UI") + init_data = datasets["init"] + self._figure["line"].add_line( + data=init_data["data"], colors=init_data["colors"], name="line" + ) + self._figure["scatter"].add_scatter( + **init_data, + uniform_size=False, + uniform_marker=False, + uniform_edge_color=False, + point_rotation_mode="vertex", + name="scatter", + ) + + def update(self): + for graphic in ["line", "scatter"]: + if graphic == "line": + features = ["data", "colors"] + + elif graphic == "scatter": + features = list(datasets["init"].keys()) + + for size in ["small", "large"]: + for fea in features: + if imgui.button(f"{size} - {graphic} - {fea}"): + self._replace(graphic, fea, size) + + def _replace( + self, + graphic: Literal["line", "scatter", "image"], + feature: Literal["data", "colors", "markers", "sizes", "point_rotations"], + size: Literal["small", "large"], + ): + new_value = datasets[size][feature] + + setattr(self._figure[graphic][graphic], feature, new_value) + + +figure = fpl.Figure(shape=(3, 1), size=(700, 1600), names=["line", "scatter", "image"]) +ui = UI(figure) +figure.add_gui(ui) + +figure.show() + + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/lorenz_animation.py b/examples/misc/lorenz_animation.py index 20aee5d83..52a77a243 100644 --- a/examples/misc/lorenz_animation.py +++ b/examples/misc/lorenz_animation.py @@ -60,7 +60,12 @@ def lorenz(xyz, *, s=10, r=28, b=2.667): scatter_markers = list() for graphic in lorenz_line: - marker = figure[0, 0].add_scatter(graphic.data.value[0], sizes=16, colors=graphic.colors[0]) + marker = figure[0, 0].add_scatter( + graphic.data.value[0], + sizes=16, + colors=graphic.colors, + edge_colors="w", + ) scatter_markers.append(marker) # initialize time diff --git a/examples/misc/reshape_lines_scatters.py b/examples/misc/reshape_lines_scatters.py new file mode 100644 index 000000000..db8adb29e --- /dev/null +++ b/examples/misc/reshape_lines_scatters.py @@ -0,0 +1,92 @@ +""" +Change number of points in lines and scatters +============================================= + +This example sets lines and scatters with new data of a different shape, i.e. new data with more or fewer datapoints. +Internally, this creates new buffers for the feature that is being set (data, colors, markers, etc.). Note that there +are performance drawbacks to doing this, so it is recommended to maintain the same number of datapoints in a graphic +when possible. You only want to change the number of datapoints when it's really necessary, and you don't want to do +it constantly (such as tens or hundreds of times per second). + +This example is also useful for manually checking that GPU buffers are freed when they're no longer in use. Run this +example while monitoring VRAM usage with `nvidia-smi` +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data to start with +xs = np.linspace(0, 10 * np.pi, 100) +ys = np.sin(xs) + +data = np.column_stack([xs, ys]) + +# create a figure, add a line, scatter and line_stack +figure = fpl.Figure(shape=(3, 1), size=(700, 700)) + +line = figure[0, 0].add_line(data) + +scatter = figure[1, 0].add_scatter( + np.random.rand(100, 3), + colors=np.random.rand(100, 4), + markers=np.random.choice(list("osD+x^v<>*"), size=100), + sizes=(np.random.rand(100) + 1) * 3, + edge_colors=np.random.rand(100, 4), + point_rotations=np.random.rand(100) * 180, + uniform_marker=False, + uniform_size=False, + uniform_edge_color=False, + point_rotation_mode="vertex", +) + +line_stack = figure[2, 0].add_line_stack(np.stack([data] * 10), cmap="viridis") + +text = figure[0, 0].add_text(f"n_points: {100}", offset=(0, 1.5, 0), anchor="middle-left") + +figure.show(maintain_aspect=False) + +i = 0 + + +def update(): + # set a new larger or smaller data array on every render + global i + + # create new data + freq = np.abs(np.sin(i)) * 10 + n_points = int((freq * 20_000) + 10) + + xs = np.linspace(0, 10 * np.pi, n_points) + ys = np.sin(xs * freq) + + new_data = np.column_stack([xs, ys]) + + # update line data + line.data = new_data + + # update scatter data, colors, markers, etc. + scatter.data = np.random.rand(n_points, 3) + scatter.colors = np.random.rand(n_points, 4) + scatter.markers = np.random.choice(list("osD+x^v<>*"), size=n_points) + scatter.edge_colors = np.random.rand(n_points, 4) + scatter.point_rotations = np.random.rand(n_points) * 180 + + # update line stack data + line_stack.data = np.stack([new_data] * 10) + + text.text = f"n_points: {n_points}" + + i += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/scatter_animation.py b/examples/misc/scatter_animation.py index d37aea976..549059b65 100644 --- a/examples/misc/scatter_animation.py +++ b/examples/misc/scatter_animation.py @@ -37,7 +37,7 @@ figure = fpl.Figure(size=(700, 560)) subplot_scatter = figure[0, 0] # use an alpha value since this will be a lot of points -scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, colors=colors, alpha=0.6) +scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, uniform_size=False, colors=colors, alpha=0.6) def update_points(subplot): diff --git a/examples/misc/scatter_sizes_animation.py b/examples/misc/scatter_sizes_animation.py index 53a616a68..2092787f3 100644 --- a/examples/misc/scatter_sizes_animation.py +++ b/examples/misc/scatter_sizes_animation.py @@ -20,7 +20,7 @@ figure = fpl.Figure(size=(700, 560)) -figure[0, 0].add_scatter(data, sizes=sizes, name="sine") +figure[0, 0].add_scatter(data, sizes=sizes, uniform_size=False, name="sine") i = 0 diff --git a/examples/notebooks/quickstart.ipynb b/examples/notebooks/quickstart.ipynb index 7b7551588..61bcb6b06 100644 --- a/examples/notebooks/quickstart.ipynb +++ b/examples/notebooks/quickstart.ipynb @@ -719,8 +719,8 @@ "# we will add all the lines to the same subplot\n", "subplot = fig_lines[0, 0]\n", "\n", - "# plot sine wave, use a single color\n", - "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\")\n", + "# plot sine wave, use a single color for now, but we will set per-vertex colors later\n", + "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\", color_mode=\"vertex\")\n", "\n", "# you can also use colormaps for lines!\n", "cosine = subplot.add_line(data=cosine_data, thickness=12, cmap=\"autumn\")\n", diff --git a/examples/scatter/scatter_iris.py b/examples/scatter/scatter_iris.py index b9df16026..fc228e5bf 100644 --- a/examples/scatter/scatter_iris.py +++ b/examples/scatter/scatter_iris.py @@ -35,6 +35,7 @@ cmap="tab10", cmap_transform=clusters_labels, markers=markers, + uniform_marker=False, ) figure.show() diff --git a/examples/scatter/scatter_size.py b/examples/scatter/scatter_size.py index 30d3e6ea3..2b3899dbe 100644 --- a/examples/scatter/scatter_size.py +++ b/examples/scatter/scatter_size.py @@ -35,7 +35,7 @@ ) # add a set of scalar sizes non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5 -figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red") +figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, uniform_size=False, colors="red") for graph in figure: graph.auto_scale(maintain_aspect=True) diff --git a/examples/scatter/scatter_validate.py b/examples/scatter/scatter_validate.py index abddffee0..45f0a177c 100644 --- a/examples/scatter/scatter_validate.py +++ b/examples/scatter/scatter_validate.py @@ -41,6 +41,7 @@ uniform_edge_color=False, edge_colors=["w"] * 3 + ["orange"] * 3 + ["blue"] * 3 + ["green"], markers=list("osD+x^v<>*"), + uniform_marker=False, edge_width=2.0, sizes=20, uniform_size=True, @@ -64,6 +65,7 @@ sine, markers="s", sizes=xs * 5, + uniform_size=False, offset=(0, 2, 0) ) diff --git a/examples/scatter/spinning_spiral.py b/examples/scatter/spinning_spiral.py index 89e74eaec..4f947970a 100644 --- a/examples/scatter/spinning_spiral.py +++ b/examples/scatter/spinning_spiral.py @@ -34,7 +34,14 @@ canvas_kwargs={"max_fps": 500, "vsync": False} ) -spiral = figure[0, 0].add_scatter(data, cmap="viridis_r", edge_colors=None, alpha=0.5, sizes=sizes) +spiral = figure[0, 0].add_scatter( + data, + cmap="viridis_r", + edge_colors=None, + alpha=0.5, + sizes=sizes, + uniform_size=False, +) # pre-generate normally distributed data to jitter the points before each render jitter = np.random.normal(scale=0.001, size=n * 3).reshape((n, 3)) diff --git a/fastplotlib/graphics/_positions_base.py b/fastplotlib/graphics/_positions_base.py index af7d7badb..763f5e775 100644 --- a/fastplotlib/graphics/_positions_base.py +++ b/fastplotlib/graphics/_positions_base.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +from numbers import Real +from typing import Any, Sequence, Literal +from warnings import warn import numpy as np @@ -18,12 +20,20 @@ class PositionsGraphic(Graphic): @property def data(self) -> VertexPositions: - """Get or set the graphic's data""" + """ + Get or set the graphic's data. + + Note that if the number of datapoints does not match the number of + current datapoints a new buffer is automatically allocated. This can + have performance drawbacks when you have a very large number of datapoints. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, value): - self._data[:] = value + self._data.set_value(self, value) @property def colors(self) -> VertexColors | pygfx.Color: @@ -36,11 +46,59 @@ def colors(self) -> VertexColors | pygfx.Color: @colors.setter def colors(self, value: str | np.ndarray | Sequence[float] | Sequence[str]): + self._colors.set_value(self, value) + + @property + def color_mode(self) -> Literal["uniform", "vertex"]: + """ + Get or set the color mode. Note that after setting the color_mode, you will have to set the `colors` + as well for switching between 'uniform' and 'vertex' modes. + """ + return self.world_object.material.color_mode + + @color_mode.setter + def color_mode(self, mode: Literal["uniform", "vertex"]): + valid = ("uniform", "vertex") + if mode not in valid: + raise ValueError(f"`color_mode` must be one of : {valid}") + if mode == "vertex" and isinstance(self._colors, UniformColor): + # uniform -> vertex + # need to make a new vertex buffer and get rid of uniform buffer + new_colors = self._create_colors_buffer(self._colors.value, "vertex") + # we can't clear world_object.material.color so just set the colors buffer on the geometry + # this doesn't really matter anyways since the lingering uniform color takes up just a few bytes + self.world_object.geometry.colors = new_colors._fpl_buffer + + elif mode == "uniform" and isinstance(self._colors, VertexColors): + # vertex -> uniform + # use first vertex color and spit out a warning + warn( + "changing `color_mode` from vertex -> uniform, will use first vertex color " + "for the uniform and discard the remaining color values" + ) + new_colors = self._create_colors_buffer(self._colors.value[0], "uniform") + self.world_object.geometry.colors = None + self.world_object.material.color = new_colors.value + + # clear out cmap + self._cmap.clear_event_handlers() + self._cmap = None + + else: + # no change, return + return + + # restore event handlers onto the new colors feature + new_colors._event_handlers[:] = self._colors._event_handlers + self._colors.clear_event_handlers() + # this should trigger gc + self._colors = new_colors + + # this is created so that cmap can be set later if isinstance(self._colors, VertexColors): - self._colors[:] = value + self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - elif isinstance(self._colors, UniformColor): - self._colors.set_value(self, value) + self.world_object.material.color_mode = mode @property def cmap(self) -> VertexCmap: @@ -53,8 +111,8 @@ def cmap(self) -> VertexCmap: @cmap.setter def cmap(self, name: str): - if self._cmap is None: - raise BufferError("Cannot use cmap with uniform_colors=True") + if self.color_mode == "uniform": + raise ValueError("cannot use `cmap` with `color_mode` = 'uniform'") self._cmap[:] = name @@ -71,14 +129,68 @@ def size_space(self): def size_space(self, value: str): self._size_space.set_value(self, value) + def _create_colors_buffer(self, colors, color_mode) -> UniformColor | VertexColors: + # creates either a UniformColor or VertexColors based on the given `colors` and `color_mode` + # if `color_mode` = "auto", returns {UniformColor | VertexColor} based on what the `colors` arg represents + # if `color_mode` = "uniform", it verifies that the user `colors` input represents just 1 color + # if `color_mode` = "vertex", always returns VertexColors regardless of whether `colors` represents >= 1 colors + + if isinstance(colors, VertexColors): + if color_mode == "uniform": + raise ValueError( + "if a `VertexColors` instance is provided for `colors`, " + "`color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) + # share buffer with existing colors instance + new_colors = colors + # blank colormap instance + self._cmap = VertexCmap(new_colors, cmap_name=None, transform=None) + + else: + # determine if a single or multiple colors were passed and decide color mode + if isinstance(colors, (pygfx.Color, str)) or ( + len(colors) in [3, 4] and all(isinstance(v, Real) for v in colors) + ): + # one color specified as a str or pygfx.Color, or one color specified with RGB(A) values + if color_mode in ("auto", "uniform"): + new_colors = UniformColor(colors) + else: + new_colors = VertexColors( + colors, n_colors=self._data.value.shape[0] + ) + + elif all(isinstance(c, (str, pygfx.Color)) for c in colors): + # sequence of colors + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + + elif len(colors) > 4: + # sequence of multiple colors, must again ensure color_mode is not uniform + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + else: + raise ValueError( + "`colors` must be a str, pygfx.Color, array, list or tuple indicating an RGB(A) color, or a " + "sequence of str, pygfx.Color, or array of shape [n_datapoints, 3 | 4]" + ) + + return new_colors + def __init__( self, data: Any, colors: str | np.ndarray | tuple[float] | list[float] | list[str] = "w", - uniform_color: bool = False, cmap: str | VertexCmap = None, cmap_transform: np.ndarray = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", *args, **kwargs, @@ -86,22 +198,31 @@ def __init__( if isinstance(data, VertexPositions): self._data = data else: - self._data = VertexPositions(data, isolated_buffer=isolated_buffer) + self._data = VertexPositions(data) if cmap_transform is not None and cmap is None: raise ValueError("must pass `cmap` if passing `cmap_transform`") + valid = ("auto", "uniform", "vertex") + + # default _cmap is None + self._cmap = None + + if color_mode not in valid: + raise ValueError(f"`color_mode` must be one of {valid}") + if cmap is not None: # if a cmap is specified it overrides colors argument - if uniform_color: - raise TypeError("Cannot use cmap if uniform_color=True") + if color_mode == "uniform": + raise ValueError( + "if a `cmap` is provided, `color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) if isinstance(cmap, str): # make colors from cmap if isinstance(colors, VertexColors): # share buffer with existing colors instance for the cmap self._colors = colors - self._colors._shared += 1 else: # create vertex colors buffer self._colors = VertexColors("w", n_colors=self._data.value.shape[0]) @@ -115,34 +236,18 @@ def __init__( # use existing cmap instance self._cmap = cmap self._colors = cmap._vertex_colors + else: raise TypeError( "`cmap` argument must be a cmap name or an existing `VertexCmap` instance" ) else: # no cmap given - if isinstance(colors, VertexColors): - # share buffer with existing colors instance - self._colors = colors - self._colors._shared += 1 - # blank colormap instance + self._colors = self._create_colors_buffer(colors, color_mode) + + # this is created so that cmap can be set later + if isinstance(self._colors, VertexColors): self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - else: - if uniform_color: - if not isinstance(colors, str): # not a single color - if not len(colors) in [3, 4]: # not an RGB(A) array - raise TypeError( - "must pass a single color if using `uniform_colors=True`" - ) - self._colors = UniformColor(colors) - self._cmap = None - else: - self._colors = VertexColors( - colors, n_colors=self._data.value.shape[0] - ) - self._cmap = VertexCmap( - self._colors, cmap_name=None, transform=None - ) self._size_space = SizeSpace(size_space) super().__init__(*args, **kwargs) diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 779310476..76352b4ef 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -1,5 +1,6 @@ +import weakref from warnings import warn -from typing import Literal +from typing import Callable import numpy as np from numpy.typing import NDArray @@ -78,7 +79,7 @@ def block_events(self, val: bool): """ self._block_events = val - def add_event_handler(self, handler: callable): + def add_event_handler(self, handler: Callable): """ Add an event handler. All added event handlers are called when this feature changes. @@ -89,7 +90,7 @@ def add_event_handler(self, handler: callable): Parameters ---------- - handler: callable + handler: Callable a function to call when this feature changes """ @@ -102,7 +103,7 @@ def add_event_handler(self, handler: callable): self._event_handlers.append(handler) - def remove_event_handler(self, handler: callable): + def remove_event_handler(self, handler: Callable): """ Remove a registered event ``handler``. @@ -137,32 +138,28 @@ class BufferManager(GraphicFeature): def __init__( self, - data: NDArray | pygfx.Buffer, - buffer_type: Literal["buffer", "texture", "texture-array"] = "buffer", - isolated_buffer: bool = True, + data: NDArray | pygfx.Buffer | None, **kwargs, ): super().__init__(**kwargs) - if isolated_buffer and not isinstance(data, pygfx.Resource): - # useful if data is read-only, example: memmaps - bdata = np.zeros(data.shape, dtype=data.dtype) - bdata[:] = data[:] - else: - # user's input array is used as the buffer - bdata = data - - if isinstance(data, pygfx.Resource): - # already a buffer, probably used for - # managing another BufferManager, example: VertexCmap manages VertexColors - self._buffer = data - elif buffer_type == "buffer": - self._buffer = pygfx.Buffer(bdata) + + # if data is None, then the BufferManager just provides a view into an existing buffer + # example: VertexCmap is basically a view into VertexColors + if data is not None: + if isinstance(data, pygfx.Resource): + # already a buffer, probably used for + # managing another BufferManager, example: VertexCmap manages VertexColors + self._fpl_buffer = data + else: + # create a buffer + bdata = np.empty(data.shape, dtype=data.dtype) + bdata[:] = data[:] + + self._fpl_buffer = pygfx.Buffer(bdata) else: - raise ValueError( - "`data` must be a pygfx.Buffer instance or `buffer_type` must be one of: 'buffer' or 'texture'" - ) + self._fpl_buffer = None - self._event_handlers: list[callable] = list() + self._event_handlers: list[Callable] = list() @property def value(self) -> np.ndarray: @@ -174,9 +171,10 @@ def set_value(self, graphic, value): self[:] = value @property - def buffer(self) -> pygfx.Buffer | pygfx.Texture: - """managed buffer""" - return self._buffer + def buffer(self) -> pygfx.Buffer: + """managed buffer, returns a weakref proxy""" + # the user should never create their own references to the buffer + return weakref.proxy(self._fpl_buffer) @property def __array_interface__(self): diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index 648f79bc8..cb66bb1ef 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -33,7 +33,7 @@ class TextureArray(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True, property_name: str = "data"): + def __init__(self, data, property_name: str = "data"): super().__init__(property_name=property_name) data = self._fix_data(data) @@ -41,13 +41,9 @@ def __init__(self, data, isolated_buffer: bool = True, property_name: str = "dat shared = pygfx.renderers.wgpu.get_shared() self._texture_limit_2d = shared.device.limits["max-texture-dimension-2d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( diff --git a/fastplotlib/graphics/features/_mesh.py b/fastplotlib/graphics/features/_mesh.py index 7355acb4e..776d77ce4 100644 --- a/fastplotlib/graphics/features/_mesh.py +++ b/fastplotlib/graphics/features/_mesh.py @@ -51,18 +51,14 @@ class MeshIndices(VertexPositions): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "indices" - ): + def __init__(self, data: Any, property_name: str = "indices"): """ Manages the vertex indices buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim != 2 or data.shape[1] not in (3, 4): diff --git a/fastplotlib/graphics/features/_positions.py b/fastplotlib/graphics/features/_positions.py index 295d22417..7b67e6bd7 100644 --- a/fastplotlib/graphics/features/_positions.py +++ b/fastplotlib/graphics/features/_positions.py @@ -39,7 +39,6 @@ def __init__( self, colors: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], n_colors: int, - isolated_buffer: bool = True, property_name: str = "colors", ): """ @@ -57,9 +56,56 @@ def __init__( """ data = parse_colors(colors, n_colors) - super().__init__( - data=data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=data, property_name=property_name) + + def set_value( + self, + graphic, + value: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], + ): + """set the entire array, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + # TODO: Refactor this triage so it's more elegant + + # first make sure it's not representing one color + skip = False + if isinstance(value, np.ndarray): + if (value.shape in ((3,), (4,))) and ( + np.issubdtype(value.dtype, np.floating) + or np.issubdtype(value.dtype, np.integer) + ): + # represents one color + skip = True + elif isinstance(value, (list, tuple)): + if len(value) in (3, 4) and all( + [isinstance(v, (float, int)) for v in value] + ): + # represents one color + skip = True + + # check if the number of elements matches current buffer size + if not skip and self.buffer.data.shape[0] != len(value): + # parse the new colors + new_colors = parse_colors(value, len(value)) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(new_colors) + graphic.world_object.geometry.colors = self._fpl_buffer + + if len(self._event_handlers) < 1: + return + + event_info = { + "key": slice(None), + "value": new_colors, + "user_value": value, + } + + event = GraphicFeatureEvent(self._property_name, info=event_info) + self._call_event_handlers(event) + return + + self[:] = value @block_reentrance def __setitem__( @@ -231,18 +277,14 @@ class VertexPositions(BufferManager): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "data" - ): + def __init__(self, data: Any, property_name: str = "data"): """ Manages the vertex positions buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim == 1: @@ -261,13 +303,42 @@ def _fix_data(self, data): return to_gpu_supported_dtype(data) + def set_value(self, graphic, value): + """Sets the entire array, creates new buffer if necessary""" + if isinstance(value, np.ndarray): + if self.buffer.data.shape[0] != value.shape[0]: + # number of items doesn't match, create a new buffer + + # if data is not 3D + if value.ndim == 1: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 2: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 3: + # need to allocate a buffer to use here + bdata = np.empty(value.shape, dtype=np.float32) + bdata[:] = value[:] + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(bdata) + graphic.world_object.geometry.positions = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, key: int | slice | np.ndarray[int | bool] | tuple[slice, ...], value: np.ndarray | float | list[float], ): - # directly use the key to slice the buffer + # directly use the key to slice the buffer and set the values self.buffer.data[key] = value # _update_range handles parsing the key to @@ -306,7 +377,7 @@ def __init__( provides a way to set colormaps with arbitrary transforms """ - super().__init__(data=vertex_colors.buffer, property_name=property_name) + super().__init__(data=None, property_name=property_name) self._vertex_colors = vertex_colors self._cmap_name = cmap_name @@ -331,6 +402,10 @@ def __init__( # set vertex colors from cmap self._vertex_colors[:] = colors + @property + def buffer(self) -> pygfx.Buffer: + return self._vertex_colors.buffer + @block_reentrance def __setitem__(self, key: slice, cmap_name): if not isinstance(key, slice): diff --git a/fastplotlib/graphics/features/_scatter.py b/fastplotlib/graphics/features/_scatter.py index 16671ef89..36c8527be 100644 --- a/fastplotlib/graphics/features/_scatter.py +++ b/fastplotlib/graphics/features/_scatter.py @@ -100,6 +100,37 @@ def searchsorted_markers_to_int_array(markers_str_array: np.ndarray[str]): return marker_int_searchsorted_vals[indices] +def parse_markers_init(markers: str | Sequence[str] | np.ndarray, n_datapoints: int): + # first validate then allocate buffers + + if isinstance(markers, str): + markers = user_input_to_marker(markers) + + elif isinstance(markers, (tuple, list, np.ndarray)): + validate_user_markers_array(markers) + + # allocate buffers + markers_int_array = np.zeros(n_datapoints, dtype=np.int32) + + marker_str_length = max(map(len, list(pygfx.MarkerShape))) + + markers_readable_array = np.empty(n_datapoints, dtype=f" np.ndarray[str]: @@ -200,6 +200,25 @@ def _set_markers_arrays(self, key, value, n_markers): "new markers value must be a str, Sequence or np.ndarray of new marker values" ) + def set_value(self, graphic, value): + """set all the markers, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # need to create a new buffer + markers_int_array, self._markers_readable_array = parse_markers_init( + value, len(value) + ) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(markers_int_array) + graphic.world_object.geometry.markers = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -414,18 +433,15 @@ def __init__( self, rotations: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "point_rotations", ): """ Manages rotations buffer of scatter points. """ - sizes = self._fix_sizes(rotations, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + sizes = self._fix_rotations(rotations, n_datapoints) + super().__init__(data=sizes, property_name=property_name) - def _fix_sizes( + def _fix_rotations( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, @@ -454,6 +470,22 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all rotations, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != value.shape[0]: + # need to create a new buffer + value = self._fix_rotations(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.rotations = self._fpl_buffer + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -488,16 +520,13 @@ def __init__( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "sizes", ): """ Manages sizes buffer of scatter points. """ sizes = self._fix_sizes(sizes, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=sizes, property_name=property_name) def _fix_sizes( self, @@ -533,6 +562,23 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all sizes, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # create new buffer + value = self._fix_sizes(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.sizes = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 9c86d25fc..729562b06 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -22,7 +22,6 @@ class VectorPositions(GraphicFeature): def __init__( self, positions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "positions", ): """ @@ -111,7 +110,6 @@ class VectorDirections(GraphicFeature): def __init__( self, directions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "directions", ): """Manages vector field positions by managing the mesh instance buffer's full transform matrix""" diff --git a/fastplotlib/graphics/features/_volume.py b/fastplotlib/graphics/features/_volume.py index ec4c4052a..532065fb7 100644 --- a/fastplotlib/graphics/features/_volume.py +++ b/fastplotlib/graphics/features/_volume.py @@ -34,7 +34,7 @@ class TextureArrayVolume(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True): + def __init__(self, data): super().__init__(property_name="data") data = self._fix_data(data) @@ -43,13 +43,9 @@ def __init__(self, data, isolated_buffer: bool = True): self._texture_size_limit = shared.device.limits["max-texture-dimension-3d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer that will be used for the texture data + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 44bffcedc..760b856d2 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -1,6 +1,7 @@ import math from typing import * +import numpy as np import pygfx from ..utils import quick_min_max @@ -102,7 +103,6 @@ def __init__( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, **kwargs, ): """ @@ -130,12 +130,6 @@ def __init__( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -143,7 +137,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArray): # share buffer @@ -151,7 +145,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the multiple textures on the GPU that represent this image - self._data = TextureArray(data, isolated_buffer=isolated_buffer) + self._data = TextureArray(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -165,6 +159,7 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) # set map to None for RGB images if self._data.value.ndim > 2: @@ -173,7 +168,6 @@ def __init__( else: # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) _map = pygfx.TextureMap( self._cmap.texture, @@ -189,6 +183,14 @@ def __init__( pick_write=True, ) + # create the _ImageTile world objects, add to group + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_ImageTile]: + tiles = list() # iterate through each texture chunk and create # an _ImageTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -209,17 +211,58 @@ def __init__( img.world.x = data_col_start img.world.y = data_row_start - world_object.add(img) + tiles.append(img) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArray: - """Get or set the image data""" + """ + Get or set the image data. + + Note that if the shape of the new data array does not equal the shape of + current data array, a new set of GPU Textures are automatically created. + This can have performance drawbacks when you have a ver large images. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArray(data) + + # cmap based on if rgb or grayscale + if self._data.value.ndim > 2: + self._cmap = None + + # must be None if RGB(A) + self._material.map = None + else: + if self.cmap is None: # have switched from RGBA -> grayscale image + # create default cmap + self._cmap = ImageCmap("plasma") + self._material.map = pygfx.TextureMap( + self._cmap.texture, + filter=self._cmap_interpolation.value, + wrap="clamp-to-edge", + ) + + self._material.clim = quick_min_max(self.data.value) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + return + self._data[:] = data @property diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py index db8f29eaa..a3b379492 100644 --- a/fastplotlib/graphics/image_volume.py +++ b/fastplotlib/graphics/image_volume.py @@ -113,7 +113,6 @@ def __init__( substep_size: float = 0.1, emissive: str | tuple | np.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, **kwargs, ): """ @@ -170,11 +169,6 @@ def __init__( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -188,7 +182,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArrayVolume): # share existing buffer @@ -196,7 +190,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the textures on the GPU that represent this image volume - self._data = TextureArrayVolume(data, isolated_buffer=isolated_buffer) + self._data = TextureArrayVolume(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -237,6 +231,15 @@ def __init__( self._mode = VolumeRenderMode(mode) + # create tiles + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_VolumeTile]: + tiles = list() + # iterate through each texture chunk and create # a _VolumeTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -259,9 +262,9 @@ def __init__( vol.world.x = data_col_start vol.world.y = data_row_start - world_object.add(vol) + tiles.append(vol) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArrayVolume: @@ -270,6 +273,21 @@ def data(self) -> TextureArrayVolume: @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArrayVolume(data) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + return + self._data[:] = data @property diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py index a4f42704f..bba10b10f 100644 --- a/fastplotlib/graphics/line.py +++ b/fastplotlib/graphics/line.py @@ -18,6 +18,7 @@ UniformColor, VertexCmap, SizeSpace, + UniformRotations, ) from ..utils import quick_min_max @@ -36,10 +37,9 @@ def __init__( data: Any, thickness: float = 2.0, colors: str | np.ndarray | Sequence = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray | Sequence = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", **kwargs, ): @@ -61,15 +61,19 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -84,10 +88,9 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) @@ -102,8 +105,8 @@ def __init__( aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") - if uniform_color: - geometry = pygfx.Geometry(positions=self._data.buffer) + if isinstance(self._colors, UniformColor): + geometry = pygfx.Geometry(positions=self._data._fpl_buffer) material = MaterialCls( aa=aa, thickness=self.thickness, @@ -123,7 +126,7 @@ def __init__( depth_compare="<=", ) geometry = pygfx.Geometry( - positions=self._data.buffer, colors=self._colors.buffer + positions=self._data._fpl_buffer, colors=self._colors._fpl_buffer ) world_object: pygfx.Line = pygfx.Line(geometry=geometry, material=material) diff --git a/fastplotlib/graphics/line_collection.py b/fastplotlib/graphics/line_collection.py index d08231f7d..5ec56777e 100644 --- a/fastplotlib/graphics/line_collection.py +++ b/fastplotlib/graphics/line_collection.py @@ -128,14 +128,13 @@ def __init__( data: np.ndarray | List[np.ndarray], thickness: float | Sequence[float] = 2.0, colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", - uniform_colors: bool = False, cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ): @@ -170,6 +169,9 @@ def __init__( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -320,11 +322,10 @@ def __init__( data=d, thickness=_s, colors=_c, - uniform_color=uniform_colors, cmap=_cmap, + color_mode=color_mode, name=_name, metadata=_m, - isolated_buffer=isolated_buffer, **kwargs_lines, ) @@ -560,7 +561,6 @@ def __init__( names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, @@ -634,7 +634,6 @@ def __init__( names=names, metadata=metadata, metadatas=metadatas, - isolated_buffer=isolated_buffer, kwargs_lines=kwargs_lines, **kwargs, ) diff --git a/fastplotlib/graphics/mesh.py b/fastplotlib/graphics/mesh.py index 0e1ac42a3..efe03c57b 100644 --- a/fastplotlib/graphics/mesh.py +++ b/fastplotlib/graphics/mesh.py @@ -38,7 +38,6 @@ def __init__( mapcoords: Any = None, cmap: str | dict | pygfx.Texture | pygfx.TextureMap | np.ndarray = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, **kwargs, ): """ @@ -77,12 +76,6 @@ def __init__( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -93,16 +86,12 @@ def __init__( if isinstance(positions, VertexPositions): self._positions = positions else: - self._positions = VertexPositions( - positions, isolated_buffer=isolated_buffer, property_name="positions" - ) + self._positions = VertexPositions(positions, property_name="positions") if isinstance(positions, MeshIndices): self._indices = indices else: - self._indices = MeshIndices( - indices, isolated_buffer=isolated_buffer, property_name="indices" - ) + self._indices = MeshIndices(indices, property_name="indices") self._cmap = MeshCmap(cmap) @@ -139,7 +128,7 @@ def __init__( ) geometry = pygfx.Geometry( - positions=self._positions.buffer, indices=self._indices._buffer + positions=self._positions.buffer, indices=self._indices._fpl_buffer ) valid_modes = ["basic", "phong", "slice"] diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index 5268dcc51..b9cacf908 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -40,12 +40,12 @@ def __init__( self, data: Any, colors: str | np.ndarray | Sequence[float] | Sequence[str] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: str | np.ndarray | Sequence[str] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: str | np.ndarray | pygfx.Color | Sequence[float] = "black", uniform_edge_color: bool = True, @@ -54,9 +54,8 @@ def __init__( point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", sizes: float | np.ndarray | Sequence[float] = 5, - uniform_size: bool = False, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, **kwargs, ): """ @@ -72,18 +71,23 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -103,9 +107,10 @@ def __init__( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -125,8 +130,9 @@ def __init__( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -147,17 +153,13 @@ def __init__( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -166,17 +168,16 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) n_datapoints = self.data.value.shape[0] - geo_kwargs = {"positions": self._data.buffer} + geo_kwargs = {"positions": self._data._fpl_buffer} aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") @@ -214,7 +215,7 @@ def __init__( self._markers = VertexMarkers(markers, n_datapoints) - geo_kwargs["markers"] = self._markers.buffer + geo_kwargs["markers"] = self._markers._fpl_buffer if edge_colors is None: # interpret as no edge color @@ -237,7 +238,7 @@ def __init__( edge_colors, n_datapoints, property_name="edge_colors" ) material_kwargs["edge_color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["edge_colors"] = self._edge_colors.buffer + geo_kwargs["edge_colors"] = self._edge_colors._fpl_buffer self._edge_width = EdgeWidth(edge_width) material_kwargs["edge_width"] = self._edge_width.value @@ -274,12 +275,12 @@ def __init__( self._size_space = SizeSpace(size_space) - if uniform_color: + if isinstance(self._colors, UniformColor): material_kwargs["color_mode"] = pygfx.ColorMode.uniform material_kwargs["color"] = self.colors else: material_kwargs["color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["colors"] = self.colors.buffer + geo_kwargs["colors"] = self.colors._fpl_buffer if uniform_size: material_kwargs["size_mode"] = pygfx.SizeMode.uniform @@ -288,14 +289,14 @@ def __init__( else: material_kwargs["size_mode"] = pygfx.SizeMode.vertex self._sizes = VertexPointSizes(sizes, n_datapoints=n_datapoints) - geo_kwargs["sizes"] = self.sizes.buffer + geo_kwargs["sizes"] = self.sizes._fpl_buffer match point_rotation_mode: case pygfx.enums.RotationMode.vertex: self._point_rotations = VertexRotations( point_rotations, n_datapoints=n_datapoints ) - geo_kwargs["rotations"] = self._point_rotations.buffer + geo_kwargs["rotations"] = self._point_rotations._fpl_buffer case pygfx.enums.RotationMode.uniform: self._point_rotations = UniformRotations(point_rotations) @@ -338,10 +339,8 @@ def markers(self, value: str | np.ndarray[str] | Sequence[str]): raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the markers" ) - if isinstance(self._markers, VertexMarkers): - self._markers[:] = value - elif isinstance(self._markers, UniformMarker): - self._markers.set_value(self, value) + + self._markers.set_value(self, value) @property def edge_colors(self) -> str | pygfx.Color | VertexColors | None: @@ -359,12 +358,7 @@ def edge_colors(self, value: str | np.ndarray | Sequence[str] | Sequence[float]) raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the edge_colors" ) - - if isinstance(self._edge_colors, VertexColors): - self._edge_colors[:] = value - - elif isinstance(self._edge_colors, UniformEdgeColor): - self._edge_colors.set_value(self, value) + self._edge_colors.set_value(self, value) @property def edge_width(self) -> float | None: @@ -406,11 +400,7 @@ def point_rotations(self, value: float | np.ndarray[float]): f"it be 'uniform' or 'vertex' to set the `point_rotations`" ) - if isinstance(self._point_rotations, VertexRotations): - self._point_rotations[:] = value - - elif isinstance(self._point_rotations, UniformRotations): - self._point_rotations.set_value(self, value) + self._point_rotations.set_value(self, value) @property def image(self) -> TextureArray | None: @@ -437,8 +427,4 @@ def sizes(self) -> VertexPointSizes | float: @sizes.setter def sizes(self, value): - if isinstance(self._sizes, VertexPointSizes): - self._sizes[:] = value - - elif isinstance(self._sizes, UniformSize): - self._sizes.set_value(self, value) + self._sizes.set_value(self, value) diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index 3eb018f55..eda7b1492 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -33,8 +33,7 @@ def add_image( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ImageGraphic: """ @@ -62,12 +61,6 @@ def add_image( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -81,8 +74,7 @@ def add_image( cmap, interpolation, cmap_interpolation, - isolated_buffer, - **kwargs, + **kwargs ) def add_image_volume( @@ -100,8 +92,7 @@ def add_image_volume( substep_size: float = 0.1, emissive: str | tuple | numpy.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ImageVolumeGraphic: """ @@ -158,11 +149,6 @@ def add_image_volume( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -183,8 +169,7 @@ def add_image_volume( substep_size, emissive, shininess, - isolated_buffer, - **kwargs, + **kwargs ) def add_line_collection( @@ -192,16 +177,15 @@ def add_line_collection( data: Union[numpy.ndarray, List[numpy.ndarray]], thickness: Union[float, Sequence[float]] = 2.0, colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", - uniform_colors: bool = False, cmap: Union[Sequence[str], str] = None, cmap_transform: Union[numpy.ndarray, List] = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineCollection: """ @@ -235,6 +219,9 @@ def add_line_collection( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -261,16 +248,15 @@ def add_line_collection( data, thickness, colors, - uniform_colors, cmap, cmap_transform, + color_mode, name, names, metadata, metadatas, - isolated_buffer, kwargs_lines, - **kwargs, + **kwargs ) def add_line( @@ -278,12 +264,11 @@ def add_line( data: Any, thickness: float = 2.0, colors: Union[str, numpy.ndarray, Sequence] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: Union[numpy.ndarray, Sequence] = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", - **kwargs, + **kwargs ) -> LineGraphic: """ @@ -304,15 +289,19 @@ def add_line( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -329,12 +318,11 @@ def add_line( data, thickness, colors, - uniform_color, cmap, cmap_transform, - isolated_buffer, + color_mode, size_space, - **kwargs, + **kwargs ) def add_line_stack( @@ -348,11 +336,10 @@ def add_line_stack( names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineStack: """ @@ -425,11 +412,10 @@ def add_line_stack( names, metadata, metadatas, - isolated_buffer, separation, separation_axis, kwargs_lines, - **kwargs, + **kwargs ) def add_mesh( @@ -448,8 +434,7 @@ def add_mesh( | numpy.ndarray ) = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> MeshGraphic: """ @@ -488,12 +473,6 @@ def add_mesh( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -509,8 +488,7 @@ def add_mesh( mapcoords, cmap, clim, - isolated_buffer, - **kwargs, + **kwargs ) def add_polygon( @@ -527,7 +505,7 @@ def add_polygon( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> PolygonGraphic: """ @@ -656,12 +634,12 @@ def add_scatter( self, data: Any, colors: Union[str, numpy.ndarray, Sequence[float], Sequence[str]] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: numpy.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: Union[str, numpy.ndarray, Sequence[str]] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: Union[ str, pygfx.utils.color.Color, numpy.ndarray, Sequence[float] @@ -672,10 +650,9 @@ def add_scatter( point_rotations: float | numpy.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, - uniform_size: bool = False, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ScatterGraphic: """ @@ -691,18 +668,23 @@ def add_scatter( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -722,9 +704,10 @@ def add_scatter( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -744,8 +727,9 @@ def add_scatter( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -766,17 +750,13 @@ def add_scatter( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -786,9 +766,9 @@ def add_scatter( ScatterGraphic, data, colors, - uniform_color, cmap, cmap_transform, + color_mode, mode, markers, uniform_marker, @@ -802,8 +782,7 @@ def add_scatter( sizes, uniform_size, size_space, - isolated_buffer, - **kwargs, + **kwargs ) def add_surface( @@ -820,7 +799,7 @@ def add_surface( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> SurfaceGraphic: """ @@ -874,7 +853,7 @@ def add_text( screen_space: bool = True, offset: tuple[float] = (0, 0, 0), anchor: str = "middle-center", - **kwargs, + **kwargs ) -> TextGraphic: """ @@ -925,7 +904,7 @@ def add_text( screen_space, offset, anchor, - **kwargs, + **kwargs ) def add_vectors( @@ -935,7 +914,7 @@ def add_vectors( color: Union[str, Sequence[float], numpy.ndarray] = "w", size: float = None, vector_shape_options: dict = None, - **kwargs, + **kwargs ) -> VectorsGraphic: """ @@ -980,5 +959,5 @@ def add_vectors( color, size, vector_shape_options, - **kwargs, + **kwargs ) diff --git a/tests/test_colors_buffer_manager.py b/tests/test_colors_buffer_manager.py index 7b1aef16a..f9d56189e 100644 --- a/tests/test_colors_buffer_manager.py +++ b/tests/test_colors_buffer_manager.py @@ -48,10 +48,10 @@ def test_int(test_graphic): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -98,10 +98,10 @@ def test_tuple(test_graphic, slice_method): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -190,10 +190,10 @@ def test_slice(color_input, slice_method: dict, test_graphic: bool): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors diff --git a/tests/test_markers_buffer_manager.py b/tests/test_markers_buffer_manager.py index 65ead392e..488bed194 100644 --- a/tests/test_markers_buffer_manager.py +++ b/tests/test_markers_buffer_manager.py @@ -46,10 +46,10 @@ def test_create_buffer(test_graphic): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) vertex_markers = scatter.markers assert isinstance(vertex_markers, VertexMarkers) - assert vertex_markers.buffer is scatter.world_object.geometry.markers + assert vertex_markers._fpl_buffer is scatter.world_object.geometry.markers else: vertex_markers = VertexMarkers(MARKERS1, len(data)) @@ -68,7 +68,7 @@ def test_int(test_graphic, index: int): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers else: @@ -108,7 +108,7 @@ def test_slice(test_graphic, slice_method): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers diff --git a/tests/test_point_rotations_buffer_manager.py b/tests/test_point_rotations_buffer_manager.py index ec5fdbe0f..50ee88984 100644 --- a/tests/test_point_rotations_buffer_manager.py +++ b/tests/test_point_rotations_buffer_manager.py @@ -35,7 +35,7 @@ def test_create_buffer(test_graphic): scatter = fig[0, 0].add_scatter(data, point_rotation_mode="vertex", point_rotations=ROTATIONS1) vertex_rotations = scatter.point_rotations assert isinstance(vertex_rotations, VertexRotations) - assert vertex_rotations.buffer is scatter.world_object.geometry.rotations + assert vertex_rotations._fpl_buffer is scatter.world_object.geometry.rotations else: vertex_rotations = VertexRotations(ROTATIONS1, len(data)) diff --git a/tests/test_positions_data_buffer_manager.py b/tests/test_positions_data_buffer_manager.py index e2582d4ba..cc550abf0 100644 --- a/tests/test_positions_data_buffer_manager.py +++ b/tests/test_positions_data_buffer_manager.py @@ -57,7 +57,7 @@ def test_int(test_graphic): graphic = fig[0, 0].add_scatter(data=data) points = graphic.data - assert graphic.data.buffer is graphic.world_object.geometry.positions + assert graphic.data._fpl_buffer is graphic.world_object.geometry.positions global EVENT_RETURN_VALUE graphic.add_event_handler(event_handler, "data") else: diff --git a/tests/test_positions_graphics.py b/tests/test_positions_graphics.py index 31c001888..4bc93b626 100644 --- a/tests/test_positions_graphics.py +++ b/tests/test_positions_graphics.py @@ -37,12 +37,12 @@ def test_sizes_slice(): @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("b")]) -@pytest.mark.parametrize("uniform_color", [True, False]) -def test_uniform_color(graphic_type, colors, uniform_color): +@pytest.mark.parametrize("color_mode", ["uniform", "vertex"]) +def test_color_mode(graphic_type, colors, color_mode): fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -54,7 +54,7 @@ def test_uniform_color(graphic_type, colors, uniform_color): elif graphic_type == "scatter": graphic = fig[0, 0].add_scatter(data=data, **kwargs) - if uniform_color: + if color_mode == "uniform": assert isinstance(graphic._colors, UniformColor) assert isinstance(graphic.colors, pygfx.Color) if colors is None: @@ -130,17 +130,17 @@ def test_positions_graphics_data( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["vertex"]) def test_positions_graphic_vertex_colors( graphic_type, colors, - uniform_color, + color_mode, ): # test different ways of passing vertex colors fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -153,10 +153,9 @@ def test_positions_graphic_vertex_colors( graphic = fig[0, 0].add_scatter(data=data, **kwargs) # color per vertex - # uniform colors is default False, or set to False - assert isinstance(graphic._colors, VertexColors) - assert isinstance(graphic.colors, VertexColors) - assert len(graphic.colors) == len(graphic.data) + assert isinstance(graphic._colors, VertexColors) + assert isinstance(graphic.colors, VertexColors) + assert len(graphic.colors) == len(graphic.data) if colors is None: # default @@ -179,7 +178,7 @@ def test_positions_graphic_vertex_colors( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["auto", "vertex"]) @pytest.mark.parametrize("cmap", ["jet"]) @pytest.mark.parametrize( "cmap_transform", [None, [3, 5, 2, 1, 0, 6, 9, 7, 4, 8], np.arange(9, -1, -1)] @@ -187,7 +186,7 @@ def test_positions_graphic_vertex_colors( def test_cmap( graphic_type, colors, - uniform_color, + color_mode, cmap, cmap_transform, ): @@ -195,7 +194,7 @@ def test_cmap( fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "cmap_transform", "colors", "uniform_color"]: + for kwarg in ["cmap", "cmap_transform", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -220,7 +219,8 @@ def test_cmap( # make sure buffer is identical # cmap overrides colors argument - assert graphic.colors.buffer is graphic.cmap.buffer + # use __repr__.__self__ to get the real reference from the cmap feature instead of the weakref proxy + assert graphic.colors._fpl_buffer is graphic.cmap.buffer.__repr__.__self__ npt.assert_almost_equal(graphic.cmap.value, truth) npt.assert_almost_equal(graphic.colors.value, truth) @@ -261,14 +261,14 @@ def test_cmap( "colors", [None, *generate_color_inputs("multi")] ) # cmap arg overrides colors @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color): +def test_incompatible_cmap_color_args(graphic_type, cmap, colors, color_mode): # test incompatible cmap args fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "colors", "uniform_color"]: + for kwarg in ["cmap", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -276,24 +276,24 @@ def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color) data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [*generate_color_inputs("multi")]) @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_color_args(graphic_type, colors, uniform_color): +def test_incompatible_color_args(graphic_type, colors, color_mode): # test incompatible color args fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -301,16 +301,15 @@ def test_incompatible_color_args(graphic_type, colors, uniform_color): data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("sizes", [None, 5.0, np.linspace(3, 8, 10, dtype=np.float32)]) -@pytest.mark.parametrize("uniform_size", [None, False]) -def test_sizes(sizes, uniform_size): +def test_sizes(sizes): # test scatter sizes fig = fpl.Figure() @@ -322,7 +321,7 @@ def test_sizes(sizes, uniform_size): data = generate_positions_spiral_data("xy") - graphic = fig[0, 0].add_scatter(data=data, **kwargs) + graphic = fig[0, 0].add_scatter(data=data, uniform_size=False, **kwargs) assert isinstance(graphic.sizes, VertexPointSizes) assert isinstance(graphic._sizes, VertexPointSizes) diff --git a/tests/test_replace_buffer.py b/tests/test_replace_buffer.py new file mode 100644 index 000000000..a9d0ffe41 --- /dev/null +++ b/tests/test_replace_buffer.py @@ -0,0 +1,155 @@ +import gc +import weakref + +import pytest +import numpy as np +from itertools import product + +import fastplotlib as fpl +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic + +# These are only de-referencing tests for positions graphics, and ImageGraphic +# they do not test that VRAM gets free, for now this can only be checked manually +# with the tests in examples/misc/buffer_replace_gc.py + + +@pytest.mark.parametrize("graphic_type", ["line", "scatter"]) +@pytest.mark.parametrize("new_buffer_size", [50, 150]) +def test_replace_positions_buffer(graphic_type, new_buffer_size): + fig = fpl.Figure() + + # create some data with an initial shape + orig_datapoints = 100 + + xs = np.linspace(0, 2 * np.pi, orig_datapoints) + ys = np.sin(xs) + zs = np.cos(xs) + + data = np.column_stack([xs, ys, zs]) + + # add add_line or add_scatter method + adder = getattr(fig[0, 0], f"add_{graphic_type}") + + if graphic_type == "scatter": + kwargs = { + "markers": np.random.choice(list("osD+x^v<>*"), size=orig_datapoints), + "uniform_marker": False, + "sizes": np.abs(ys), + "uniform_size": False, + # TODO: skipping edge_colors for now since that causes a WGPU bind group error that we will figure out later + # anyways I think changing buffer sizes in combination with per-vertex edge colors is a literal edge-case + "point_rotations": zs * 180, + "point_rotation_mode": "vertex", + } + else: + kwargs = dict() + + # add a line or scatter graphic + graphic = adder(data=data, colors=np.random.rand(orig_datapoints, 4), **kwargs) + + fig.show() + + # weakrefs to the original buffers + # these should raise a ReferenceError when the corresponding feature is replaced with data of a different shape + orig_data_buffer = weakref.proxy(graphic.data._fpl_buffer) + orig_colors_buffer = weakref.proxy(graphic.colors._fpl_buffer) + + buffers = [orig_data_buffer, orig_colors_buffer] + + # extra buffers for the scatters + if graphic_type == "scatter": + for attr in ["markers", "sizes", "point_rotations"]: + buffers.append(weakref.proxy(getattr(graphic, attr)._fpl_buffer)) + + # create some new data that requires a different buffer shape + xs = np.linspace(0, 15 * np.pi, new_buffer_size) + ys = np.sin(xs) + zs = np.cos(xs) + + new_data = np.column_stack([xs, ys, zs]) + + # set data that requires a larger buffer and check that old buffer is no longer referenced + graphic.data = new_data + graphic.colors = np.random.rand(new_buffer_size, 4) + + if graphic_type == "scatter": + # changes values so that new larger buffers must be allocated + graphic.markers = np.random.choice(list("osD+x^v<>*"), size=new_buffer_size) + graphic.sizes = np.abs(zs) + graphic.point_rotations = ys * 180 + + # make sure old original buffers are de-referenced + for i in range(len(buffers)): + with pytest.raises(ReferenceError) as fail: + buffers[i] + pytest.fail( + f"GC failed for buffer: {buffers[i]}, " + f"with referrers: {gc.get_referrers(buffers[i].__repr__.__self__)}" + ) + + +# test all combination of dims that require TextureArrays of shapes 1x1, 1x2, 1x3, 2x3, 3x3 etc. +@pytest.mark.parametrize( + "new_buffer_size", list(product(*[[(500, 1), (1200, 2), (2200, 3)]] * 2)) +) +def test_replace_image_buffer(new_buffer_size): + # make an image with some starting shape + orig_size = (1_500, 1_500) + + data = np.random.rand(*orig_size) + + fig = fpl.Figure() + image = fig[0, 0].add_image(data) + + # the original Texture buffers that represent the individual image tiles + orig_buffers = [ + weakref.proxy(image.data.buffer.ravel()[i]) + for i in range(image.data.buffer.size) + ] + orig_shape = image.data.buffer.shape + + fig.show() + + # dimensions for a new image + new_dims = [v[0] for v in new_buffer_size] + + # the number of tiles required in each dim/shape of the TextureArray + new_shape = tuple(v[1] for v in new_buffer_size) + + # make the new data and set the image + new_data = np.random.rand(*new_dims) + image.data = new_data + + # test that old Texture buffers are de-referenced + for i in range(len(orig_buffers)): + with pytest.raises(ReferenceError) as fail: + orig_buffers[i] + pytest.fail( + f"GC failed for buffer: {orig_buffers[i]}, of shape: {orig_shape}" + f"with referrers: {gc.get_referrers(orig_buffers[i].__repr__.__self__)}" + ) + + # check new texture array + check_texture_array( + data=new_data, + ta=image.data, + buffer_size=np.prod(new_shape), + buffer_shape=new_shape, + row_indices_size=new_shape[0], + col_indices_size=new_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), + ) + + # check that new image tiles are arranged correctly + check_image_graphic(image.data, image) diff --git a/tests/test_scatter_graphic.py b/tests/test_scatter_graphic.py index a61681f24..930d8c495 100644 --- a/tests/test_scatter_graphic.py +++ b/tests/test_scatter_graphic.py @@ -133,7 +133,7 @@ def test_edge_colors(edge_colors): npt.assert_almost_equal(scatter.edge_colors.value, MULTI_COLORS_TRUTH) assert ( - scatter.edge_colors.buffer is scatter.world_object.geometry.edge_colors + scatter.edge_colors._fpl_buffer is scatter.world_object.geometry.edge_colors ) # test changes, don't need to test extensively here since it's tested in the main VertexColors test diff --git a/tests/test_texture_array.py b/tests/test_texture_array.py index 6220f2fe5..01abb9a97 100644 --- a/tests/test_texture_array.py +++ b/tests/test_texture_array.py @@ -2,14 +2,9 @@ from numpy import testing as npt import pytest -import pygfx - import fastplotlib as fpl from fastplotlib.graphics.features import TextureArray -from fastplotlib.graphics.image import _ImageTile - - -MAX_TEXTURE_SIZE = 1024 +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic def make_data(n_rows: int, n_cols: int) -> np.ndarray: @@ -25,50 +20,6 @@ def make_data(n_rows: int, n_cols: int) -> np.ndarray: return np.vstack([sine * i for i in range(n_rows)]).astype(np.float32) -def check_texture_array( - data: np.ndarray, - ta: TextureArray, - buffer_size: int, - buffer_shape: tuple[int, int], - row_indices_size: int, - col_indices_size: int, - row_indices_values: np.ndarray, - col_indices_values: np.ndarray, -): - - npt.assert_almost_equal(ta.value, data) - - assert ta.buffer.size == buffer_size - assert ta.buffer.shape == buffer_shape - - assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) - - assert ta.row_indices.size == row_indices_size - assert ta.col_indices.size == col_indices_size - npt.assert_array_equal(ta.row_indices, row_indices_values) - npt.assert_array_equal(ta.col_indices, col_indices_values) - - # make sure chunking is correct - for texture, chunk_index, data_slice in ta: - assert ta.buffer[chunk_index] is texture - chunk_row, chunk_col = chunk_index - - data_row_start_index = chunk_row * MAX_TEXTURE_SIZE - data_col_start_index = chunk_col * MAX_TEXTURE_SIZE - - data_row_stop_index = min( - data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE - ) - data_col_stop_index = min( - data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE - ) - - row_slice = slice(data_row_start_index, data_row_stop_index) - col_slice = slice(data_col_start_index, data_col_stop_index) - - assert data_slice == (row_slice, col_slice) - - def check_set_slice(data, ta, row_slice, col_slice): ta[row_slice, col_slice] = 1 npt.assert_almost_equal(ta[row_slice, col_slice], 1) @@ -85,17 +36,6 @@ def make_image_graphic(data) -> fpl.ImageGraphic: return fig[0, 0].add_image(data) -def check_image_graphic(texture_array, graphic): - # make sure each ImageTile has the right texture - for (texture, chunk_index, data_slice), img in zip( - texture_array, graphic.world_object.children - ): - assert isinstance(img, _ImageTile) - assert img.geometry.grid is texture - assert img.world.x == data_slice[1].start - assert img.world.y == data_slice[0].start - - @pytest.mark.parametrize("test_graphic", [False, True]) def test_small_texture(test_graphic): # tests TextureArray with dims that requires only 1 texture @@ -162,15 +102,27 @@ def test_wide(test_graphic): else: ta = TextureArray(data) + ta_shape = (2, 3) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(2, 3), - row_indices_size=2, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -189,15 +141,27 @@ def test_tall(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 2) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(3, 2), - row_indices_size=3, - col_indices_size=2, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -216,15 +180,27 @@ def test_square(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 3) + check_texture_array( data, ta=ta, - buffer_size=9, - buffer_shape=(3, 3), - row_indices_size=3, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: diff --git a/tests/utils_textures.py b/tests/utils_textures.py new file mode 100644 index 000000000..f40a7371c --- /dev/null +++ b/tests/utils_textures.py @@ -0,0 +1,64 @@ +import numpy as np +import pygfx +from numpy import testing as npt + +from fastplotlib.graphics.features import TextureArray +from fastplotlib.graphics.image import _ImageTile + + +MAX_TEXTURE_SIZE = 1024 + + +def check_texture_array( + data: np.ndarray, + ta: TextureArray, + buffer_size: int, + buffer_shape: tuple[int, int], + row_indices_size: int, + col_indices_size: int, + row_indices_values: np.ndarray, + col_indices_values: np.ndarray, +): + + npt.assert_almost_equal(ta.value, data) + + assert ta.buffer.size == buffer_size + assert ta.buffer.shape == buffer_shape + + assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) + + assert ta.row_indices.size == row_indices_size + assert ta.col_indices.size == col_indices_size + npt.assert_array_equal(ta.row_indices, row_indices_values) + npt.assert_array_equal(ta.col_indices, col_indices_values) + + # make sure chunking is correct + for texture, chunk_index, data_slice in ta: + assert ta.buffer[chunk_index] is texture + chunk_row, chunk_col = chunk_index + + data_row_start_index = chunk_row * MAX_TEXTURE_SIZE + data_col_start_index = chunk_col * MAX_TEXTURE_SIZE + + data_row_stop_index = min( + data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE + ) + data_col_stop_index = min( + data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE + ) + + row_slice = slice(data_row_start_index, data_row_stop_index) + col_slice = slice(data_col_start_index, data_col_stop_index) + + assert data_slice == (row_slice, col_slice) + + +def check_image_graphic(texture_array, graphic): + # make sure each ImageTile has the right texture + for (texture, chunk_index, data_slice), img in zip( + texture_array, graphic.world_object.children + ): + assert isinstance(img, _ImageTile) + assert img.geometry.grid is texture + assert img.world.x == data_slice[1].start + assert img.world.y == data_slice[0].start From aefe418192709223a6850ec37932f176664e24a8 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Feb 2026 18:44:42 -0500 Subject: [PATCH 019/101] some basic OOC working --- .../widgets/nd_widget/_nd_positions.py | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 1871e027e..65d1f59c5 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -9,6 +9,7 @@ from ...utils import subsample_array, ArrayProtocol from ...graphics import ( + Graphic, ImageGraphic, LineGraphic, LineStack, @@ -264,13 +265,14 @@ def get(self, indices: tuple[Any, ...]): ).squeeze() # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary + # we need to slice upto dw since we add the `datapoints_window_size` above graphic_data[..., : dw, dims] = wf( windows, axis=-1 ).reshape(graphic_data.shape[0], dw, len(dims)) - return graphic_data[..., : dw, :] + return graphic_data[..., : dw : max(1, dw // self.p_max), :] - return graphic_data + return graphic_data[..., : graphic_data.shape[-2] : max(1, graphic_data.shape[-2] // self.p_max), :] class NDPositions: @@ -303,6 +305,8 @@ def __init__( index_mappings=index_mappings, ) + self._processor.p_max = 1_000 + self._indices = tuple([0] * self._processor.n_slider_dims) self._create_graphic(graphic) @@ -348,15 +352,21 @@ def indices(self, indices): self.graphic.data[:, : data_slice.shape[-1]] = data_slice elif isinstance(self.graphic, (LineCollection, ScatterCollection)): - for i in range(len(self.graphic)): - # data_slice shape is [n_lines, n_datapoints, 2 | 3] - self.graphic[i].data[:, : data_slice.shape[-1]] = data_slice[i] + for g, new_data in zip(self.graphic.graphics, data_slice): + if g.data.value.shape[0] != new_data.shape[0]: + # will replace buffer internally + g.data = new_data + else: + # if data are only xy, set only xy + g.data[:, :new_data.shape[1]] = new_data elif isinstance(self.graphic, ImageGraphic): image_data, x0, x_scale = self._create_heatmap_data(data_slice) self.graphic.data = image_data self.graphic.offset = (x0, *self.graphic.offset[1:]) + self._indices = indices + def _create_graphic( self, graphic_cls: Type[ @@ -368,6 +378,9 @@ def _create_graphic( | ImageGraphic ], ): + if not issubclass(graphic_cls, Graphic): + raise TypeError + data_slice = self.processor.get(self.indices) if issubclass(graphic_cls, ImageGraphic): @@ -412,3 +425,13 @@ def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: x0 = data_slice[0, 0, 0] return y_interp, x0, x_scale + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + self.processor.display_window = dw + self.indices = self.indices From b6d6e62d0f2a55ff4152c29db23acd6e578d391f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 6 Feb 2026 12:27:51 -0500 Subject: [PATCH 020/101] max num of dipslay datapoints --- .../widgets/nd_widget/_nd_positions.py | 53 ++++++++++++++++--- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 65d1f59c5..6cc29d92a 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -29,12 +29,14 @@ def __init__( data: ArrayProtocol, multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points display_window: int | float | None = 100, # window for n_datapoints dim only + max_display_datapoints: int = 1_000, datapoints_window_func: Callable | None = None, datapoints_window_size: int | None = None, **kwargs, ): self._display_window = display_window + self._max_display_datapoints = max_display_datapoints # TOOD: this does data validation twice and is a bit messy, cleanup self._data = self._validate_data(data) @@ -64,6 +66,19 @@ def display_window(self, dw: int | float | None): self._display_window = dw + @property + def max_display_datapoints(self) -> int: + return self._max_display_datapoints + + @max_display_datapoints.setter + def max_display_datapoints(self, n: int): + if not isinstance(n, (int, np.integer)): + raise TypeError + if n < 2: + raise ValueError + + self._max_display_datapoints = n + @property def multi(self) -> bool: return self._multi @@ -231,12 +246,15 @@ def get(self, indices: tuple[Any, ...]): # data that will be used for the graphical representation # a copy is made, if there were no window functions then this is a view of the original data - graphic_data = window_output[tuple(slices)].copy() + graphic_data = window_output[tuple(slices)] # apply window function on the `p` n_datapoints dim if ( self.datapoints_window_func is not None and self.datapoints_window_size is not None + # if there are too many points to efficiently compute the window func + # applying a window func also requires making a copy so that's a further performance hit + and (dw < self.max_display_datapoints * 2) ): # get windows @@ -264,18 +282,30 @@ def get(self, indices: tuple[Any, ...]): graphic_data[..., dims], ws, axis=-2 ).squeeze() + # make a copy because we need to modify it + graphic_data = graphic_data.copy() + # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary # we need to slice upto dw since we add the `datapoints_window_size` above - graphic_data[..., : dw, dims] = wf( - windows, axis=-1 - ).reshape(graphic_data.shape[0], dw, len(dims)) + graphic_data[..., :dw, dims] = wf(windows, axis=-1).reshape( + graphic_data.shape[0], dw, len(dims) + ) - return graphic_data[..., : dw : max(1, dw // self.p_max), :] + return graphic_data[ + ..., : dw : max(1, dw // self.max_display_datapoints), : + ] - return graphic_data[..., : graphic_data.shape[-2] : max(1, graphic_data.shape[-2] // self.p_max), :] + return graphic_data[ + ..., + : graphic_data.shape[-2] : max( + 1, graphic_data.shape[-2] // self.max_display_datapoints + ), + :, + ] class NDPositions: + def __init__( self, data, @@ -292,6 +322,8 @@ def __init__( window_funcs: tuple[WindowFuncCallable | None] | None = None, window_sizes: tuple[int | None] | None = None, index_mappings: tuple[Callable[[Any], int] | None] | None = None, + max_display_datapoints: int = 1_000, + graphic_kwargs: dict = None, ): if issubclass(graphic, LineCollection): multi = True @@ -300,6 +332,7 @@ def __init__( data, multi=multi, display_window=display_window, + max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, window_sizes=window_sizes, index_mappings=index_mappings, @@ -358,7 +391,7 @@ def indices(self, indices): g.data = new_data else: # if data are only xy, set only xy - g.data[:, :new_data.shape[1]] = new_data + g.data[:, : new_data.shape[1]] = new_data elif isinstance(self.graphic, ImageGraphic): image_data, x0, x_scale = self._create_heatmap_data(data_slice) @@ -396,7 +429,11 @@ def _create_graphic( ) else: - self._graphic = graphic_cls(data_slice) + if issubclass(graphic_cls, LineStack): + kwargs = {"separation": 0.0} + else: + kwargs = dict() + self._graphic = graphic_cls(data_slice, **kwargs) def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: """return [n_rows, n_cols] shape data""" From 976459b662d6a2d133009721e1c8c3bf1457fef5 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 11 Feb 2026 03:40:12 -0500 Subject: [PATCH 021/101] scatter stack, not tested --- fastplotlib/graphics/scatter_collection.py | 123 +++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py index b1569cacc..ac8cc307e 100644 --- a/fastplotlib/graphics/scatter_collection.py +++ b/fastplotlib/graphics/scatter_collection.py @@ -515,3 +515,126 @@ def _get_linear_selector_init_args(self, axis, padding): center = bbox[:, 0].mean() return bounds, limits, size, center + + +axes = {"x": 0, "y": 1, "z": 2} + + +class ScatterStack(ScatterCollection): + def __init__( + self, + data: List[np.ndarray], + thickness: float | Iterable[float] = 2.0, + colors: str | Iterable[str] | np.ndarray | Iterable[np.ndarray] = "w", + cmap: Iterable[str] | str = None, + cmap_transform: np.ndarray | List = None, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + isolated_buffer: bool = True, + separation: float = 0.0, + separation_axis: str = "y", + kwargs_lines: list[dict] = None, + **kwargs, + ): + """ + Create a stack of :class:`.LineGraphic` that are separated along the "x" or "y" axis. + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + thickness: float or Iterable of float, default 2.0 + | if ``float``, single thickness will be used for all lines + | if ``list`` of ``float``, each value will apply to the individual lines + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + metadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + separation: float, default 0.0 + space in between each line graphic in the stack + + separation_axis: str, default "y" + axis in which the line graphics in the stack should be separated + + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + super().__init__( + data=data, + thickness=thickness, + colors=colors, + cmap=cmap, + cmap_transform=cmap_transform, + name=name, + names=names, + metadata=metadata, + metadatas=metadatas, + isolated_buffer=isolated_buffer, + kwargs_lines=kwargs_lines, + **kwargs, + ) + + self._sepration_axis = separation_axis + self._separation = separation + + self.separation = separation + + @property + def separation(self) -> float: + """distance between each line in the stack, in world space""" + return self._separation + + @separation.setter + def separation(self, value: float): + separation = float(value) + + axis_zero = 0 + for i, line in enumerate(self.graphics): + if self._sepration_axis == "x": + line.offset = (axis_zero, *line.offset[1:]) + + elif self._sepration_axis == "y": + line.offset = (line.offset[0], axis_zero, line.offset[2]) + + axis_zero = ( + axis_zero + line.data.value[:, axes[self._sepration_axis]].max() + separation + ) + + self._separation = value From a9bfa4480bc385a32223644d29c3564ee5aef6ac Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 11 Feb 2026 16:47:02 -0500 Subject: [PATCH 022/101] progress --- .../widgets/nd_widget/_nd_positions.py | 131 +++++++++++++----- 1 file changed, 98 insertions(+), 33 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 6cc29d92a..8d30fe37a 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -26,7 +26,7 @@ class NDPositionsProcessor(NDProcessor): def __init__( self, - data: ArrayProtocol, + data: Any, multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points display_window: int | float | None = 100, # window for n_datapoints dim only max_display_datapoints: int = 1_000, @@ -34,7 +34,6 @@ def __init__( datapoints_window_size: int | None = None, **kwargs, ): - self._display_window = display_window self._max_display_datapoints = max_display_datapoints @@ -193,6 +192,66 @@ def _apply_window_functions(self, indices: tuple[int, ...]): return data_sliced + def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: + # given indices, return slice using display window + + # display window is interpreted using the index mapping for the `p` dim + dw = self.display_window + + if dw is None: + # just map p dimension at this index and return + index_p = self.index_mappings[-1](indices[-1]) + return (slice(index_p, index_p + 1),) + + # display window is in reference units, apply display window and then map to array indices + # clamp w.r.t. 0 and processor shape `p` dim + hw = dw / 2 + index_p_start = max(self.index_mappings[-1](indices[-1] - hw), 0) + index_p_stop = min(self.index_mappings[-1](indices[-1] + hw), self.shape[-2]) + if index_p_start >= index_p_stop: + index_p_stop = index_p_start + 1 + + slices = [slice(index_p_start, index_p_stop)] + + if self.multi: + slices.insert(0, slice(None)) + + return tuple(slices) + + # + # # clamp w.r.t. processor shape + # + # dw = self.index_mappings[-1](self.display_window) + # + # if dw == 1: + # slices = [slice(index_p, index_p + 1)] + # + # else: + # # half window size + # hw = dw // 2 + # + # # for now assume just a single index provided that indicates x axis value + # start = max(index_p - hw, 0) + # stop = start + dw + # # also add window size of `p` dim so window_func output has the same number of datapoints + # if ( + # self.datapoints_window_func is not None + # and self.datapoints_window_size is not None + # ): + # stop += self.datapoints_window_size - 1 + # # TODO: pad with constant if we're using a window func and the index is near the end + # + # # TODO: uncomment this once we have resizeable buffers!! + # # stop = min(index_p + hw, self.shape[-2]) + # + # slices = [slice(start, stop)] + # + # if self.multi: + # # n - 2 dim is n_lines or n_scatters + # slices.insert(0, slice(None)) + # + # return tuple(slices) + def get(self, indices: tuple[Any, ...]): """ slices through all slider dims and outputs an array that can be used to set graphic data @@ -214,40 +273,45 @@ def get(self, indices: tuple[Any, ...]): # TODO: window function on the `p` n_datapoints dimension if self.display_window is not None: - # display window is interpreted using the index mapping for the `p` dim - dw = self.index_mappings[-1](self.display_window) - - if dw == 1: - slices = [slice(indices[-1], indices[-1] + 1)] - - else: - # half window size - hw = dw // 2 - - # for now assume just a single index provided that indicates x axis value - start = max(indices[-1] - hw, 0) - stop = start + dw - # also add window size of `p` dim so window_func output has the same number of datapoints - if ( - self.datapoints_window_func is not None - and self.datapoints_window_size is not None - ): - stop += self.datapoints_window_size - 1 - # TODO: pad with constant if we're using a window func and the index is near the end - - # TODO: uncomment this once we have resizeable buffers!! - # stop = min(indices[-1] + hw, self.shape[-2]) + slices = self._get_dw_slices(indices) - slices = [slice(start, stop)] - - if self.multi: - # n - 2 dim is n_lines or n_scatters - slices.insert(0, slice(None)) + # if self.display_window is not None: + # # display window is interpreted using the index mapping for the `p` dim + # dw = self.index_mappings[-1](self.display_window) + # + # if dw == 1: + # slices = [slice(indices[-1], indices[-1] + 1)] + # + # else: + # # half window size + # hw = dw // 2 + # + # # for now assume just a single index provided that indicates x axis value + # start = max(indices[-1] - hw, 0) + # stop = start + dw + # # also add window size of `p` dim so window_func output has the same number of datapoints + # if ( + # self.datapoints_window_func is not None + # and self.datapoints_window_size is not None + # ): + # stop += self.datapoints_window_size - 1 + # # TODO: pad with constant if we're using a window func and the index is near the end + # + # # TODO: uncomment this once we have resizeable buffers!! + # # stop = min(indices[-1] + hw, self.shape[-2]) + # + # slices = [slice(start, stop)] + # + # if self.multi: + # # n - 2 dim is n_lines or n_scatters + # slices.insert(0, slice(None)) # data that will be used for the graphical representation # a copy is made, if there were no window functions then this is a view of the original data graphic_data = window_output[tuple(slices)] + dw = self.index_mappings[-1](self.display_window) + # apply window function on the `p` n_datapoints dim if ( self.datapoints_window_func is not None @@ -308,7 +372,7 @@ class NDPositions: def __init__( self, - data, + data: Any, graphic: Type[ LineGraphic | LineCollection @@ -317,6 +381,7 @@ def __init__( | ScatterCollection | ImageGraphic ], + processor: type[NDPositionsProcessor] = NDPositionsProcessor, multi: bool = False, display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, @@ -328,7 +393,7 @@ def __init__( if issubclass(graphic, LineCollection): multi = True - self._processor = NDPositionsProcessor( + self._processor = processor( data, multi=multi, display_window=display_window, @@ -420,7 +485,7 @@ def _create_graphic( if not self.processor.multi: raise ValueError - if self.processor.data.shape[-1] != 2: + if self.processor.shape[-1] != 2: raise ValueError image_data, x0, x_scale = self._create_heatmap_data(data_slice) From 596b8e76227eb7cb80cdb8364b0041f9f6b7a83c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 13 Feb 2026 23:15:33 -0500 Subject: [PATCH 023/101] scatter collection updates --- fastplotlib/graphics/scatter_collection.py | 4 -- fastplotlib/layouts/_graphic_methods_mixin.py | 46 +++++++++---------- 2 files changed, 21 insertions(+), 29 deletions(-) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py index ac8cc307e..b8e7556ad 100644 --- a/fastplotlib/graphics/scatter_collection.py +++ b/fastplotlib/graphics/scatter_collection.py @@ -114,7 +114,6 @@ def __init__( self, data: np.ndarray | List[np.ndarray], colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", - uniform_colors: bool = False, cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, sizes: float | Sequence[float] = 5.0, @@ -122,7 +121,6 @@ def __init__( names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ): @@ -291,12 +289,10 @@ def __init__( lg = ScatterGraphic( data=d, colors=_c, - uniform_color=uniform_colors, sizes=sizes, cmap=_cmap, name=_name, metadata=_m, - isolated_buffer=isolated_buffer, **kwargs_lines, ) diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index eda7b1492..bd01855bd 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -33,7 +33,7 @@ def add_image( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - **kwargs + **kwargs, ) -> ImageGraphic: """ @@ -74,7 +74,7 @@ def add_image( cmap, interpolation, cmap_interpolation, - **kwargs + **kwargs, ) def add_image_volume( @@ -92,7 +92,7 @@ def add_image_volume( substep_size: float = 0.1, emissive: str | tuple | numpy.ndarray = (0, 0, 0), shininess: int = 30, - **kwargs + **kwargs, ) -> ImageVolumeGraphic: """ @@ -169,7 +169,7 @@ def add_image_volume( substep_size, emissive, shininess, - **kwargs + **kwargs, ) def add_line_collection( @@ -185,7 +185,7 @@ def add_line_collection( metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, kwargs_lines: list[dict] = None, - **kwargs + **kwargs, ) -> LineCollection: """ @@ -256,7 +256,7 @@ def add_line_collection( metadata, metadatas, kwargs_lines, - **kwargs + **kwargs, ) def add_line( @@ -268,7 +268,7 @@ def add_line( cmap_transform: Union[numpy.ndarray, Sequence] = None, color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", - **kwargs + **kwargs, ) -> LineGraphic: """ @@ -322,7 +322,7 @@ def add_line( cmap_transform, color_mode, size_space, - **kwargs + **kwargs, ) def add_line_stack( @@ -339,7 +339,7 @@ def add_line_stack( separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, - **kwargs + **kwargs, ) -> LineStack: """ @@ -415,7 +415,7 @@ def add_line_stack( separation, separation_axis, kwargs_lines, - **kwargs + **kwargs, ) def add_mesh( @@ -434,7 +434,7 @@ def add_mesh( | numpy.ndarray ) = None, clim: tuple[float, float] = None, - **kwargs + **kwargs, ) -> MeshGraphic: """ @@ -488,7 +488,7 @@ def add_mesh( mapcoords, cmap, clim, - **kwargs + **kwargs, ) def add_polygon( @@ -505,7 +505,7 @@ def add_polygon( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs + **kwargs, ) -> PolygonGraphic: """ @@ -552,15 +552,13 @@ def add_scatter_collection( self, data: Union[numpy.ndarray, List[numpy.ndarray]], colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", - uniform_colors: bool = False, cmap: Union[Sequence[str], str] = None, cmap_transform: Union[numpy.ndarray, List] = None, - sizes: Union[float, Sequence[float]] = 2.0, + sizes: Union[float, Sequence[float]] = 5.0, name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ) -> ScatterCollection: @@ -617,7 +615,6 @@ def add_scatter_collection( ScatterCollection, data, colors, - uniform_colors, cmap, cmap_transform, sizes, @@ -625,7 +622,6 @@ def add_scatter_collection( names, metadata, metadatas, - isolated_buffer, kwargs_lines, **kwargs, ) @@ -652,7 +648,7 @@ def add_scatter( sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, uniform_size: bool = True, size_space: str = "screen", - **kwargs + **kwargs, ) -> ScatterGraphic: """ @@ -782,7 +778,7 @@ def add_scatter( sizes, uniform_size, size_space, - **kwargs + **kwargs, ) def add_surface( @@ -799,7 +795,7 @@ def add_surface( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs + **kwargs, ) -> SurfaceGraphic: """ @@ -853,7 +849,7 @@ def add_text( screen_space: bool = True, offset: tuple[float] = (0, 0, 0), anchor: str = "middle-center", - **kwargs + **kwargs, ) -> TextGraphic: """ @@ -904,7 +900,7 @@ def add_text( screen_space, offset, anchor, - **kwargs + **kwargs, ) def add_vectors( @@ -914,7 +910,7 @@ def add_vectors( color: Union[str, Sequence[float], numpy.ndarray] = "w", size: float = None, vector_shape_options: dict = None, - **kwargs + **kwargs, ) -> VectorsGraphic: """ @@ -959,5 +955,5 @@ def add_vectors( color, size, vector_shape_options, - **kwargs + **kwargs, ) From db2431f47b5a2cdb21b917c36967aa0e9b7bacbe Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 13 Feb 2026 23:16:03 -0500 Subject: [PATCH 024/101] tootip handlers for ndpositions --- fastplotlib/widgets/nd_widget/_nd_positions.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py index 8d30fe37a..9a2d25048 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -1,3 +1,4 @@ +from functools import partial import inspect from typing import Literal, Callable, Any, Type from warnings import warn @@ -465,6 +466,13 @@ def indices(self, indices): self._indices = indices + def _tooltip_handler(self, graphic, pick_info): + if isinstance(self.graphic, (LineCollection, ScatterCollection)): + # get graphic within the collection + n_index = np.argwhere(self.graphic.graphics == graphic).item() + p_index = pick_info["vertex_index"] + return self.processor.format_tooltip(n_index, p_index) + def _create_graphic( self, graphic_cls: Type[ @@ -500,6 +508,11 @@ def _create_graphic( kwargs = dict() self._graphic = graphic_cls(data_slice, **kwargs) + if hasattr(self.processor, "format_tooltip"): + if isinstance(self._graphic, (LineCollection, ScatterCollection)): + for g in self._graphic.graphics: + g.tooltip_format = partial(self._tooltip_handler, g) + def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: """return [n_rows, n_cols] shape data""" # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense From 57d9a6ba1e3cfacdac5b2055e12b05b05ccac957 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Feb 2026 03:34:03 -0500 Subject: [PATCH 025/101] refactoring, general NDPP_Pandas processor for any dataframe data --- fastplotlib/widgets/nd_widget/__init__.py | 2 + .../nd_widget/_nd_positions/__init__.py | 23 +++++ .../nd_widget/_nd_positions/_pandas.py | 94 +++++++++++++++++++ .../widgets/nd_widget/_nd_positions/_zarr.py | 4 + .../core.py} | 54 +++-------- .../nd_widget/{_nd_image.py => nd_image.py} | 2 +- .../{_processor_base.py => processor_base.py} | 13 +++ 7 files changed, 149 insertions(+), 43 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_nd_positions/__init__.py create mode 100644 fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py create mode 100644 fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py rename fastplotlib/widgets/nd_widget/{_nd_positions.py => _nd_positions/core.py} (92%) rename fastplotlib/widgets/nd_widget/{_nd_image.py => nd_image.py} (87%) rename fastplotlib/widgets/nd_widget/{_processor_base.py => processor_base.py} (96%) diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index e69de29bb..70c2e7621 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -0,0 +1,2 @@ +from .processor_base import NDProcessor +from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py new file mode 100644 index 000000000..03bb0e8f7 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py @@ -0,0 +1,23 @@ +import importlib + +from .core import NDPositions, NDPositionsProcessor + +class Extras: + pass + +ndp_extras = Extras() + + +for optional in ["pandas", "zarr"]: + try: + importlib.import_module(optional) + except ImportError: + pass + else: + module = importlib.import_module(f"._{optional}", "fastplotlib.widgets.nd_widget._nd_positions") + cls = getattr(module, f"NDPP_{optional.capitalize()}") + setattr( + ndp_extras, + f"NDPP_{optional.capitalize()}", + cls + ) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py new file mode 100644 index 000000000..de26c8a9d --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -0,0 +1,94 @@ +import numpy as np +import pandas as pd + +from .core import NDPositionsProcessor + + +class NDPP_Pandas(NDPositionsProcessor): + def __init__( + self, + data: pd.DataFrame, + columns: list[tuple[str, str] | tuple[str, str, str]], + tooltip_columns: list[str] = None, + max_display_datapoints: int = 1_000, + **kwargs, + ): + data = data + + self._columns = columns + + if tooltip_columns is not None: + if len(tooltip_columns) != len(self.columns): + raise ValueError + self._tooltip_columns = tooltip_columns + self._tooltip = True + else: + self._tooltip_columns = None + self._tooltip = False + + super().__init__( + data=data, + max_display_datapoints=max_display_datapoints, + **kwargs, + ) + + @property + def data(self) -> pd.DataFrame: + return self._data + + def _validate_data(self, data: pd.DataFrame): + if not isinstance(data, pd.DataFrame): + raise TypeError + + return data + + @property + def columns(self) -> list[tuple[str, str] | tuple[str, str, str]]: + return self._columns + + @property + def multi(self) -> bool: + return True + + @multi.setter + def multi(self, v): + pass + + @property + def shape(self) -> tuple[int, ...]: + # n_graphical_elements, n_timepoints, 2 + return len(self.columns), self.data.index.size, 2 + + @property + def ndim(self) -> int: + return len(self.shape) + + @property + def n_slider_dims(self) -> int: + return 1 + + @property + def tooltip(self) -> bool: + return self._tooltip + + def tooltip_format(self, n: int, p: int): + # datapoint index w.r.t. full data + p += self._slices[-1].start + return str(self.data[self._tooltip_columns[n]][p]) + + def get(self, indices: tuple[float | int, ...]) -> np.ndarray: + if not isinstance(indices, tuple): + raise TypeError(".get() must receive a tuple of float | int indices") + # assume no additional slider dims, only time slider dim + self._slices = self._get_dw_slices(indices) + + + gdata_shape = len(self.columns), self._slices[-1].stop - self._slices[-1].start, 3 + gdata = np.zeros(shape=gdata_shape, dtype=np.float32) + + for i, col in enumerate(self.columns): + gdata[i, :, :len(col)] = np.column_stack( + [self.data[c][self._slices[-1]] for c in col] + ) + + return gdata diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py new file mode 100644 index 000000000..fb3bb7015 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py @@ -0,0 +1,4 @@ +# placeholder + +class NDPP_Zarr: + pass diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py similarity index 92% rename from fastplotlib/widgets/nd_widget/_nd_positions.py rename to fastplotlib/widgets/nd_widget/_nd_positions/core.py index 9a2d25048..b95916ce8 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -1,15 +1,13 @@ from functools import partial -import inspect from typing import Literal, Callable, Any, Type from warnings import warn import numpy as np -from numpy.typing import ArrayLike from numpy.lib.stride_tricks import sliding_window_view -from ...utils import subsample_array, ArrayProtocol +from ....utils import subsample_array, ArrayProtocol -from ...graphics import ( +from ....graphics import ( Graphic, ImageGraphic, LineGraphic, @@ -18,7 +16,7 @@ ScatterGraphic, ScatterCollection, ) -from ._processor_base import NDProcessor, WindowFuncCallable +from ..processor_base import NDProcessor, WindowFuncCallable # TODO: Maybe get rid of n_display_dims in NDProcessor, @@ -219,40 +217,6 @@ def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: return tuple(slices) - # - # # clamp w.r.t. processor shape - # - # dw = self.index_mappings[-1](self.display_window) - # - # if dw == 1: - # slices = [slice(index_p, index_p + 1)] - # - # else: - # # half window size - # hw = dw // 2 - # - # # for now assume just a single index provided that indicates x axis value - # start = max(index_p - hw, 0) - # stop = start + dw - # # also add window size of `p` dim so window_func output has the same number of datapoints - # if ( - # self.datapoints_window_func is not None - # and self.datapoints_window_size is not None - # ): - # stop += self.datapoints_window_size - 1 - # # TODO: pad with constant if we're using a window func and the index is near the end - # - # # TODO: uncomment this once we have resizeable buffers!! - # # stop = min(index_p + hw, self.shape[-2]) - # - # slices = [slice(start, stop)] - # - # if self.multi: - # # n - 2 dim is n_lines or n_scatters - # slices.insert(0, slice(None)) - # - # return tuple(slices) - def get(self, indices: tuple[Any, ...]): """ slices through all slider dims and outputs an array that can be used to set graphic data @@ -370,10 +334,10 @@ def get(self, indices: tuple[Any, ...]): class NDPositions: - def __init__( self, data: Any, + *args, graphic: Type[ LineGraphic | LineCollection @@ -390,18 +354,24 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, graphic_kwargs: dict = None, + processor_kwargs: dict = None, ): if issubclass(graphic, LineCollection): multi = True + if processor_kwargs is None: + processor_kwargs = dict() + self._processor = processor( data, + *args, multi=multi, display_window=display_window, max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, window_sizes=window_sizes, index_mappings=index_mappings, + **processor_kwargs, ) self._processor.p_max = 1_000 @@ -471,7 +441,7 @@ def _tooltip_handler(self, graphic, pick_info): # get graphic within the collection n_index = np.argwhere(self.graphic.graphics == graphic).item() p_index = pick_info["vertex_index"] - return self.processor.format_tooltip(n_index, p_index) + return self.processor.tooltip_format(n_index, p_index) def _create_graphic( self, @@ -508,7 +478,7 @@ def _create_graphic( kwargs = dict() self._graphic = graphic_cls(data_slice, **kwargs) - if hasattr(self.processor, "format_tooltip"): + if self.processor.tooltip: if isinstance(self._graphic, (LineCollection, ScatterCollection)): for g in self._graphic.graphics: g.tooltip_format = partial(self._tooltip_handler, g) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/nd_image.py similarity index 87% rename from fastplotlib/widgets/nd_widget/_nd_image.py rename to fastplotlib/widgets/nd_widget/nd_image.py index f115e146e..4972db9d5 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/nd_image.py @@ -1,6 +1,6 @@ from typing import Literal -from ._processor_base import NDProcessor +from .processor_base import NDProcessor class NDImageProcessor(NDProcessor): diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/processor_base.py similarity index 96% rename from fastplotlib/widgets/nd_widget/_processor_base.py rename to fastplotlib/widgets/nd_widget/processor_base.py index 225608cca..a1cd5311c 100644 --- a/fastplotlib/widgets/nd_widget/_processor_base.py +++ b/fastplotlib/widgets/nd_widget/processor_base.py @@ -55,6 +55,19 @@ def _validate_data(self, data: ArrayProtocol): return data + @property + def tooltip(self) -> bool: + """ + whether or not a custom tooltip formatter method exists + """ + return False + + def tooltip_format(self, *args) -> str | None: + """ + Override in subclass to format custom tooltips + """ + return None + @property def slider_dims(self): raise NotImplementedError From 47ec02add5b7904add70cfc4fbaa618a702e6a0b Mon Sep 17 00:00:00 2001 From: Kushal Kolar Date: Mon, 16 Feb 2026 05:45:18 -0500 Subject: [PATCH 026/101] separate array logic and graphic logic in `ImageWidget` (#868) * start separating iw plotting and array logic * some more basics down * comment * collapse into just having a window function, no frame_function * progress * placeholder for computing histogram * formatting * remove spaghetti * more progress * basics working :D * black * most of the basics work in iw * fix * progress * progress but still broken * flippin display dims works * camera scale must be positive for MIP rendering * a very difficult to encounter iterator bug! * patch iterator caveats * mostly worksgit status * add ArrayProtocol * rename * fixes * set camera orthogonal to xy plane when going from 3d -> 2d * naming, cleaning * cleanup, correct way to push and pop dims * quality of life improvements * new histogram lut tool * new hlut tool * imagewidget rgb toggle works * more progress * support rgb(a) image volumes * ImageGraphic cleanup * cleanup, docs * fix * updates * new per-data array properties work * black formatting * fixes and other things * typing tweaks * better iterator, fix bugs * fixes * show tooltips in right clck menu * ignore nans and inf for histogram * histogram of zeros * docstrings * fix imgui pixels * iw indices event handlers only get a tuple of the indices * bugfix * fix cmap setter * spatial_func better name * bugfix * hist specify quantile * fix typos (#991) * fix typos * add rendercanvas to intersphinx_mapping * nd-iw backup * correct ImageGraphic w.r.t. ndw * last fixes in ndi --- .github/workflows/docs-deploy.yml | 2 +- fastplotlib/graphics/_base.py | 9 + fastplotlib/graphics/features/_base.py | 4 +- .../graphics/features/_selection_features.py | 6 +- fastplotlib/graphics/image.py | 15 +- fastplotlib/graphics/image_volume.py | 14 +- .../graphics/selectors/_linear_region.py | 4 +- fastplotlib/graphics/utils.py | 15 +- fastplotlib/layouts/_figure.py | 34 +- fastplotlib/layouts/_plot_area.py | 5 +- fastplotlib/tools/_histogram_lut.py | 588 +++++----- fastplotlib/ui/_base.py | 4 +- .../ui/right_click_menus/_colormap_picker.py | 3 +- fastplotlib/utils/_protocols.py | 3 + fastplotlib/widgets/image_widget/__init__.py | 1 + .../widgets/image_widget/_nd_iw_backup.py | 1007 +++++++++++++++++ .../widgets/image_widget/_processor.py | 519 +++++++++ .../widgets/image_widget/_properties.py | 139 +++ fastplotlib/widgets/image_widget/_sliders.py | 91 +- 19 files changed, 2089 insertions(+), 374 deletions(-) create mode 100644 fastplotlib/widgets/image_widget/_nd_iw_backup.py create mode 100644 fastplotlib/widgets/image_widget/_processor.py create mode 100644 fastplotlib/widgets/image_widget/_properties.py diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 470e2e5a5..f17941405 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -49,7 +49,7 @@ jobs: - name: build docs run: | cd docs - RTD_BUILD=1 make html SPHINXOPTS="-W --keep-going" + DOCS_BUILD=1 make html SPHINXOPTS="-W --keep-going" # set environment variable `DOCS_VERSION_DIR` to either the pr-branch name, "dev", or the release version tag - name: set output pr diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index 47673cbc0..e0602e4e3 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -178,6 +178,7 @@ def __init__( self._alpha_mode = AlphaMode(alpha_mode) self._visible = Visible(visible) self._block_events = False + self._block_handlers = list() self._axes: Axes = None @@ -274,6 +275,11 @@ def block_events(self) -> bool: def block_events(self, value: bool): self._block_events = value + @property + def block_handlers(self) -> list: + """Used to block event handlers for a graphic and prevent recursion.""" + return self._block_handlers + @property def world_object(self) -> pygfx.WorldObject: """Associated pygfx WorldObject. Always returns a proxy, real object cannot be accessed directly.""" @@ -440,6 +446,9 @@ def _handle_event(self, callback, event: pygfx.Event): if self.block_events: return + if callback in self._block_handlers: + return + if event.type in self._features: # for feature events event._target = self.world_object diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 76352b4ef..68fe54c33 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -318,7 +318,7 @@ def __repr__(self): def block_reentrance(set_value): # decorator to block re-entrant set_value methods # useful when creating complex, circular, bidirectional event graphs - def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): + def set_value_wrapper(self: GraphicFeature, graphic_or_key, value, **kwargs): """ wraps GraphicFeature.set_value @@ -334,7 +334,7 @@ def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): try: # block re-execution of set_value until it has *fully* finished executing self._reentrant_block = True - set_value(self, graphic_or_key, value) + set_value(self, graphic_or_key, value, **kwargs) except Exception as exc: # raise original exception raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! diff --git a/fastplotlib/graphics/features/_selection_features.py b/fastplotlib/graphics/features/_selection_features.py index 9b30dd70c..1f049f0cb 100644 --- a/fastplotlib/graphics/features/_selection_features.py +++ b/fastplotlib/graphics/features/_selection_features.py @@ -118,7 +118,7 @@ def axis(self) -> str: return self._axis @block_reentrance - def set_value(self, selector, value: Sequence[float]): + def set_value(self, selector, value: Sequence[float], *, change: str = "full"): """ Set start, stop range of selector @@ -182,7 +182,9 @@ def set_value(self, selector, value: Sequence[float]): if len(self._event_handlers) < 1: return - event = GraphicFeatureEvent(self._property_name, {"value": self.value}) + event = GraphicFeatureEvent( + self._property_name, {"value": self.value, "change": change} + ) event.get_selected_indices = selector.get_selected_indices event.get_selected_data = selector.get_selected_data diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 760b856d2..7b670d531 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -162,10 +162,11 @@ def __init__( self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) # set map to None for RGB images - if self._data.value.ndim > 2: + if self._data.value.ndim == 3: self._cmap = None _map = None - else: + + elif self._data.value.ndim == 2: # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) @@ -174,6 +175,12 @@ def __init__( filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + else: + raise ValueError( + f"ImageGraphic `data` must have 2 dimensions for grayscale images, or 3 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) # one common material is used for every Texture chunk self._material = pygfx.ImageBasicMaterial( @@ -275,8 +282,6 @@ def cmap(self) -> str | None: if self._cmap is not None: return self._cmap.value - return None - @cmap.setter def cmap(self, name: str): if self.data.value.ndim > 2: @@ -312,7 +317,7 @@ def interpolation(self, value: str): @property def cmap_interpolation(self) -> str: - """cmap interpolation method""" + """cmap interpolation method, 'linear' or 'nearest'. Used only for grayscale images""" return self._cmap_interpolation.value @cmap_interpolation.setter diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py index a3b379492..3d2d064e8 100644 --- a/fastplotlib/graphics/image_volume.py +++ b/fastplotlib/graphics/image_volume.py @@ -204,18 +204,24 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - # TODO: I'm assuming RGB volume images aren't supported??? # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - self._texture_map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + if self._data.value.ndim not in (3, 4): + raise ValueError( + f"ImageVolumeGraphic `data` must have 3 dimensions for grayscale images, " + f"or 4 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) + self._plane = VolumeSlicePlane(plane) self._threshold = VolumeIsoThreshold(threshold) self._step_size = VolumeIsoStepSize(step_size) @@ -301,7 +307,7 @@ def mode(self, mode: str): @property def cmap(self) -> str: - """Get or set colormap name""" + """Get or set colormap name, only used for grayscale images""" return self._cmap.value @cmap.setter diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index 70a8dffa8..8a8583ae9 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -472,9 +472,9 @@ def _move_graphic(self, move_info: MoveInfo): if move_info.source == self._edges[0]: # change only left or bottom bound new_min = min(cur_min + delta, cur_max) - self._selection.set_value(self, (new_min, cur_max)) + self._selection.set_value(self, (new_min, cur_max), change="min") elif move_info.source == self._edges[1]: # change only right or top bound new_max = max(cur_max + delta, cur_min) - self._selection.set_value(self, (cur_min, new_max)) + self._selection.set_value(self, (cur_min, new_max), change="max") diff --git a/fastplotlib/graphics/utils.py b/fastplotlib/graphics/utils.py index 6be5aefc4..f32d80809 100644 --- a/fastplotlib/graphics/utils.py +++ b/fastplotlib/graphics/utils.py @@ -1,13 +1,16 @@ from contextlib import contextmanager +from typing import Callable, Iterable from ._base import Graphic @contextmanager -def pause_events(*graphics: Graphic): +def pause_events(*graphics: Graphic, event_handlers: Iterable[Callable] = None): """ Context manager for pausing Graphic events. + Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. + Examples -------- @@ -30,8 +33,14 @@ def pause_events(*graphics: Graphic): original_vals = [g.block_events for g in graphics] for g in graphics: - g.block_events = True + if event_handlers is not None: + g.block_handlers.extend([e for e in event_handlers]) + else: + g.block_events = True yield for g, value in zip(graphics, original_vals): - g.block_events = value + if event_handlers is not None: + g.block_handlers.clear() + else: + g.block_events = value diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 79b5be3a8..00b915b1f 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -539,7 +539,7 @@ def _render(self, draw=True): # call the animation functions before render self._call_animate_functions(self._animate_funcs_pre) - for subplot in self: + for subplot in self._subplots.ravel(): subplot._render() # overlay render pass @@ -606,14 +606,14 @@ def show( sidecar_kwargs = dict() # flip y-axis if ImageGraphics are present - for subplot in self: + for subplot in self._subplots.ravel(): for g in subplot.graphics: if isinstance(g, ImageGraphic): subplot.camera.local.scale_y *= -1 break if autoscale: - for subplot in self: + for subplot in self._subplots.ravel(): if maintain_aspect is None: _maintain_aspect = subplot.camera.maintain_aspect else: @@ -622,7 +622,7 @@ def show( # set axes visibility if False if not axes_visible: - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.visible = False # parse based on canvas type @@ -646,15 +646,15 @@ def show( elif self.canvas.__class__.__name__ == "OffscreenRenderCanvas": # for test and docs gallery screenshots self._fpl_reset_layout() - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.update_using_camera() # render call is blocking only on github actions for some reason, # but not for rtd build, this is a workaround # for CI tests, the render call works if it's in test_examples # but it is necessary for the gallery images too so that's why this check is here - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": self._render() else: # assume GLFW @@ -770,7 +770,7 @@ def clear_animations(self, removal: str = None): def clear(self): """Clear all Subplots""" - for subplot in self: + for subplot in self._subplots.ravel(): subplot.clear() def export_numpy(self, rgb: bool = False) -> np.ndarray: @@ -929,18 +929,20 @@ def __getitem__(self, index: str | int | tuple[int, int]) -> Subplot: return subplot raise IndexError(f"no subplot with given name: {index}") + if isinstance(index, (int, np.integer)): + return self._subplots.ravel()[index] + if isinstance(self.layout, GridLayout): return self._subplots[index[0], index[1]] - return self._subplots[index] + raise TypeError( + f"Can index figure using subplot name, numerical subplot index, or a " + f"tuple[int, int] if the layout is a grid" + ) def __iter__(self): - self._current_iter = iter(range(len(self))) - return self - - def __next__(self) -> Subplot: - pos = self._current_iter.__next__() - return self._subplots.ravel()[pos] + for subplot in self._subplots.ravel(): + yield subplot def __len__(self): """number of subplots""" @@ -955,6 +957,6 @@ def __repr__(self): return ( f"fastplotlib.{self.__class__.__name__}" f" Subplots:\n" - f"\t{newline.join(subplot.__str__() for subplot in self)}" + f"\t{newline.join(subplot.__str__() for subplot in self._subplots.ravel())}" f"\n" ) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index f83dcfbcb..405a01546 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -233,7 +233,10 @@ def controller(self, new_controller: str | pygfx.Controller): # pygfx plans on refactoring viewports anyways if self.parent is not None: if self.parent.__class__.__name__.endswith("Figure"): - for subplot in self.parent: + # always use figure._subplots.ravel() in internal fastplotlib code + # otherwise if we use `for subplot in figure`, this could conflict + # with a user's iterator where they are doing `for subplot in figure` !!! + for subplot in self.parent._subplots.ravel(): if subplot.camera in cameras_list: new_controller.register_events(subplot.viewport) subplot._controller = new_controller diff --git a/fastplotlib/tools/_histogram_lut.py b/fastplotlib/tools/_histogram_lut.py index d651137da..8edfb046b 100644 --- a/fastplotlib/tools/_histogram_lut.py +++ b/fastplotlib/tools/_histogram_lut.py @@ -6,424 +6,412 @@ import pygfx -from ..utils import subsample_array +from ..utils import subsample_array, RenderQueue from ..graphics import LineGraphic, ImageGraphic, ImageVolumeGraphic, TextGraphic from ..graphics.utils import pause_events from ..graphics._base import Graphic +from ..graphics.features import GraphicFeatureEvent from ..graphics.selectors import LinearRegionSelector -def _get_image_graphic_events(image_graphic: ImageGraphic) -> list[str]: - """Small helper function to return the relevant events for an ImageGraphic""" - events = ["vmin", "vmax"] +def _format_value(value: float): + abs_val = abs(value) + if abs_val < 0.01 or abs_val > 9_999: + return f"{value:.2e}" + else: + return f"{value:.2f}" - if not image_graphic.data.value.ndim > 2: - events.append("cmap") - # if RGB(A), do not add cmap - - return events - - -# TODO: This is a widget, we can think about a BaseWidget class later if necessary class HistogramLUTTool(Graphic): _fpl_support_tooltip = False def __init__( self, - data: np.ndarray, - images: ( - ImageGraphic - | ImageVolumeGraphic - | Sequence[ImageGraphic | ImageVolumeGraphic] - ), - nbins: int = 100, - flank_divisor: float = 5.0, + histogram: tuple[np.ndarray, np.ndarray], + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None = None, **kwargs, ): """ - HistogramLUT tool that can be used to control the vmin, vmax of ImageGraphics or ImageVolumeGraphics. - If used to control multiple images or image volumes it is assumed that they share a representation of - the same data, and that their histogram, vmin, and vmax are identical. For example, displaying a - ImageVolumeGraphic and several images that represent slices of the same volume data. + A histogram tool that allows adjusting the vmin, vmax of images. + Also allows changing the cmap LUT for grayscale images and displays a colorbar. Parameters ---------- - data: np.ndarray - - images: ImageGraphic | ImageVolumeGraphic | tuple[ImageGraphic | ImageVolumeGraphic] - - nbins: int, defaut 100. - Total number of bins used in the histogram + histogram: tuple[np.ndarray, np.ndarray] + [frequency, bin_edges], must be 100 bins - flank_divisor: float, default 5.0. - Fraction of empty histogram bins on the tails of the distribution set `np.inf` for no flanks + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] + the images that are managed by the histogram tool - kwargs: passed to ``Graphic`` + kwargs: + passed to ``Graphic`` """ - super().__init__(**kwargs) - - self._nbins = nbins - self._flank_divisor = flank_divisor - - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - self._images = images + super().__init__(**kwargs) - self._data = weakref.proxy(data) + if len(histogram) != 2: + raise TypeError - self._scale_factor: float = 1.0 + self._block_reentrance = False + self._images = list() - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + self._bin_centers_flanked = np.zeros(120, dtype=np.float64) + self._freq_flanked = np.zeros(120, dtype=np.float32) - line_data = np.column_stack([hist_scaled, edges_flanked]) + # 100 points for the histogram, 10 points on each side for the flank + line_data = np.column_stack( + [np.zeros(120, dtype=np.float32), np.arange(0, 120)] + ) - self._histogram_line = LineGraphic( - line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(0, 0, -1) + # line that displays the histogram + self._line = LineGraphic( + line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(1, 0, 0) + ) + self._line.world_object.local.scale_x = -1 + + # vmin, vmax selector + self._selector = LinearRegionSelector( + selection=(10, 110), + limits=(0, 119), + size=1.5, + center=0.5, # frequency data are normalized between 0-1 + axis="y", + parent=self._line, ) - bounds = (edges[0] * self._scale_factor, edges[-1] * self._scale_factor) - limits = (edges_flanked[0], edges_flanked[-1]) - size = 120 # since it's scaled to 100 - origin = (hist_scaled.max() / 2, 0) + self._selector.add_event_handler(self._selector_event_handler, "selection") - self._linear_region_selector = LinearRegionSelector( - selection=bounds, - limits=limits, - size=size, - center=origin[0], - axis="y", - parent=self._histogram_line, + self._colorbar = ImageGraphic( + data=np.zeros([120, 2]), interpolation="linear", offset=(1.5, 0, 0) ) - self._vmin = self.images[0].vmin - self._vmax = self.images[0].vmax + # make the colorbar thin + self._colorbar.world_object.local.scale_x = 0.15 + self._colorbar.add_event_handler(self._open_cmap_picker, "click") - # there will be a small difference with the histogram edges so this makes them both line up exactly - self._linear_region_selector.selection = ( - self._vmin * self._scale_factor, - self._vmax * self._scale_factor, + # colorbar ruler + self._ruler = pygfx.Ruler( + end_pos=(0, 119, 0), + alpha_mode="solid", + render_queue=RenderQueue.axes, + tick_side="right", + tick_marker="tick_right", + tick_format=self._ruler_tick_map, + min_tick_distance=10, ) + self._ruler.local.x = 1.75 - vmin_str, vmax_str = self._get_vmin_vmax_str() + # TODO: need to auto-scale using the text so it appears nicely, will do later + self._ruler.visible = False self._text_vmin = TextGraphic( - text=vmin_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="top-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - + # this is to make sure clicking text doesn't conflict with the selector tool + # since the text appears near the selector tool self._text_vmin.world_object.material.pick_write = False self._text_vmax = TextGraphic( - text=vmax_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="bottom-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - self._text_vmax.world_object.material.pick_write = False - widget_wo = pygfx.Group() - widget_wo.add( - self._histogram_line.world_object, - self._linear_region_selector.world_object, + # add all the world objects to a pygfx.Group + wo = pygfx.Group() + wo.add( + self._line.world_object, + self._selector.world_object, + self._colorbar.world_object, + self._ruler, self._text_vmin.world_object, self._text_vmax.world_object, ) + self._set_world_object(wo) - self._set_world_object(widget_wo) + # for convenience, a list that stores all the graphics managed by the histogram LUT tool + self._children = [ + self._line, + self._selector, + self._colorbar, + self._text_vmin, + self._text_vmax, + ] - self.world_object.local.scale_x *= -1 + # set histogram + self.histogram = histogram - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) + # set the images + self.images = images - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) + def _fpl_add_plot_area_hook(self, plot_area): + self._plot_area = plot_area - self._linear_region_selector.add_event_handler( - self._linear_region_handler, "selection" - ) + for child in self._children: + # need all of them to call the add_plot_area_hook so that events are connected correctly + # example, the linear region selector needs all the canvas events to be connected + child._fpl_add_plot_area_hook(plot_area) - ig_events = _get_image_graphic_events(self.images[0]) + if hasattr(self._plot_area, "size"): + # if it's in a dock area + self._plot_area.size = 80 - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + # disable the controller in this plot area + self._plot_area.controller.enabled = False + self._plot_area.auto_scale(maintain_aspect=False) - # colorbar for grayscale images - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + # tick text for colorbar ruler doesn't show without this call + self._ruler.update(plot_area.camera, plot_area.canvas.get_logical_size()) - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + def _ruler_tick_map(self, bin_index, *args): + return f"{self._bin_centers_flanked[int(bin_index)]:.2f}" - def _make_colorbar(self, edges_flanked) -> ImageGraphic: - # use the histogram edge values as data for an - # image with 2 columns, this will be our colorbar! - colorbar_data = np.column_stack( - [ - np.linspace( - edges_flanked[0], edges_flanked[-1], ceil(np.ptp(edges_flanked)) - ) - ] - * 2 - ).astype(np.float32) - - colorbar_data /= self._scale_factor - - cbar = ImageGraphic( - data=colorbar_data, - vmin=self.vmin, - vmax=self.vmax, - cmap=self.images[0].cmap, - interpolation="linear", - offset=(-55, edges_flanked[0], -1), - ) + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray]: + """histogram [frequency, bin_centers]. Frequency is flanked by 10 zeros on both sides""" + return self._freq_flanked, self._bin_centers_flanked - cbar.world_object.world.scale_x = 20 - self._cmap = self.images[0].cmap + @histogram.setter + def histogram( + self, histogram: tuple[np.ndarray, np.ndarray], limits: tuple[int, int] = None + ): + """set histogram with pre-compuated [frequency, edges], must have exactly 100 bins""" - return cbar + freq, edges = histogram - def _get_vmin_vmax_str(self) -> tuple[str, str]: - if self.vmin < 0.001 or self.vmin > 99_999: - vmin_str = f"{self.vmin:.2e}" - else: - vmin_str = f"{self.vmin:.2f}" + if freq.max() > 0: + # if the histogram is made from an empty array, then the max freq will be 0 + # we don't want to divide by 0 because then we just get nans + freq = freq / freq.max() - if self.vmax < 0.001 or self.vmax > 99_999: - vmax_str = f"{self.vmax:.2e}" - else: - vmax_str = f"{self.vmax:.2f}" + bin_centers = 0.5 * (edges[1:] + edges[:-1]) - return vmin_str, vmax_str + step = bin_centers[1] - bin_centers[0] - def _fpl_add_plot_area_hook(self, plot_area): - self._plot_area = plot_area - self._linear_region_selector._fpl_add_plot_area_hook(plot_area) - self._histogram_line._fpl_add_plot_area_hook(plot_area) + under_flank = np.linspace(bin_centers[0] - step * 10, bin_centers[0] - step, 10) + over_flank = np.linspace( + bin_centers[-1] + step, bin_centers[-1] + step * 10, 10 + ) + self._bin_centers_flanked[:] = np.concatenate( + [under_flank, bin_centers, over_flank] + ) + + self._freq_flanked[10:110] = freq - self._plot_area.auto_scale() - self._plot_area.controller.enabled = True + self._line.data[:, 0] = self._freq_flanked + self._colorbar.data = np.column_stack( + [self._bin_centers_flanked, self._bin_centers_flanked] + ) - def _calculate_histogram(self, data): + # self.vmin, self.vmax = bin_centers[0], bin_centers[-1] - # get a subsampled view of this array - data_ss = subsample_array(data, max_size=int(1e6)) # 1e6 is default - hist, edges = np.histogram(data_ss, bins=self._nbins) + if hasattr(self, "plot_area"): + self._ruler.update( + self._plot_area.camera, self._plot_area.canvas.get_logical_size() + ) - # used if data ptp <= 10 because event things get weird - # with tiny world objects due to floating point error - # so if ptp <= 10, scale up by a factor - data_interval = edges[-1] - edges[0] - self._scale_factor: int = max(1, 100 * int(10 / data_interval)) + @property + def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic, ...] | None: + """get or set the managed images""" + return tuple(self._images) - edges = edges * self._scale_factor + @images.setter + def images(self, new_images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None): + self._disconnect_images() + self._images.clear() - bin_width = edges[1] - edges[0] + if new_images is None: + return - flank_nbins = int(self._nbins / self._flank_divisor) - flank_size = flank_nbins * bin_width + if isinstance(new_images, (ImageGraphic, ImageVolumeGraphic)): + new_images = [new_images] - flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) - flank_right = np.arange( - edges[-1] + bin_width, edges[-1] + flank_size, bin_width - ) + if not all( + [ + isinstance(image, (ImageGraphic, ImageVolumeGraphic)) + for image in new_images + ] + ): + raise TypeError - edges_flanked = np.concatenate((flank_left, edges, flank_right)) + for image in new_images: + if image.cmap is not None: + self._colorbar.visible = True + break + else: + self._colorbar.visible = False - hist_flanked = np.concatenate( - (np.zeros(flank_nbins), hist, np.zeros(flank_nbins)) - ) + self._images = list(new_images) - # scale 0-100 to make it easier to see - # float32 data can produce unnecessarily high values - hist_scale_value = hist_flanked.max() - if np.allclose(hist_scale_value, 0): - hist_scale_value = 1 - hist_scaled = hist_flanked / (hist_scale_value / 100) + # reset vmin, vmax using first image + self.vmin = self._images[0].vmin + self.vmax = self._images[0].vmax - if edges_flanked.size > hist_scaled.size: - # we don't care about accuracy here so if it's off by 1-2 bins that's fine - edges_flanked = edges_flanked[: hist_scaled.size] + if self._images[0].cmap is not None: + self._colorbar.cmap = self._images[0].cmap - return hist, edges, hist_scaled, edges_flanked + # connect event handlers + for image in self._images: + image.add_event_handler(self._image_event_handler, "vmin", "vmax") + image.add_event_handler(self._disconnect_images, "deleted") + if image.cmap is not None: + image.add_event_handler( + self._image_event_handler, "vmin", "vmax", "cmap" + ) - def _linear_region_handler(self, ev): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - selected_ixs = self._linear_region_selector.selection - vmin, vmax = selected_ixs[0], selected_ixs[1] - vmin, vmax = vmin / self._scale_factor, vmax / self._scale_factor - self.vmin, self.vmax = vmin, vmax + def _disconnect_images(self, *args): + """disconnect event handlers of the managed images""" + for image in self._images: + for ev, handlers in image.event_handlers: + if self._image_event_handler in handlers: + image.remove_event_handler(self._image_event_handler, ev) - def _image_cmap_handler(self, ev): - setattr(self, ev.type, ev.info["value"]) + def _image_event_handler(self, ev): + """when the image vmin, vmax, or cmap changes it will update the HistogramLUTTool""" + new_value = ev.info["value"] + setattr(self, ev.type, new_value) @property def cmap(self) -> str: - return self._cmap + """get or set the colormap, only for grayscale images""" + return self._colorbar.cmap @cmap.setter def cmap(self, name: str): - if self._colorbar is None: + if self._block_reentrance: return - with pause_events(*self.images): - for ig in self.images: - ig.cmap = name + if name is None: + return - self._cmap = name + self._block_reentrance = True + try: self._colorbar.cmap = name + with pause_events( + *self._images, event_handlers=[self._image_event_handler] + ): + for image in self._images: + if image.cmap is None: + # rgb(a) images have no cmap + continue + + image.cmap = name + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False + @property def vmin(self) -> float: - return self._vmin + """get or set the vmin, the lower contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[0]) + return self._bin_centers_flanked[index] @vmin.setter def vmin(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - value * self._scale_factor, - self._linear_region_selector.selection[1], - ) - for ig in self.images: - ig.vmin = value + if self._block_reentrance: + return + self._block_reentrance = True + try: + index_min = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (index_min, self._selector.selection[1]) - self._vmin = value - if self._colorbar is not None: - self._colorbar.vmin = value + self._colorbar.vmin = value - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) - self._text_vmin.text = vmin_str + self._text_vmin.text = _format_value(value) + self._text_vmin.offset = (-0.45, self._selector.selection[0], 0) + + for image in self._images: + image.vmin = value + + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False @property def vmax(self) -> float: - return self._vmax + """get or set the vmax, the upper contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[1]) + return self._bin_centers_flanked[index] @vmax.setter def vmax(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - self._linear_region_selector.selection[0], - value * self._scale_factor, - ) - - for ig in self.images: - ig.vmax = value - - self._vmax = value - if self._colorbar is not None: - self._colorbar.vmax = value - - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) - self._text_vmax.text = vmax_str - - def set_data(self, data, reset_vmin_vmax: bool = True): - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) - - line_data = np.column_stack([hist_scaled, edges_flanked]) - - # set x and y vals - self._histogram_line.data[:, :2] = line_data - - bounds = (edges[0], edges[-1]) - limits = (edges_flanked[0], edges_flanked[-11]) - origin = (hist_scaled.max() / 2, 0) - - if reset_vmin_vmax: - # reset according to the new data - self._linear_region_selector.limits = limits - self._linear_region_selector.selection = bounds - else: - with pause_events(self._linear_region_selector, *self.images): - # don't change the current selection - self._linear_region_selector.limits = limits - - self._data = weakref.proxy(data) - - if self._colorbar is not None: - self._colorbar.clear_event_handlers() - self.world_object.remove(self._colorbar.world_object) - - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + if self._block_reentrance: + return - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + self._block_reentrance = True + try: + index_max = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (self._selector.selection[0], index_max) - # reset plotarea dims - self._plot_area.auto_scale() + self._colorbar.vmax = value - @property - def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic]: - return self._images + self._text_vmax.text = _format_value(value) + self._text_vmax.offset = (-0.45, self._selector.selection[1], 0) - @images.setter - def images(self, images): - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) + for image in self._images: + image.vmax = value - if self._images is not None: - for ig in self._images: - # cleanup events from current image graphics - ig_events = _get_image_graphic_events(ig) - ig.remove_event_handler(self._image_cmap_handler, *ig_events) + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False - self._images = images + def _selector_event_handler(self, ev: GraphicFeatureEvent): + """when the selector's selctor has changed, it will update the vmin, vmax, or both""" + selection = ev.info["value"] + index_min = int(selection[0]) + vmin = self._bin_centers_flanked[index_min] - ig_events = _get_image_graphic_events(self._images[0]) + index_max = int(selection[1]) + vmax = self._bin_centers_flanked[index_max] - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + match ev.info["change"]: + case "min": + self.vmin = vmin + case "max": + self.vmax = vmax + case _: + self.vmin, self.vmax = vmin, vmax def _open_cmap_picker(self, ev): + """open imgui cmap picker""" # check if right click if ev.button != 2: return @@ -433,7 +421,11 @@ def _open_cmap_picker(self, ev): self._plot_area.get_figure().open_popup("colormap-picker", pos, lut_tool=self) def _fpl_prepare_del(self): - self._linear_region_selector._fpl_prepare_del() - self._histogram_line._fpl_prepare_del() - del self._histogram_line - del self._linear_region_selector + """cleanup, need to disconnect events and remove image references for proper garbage collection""" + self._disconnect_images() + self._images.clear() + + for i in range(len(self._children)): + g = self._children.pop(0) + g._fpl_prepare_del() + del g diff --git a/fastplotlib/ui/_base.py b/fastplotlib/ui/_base.py index 3e763e08c..9767cf76f 100644 --- a/fastplotlib/ui/_base.py +++ b/fastplotlib/ui/_base.py @@ -123,8 +123,9 @@ def size(self) -> int | None: @size.setter def size(self, value): if not isinstance(value, int): - raise TypeError + raise TypeError(f"{self.__class__.__name__}.size must be an ") self._size = value + self._set_rect() @property def location(self) -> str: @@ -153,6 +154,7 @@ def height(self) -> int: def _set_rect(self, *args): self._x, self._y, self._width, self._height = self.get_rect() + self._figure._fpl_reset_layout() def get_rect(self) -> tuple[int, int, int, int]: """ diff --git a/fastplotlib/ui/right_click_menus/_colormap_picker.py b/fastplotlib/ui/right_click_menus/_colormap_picker.py index a80e5b2aa..9df26dcdc 100644 --- a/fastplotlib/ui/right_click_menus/_colormap_picker.py +++ b/fastplotlib/ui/right_click_menus/_colormap_picker.py @@ -154,7 +154,8 @@ def update(self): self._texture_height = (imgui.get_font_size()) - 2 if imgui.menu_item("Reset vmin-vmax", "", False)[0]: - self._lut_tool.images[0].reset_vmin_vmax() + for image in self._lut_tool.images: + image.reset_vmin_vmax() # add all the cmap options for cmap_type in COLORMAP_NAMES.keys(): diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py index c168ecfa4..7ae63ed67 100644 --- a/fastplotlib/utils/_protocols.py +++ b/fastplotlib/utils/_protocols.py @@ -1,6 +1,9 @@ from typing import Protocol, runtime_checkable +ARRAY_LIKE_ATTRS = ["shape", "ndim", "__getitem__"] + + @runtime_checkable class ArrayProtocol(Protocol): @property diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index 70a1aa8ae..dc5daea55 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,6 +2,7 @@ if IMGUI: from ._widget import ImageWidget + from ._processor import NDImageProcessor else: diff --git a/fastplotlib/widgets/image_widget/_nd_iw_backup.py b/fastplotlib/widgets/image_widget/_nd_iw_backup.py new file mode 100644 index 000000000..7db265c0c --- /dev/null +++ b/fastplotlib/widgets/image_widget/_nd_iw_backup.py @@ -0,0 +1,1007 @@ +from typing import Callable, Sequence, Literal +from warnings import warn + +import numpy as np + +from rendercanvas import BaseRenderCanvas + +from ...layouts import ImguiFigure as Figure +from ...graphics import ImageGraphic, ImageVolumeGraphic +from ...utils import calculate_figure_shape, quick_min_max, ArrayProtocol +from ...tools import HistogramLUTTool +from ._sliders import ImageWidgetSliders +from ._processor import NDImageProcessor, WindowFuncCallable +from ._properties import ImageWidgetProperty, Indices + + +IMGUI_SLIDER_HEIGHT = 49 + + +class ImageWidget: + def __init__( + self, + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None, + processors: NDImageProcessor | Sequence[NDImageProcessor] = NDImageProcessor, + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]] = 2, + slider_dim_names: Sequence[str] | None = None, # dim names left -> right + rgb: bool | Sequence[bool] = False, + cmap: str | Sequence[str] = "plasma", + window_funcs: ( + tuple[WindowFuncCallable | None, ...] + | WindowFuncCallable + | None + | Sequence[ + tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None + ] + ) = None, + window_sizes: ( + tuple[int | None, ...] | Sequence[tuple[int | None, ...] | None] + ) = None, + window_order: tuple[int, ...] | Sequence[tuple[int, ...] | None] = None, + spatial_func: ( + Callable[[ArrayProtocol], ArrayProtocol] + | Sequence[Callable[[ArrayProtocol], ArrayProtocol]] + | None + ) = None, + sliders_dim_order: Literal["right", "left"] = "right", + figure_shape: tuple[int, int] = None, + names: Sequence[str] = None, + figure_kwargs: dict = None, + histogram_widget: bool = True, + histogram_init_quantile: int = (0, 100), + graphic_kwargs: dict | Sequence[dict] = None, + ): + """ + This widget facilitates high-level navigation through image stacks, which are arrays containing one or more + images. It includes sliders for key dimensions such as "t" (time) and "z", enabling users to smoothly navigate + through one or multiple image stacks simultaneously. + + Allowed dimensions orders for each image stack: Note that each has a an optional (c) channel which refers to + RGB(A) a channel. So this channel should be either 3 or 4. + + Parameters + ---------- + data: ArrayProtocol | Sequence[ArrayProtocol | None] | None + array-like or a list of array-like, each array must have a minimum of 2 dimensions + + processors: NDImageProcessor | Sequence[NDImageProcessor], default NDImageProcessor + The image processors used for each n-dimensional data array + + n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]], default 2 + number of display dimensions + + slider_dim_names: Sequence[str], optional + optional list/tuple of names for each slider dim + + rgb: bool | Sequence[bool], default + whether or not each data array represents RGB(A) images + + figure_shape: Optional[Tuple[int, int]] + manually provide the shape for the Figure, otherwise the number of rows and columns is estimated + + figure_kwargs: dict, optional + passed to ``Figure`` + + names: Optional[str] + gives names to the subplots + + histogram_widget: bool, default False + make histogram LUT widget for each subplot + + rgb: bool | list[bool], default None + bool or list of bool for each input data array in the ImageWidget, indicating whether the corresponding + data arrays are grayscale or RGB(A). + + graphic_kwargs: Any + passed to each ImageGraphic in the ImageWidget figure subplots + + """ + + if figure_kwargs is None: + figure_kwargs = dict() + + if isinstance(data, ArrayProtocol) or (data is None): + data = [data] + + elif isinstance(data, (list, tuple)): + # verify that it's a list of np.ndarray + if not all([isinstance(d, ArrayProtocol) or d is None for d in data]): + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" + ) + + else: + raise TypeError( + f"`data` must be an array-like type or a list/tuple of array-like or None. " + f"You have passed the following type {type(data)}" + ) + + if issubclass(processors, NDImageProcessor): + processors = [processors] * len(data) + + elif isinstance(processors, (tuple, list)): + if not all([issubclass(p, NDImageProcessor) for p in processors]): + raise TypeError( + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" + ) + + else: + raise TypeError( + f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " + f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" + ) + + # subplot layout + if figure_shape is None: + if "shape" in figure_kwargs: + figure_shape = figure_kwargs["shape"] + else: + figure_shape = calculate_figure_shape(len(data)) + + # Regardless of how figure_shape is computed, below code + # verifies that figure shape is large enough for the number of image arrays passed + if figure_shape[0] * figure_shape[1] < len(data): + original_shape = (figure_shape[0], figure_shape[1]) + figure_shape = calculate_figure_shape(len(data)) + warn( + f"Original `figure_shape` was: {original_shape} " + f" but data length is {len(data)}" + f" Resetting figure shape to: {figure_shape}" + ) + + elif isinstance(rgb, bool): + rgb = [rgb] * len(data) + + if not all([isinstance(v, bool) for v in rgb]): + raise TypeError( + f"`rgb` parameter must be a bool or a Sequence of bool, you have passed: {rgb}" + ) + + if not len(rgb) == len(data): + raise ValueError( + f"len(rgb) != len(data), {len(rgb)} != {len(data)}. These must be equal" + ) + + if names is not None: + if not all([isinstance(n, str) for n in names]): + raise TypeError("optional argument `names` must be a Sequence of str") + + if len(names) != len(data): + raise ValueError( + "number of `names` for subplots must be same as the number of data arrays" + ) + + # verify window funcs + if window_funcs is None: + win_funcs = [None] * len(data) + + elif callable(window_funcs) or all( + [callable(f) or f is None for f in window_funcs] + ): + # across all data arrays + # one window function defined for all dims, or window functions defined per-dim + win_funcs = [window_funcs] * len(data) + + # if the above two clauses didn't trigger, then window_funcs defined per-dim, per data array + elif len(window_funcs) != len(data): + raise IndexError + else: + win_funcs = window_funcs + + # verify window sizes + if window_sizes is None: + win_sizes = [window_sizes] * len(data) + + elif isinstance(window_sizes, int): + win_sizes = [window_sizes] * len(data) + + elif all([isinstance(size, int) or size is None for size in window_sizes]): + # window sizes defined per-dim across all data arrays + win_sizes = [window_sizes] * len(data) + + elif len(window_sizes) != len(data): + # window sizes defined per-dim, per data array + raise IndexError + else: + win_sizes = window_sizes + + # verify window orders + if window_order is None: + win_order = [None] * len(data) + + elif all([isinstance(o, int) for o in order]): + # window order defined per-dim across all data arrays + win_order = [window_order] * len(data) + + elif len(window_order) != len(data): + raise IndexError + + else: + win_order = window_order + + # verify spatial_func + if spatial_func is None: + spatial_func = [None] * len(data) + + elif callable(spatial_func): + # same spatial_func for all data arrays + spatial_func = [spatial_func] * len(data) + + elif len(spatial_func) != len(data): + raise IndexError + + else: + spatial_func = spatial_func + + # verify number of display dims + if isinstance(n_display_dims, (int, np.integer)): + n_display_dims = [n_display_dims] * len(data) + + elif isinstance(n_display_dims, (tuple, list)): + if not all([isinstance(n, (int, np.integer)) for n in n_display_dims]): + raise TypeError + + if len(n_display_dims) != len(data): + raise IndexError + else: + raise TypeError + + n_display_dims = tuple(n_display_dims) + + if sliders_dim_order not in ("right",): + raise ValueError( + f"Only 'right' slider dims order is currently supported, you passed: {sliders_dim_order}" + ) + self._sliders_dim_order = sliders_dim_order + + self._slider_dim_names = None + self.slider_dim_names = slider_dim_names + + self._histogram_widget = histogram_widget + + # make NDImageArrays + self._image_processors: list[NDImageProcessor] = list() + for i in range(len(data)): + cls = processors[i] + image_processor = cls( + data=data[i], + rgb=rgb[i], + n_display_dims=n_display_dims[i], + window_funcs=win_funcs[i], + window_sizes=win_sizes[i], + window_order=win_order[i], + spatial_func=spatial_func[i], + compute_histogram=self._histogram_widget, + ) + + self._image_processors.append(image_processor) + + self._data = ImageWidgetProperty(self, "data") + self._rgb = ImageWidgetProperty(self, "rgb") + self._n_display_dims = ImageWidgetProperty(self, "n_display_dims") + self._window_funcs = ImageWidgetProperty(self, "window_funcs") + self._window_sizes = ImageWidgetProperty(self, "window_sizes") + self._window_order = ImageWidgetProperty(self, "window_order") + self._spatial_func = ImageWidgetProperty(self, "spatial_func") + + if len(set(n_display_dims)) > 1: + # assume user wants one controller for 2D images and another for 3D image volumes + n_subplots = np.prod(figure_shape) + controller_ids = [0] * n_subplots + controller_types = ["panzoom"] * n_subplots + + for i in range(len(data)): + if n_display_dims[i] == 2: + controller_ids[i] = 1 + else: + controller_ids[i] = 2 + controller_types[i] = "orbit" + + # needs to be a list of list + controller_ids = [controller_ids] + + else: + controller_ids = "sync" + controller_types = None + + figure_kwargs_default = { + "controller_ids": controller_ids, + "controller_types": controller_types, + "names": names, + } + + # update the default kwargs with any user-specified kwargs + # user specified kwargs will overwrite the defaults + figure_kwargs_default.update(figure_kwargs) + figure_kwargs_default["shape"] = figure_shape + + if graphic_kwargs is None: + graphic_kwargs = [dict()] * len(data) + + elif isinstance(graphic_kwargs, dict): + graphic_kwargs = [graphic_kwargs] * len(data) + + elif len(graphic_kwargs) != len(data): + raise IndexError + + if cmap is None: + cmap = [None] * len(data) + + elif isinstance(cmap, str): + cmap = [cmap] * len(data) + + elif not all([isinstance(c, str) for c in cmap]): + raise TypeError(f"`cmap` must be a or a list/tuple of ") + + self._figure: Figure = Figure(**figure_kwargs_default) + + self._indices = Indices(list(0 for i in range(self.n_sliders)), self) + + for i, subplot in zip(range(len(self._image_processors)), self.figure): + image_data = self._get_image( + self._image_processors[i], tuple(self._indices) + ) + + if image_data is None: + # this subplot/data array is blank, skip + continue + + # next 20 lines are just vmin, vmax parsing + vmin_specified, vmax_specified = None, None + if "vmin" in graphic_kwargs[i].keys(): + vmin_specified = graphic_kwargs[i].pop("vmin") + if "vmax" in graphic_kwargs[i].keys(): + vmax_specified = graphic_kwargs[i].pop("vmax") + + if (vmin_specified is None) or (vmax_specified is None): + # if either vmin or vmax are not specified, calculate an estimate by subsampling + vmin_estimate, vmax_estimate = quick_min_max( + self._image_processors[i].data + ) + + # decide vmin, vmax passed to ImageGraphic constructor based on whether it's user specified or now + if vmin_specified is None: + # user hasn't specified vmin, use estimated value + vmin = vmin_estimate + else: + # user has provided a specific value, use that + vmin = vmin_specified + + if vmax_specified is None: + vmax = vmax_estimate + else: + vmax = vmax_specified + else: + # both vmin and vmax are specified + vmin, vmax = vmin_specified, vmax_specified + + graphic_kwargs[i]["cmap"] = cmap[i] + + if self._image_processors[i].n_display_dims == 2: + # create an Image + graphic = ImageGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + elif self._image_processors[i].n_display_dims == 3: + # create an ImageVolume + graphic = ImageVolumeGraphic( + data=image_data, + name="image_widget_managed", + vmin=vmin, + vmax=vmax, + **graphic_kwargs[i], + ) + subplot.camera.fov = 50 + + subplot.add_graphic(graphic) + + self._reset_histogram(subplot, self._image_processors[i]) + + self._sliders_ui = ImageWidgetSliders( + figure=self.figure, + size=57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders), + location="bottom", + title="ImageWidget Controls", + image_widget=self, + ) + + self.figure.add_gui(self._sliders_ui) + + self._indices_changed_handlers = set() + + self._reentrant_block = False + + @property + def data(self) -> ImageWidgetProperty[ArrayProtocol | None]: + """get or set the nd-image data arrays""" + return self._data + + @data.setter + def data(self, new_data: Sequence[ArrayProtocol | None]): + if isinstance(new_data, ArrayProtocol) or new_data is None: + new_data = [new_data] * len(self._image_processors) + + if len(new_data) != len(self._image_processors): + raise IndexError + + # if the data array hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new_data, image_processor) in enumerate( + zip(new_data, self._image_processors) + ): + if new_data is image_processor.data: + skip_indices.append(i) + continue + + image_processor.data = new_data + + self._reset(skip_indices) + + @property + def rgb(self) -> ImageWidgetProperty[bool]: + """get or set the rgb toggle for each data array""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: Sequence[bool]): + if isinstance(rgb, bool): + rgb = [rgb] * len(self._image_processors) + + if len(rgb) != len(self._image_processors): + raise IndexError + + # if the rgb option hasn't been changed + # graphics will not be reset for this data index + skip_indices = list() + + for i, (new, image_processor) in enumerate(zip(rgb, self._image_processors)): + if image_processor.rgb == new: + skip_indices.append(i) + continue + + image_processor.rgb = new + + self._reset(skip_indices) + + @property + def n_display_dims(self) -> ImageWidgetProperty[Literal[2, 3]]: + """Get or set the number of display dimensions for each data array, 2 is a 2D image, 3 is a 3D volume image""" + return self._n_display_dims + + @n_display_dims.setter + def n_display_dims(self, new_ndd: Sequence[Literal[2, 3]] | Literal[2, 3]): + if isinstance(new_ndd, (int, np.integer)): + if new_ndd == 2 or new_ndd == 3: + new_ndd = [new_ndd] * len(self._image_processors) + else: + raise ValueError + + if len(new_ndd) != len(self._image_processors): + raise IndexError + + if not all([(n == 2) or (n == 3) for n in new_ndd]): + raise ValueError + + # if the n_display_dims hasn't been changed for this data array + # graphics will not be reset for this data array index + skip_indices = list() + + # first update image arrays + for i, (image_processor, new) in enumerate( + zip(self._image_processors, new_ndd) + ): + if new > image_processor.max_n_display_dims: + raise IndexError( + f"number of display dims exceeds maximum number of possible " + f"display dimensions: {image_processor.max_n_display_dims}, for array at index: " + f"{i} with shape: {image_processor.shape}, and rgb set to: {image_processor.rgb}" + ) + + if image_processor.n_display_dims == new: + skip_indices.append(i) + else: + image_processor.n_display_dims = new + + self._reset(skip_indices) + + @property + def window_funcs(self) -> ImageWidgetProperty[tuple[WindowFuncCallable | None] | None]: + """get or set the window functions""" + return self._window_funcs + + @window_funcs.setter + def window_funcs(self, new_funcs: Sequence[WindowFuncCallable | None] | None): + if callable(new_funcs) or new_funcs is None: + new_funcs = [new_funcs] * len(self._image_processors) + + if len(new_funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_funcs", new_funcs) + + @property + def window_sizes(self) -> ImageWidgetProperty[tuple[int | None, ...] | None]: + """get or set the window sizes""" + return self._window_sizes + + @window_sizes.setter + def window_sizes( + self, new_sizes: Sequence[tuple[int | None, ...] | int | None] | int | None + ): + if isinstance(new_sizes, int) or new_sizes is None: + # same window for all data arrays + new_sizes = [new_sizes] * len(self._image_processors) + + if len(new_sizes) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_sizes", new_sizes) + + @property + def window_order(self) -> ImageWidgetProperty[tuple[int, ...] | None]: + """get or set order in which window functions are applied over dimensions""" + return self._window_order + + @window_order.setter + def window_order(self, new_order: Sequence[tuple[int, ...]]): + if new_order is None: + new_order = [new_order] * len(self._image_processors) + + if all([isinstance(order, (int, np.integer))] for order in new_order): + # same order specified across all data arrays + new_order = [new_order] * len(self._image_processors) + + if len(new_order) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("window_order", new_order) + + @property + def spatial_func(self) -> ImageWidgetProperty[Callable | None]: + """Get or set a spatial_func that operates on the spatial dimensions of the 2D or 3D image""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, funcs: Callable | Sequence[Callable] | None): + if callable(funcs) or funcs is None: + funcs = [funcs] * len(self._image_processors) + + if len(funcs) != len(self._image_processors): + raise IndexError + + self._set_image_processor_funcs("spatial_func", funcs) + + def _set_image_processor_funcs(self, attr, new_values): + """sets window_funcs, window_sizes, window_order, or spatial_func and updates displayed data and histograms""" + for new, image_processor, subplot in zip( + new_values, self._image_processors, self.figure + ): + if getattr(image_processor, attr) == new: + continue + + setattr(image_processor, attr, new) + + # window functions and spatial functions will only change the histogram + # they do not change the collections of dimensions, so we don't need to call _reset_dimensions + # they also do not change the image graphic, so we do not need to call _reset_image_graphics + self._reset_histogram(subplot, image_processor) + + # update the displayed image data in the graphics + self.indices = self.indices + + @property + def indices(self) -> ImageWidgetProperty[int]: + """ + Get or set the current indices. + + Returns + ------- + indices: ImageWidgetProperty[int] + integer index for each slider dimension + + """ + return self._indices + + @indices.setter + def indices(self, new_indices: Sequence[int]): + if self._reentrant_block: + return + + try: + self._reentrant_block = True # block re-execution until new_indices has *fully* completed execution + + if len(new_indices) != self.n_sliders: + raise IndexError( + f"len(new_indices) != ImageWidget.n_sliders, {len(new_indices)} != {self.n_sliders}. " + f"The length of the new_indices must be the same as the number of sliders" + ) + + if any([i < 0 for i in new_indices]): + raise IndexError( + f"only positive index values are supported, you have passed: {new_indices}" + ) + + for image_processor, graphic in zip(self._image_processors, self.graphics): + new_data = self._get_image(image_processor, indices=new_indices) + if new_data is None: + continue + + graphic.data = new_data + + self._indices._fpl_set(new_indices) + + # call any event handlers + for handler in self._indices_changed_handlers: + handler(tuple(self.indices)) + + except Exception as exc: + # raise original exception + raise exc # indices setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._reentrant_block = False + + @property + def histogram_widget(self) -> bool: + """show or hide the histograms""" + return self._histogram_widget + + @histogram_widget.setter + def histogram_widget(self, show_histogram: bool): + if not isinstance(show_histogram, bool): + raise TypeError( + f"`histogram_widget` can be set with a bool, you have passed: {show_histogram}" + ) + + for subplot, image_processor in zip(self.figure, self._image_processors): + image_processor.compute_histogram = show_histogram + self._reset_histogram(subplot, image_processor) + + @property + def n_sliders(self) -> int: + """number of sliders""" + return max([a.n_slider_dims for a in self._image_processors]) + + @property + def bounds(self) -> tuple[int, ...]: + """The max bound across all dimensions across all data arrays""" + # initialize with 0 + bounds = [0] * self.n_sliders + + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # in reverse because dims go left <- right + for i, dim in enumerate(range(-1, -self.n_sliders - 1, -1)): + # across each dim + for array in self._image_processors: + if i > array.n_slider_dims - 1: + continue + # across each data array + # dims go left <- right + bounds[dim] = max(array.slider_dims_shape[dim], bounds[dim]) + + return bounds + + @property + def slider_dim_names(self) -> tuple[str, ...]: + return self._slider_dim_names + + @slider_dim_names.setter + def slider_dim_names(self, names: Sequence[str]): + if names is None: + self._slider_dim_names = None + return + + if not all([isinstance(n, str) for n in names]): + raise TypeError(f"`slider_dim_names` must be set with a list/tuple of , you passed: {names}") + + if len(set(names)) != len(names): + raise ValueError( + f"`slider_dim_names` must be unique, you passed: {names}" + ) + + self._slider_dim_names = tuple(names) + + def _get_image( + self, image_processor: NDImageProcessor, indices: Sequence[int] + ) -> ArrayProtocol: + """Get a processed 2d or 3d image from the NDImage at the given indices""" + n = image_processor.n_slider_dims + + if self._sliders_dim_order == "right": + return image_processor.get(indices[-n:]) + + elif self._sliders_dim_order == "left": + # TODO: left -> right is not fully implemented yet in ImageWidget + return image_processor.get(indices[:n]) + + def _reset_dimensions(self): + """reset the dimensions w.r.t. current collection of NDImageProcessors""" + # TODO: implement left -> right slider dims ordering, right now it's only right -> left + # add or remove dims from indices + # trim any excess dimensions + while len(self._indices) > self.n_sliders: + # remove outer most dims first + self._indices.pop_dim() + self._sliders_ui.pop_dim() + + # add any new dimensions that aren't present + while len(self.indices) < self.n_sliders: + # insert right -> left + self._indices.push_dim() + self._sliders_ui.push_dim() + + self._sliders_ui.size = 57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders) + + def _reset_image_graphics(self, subplot, image_processor): + """delete and create a new image graphic if necessary""" + new_image = self._get_image(image_processor, indices=tuple(self.indices)) + if new_image is None: + if "image_widget_managed" in subplot: + # delete graphic from this subplot if present + subplot.delete_graphic(subplot["image_widget_managed"]) + # skip this subplot + return + + # check if a graphic exists + if "image_widget_managed" in subplot: + # create a new graphic only if the Texture buffer shape doesn't match + if subplot["image_widget_managed"].data.value.shape == new_image.shape: + return + + # keep cmap + cmap = subplot["image_widget_managed"].cmap + if cmap is None: + # ex: going from rgb -> grayscale + cmap = "plasma" + # delete graphic since it will be replaced + subplot.delete_graphic(subplot["image_widget_managed"]) + else: + # default cmap + cmap = "plasma" + + if image_processor.n_display_dims == 2: + g = subplot.add_image( + data=new_image, cmap=cmap, name="image_widget_managed" + ) + + # set camera orthogonal to the xy plane, flip y axis + subplot.camera.set_state( + { + "position": [0, 0, -1], + "rotation": [0, 0, 0, 1], + "scale": [1, -1, 1], + "reference_up": [0, 1, 0], + "fov": 0, + "depth_range": None, + } + ) + + subplot.controller = "panzoom" + subplot.axes.intersection = None + subplot.auto_scale() + + elif image_processor.n_display_dims == 3: + g = subplot.add_image_volume( + data=new_image, cmap=cmap, name="image_widget_managed" + ) + subplot.camera.fov = 50 + subplot.controller = "orbit" + + # make sure all 3D dimension camera scales are positive + # MIP rendering doesn't work with negative camera scales + for dim in ["x", "y", "z"]: + if getattr(subplot.camera.local, f"scale_{dim}") < 0: + setattr(subplot.camera.local, f"scale_{dim}", 1) + + subplot.auto_scale() + + def _reset_histogram(self, subplot, image_processor): + """reset the histogram""" + if not self._histogram_widget: + subplot.docks["right"].size = 0 + return + + if image_processor.histogram is None: + # no histogram available for this processor + # either there is no data array in this subplot, + # or a histogram routine does not exist for this processor + subplot.docks["right"].size = 0 + return + + if "image_widget_managed" not in subplot: + # no image in this subplot + subplot.docks["right"].size = 0 + return + + image = subplot["image_widget_managed"] + + if "histogram_lut" in subplot.docks["right"]: + hlut: HistogramLUTTool = subplot.docks["right"]["histogram_lut"] + hlut.histogram = image_processor.histogram + hlut.images = image + if subplot.docks["right"].size < 1: + subplot.docks["right"].size = 80 + + else: + # need to make one + hlut = HistogramLUTTool( + histogram=image_processor.histogram, + images=image, + name="histogram_lut", + ) + + subplot.docks["right"].add_graphic(hlut) + subplot.docks["right"].size = 80 + + self.reset_vmin_vmax() + + def _reset(self, skip_data_indices: tuple[int, ...] = None): + if skip_data_indices is None: + skip_data_indices = tuple() + + # reset the slider indices according to the new collection of dimensions + self._reset_dimensions() + # update graphics where display dims have changed accordings to indices + for i, (subplot, image_processor) in enumerate( + zip(self.figure, self._image_processors) + ): + if i in skip_data_indices: + continue + + self._reset_image_graphics(subplot, image_processor) + self._reset_histogram(subplot, image_processor) + + # force an update + self.indices = self.indices + + @property + def figure(self) -> Figure: + """ + ``Figure`` used by `ImageWidget`. + """ + return self._figure + + @property + def graphics(self) -> list[ImageGraphic]: + """List of ``ImageWidget`` managed graphics.""" + iw_managed = list() + for subplot in self.figure: + if "image_widget_managed" in subplot: + iw_managed.append(subplot["image_widget_managed"]) + else: + iw_managed.append(None) + return tuple(iw_managed) + + @property + def cmap(self) -> tuple[str | None, ...]: + """get the cmaps, or set the cmap across all images""" + return tuple(g.cmap for g in self.graphics) + + @cmap.setter + def cmap(self, name: str): + for g in self.graphics: + if g is None: + # no data at this index + continue + + if g.cmap is None: + # if rgb + continue + + g.cmap = name + + def add_event_handler(self, handler: callable, event: str = "indices"): + """ + Register an event handler. + + Currently the only event that ImageWidget supports is "indices". This event is + emitted whenever the indices of the ImageWidget changes. + + Parameters + ---------- + handler: callable + callback function, must take a tuple of int as the only argument. This tuple will be the `indices` + + event: str, "indices" + the only supported event is "indices" + + Example + ------- + + .. code-block:: py + + def my_handler(indices): + print(indices) + # example prints: (100, 15) if the data has 2 slider dimensions with sliders at positions 100, 15 + + # create an image widget + iw = ImageWidget(...) + + # add event handler + iw.add_event_handler(my_handler) + + """ + if event != "indices": + raise ValueError("`indices` is the only event supported by `ImageWidget`") + + self._indices_changed_handlers.add(handler) + + def remove_event_handler(self, handler: callable): + """Remove a registered event handler""" + self._indices_changed_handlers.remove(handler) + + def clear_event_handlers(self): + """Clear all registered event handlers""" + self._indices_changed_handlers.clear() + + def reset_vmin_vmax(self): + """ + Reset the vmin and vmax w.r.t. the full data + """ + for image_processor, subplot in zip(self._image_processors, self.figure): + if "histogram_lut" not in subplot.docks["right"]: + continue + + if image_processor.histogram is None: + continue + + hlut = subplot.docks["right"]["histogram_lut"] + hlut.histogram = image_processor.histogram + + edges = image_processor.histogram[1] + + hlut.vmin, hlut.vmax = edges[0], edges[-1] + + def reset_vmin_vmax_frame(self): + """ + Resets the vmin vmax and HistogramLUT widgets w.r.t. the current data shown in the + ImageGraphic instead of the data in the full data array. For example, if a post-processing + function is used, the range of values in the ImageGraphic can be very different from the + range of values in the full data array. + """ + + for subplot, image_processor in zip(self.figure, self._image_processors): + if "histogram_lut" not in subplot.docks["right"]: + continue + + if image_processor.histogram is None: + continue + + hlut = subplot.docks["right"]["histogram_lut"] + # set the data using the current image graphic data + image = subplot["image_widget_managed"] + freqs, edges = np.histogram(image.data.value, bins=100) + hlut.histogram = (freqs, edges) + hlut.vmin, hlut.vmax = edges[0], edges[-1] + + def show(self, **kwargs): + """ + Show the widget. + + Parameters + ---------- + + kwargs: Any + passed to `Figure.show()`t + + Returns + ------- + BaseRenderCanvas + In Qt or GLFW, the canvas window containing the Figure will be shown. + In jupyter, it will display the plot in the output cell or sidecar. + + """ + + return self.figure.show(**kwargs) + + def close(self): + """Close Widget""" + self.figure.close() diff --git a/fastplotlib/widgets/image_widget/_processor.py b/fastplotlib/widgets/image_widget/_processor.py new file mode 100644 index 000000000..0dce84a5e --- /dev/null +++ b/fastplotlib/widgets/image_widget/_processor.py @@ -0,0 +1,519 @@ +import inspect +from typing import Literal, Callable +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDImageProcessor: + def __init__( + self, + data: ArrayLike | None, + n_display_dims: Literal[2, 3] = 2, + rgb: bool = False, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_sizes: tuple[int | None, ...] | int = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + ): + """ + An ND image that supports computing window functions, and functions over spatial dimensions. + + Parameters + ---------- + data: ArrayLike + array-like data, must have 2 or more dimensions + + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + rgb: bool, default False + whether the image data is RGB(A) or not + + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional + A function or a ``tuple`` of functions that are applied to a rolling window of the data. + + You can provide unique window functions for each dimension. If you want to apply a window function + only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. + + A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which + the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window + function **must** return an array that has the same number of dimensions as the original ``data`` array, + therefore the size of the dimension along which the window was applied will reduce to ``1``. + + The output array-like type from a window function **must** support a ``.squeeze()`` method, but the + function itself should NOT squeeze the output array. + + window_sizes: tuple[int | None, ...], optional + ``tuple`` of ``int`` that specifies the window size for each dimension. + + window_order: tuple[int, ...] | None, optional + order in which to apply the window functions, by default just applies it from the left-most dim to the + right-most slider dim. + + spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional + A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. + This function is applied after the window functions (if present). + + compute_histogram: bool, default True + Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. + Disable if slow. + + """ + # set as False until data, window funcs stuff and spatial func is all set + self._compute_histogram = False + + self.data = data + self.n_display_dims = n_display_dims + self.rgb = rgb + + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + + self._spatial_func = spatial_func + + self._compute_histogram = compute_histogram + self._recompute_histogram() + + @property + def data(self) -> ArrayLike | None: + """get or set the data array""" + return self._data + + @data.setter + def data(self, data: ArrayLike): + # check that all array-like attributes are present + if data is None: + self._data = None + return + + if not isinstance(data, ArrayProtocol): + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 2: + raise IndexError( + f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" + ) + + self._data = data + self._recompute_histogram() + + @property + def ndim(self) -> int: + if self.data is None: + return 0 + + return self.data.ndim + + @property + def shape(self) -> tuple[int, ...]: + if self._data is None: + return tuple() + + return self.data.shape + + @property + def rgb(self) -> bool: + """whether or not the data is rgb(a)""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: bool): + if not isinstance(rgb, bool): + raise TypeError + + if rgb and self.ndim < 3: + raise IndexError( + f"require 3 or more dims for RGB, you have: {self.ndim} dims" + ) + + self._rgb = rgb + + @property + def n_slider_dims(self) -> int: + """number of slider dimensions""" + if self._data is None: + return 0 + + return self.ndim - self.n_display_dims - int(self.rgb) + + @property + def slider_dims(self) -> tuple[int, ...] | None: + """tuple indicating the slider dimension indices""" + if self.n_slider_dims == 0: + return None + + return tuple(range(self.n_slider_dims)) + + @property + def slider_dims_shape(self) -> tuple[int, ...] | None: + if self.n_slider_dims == 0: + return None + + return tuple(self.shape[i] for i in self.slider_dims) + + @property + def n_display_dims(self) -> Literal[2, 3]: + """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" + return self._n_display_dims + + # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget + @n_display_dims.setter + def n_display_dims(self, n: Literal[2, 3]): + if not (n == 2 or n == 3): + raise ValueError( + f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" + ) + self._n_display_dims = n + self._recompute_histogram() + + @property + def max_n_display_dims(self) -> int: + """maximum number of possible display dims""" + # min 2, max 3, accounts for if data is None and ndim is 0 + return max(2, min(3, self.ndim - int(self.rgb))) + + @property + def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: + """tuple indicating the display dimension indices""" + return tuple(range(self.data.ndim))[self.n_slider_dims :] + + @property + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = None + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + if all([f is None for f in window_funcs]): + self._window_funcs = None + return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) + + @property + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = None + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_sizes` must be the same as the number of slider dims, " + f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) + self._recompute_histogram() + + @property + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) + self._recompute_histogram() + + @property + def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: + """get or set a spatial_func function, see docstring for details""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): + if not (callable(func) or func is not None): + raise TypeError( + f"`spatial_func` must be a callable or `None`, you have passed: {func}" + ) + + self._spatial_func = func + self._recompute_histogram() + + @property + def compute_histogram(self) -> bool: + return self._compute_histogram + + @compute_histogram.setter + def compute_histogram(self, compute: bool): + if compute: + if self._compute_histogram is False: + # compute a histogram + self._recompute_histogram() + self._compute_histogram = True + else: + self._compute_histogram = False + self._histogram = None + + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: + """ + an estimate of the histogram of the data, (histogram_values, bin_edges). + + returns `None` if `compute_histogram` is `False` + """ + return self._histogram + + def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + if winds is None or funcs is None: + # no window funcs or window sizes, just slice data and return + # clamp to max bounds + indexer = list() + for dim, i in enumerate(indices): + i = min(self.shape[dim] - 1, i) + indexer.append(i) + + return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index] - 1, i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def get(self, indices: tuple[int, ...]) -> ArrayLike | None: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + if self.data is None: + return None + + if self.n_slider_dims != 0: + if len(indices) != self.n_slider_dims: + raise IndexError( + f"Must specify index for every slider dim, you have specified an index: {indices}\n" + f"But there are: {self.n_slider_dims} slider dims." + ) + # get output after processing through all window funcs + # squeeze to remove all dims of size 1 + window_output = self._apply_window_function(indices).squeeze() + else: + # data is a static image or volume + window_output = self.data + + # apply spatial_func + if self.spatial_func is not None: + final_output = self.spatial_func(window_output) + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `spatial_func` must match the number of display dims." + f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " + f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" + ) + else: + # check that output ndim after window functions matches display dims + final_output = window_output + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `window_funcs` must match the number of display dims." + f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " + f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " + f"expected {self.n_display_dims} dims" + ) + + return final_output + + def _recompute_histogram(self): + """ + + Returns + ------- + (histogram_values, bin_edges) + + """ + if not self._compute_histogram or self.data is None: + self._histogram = None + return + + if self.spatial_func is not None: + # don't subsample spatial dims if a spatial function is used + # spatial functions often operate on the spatial dims, ex: a gaussian kernel + # so their results require the full spatial resolution, the histogram of a + # spatially subsampled image will be very different + ignore_dims = self.display_dims + else: + ignore_dims = None + + sub = subsample_array(self.data, ignore_dims=ignore_dims) + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] + + self._histogram = np.histogram(sub_real, bins=100) diff --git a/fastplotlib/widgets/image_widget/_properties.py b/fastplotlib/widgets/image_widget/_properties.py new file mode 100644 index 000000000..060314439 --- /dev/null +++ b/fastplotlib/widgets/image_widget/_properties.py @@ -0,0 +1,139 @@ +from pprint import pformat +from typing import Iterable + +import numpy as np + +from ._processor import NDImageProcessor + + +class ImageWidgetProperty: + __class_getitem__ = classmethod(type(list[int])) + + def __init__( + self, + image_widget, + attribute: str, + ): + self._image_widget = image_widget + self._image_processors: list[NDImageProcessor] = image_widget._image_processors + self._attribute = attribute + + def _get_key(self, key: slice | int | np.integer | str) -> int | slice: + if not isinstance(key, (slice | int, np.integer, str)): + raise TypeError( + f"can index `{self._attribute}` only with a , , or a indicating the subplot name." + f"You tried to index with: {key}" + ) + + if isinstance(key, str): + for i, subplot in enumerate(self._image_widget.figure): + if subplot.name == key: + key = i + break + else: + raise IndexError(f"No subplot with given name: {key}") + + return key + + def __getitem__(self, key): + key = self._get_key(key) + # return image processor attribute at this index + if isinstance(key, (int, np.integer)): + return getattr(self._image_processors[key], self._attribute) + + # if it's a slice + processors = self._image_processors[key] + + return tuple(getattr(p, self._attribute) for p in processors) + + def __setitem__(self, key, value): + key = self._get_key(key) + + # get the values from the ImageWidget property + new_values = list(getattr(p, self._attribute) for p in self._image_processors) + + # set the new value at this slice + new_values[key] = value + + # call the setter + setattr(self._image_widget, self._attribute, new_values) + + def __iter__(self): + for image_processor in self._image_processors: + yield getattr(image_processor, self._attribute) + + def __repr__(self): + return f"{self._attribute}: {pformat(self[:])}" + + def __eq__(self, other): + return self[:] == other + + +class Indices: + def __init__( + self, + indices: list[int], + image_widget, + ): + self._data = indices + + self._image_widget = image_widget + + def __iter__(self): + for i in self._data: + yield i + + def _parse_key(self, key: int | np.integer | str) -> int: + if not isinstance(key, (int, np.integer, str)): + raise TypeError( + f"indices can only be indexed with or types, you have used: {key}" + ) + + if isinstance(key, str): + # get integer index from user's names + names = self._image_widget._slider_dim_names + if key not in names: + raise KeyError( + f"dim with name: {key} not found in slider_dim_names, current names are: {names}" + ) + + key = names.index(key) + + return key + + def __getitem__(self, key: int | np.integer | str) -> int | tuple[int]: + if isinstance(key, str): + key = self._parse_key(key) + + return self._data[key] + + def __setitem__(self, key, value): + key = self._parse_key(key) + + if not isinstance(value, (int, np.integer)): + raise TypeError( + f"indices values can only be set with integers, you have tried to set the value: {value}" + ) + + new_indices = list(self._data) + new_indices[key] = value + + self._image_widget.indices = new_indices + + def _fpl_set(self, values): + self._data[:] = values + + def pop_dim(self): + self._data.pop(0) + + def push_dim(self): + self._data.insert(0, 0) + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + return self._data == other + + def __repr__(self): + return f"indices: {self._data}" diff --git a/fastplotlib/widgets/image_widget/_sliders.py b/fastplotlib/widgets/image_widget/_sliders.py index 393b13273..1945b8cfb 100644 --- a/fastplotlib/widgets/image_widget/_sliders.py +++ b/fastplotlib/widgets/image_widget/_sliders.py @@ -11,50 +11,66 @@ def __init__(self, figure, size, location, title, image_widget): super().__init__(figure=figure, size=size, location=location, title=title) self._image_widget = image_widget + n_sliders = self._image_widget.n_sliders + # whether or not a dimension is in play mode - self._playing: dict[str, bool] = {"t": False, "z": False} + self._playing: list[bool] = [False] * n_sliders # approximate framerate for playing - self._fps: dict[str, int] = {"t": 20, "z": 20} + self._fps: list[int] = [20] * n_sliders + # framerate converted to frame time - self._frame_time: dict[str, float] = {"t": 1 / 20, "z": 1 / 20} + self._frame_time: list[float] = [1 / 20] * n_sliders # last timepoint that a frame was displayed from a given dimension - self._last_frame_time: dict[str, float] = {"t": 0, "z": 0} + self._last_frame_time: list[float] = [perf_counter()] * n_sliders + # loop playback self._loop = False - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": - self._playing["t"] = True + # auto-plays the ImageWidget's left-most dimension in docs galleries + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": + self._playing[0] = True self._loop = True - def set_index(self, dim: str, index: int): - """set the current_index of the ImageWidget""" + self.pause = False + + def pop_dim(self): + """pop right most dim""" + i = 0 # len(self._image_widget.indices) - 1 + for l in [self._playing, self._fps, self._frame_time, self._last_frame_time]: + l.pop(i) + + def push_dim(self): + """push a new dim""" + self._playing.insert(0, False) + self._fps.insert(0, 20) + self._frame_time.insert(0, 1 / 20) + self._last_frame_time.insert(0, perf_counter()) + + def set_index(self, dim: int, new_index: int): + """set the index of the ImageWidget""" # make sure the max index for this dim is not exceeded - max_index = self._image_widget._dims_max_bounds[dim] - 1 - if index > max_index: + max_index = self._image_widget.bounds[dim] - 1 + if new_index > max_index: if self._loop: # loop back to index zero if looping is enabled - index = 0 + new_index = 0 else: # if looping not enabled, stop playing this dimension self._playing[dim] = False return - # set current_index - self._image_widget.current_index = {dim: min(index, max_index)} + # set new index + new_indices = list(self._image_widget.indices) + new_indices[dim] = new_index + self._image_widget.indices = new_indices def update(self): """called on every render cycle to update the GUI elements""" - # store the new index of the image widget ("t" and "z") - new_index = dict() - - # flag if the index changed - flag_index_changed = False - # reset vmin-vmax using full orig data if imgui.button(label=fa.ICON_FA_CIRCLE_HALF_STROKE + fa.ICON_FA_FILM): self._image_widget.reset_vmin_vmax() @@ -72,7 +88,7 @@ def update(self): now = perf_counter() # buttons and slider UI elements for each dim - for dim in self._image_widget.slider_dims: + for dim in range(self._image_widget.n_sliders): imgui.push_id(f"{self._id_counter}_{dim}") if self._playing[dim]: @@ -83,7 +99,7 @@ def update(self): # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index if now - self._last_frame_time[dim] >= self._frame_time[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) self._last_frame_time[dim] = now else: @@ -97,12 +113,12 @@ def update(self): imgui.same_line() # step back one frame button if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] - 1) + self.set_index(dim, self._image_widget.indices[dim] - 1) imgui.same_line() # step forward one frame button if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.current_index[dim] + 1) + self.set_index(dim, self._image_widget.indices[dim] + 1) imgui.same_line() # stop button @@ -137,10 +153,15 @@ def update(self): self._fps[dim] = value self._frame_time[dim] = 1 / value - val = self._image_widget.current_index[dim] - vmax = self._image_widget._dims_max_bounds[dim] - 1 + val = self._image_widget.indices[dim] + vmax = self._image_widget.bounds[dim] - 1 + + dim_name = dim + if self._image_widget._slider_dim_names is not None: + if dim < len(self._image_widget._slider_dim_names): + dim_name = self._image_widget._slider_dim_names[dim] - imgui.text(f"{dim}: ") + imgui.text(f"dim '{dim_name}:' ") imgui.same_line() # so that slider occupies full width imgui.set_next_item_width(self.width * 0.85) @@ -154,18 +175,12 @@ def update(self): # slider for this dimension changed, index = imgui.slider_int( - f"{dim}", v=val, v_min=0, v_max=vmax, flags=flags + f"d: {dim}", v=val, v_min=0, v_max=vmax, flags=flags ) - new_index[dim] = index - - # if the slider value changed for this dimension - flag_index_changed |= changed + if changed: + new_indices = list(self._image_widget.indices) + new_indices[dim] = index + self._image_widget.indices = new_indices imgui.pop_id() - - if flag_index_changed: - # if any slider dim changed set the new index of the image widget - self._image_widget.current_index = new_index - - self.size = int(imgui.get_window_height()) From 777a1d507e1995b4aaa603fc7eaa5166f5bebb1d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 16 Feb 2026 08:42:03 -0500 Subject: [PATCH 027/101] update --- fastplotlib/utils/__init__.py | 2 +- fastplotlib/widgets/nd_widget/__init__.py | 1 + fastplotlib/widgets/nd_widget/_nd_image.py | 624 +++++++++++++++++++++ fastplotlib/widgets/nd_widget/nd_image.py | 13 - 4 files changed, 626 insertions(+), 14 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_nd_image.py delete mode 100644 fastplotlib/widgets/nd_widget/nd_image.py diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index 8001ae375..6f0059f6a 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,7 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * -from ._protocols import ArrayProtocol +from ._protocols import ArrayProtocol, ARRAY_LIKE_ATTRS @dataclass diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 70c2e7621..352df09a8 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -1,2 +1,3 @@ from .processor_base import NDProcessor from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras +from ._nd_image import NDImageProcessor, NDImage diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py new file mode 100644 index 000000000..e3a3a4f80 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -0,0 +1,624 @@ +import inspect +from typing import Literal, Callable, Type, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS +from ...graphics import ImageGraphic, ImageVolumeGraphic +from .processor_base import NDProcessor + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDImageProcessor(NDProcessor): + def __init__( + self, + data: ArrayLike | None, + n_display_dims: Literal[2, 3] = 2, + rgb: bool = False, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_sizes: tuple[int | None, ...] | int = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + index_mappings = None, + ): + """ + An ND image that supports computing window functions, and functions over spatial dimensions. + + Parameters + ---------- + data: ArrayLike + array-like data, must have 2 or more dimensions + + n_display_dims: int, 2 or 3, default 2 + number of display dimensions + + rgb: bool, default False + whether the image data is RGB(A) or not + + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional + A function or a ``tuple`` of functions that are applied to a rolling window of the data. + + You can provide unique window functions for each dimension. If you want to apply a window function + only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. + + A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which + the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window + function **must** return an array that has the same number of dimensions as the original ``data`` array, + therefore the size of the dimension along which the window was applied will reduce to ``1``. + + The output array-like type from a window function **must** support a ``.squeeze()`` method, but the + function itself should NOT squeeze the output array. + + window_sizes: tuple[int | None, ...], optional + ``tuple`` of ``int`` that specifies the window size for each dimension. + + window_order: tuple[int, ...] | None, optional + order in which to apply the window functions, by default just applies it from the left-most dim to the + right-most slider dim. + + spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional + A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. + This function is applied after the window functions (if present). + + compute_histogram: bool, default True + Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. + Disable if slow. + + """ + # set as False until data, window funcs stuff and spatial func is all set + self._compute_histogram = False + + self.data = data + self.n_display_dims = n_display_dims + self.rgb = rgb + + self.window_funcs = window_funcs + self.window_sizes = window_sizes + self.window_order = window_order + + self._spatial_func = spatial_func + + self._compute_histogram = compute_histogram + self._recompute_histogram() + + self._index_mappings = self._validate_index_mappings(index_mappings) + + @property + def data(self) -> ArrayLike | None: + """get or set the data array""" + return self._data + + @data.setter + def data(self, data: ArrayLike): + # check that all array-like attributes are present + if data is None: + self._data = None + return + + if not isinstance(data, ArrayProtocol): + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 2: + raise IndexError( + f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" + ) + + self._data = data + self._recompute_histogram() + + @property + def ndim(self) -> int: + if self.data is None: + return 0 + + return self.data.ndim + + @property + def shape(self) -> tuple[int, ...]: + if self._data is None: + return tuple() + + return self.data.shape + + @property + def rgb(self) -> bool: + """whether or not the data is rgb(a)""" + return self._rgb + + @rgb.setter + def rgb(self, rgb: bool): + if not isinstance(rgb, bool): + raise TypeError + + if rgb and self.ndim < 3: + raise IndexError( + f"require 3 or more dims for RGB, you have: {self.ndim} dims" + ) + + self._rgb = rgb + + @property + def n_slider_dims(self) -> int: + """number of slider dimensions""" + if self._data is None: + return 0 + + return self.ndim - self.n_display_dims - int(self.rgb) + + @property + def slider_dims(self) -> tuple[int, ...] | None: + """tuple indicating the slider dimension indices""" + if self.n_slider_dims == 0: + return None + + return tuple(range(self.n_slider_dims)) + + @property + def slider_dims_shape(self) -> tuple[int, ...] | None: + if self.n_slider_dims == 0: + return None + + return tuple(self.shape[i] for i in self.slider_dims) + + @property + def n_display_dims(self) -> Literal[2, 3]: + """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" + return self._n_display_dims + + # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget + @n_display_dims.setter + def n_display_dims(self, n: Literal[2, 3]): + if not (n == 2 or n == 3): + raise ValueError( + f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" + ) + self._n_display_dims = n + self._recompute_histogram() + + @property + def max_n_display_dims(self) -> int: + """maximum number of possible display dims""" + # min 2, max 3, accounts for if data is None and ndim is 0 + return max(2, min(3, self.ndim - int(self.rgb))) + + @property + def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: + """tuple indicating the display dimension indices""" + return tuple(range(self.data.ndim))[self.n_slider_dims :] + + @property + def window_funcs( + self, + ) -> tuple[WindowFuncCallable | None, ...] | None: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + ): + if window_funcs is None: + self._window_funcs = None + return + + if callable(window_funcs): + window_funcs = (window_funcs,) + + # if all are None + if all([f is None for f in window_funcs]): + self._window_funcs = None + return + + self._validate_window_func(window_funcs) + + self._window_funcs = tuple(window_funcs) + self._recompute_histogram() + + def _validate_window_func(self, funcs): + if isinstance(funcs, (tuple, list)): + for f in funcs: + if f is None: + pass + elif callable(f): + sig = inspect.signature(f) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {f} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" + ) + + if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " + f"and you passed {len(funcs)} `window_funcs`: {funcs}" + ) + + @property + def window_sizes(self) -> tuple[int | None, ...] | None: + """get or set window sizes used for the corresponding window functions, see docstring for details""" + return self._window_sizes + + @window_sizes.setter + def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): + if window_sizes is None: + self._window_sizes = None + return + + if isinstance(window_sizes, int): + window_sizes = (window_sizes,) + + # if all are None + if all([w is None for w in window_sizes]): + self._window_sizes = None + return + + if not all([isinstance(w, (int)) or w is None for w in window_sizes]): + raise TypeError( + f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" + ) + + if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): + raise IndexError( + f"number of `window_sizes` must be the same as the number of slider dims, " + f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " + f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" + ) + + # make all window sizes are valid numbers + _window_sizes = list() + for i, w in enumerate(window_sizes): + if w is None: + _window_sizes.append(None) + continue + + if w < 0: + raise ValueError( + f"negative window size passed, all `window_sizes` must be positive " + f"integers or `None`, you passed: {_window_sizes}" + ) + + if w == 0 or w == 1: + # this is not a real window, set as None + w = None + + elif w % 2 == 0: + # odd window sizes makes most sense + warn( + f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + ) + w += 1 + + _window_sizes.append(w) + + self._window_sizes = tuple(_window_sizes) + self._recompute_histogram() + + @property + def window_order(self) -> tuple[int, ...] | None: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[int] | None): + if order is None: + self._window_order = None + return + + if order is not None: + if not all([d <= self.n_slider_dims for d in order]): + raise IndexError( + f"all `window_order` entries must be <= n_slider_dims\n" + f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" + ) + + if not all([d >= 0 for d in order]): + raise IndexError( + f"all `window_order` entires must be >= 0, you have passed: {order}" + ) + + self._window_order = tuple(order) + self._recompute_histogram() + + @property + def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: + """get or set a spatial_func function, see docstring for details""" + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): + if not (callable(func) or func is not None): + raise TypeError( + f"`spatial_func` must be a callable or `None`, you have passed: {func}" + ) + + self._spatial_func = func + self._recompute_histogram() + + @property + def compute_histogram(self) -> bool: + return self._compute_histogram + + @compute_histogram.setter + def compute_histogram(self, compute: bool): + if compute: + if self._compute_histogram is False: + # compute a histogram + self._recompute_histogram() + self._compute_histogram = True + else: + self._compute_histogram = False + self._histogram = None + + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: + """ + an estimate of the histogram of the data, (histogram_values, bin_edges). + + returns `None` if `compute_histogram` is `False` + """ + return self._histogram + + def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: + """applies the window functions for each dimension specified""" + # window size for each dim + winds = self._window_sizes + # window function for each dim + funcs = self._window_funcs + + if winds is None or funcs is None: + # no window funcs or window sizes, just slice data and return + # clamp to max bounds + indexer = list() + for dim, i in enumerate(indices): + i = min(self.shape[dim] - 1, i) + indexer.append(i) + + return self.data[tuple(indexer)] + + # order in which window funcs are applied + order = self._window_order + + if order is not None: + # remove any entries in `window_order` where the specified dim + # has a window function or window size specified as `None` + # example: + # window_sizes = (3, 2) + # window_funcs = (np.mean, None) + # order = (0, 1) + # `1` is removed from the order since that window_func is `None` + order = tuple( + d for d in order if winds[d] is not None and funcs[d] is not None + ) + else: + # sequential order + order = list() + for d in range(self.n_slider_dims): + if winds[d] is not None and funcs[d] is not None: + order.append(d) + + # the final indexer which will be used on the data array + indexer = list() + + for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): + # clamp i within the max bounds + i = min(self.shape[dim_index] - 1, i) + + if (w is not None) and (f is not None): + # specify slice window if both window size and function for this dim are not None + hw = int((w - 1) / 2) # half window + + # start index cannot be less than 0 + start = max(0, i - hw) + + # stop index cannot exceed the bounds of this dimension + stop = min(self.shape[dim_index] - 1, i + hw) + + s = slice(start, stop, 1) + else: + s = slice(i, i + 1, 1) + + indexer.append(s) + + # apply indexer to slice data with the specified windows + data_sliced = self.data[tuple(indexer)] + + # finally apply the window functions in the specified order + for dim in order: + f = funcs[dim] + + data_sliced = f(data_sliced, axis=dim, keepdims=True) + + return data_sliced + + def get(self, indices: tuple[int, ...]) -> ArrayLike | None: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + if self.data is None: + return None + + # apply any slider index mappings + indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) + + if self.n_slider_dims != 0: + if len(indices) != self.n_slider_dims: + raise IndexError( + f"Must specify index for every slider dim, you have specified an index: {indices}\n" + f"But there are: {self.n_slider_dims} slider dims." + ) + # get output after processing through all window funcs + # squeeze to remove all dims of size 1 + window_output = self._apply_window_function(indices).squeeze() + else: + # data is a static image or volume + window_output = self.data + + # apply spatial_func + if self.spatial_func is not None: + final_output = self.spatial_func(window_output) + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `spatial_func` must match the number of display dims." + f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " + f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" + ) + else: + # check that output ndim after window functions matches display dims + final_output = window_output + if final_output.ndim != (self.n_display_dims + int(self.rgb)): + raise IndexError( + f"Final output after of the `window_funcs` must match the number of display dims." + f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " + f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " + f"expected {self.n_display_dims} dims" + ) + + return final_output + + def _recompute_histogram(self): + """ + + Returns + ------- + (histogram_values, bin_edges) + + """ + if not self._compute_histogram or self.data is None: + self._histogram = None + return + + if self.spatial_func is not None: + # don't subsample spatial dims if a spatial function is used + # spatial functions often operate on the spatial dims, ex: a gaussian kernel + # so their results require the full spatial resolution, the histogram of a + # spatially subsampled image will be very different + ignore_dims = self.display_dims + else: + ignore_dims = None + + sub = subsample_array(self.data, ignore_dims=ignore_dims) + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] + + self._histogram = np.histogram(sub_real, bins=100) + + +class NDImage: + def __init__( + self, + data: Any, + *args, + graphic: type[ImageGraphic, ImageVolumeGraphic] = None, + processor: type[NDImageProcessor] = NDImageProcessor, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + index_mappings: tuple[Callable[[Any], int] | None] | None = None, + graphic_kwargs: dict = None, + processor_kwargs: dict = None, + ): + if processor_kwargs is None: + processor_kwargs = dict() + + self._processor = processor( + data, + *args, + window_funcs=window_funcs, + window_sizes=window_sizes, + index_mappings=index_mappings, + **processor_kwargs, + ) + + self._indices = tuple([0] * self._processor.n_slider_dims) + + self._graphic = None + + self._create_graphic() + + @property + def processor(self) -> NDImageProcessor: + return self._processor + + @property + def graphic( + self, + ) -> ( + ImageGraphic | ImageVolumeGraphic + ): + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @graphic.setter + def graphic(self, graphic_type): + # TODO implement if graphic type changes to custom user subclass + pass + + def _create_graphic(self): + match self.processor.n_display_dims: + case 2: + cls = ImageGraphic + case 3: + cls = ImageVolumeGraphic + + data_slice = self.processor.get(self.indices) + + old_graphic = self._graphic + new_graphic = cls(data_slice) + + if old_graphic is not None: + g = self._graphic + plot_area = g._plot_area + self._graphic._plot_area.delete_graphic(g) + plot_area.add_graphic(self._graphic) + + self._graphic = new_graphic + + @property + def n_display_dims(self) -> Literal[2, 3]: + return self.processor.n_display_dims + + @n_display_dims.setter + def n_display_dims(self, n: Literal[2 , 3]): + self.processor.n_display_dims = n + + self._create_graphic() + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, indices): + data_slice = self.processor.get(indices) + + self.graphic.data = data_slice + + self._indices = indices + + def _tooltip_handler(self, graphic, pick_info): + # get graphic within the collection + n_index = np.argwhere(self.graphic.graphics == graphic).item() + p_index = pick_info["vertex_index"] + return self.processor.tooltip_format(n_index, p_index) diff --git a/fastplotlib/widgets/nd_widget/nd_image.py b/fastplotlib/widgets/nd_widget/nd_image.py deleted file mode 100644 index 4972db9d5..000000000 --- a/fastplotlib/widgets/nd_widget/nd_image.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Literal - -from .processor_base import NDProcessor - - -class NDImageProcessor(NDProcessor): - @property - def n_display_dims(self) -> Literal[2, 3]: - pass - - def _validate_n_display_dims(self, n_display_dims): - if n_display_dims not in (2, 3): - raise ValueError("`n_display_dims` must be") From 13557336a13d6f0ec23cc0342644192f2ad04d99 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Feb 2026 01:29:19 -0500 Subject: [PATCH 028/101] basic minimal ndw orchestration working --- fastplotlib/__init__.py | 2 +- fastplotlib/widgets/__init__.py | 3 +- fastplotlib/widgets/nd_widget/__init__.py | 17 +- fastplotlib/widgets/nd_widget/_nd_image.py | 10 +- .../widgets/nd_widget/_nd_positions/core.py | 14 +- .../nd_widget/{processor_base.py => base.py} | 23 ++ fastplotlib/widgets/nd_widget/ndwidget.py | 214 ++++++++++++++++++ 7 files changed, 266 insertions(+), 17 deletions(-) rename fastplotlib/widgets/nd_widget/{processor_base.py => base.py} (94%) create mode 100644 fastplotlib/widgets/nd_widget/ndwidget.py diff --git a/fastplotlib/__init__.py b/fastplotlib/__init__.py index 6dab91605..bde2c89e3 100644 --- a/fastplotlib/__init__.py +++ b/fastplotlib/__init__.py @@ -19,7 +19,7 @@ else: from .layouts import Figure -from .widgets import ImageWidget +from .widgets import NDWidget, ImageWidget from .utils import config, enumerate_adapters, select_adapter, print_wgpu_report diff --git a/fastplotlib/widgets/__init__.py b/fastplotlib/widgets/__init__.py index 766620ea6..04102dbdf 100644 --- a/fastplotlib/widgets/__init__.py +++ b/fastplotlib/widgets/__init__.py @@ -1,3 +1,4 @@ +from .nd_widget import NDWidget from .image_widget import ImageWidget -__all__ = ["ImageWidget"] +__all__ = ["NDWidget", "ImageWidget"] diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 352df09a8..7855327d9 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -1,3 +1,14 @@ -from .processor_base import NDProcessor -from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras -from ._nd_image import NDImageProcessor, NDImage +from ...layouts import IMGUI + +if IMGUI: + from .base import NDProcessor + from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras + from ._nd_image import NDImageProcessor, NDImage + from .ndwidget import NDWidget +else: + class NDWidget: + def __init__(self, *args, **kwargs): + raise ModuleNotFoundError( + "NDWidget requires `imgui-bundle` to be installed.\n" + "pip install imgui-bundle" + ) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index e3a3a4f80..3e54814b2 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -7,10 +7,7 @@ from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS from ...graphics import ImageGraphic, ImageVolumeGraphic -from .processor_base import NDProcessor - -# must take arguments: array-like, `axis`: int, `keepdims`: bool -WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] +from .base import NDProcessor, NDGraphic, WindowFuncCallable class NDImageProcessor(NDProcessor): @@ -526,7 +523,7 @@ def _recompute_histogram(self): self._histogram = np.histogram(sub_real, bins=100) -class NDImage: +class NDImage(NDGraphic): def __init__( self, data: Any, @@ -538,6 +535,7 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, graphic_kwargs: dict = None, processor_kwargs: dict = None, + name: str = None, ): if processor_kwargs is None: processor_kwargs = dict() @@ -557,6 +555,8 @@ def __init__( self._create_graphic() + self._name = name + @property def processor(self) -> NDImageProcessor: return self._processor diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index b95916ce8..cd19bf2a5 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -16,7 +16,7 @@ ScatterGraphic, ScatterCollection, ) -from ..processor_base import NDProcessor, WindowFuncCallable +from ..base import NDProcessor, NDGraphic, WindowFuncCallable # TODO: Maybe get rid of n_display_dims in NDProcessor, @@ -210,7 +210,8 @@ def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: if index_p_start >= index_p_stop: index_p_stop = index_p_start + 1 - slices = [slice(index_p_start, index_p_stop)] + # round to the nearest integer since to use as arra indices + slices = [slice(round(index_p_start), round(index_p_stop))] if self.multi: slices.insert(0, slice(None)) @@ -225,19 +226,18 @@ def get(self, indices: tuple[Any, ...]): index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. """ # apply any slider index mappings - indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) + array_indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) - if len(indices) > 1: + if len(array_indices) > 1: # there are dims in addition to the n_datapoints dim # apply window funcs # window_output array should be of shape [n_datapoints, 2 | 3] - window_output = self._apply_window_functions(indices[:-1]).squeeze() + window_output = self._apply_window_functions(array_indices[:-1]).squeeze() else: window_output = self.data - # TODO: window function on the `p` n_datapoints dimension - if self.display_window is not None: + # display_window is in reference units slices = self._get_dw_slices(indices) # if self.display_window is not None: diff --git a/fastplotlib/widgets/nd_widget/processor_base.py b/fastplotlib/widgets/nd_widget/base.py similarity index 94% rename from fastplotlib/widgets/nd_widget/processor_base.py rename to fastplotlib/widgets/nd_widget/base.py index a1cd5311c..e46386e93 100644 --- a/fastplotlib/widgets/nd_widget/processor_base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -6,6 +6,7 @@ from numpy.typing import ArrayLike from ...utils import subsample_array, ArrayProtocol +from ...graphics import Graphic # must take arguments: array-like, `axis`: int, `keepdims`: bool WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] @@ -249,3 +250,25 @@ def _validate_index_mappings(self, maps): def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: pass + + +class NDGraphic: + @property + def name(self) -> str: + return self._name + + @property + def processor(self) -> NDProcessor: + raise NotImplementedError + + @property + def graphic(self) -> Graphic: + raise NotImplementedError + + @property + def indices(self) -> tuple[Any]: + raise NotImplementedError + + @indices.setter + def indices(self, new: tuple): + raise NotImplementedError diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py new file mode 100644 index 000000000..2932fa18d --- /dev/null +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -0,0 +1,214 @@ +from dataclasses import dataclass +import os +from time import perf_counter +from typing import Any, Sequence + +from imgui_bundle import imgui, icons_fontawesome_6 as fa +import numpy as np + +from ...layouts import ImguiFigure, Subplot +from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic +from ...ui import EdgeWindow +from .base import NDGraphic, NDProcessor +from ._nd_image import NDImage, NDImageProcessor +from ._nd_positions import NDPositions, NDPositionsProcessor + + +@dataclass +class ReferenceRangeContinuous: + start: int | float + stop: int | float + step: int | float + unit: str + + def __getitem__(self, index: int): + """return the value at the index w.r.t. the step size""" + # if index is negative, turn to positive index + if index < 0: + raise ValueError("negative indexing not supported") + + val = self.start + (self.step * index) + if not self.start <= val <= self.stop: + raise IndexError(f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]") + + return val + + +@dataclass +class ReferenceRangeDiscrete: + options: Sequence[Any] + unit: str + + def __getitem__(self, index: int): + if index > len(self.options): + raise IndexError + + return self.options[index] + + def __len__(self): + return len(self.options) + + +class NDWSubplot: + def __init__(self, ndw, subplot: Subplot): + self.ndw = ndw + self._subplot = subplot + + self._nd_graphics = list() + + @property + def nd_graphics(self) -> list[NDGraphic]: + return self._nd_graphics + + def __getitem__(self, key): + if isinstance(key, (int, np.integer)): + return self.nd_graphics[key] + + for g in self.nd_graphics: + if g.name == key: + return g + + else: + raise KeyError(f"NDGraphc with given key not found: {key}") + + def add_nd_image(self, *args, **kwargs): + nd = NDImage(*args, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + return nd + + def add_nd_scatter(self, *args, **kwargs): + nd = NDPositions(*args, graphic=ScatterCollection, multi=True, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + + return nd + + def add_nd_timeseries(self, *args, graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, **kwargs): + nd = NDPositions(*args, graphic=LineStack, multi=True, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + # TODO: think about auto-xrange for subplot camera + return nd + + def add_nd_lines(self, *args, **kwargs): + nd = NDPositions(*args, graphic=LineCollection, multi=True, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + return nd + + # def __repr__(self): + # return "NDWidget Subplot" + # + # def __str__(self): + # return "NDWidget Subplot" + + +class NDWSliders(EdgeWindow): + def __init__(self, figure, size, ndwidget): + super().__init__(figure=figure, size=size, title="NDWidget controls", location="bottom") + self._ndwidget = ndwidget + + # n_sliders = self._image_widget.n_sliders + # + # # whether or not a dimension is in play mode + # self._playing: list[bool] = [False] * n_sliders + # + # # approximate framerate for playing + # self._fps: list[int] = [20] * n_sliders + # + # # framerate converted to frame time + # self._frame_time: list[float] = [1 / 20] * n_sliders + # + # # last timepoint that a frame was displayed from a given dimension + # self._last_frame_time: list[float] = [perf_counter()] * n_sliders + # + # # loop playback + # self._loop = False + # + # # auto-plays the ImageWidget's left-most dimension in docs galleries + # if "DOCS_BUILD" in os.environ.keys(): + # if os.environ["DOCS_BUILD"] == "1": + # self._playing[0] = True + # self._loop = True + # + # self.pause = False + + def update(self): + indices_changed = False + + for dim_index, (current_index, refr) in enumerate(zip(self._ndwidget.indices, self._ndwidget.ref_ranges)): + if isinstance(refr, ReferenceRangeContinuous): + changed, val = imgui.slider_float( + v=current_index, + v_min=refr.start, + v_max=refr.stop, + label=refr.unit + ) + + if changed: + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = val + + indices_changed = True + + if indices_changed: + self._ndwidget.indices = tuple(new_indices) + + +class NDWidget: + def __init__(self, ref_ranges: list[tuple], **kwargs): + self._ref_ranges = list() + + for r in ref_ranges: + if len(r) == 4: + # assume start, stop, step, unit + refr = ReferenceRangeContinuous(*r) + elif len(r) == 2: + refr = ReferenceRangeDiscrete(*r) + else: + raise ValueError + + self._ref_ranges.append(refr) + + self._figure = ImguiFigure(**kwargs) + + self._subplots: dict[Subplot, NDWSubplot] = dict() + for subplot in self.figure: + self._subplots[subplot] = NDWSubplot(self, subplot) + + # starting index for all dims + self._indices = tuple(refr[0] for refr in self.ref_ranges) + + # hard code the expected height so that the first render looks right in tests, docs etc. + ui_size = 57 + (50 * len(self.indices)) + + self._sliders_ui = NDWSliders(self.figure, ui_size, self) + self.figure.add_gui(self._sliders_ui) + + @property + def figure(self) -> ImguiFigure: + return self._figure + + @property + def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete]: + return tuple(self._ref_ranges) + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, new_indices: tuple[Any]): + for subplot in self._subplots.values(): + for ndg in subplot.nd_graphics: + ndg.indices = new_indices + + self._indices = new_indices + + def __getitem__(self, key): + subplot = self.figure[key] + return self._subplots[subplot] + + def show(self, **kwargs): + return self.figure.show(**kwargs) \ No newline at end of file From 78878b1b072d37d4e87695d24c17546a91f9225a Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Feb 2026 18:10:26 -0500 Subject: [PATCH 029/101] implement auto-x for timeseries --- fastplotlib/layouts/_plot_area.py | 36 +++++++++++++++++++ .../widgets/nd_widget/_nd_positions/core.py | 17 +++++++-- fastplotlib/widgets/nd_widget/ndwidget.py | 2 +- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 405a01546..8ca914717 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -860,6 +860,42 @@ def _auto_scale_scene( camera.zoom = zoom + @property + def x_range(self) -> tuple[float, float]: + """ + Get or set the x-range currently in view. + Only valid for orthographic projections of the xy plane. + Use camera.set_state() to set the camera position for arbitrary projections. + """ + hw = self.camera.width / 2 + x = self.camera.local.x + return x - hw, x + hw + + @x_range.setter + def x_range(self, xr: tuple[float, float]): + width = xr[1] - xr[0] + x_mid = xr[0] + (width / 2) + self.camera.width = width + self.camera.local.x = x_mid + + @property + def y_range(self) -> tuple[float, float]: + """ + Get or set the y-range currently in view. + Only valid for orthographic projections of the xy plane. + Use camera.set_state() to set the camera position for arbitrary projections. + """ + hh = self.camera.width / 2 + y = self.camera.local.y + return y - hh, y + hh + + @y_range.setter + def y_range(self, yr: tuple[float, float]): + width = yr[1] - yr[0] + y_mid = yr[0] + (width / 2) + self.camera.width = width + self.camera.local.y = y_mid + def remove_graphic(self, graphic: Graphic): """ Remove a ``Graphic`` from the scene. Note: This does not garbage collect the graphic, diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index cd19bf2a5..6717bccd2 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -353,6 +353,7 @@ def __init__( window_sizes: tuple[int | None] | None = None, index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, + auto_x_range: bool = False, graphic_kwargs: dict = None, processor_kwargs: dict = None, ): @@ -378,6 +379,8 @@ def __init__( self._indices = tuple([0] * self._processor.n_slider_dims) + self._auto_x_range = auto_x_range + self._create_graphic(graphic) @property @@ -433,6 +436,12 @@ def indices(self, indices): image_data, x0, x_scale = self._create_heatmap_data(data_slice) self.graphic.data = image_data self.graphic.offset = (x0, *self.graphic.offset[1:]) + self.graphic.scale = (x_scale, *self.graphic.scale[1:]) + + # x range of the data + xr = data_slice[0, 0, 0], data_slice[0, -1, 0] + if self._auto_x_range: + self.graphic._plot_area.x_range = xr self._indices = indices @@ -504,11 +513,13 @@ def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: # x is sufficiently uniform y_interp = data_slice[..., 1] - # assume all x values are the same - x_scale = data_slice[:, -1, 0][0] / data_slice.shape[1] - x0 = data_slice[0, 0, 0] + # assume all x values are the same across all lines + # otherwise a heatmap representation makes no sense anyways + x_stop = data_slice[:, -1, 0][0] + x_scale = (x_stop - x0) / data_slice.shape[1] + return y_interp, x0, x_scale @property diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 2932fa18d..dd8610849 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -85,7 +85,7 @@ def add_nd_scatter(self, *args, **kwargs): return nd def add_nd_timeseries(self, *args, graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, **kwargs): - nd = NDPositions(*args, graphic=LineStack, multi=True, **kwargs) + nd = NDPositions(*args, graphic=graphic, multi=True, auto_x_range=True,**kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) # TODO: think about auto-xrange for subplot camera From 3ad64fa8908e219e6ad885fb6b1113226dd707cc Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Feb 2026 02:12:09 -0500 Subject: [PATCH 030/101] bugfix update worldobject -> graphic map for image tiles --- fastplotlib/graphics/_base.py | 36 ++++++++++++++++++++++++++--------- fastplotlib/graphics/image.py | 6 ++++++ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index e0602e4e3..abc3c4cad 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -291,15 +291,8 @@ def _set_world_object(self, wo: pygfx.WorldObject): # add to world object -> graphic mapping if isinstance(wo, pygfx.Group): - for child in wo.children: - if isinstance( - child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line) - ): - # unique 32 bit integer id for each world object - global_id = child.id - WORLD_OBJECT_TO_GRAPHIC[global_id] = self - # store id to pop from dict when graphic is deleted - self._world_object_ids.append(global_id) + # for Graphics which use a pygfx.Group, ImageGraphic and graphic collections + self._add_group_graphic_map(wo) else: global_id = wo.id WORLD_OBJECT_TO_GRAPHIC[global_id] = self @@ -324,6 +317,31 @@ def _set_world_object(self, wo: pygfx.WorldObject): if not all(wo.world.rotation == self.rotation): self.rotation = self.rotation + def _add_group_graphic_map(self, wo: pygfx.Group): + # add the children of the group to the WorldObject -> Graphic map + # used by images since they create new WorldObject ImageTiles when a different buffer size is required + # also used by GraphicCollections inititally, but not used for reseting like images + for child in wo.children: + if isinstance( + child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line) + ): + # unique 32 bit integer id for each world object + global_id = child.id + WORLD_OBJECT_TO_GRAPHIC[global_id] = self + # store id to pop from dict when graphic is deleted + self._world_object_ids.append(global_id) + + def _remove_group_graphic_map(self, wo: pygfx.Group): + # remove the children of the group to the WorldObject -> Graphic map + for child in wo.children: + if isinstance( + child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line) + ): + # unique 32 bit integer id for each world object + global_id = child.id + WORLD_OBJECT_TO_GRAPHIC.pop(global_id) + self._world_object_ids.remove(global_id) + @property def tooltip_format(self) -> Callable[[dict], str] | None: """ diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 7b670d531..6dfb52238 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -261,6 +261,9 @@ def data(self, data): self._material.clim = quick_min_max(self.data.value) + # remove tiles from the WorldObject -> Graphic map + self._remove_group_graphic_map(self.world_object) + # clear image tiles self.world_object.clear() @@ -268,6 +271,9 @@ def data(self, data): for tile in self._create_tiles(): self.world_object.add(tile) + # add new tiles to WorldObject -> Graphic map + self._add_group_graphic_map(self.world_object) + return self._data[:] = data From 75361c0de940e88388113ade492f87053a73d499 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Feb 2026 02:12:26 -0500 Subject: [PATCH 031/101] bugfix linear selector set limits --- fastplotlib/graphics/selectors/_linear.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fastplotlib/graphics/selectors/_linear.py b/fastplotlib/graphics/selectors/_linear.py index 0c956d57b..4ea454ee8 100644 --- a/fastplotlib/graphics/selectors/_linear.py +++ b/fastplotlib/graphics/selectors/_linear.py @@ -45,10 +45,8 @@ def limits(self, values: tuple[float, float]): # using `Real` here allows it to work with builtin `int` and `float` types, and numpy scaler types if len(values) != 2 or not all(map(lambda v: isinstance(v, Real), values)): raise TypeError("limits must be an iterable of two numeric values") - self._limits = tuple( - map(round, values) - ) # if values are close to zero things get weird so round them - self.selection._limits = self._limits + self._limits = np.asarray(values) # if values are close to zero things get weird so round them + self._selection._limits = self._limits @property def edge_color(self) -> pygfx.Color: From 38481c0ff783eec95d49a8a989cfb458d64b822c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Feb 2026 02:13:33 -0500 Subject: [PATCH 032/101] linear selector for timeseries --- .../widgets/nd_widget/_nd_positions/core.py | 23 +++++++++++++++ fastplotlib/widgets/nd_widget/ndwidget.py | 29 ++++++++++++++++--- 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 6717bccd2..c763f9100 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -16,6 +16,8 @@ ScatterGraphic, ScatterCollection, ) +from ....graphics.utils import pause_events +from ....graphics.selectors import LinearSelector from ..base import NDProcessor, NDGraphic, WindowFuncCallable @@ -354,6 +356,7 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, auto_x_range: bool = False, + linear_selector: bool = False, graphic_kwargs: dict = None, processor_kwargs: dict = None, ): @@ -383,6 +386,13 @@ def __init__( self._create_graphic(graphic) + if linear_selector: + self._linear_selector = LinearSelector(0, limits=(-np.inf, np.inf), edge_color="cyan") + else: + self._linear_selector = None + + self._pause = False + @property def processor(self) -> NDPositionsProcessor: return self._processor @@ -418,6 +428,9 @@ def indices(self) -> tuple: @indices.setter def indices(self, indices): + if self._pause: + return + data_slice = self.processor.get(indices) if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): @@ -443,8 +456,18 @@ def indices(self, indices): if self._auto_x_range: self.graphic._plot_area.x_range = xr + if self._linear_selector is not None: + with pause_events(self._linear_selector):#, event_handlers=[self._set_indices_from_selector]): + self._linear_selector.limits = xr + self._linear_selector.selection = indices[-1] + # self._set_linear_selector(x_mid, limits=xr) + self._indices = indices + # def _set_linear_selector(self, x_mid, limits): + # self._linear_selector.selection = x_mid + # self._linear_selector.limits = limits + def _tooltip_handler(self, graphic, pick_info): if isinstance(self.graphic, (LineCollection, ScatterCollection)): # get graphic within the collection diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index dd8610849..fd2491be7 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from functools import partial import os from time import perf_counter from typing import Any, Sequence @@ -84,11 +85,18 @@ def add_nd_scatter(self, *args, **kwargs): return nd - def add_nd_timeseries(self, *args, graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, **kwargs): - nd = NDPositions(*args, graphic=graphic, multi=True, auto_x_range=True,**kwargs) + def add_nd_timeseries( + self, + *args, + graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, + **kwargs + ): + nd = NDPositions(*args, graphic=graphic, multi=True, auto_x_range=True, linear_selector=True, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) - # TODO: think about auto-xrange for subplot camera + self._subplot.add_graphic(nd._linear_selector) + nd._linear_selector.add_event_handler(partial(self._set_indices_from_selector, nd), "selection") + return nd def add_nd_lines(self, *args, **kwargs): @@ -97,6 +105,19 @@ def add_nd_lines(self, *args, **kwargs): self._subplot.add_graphic(nd.graphic) return nd + def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): + # skip the NDPosition object which has the linear selector that triggered this event + skip_graphic._pause = True + + x = ev.info["value"] + indices_new = list(self.ndw.indices) + # linear selector for NDPositions always acts on the `p` dim + indices_new[-1] = x + self.ndw.indices = tuple(indices_new) + + # restore + skip_graphic._pause = False + # def __repr__(self): # return "NDWidget Subplot" # @@ -211,4 +232,4 @@ def __getitem__(self, key): return self._subplots[subplot] def show(self, **kwargs): - return self.figure.show(**kwargs) \ No newline at end of file + return self.figure.show(**kwargs) From 5e67318285d42de557ef2367f138dd72eb573aac Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Feb 2026 02:39:58 -0500 Subject: [PATCH 033/101] return full data if display_window is Noen --- .../widgets/nd_widget/_nd_positions/_pandas.py | 14 +++++++++++--- .../widgets/nd_widget/_nd_positions/core.py | 4 ++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index de26c8a9d..3e03b9c2d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -79,11 +79,19 @@ def tooltip_format(self, n: int, p: int): def get(self, indices: tuple[float | int, ...]) -> np.ndarray: if not isinstance(indices, tuple): raise TypeError(".get() must receive a tuple of float | int indices") - # assume no additional slider dims, only time slider dim - self._slices = self._get_dw_slices(indices) + # TODO: LOD by using a step size according to max_p + # TODO: Also what to do if display_window is None and data + # hasn't changed when indices keeps getting set, cache? + + # assume no additional slider dims, only time slider dim + if self.display_window is not None: + self._slices = self._get_dw_slices(indices) + gdata_shape = len(self.columns), self._slices[-1].stop - self._slices[-1].start, 3 + else: + gdata_shape = len(self.columns), self.data.shape[0], 3 + self._slices = (slice(None),) - gdata_shape = len(self.columns), self._slices[-1].stop - self._slices[-1].start, 3 gdata = np.zeros(shape=gdata_shape, dtype=np.float32) for i, col in enumerate(self.columns): diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index c763f9100..b83b4dd4c 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -200,6 +200,10 @@ def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: dw = self.display_window if dw is None: + # just return everything + return (slice(None),) + + if dw == 0: # just map p dimension at this index and return index_p = self.index_mappings[-1](indices[-1]) return (slice(index_p, index_p + 1),) From 9d7328a891b5581d7a29ad30afc1196232acb3da Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Feb 2026 23:09:43 -0500 Subject: [PATCH 034/101] arrow key to step indices --- fastplotlib/widgets/nd_widget/ndwidget.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index fd2491be7..475987e0f 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -160,19 +160,35 @@ def update(self): for dim_index, (current_index, refr) in enumerate(zip(self._ndwidget.indices, self._ndwidget.ref_ranges)): if isinstance(refr, ReferenceRangeContinuous): - changed, val = imgui.slider_float( + changed, new_index = imgui.slider_float( v=current_index, v_min=refr.start, v_max=refr.stop, label=refr.unit ) + # TODO: refactor all this stuff, make fully fledged UI if changed: new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = val + new_indices[dim_index] = new_index indices_changed = True + elif imgui.is_item_hovered(): + if imgui.is_key_pressed(imgui.Key.right_arrow): + new_index = current_index + refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + + if imgui.is_key_pressed(imgui.Key.left_arrow): + new_index = current_index - refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + if indices_changed: self._ndwidget.indices = tuple(new_indices) From 05a38ec65215514b509222d956a5a5a1c63e48ae Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 23 Feb 2026 02:53:34 -0500 Subject: [PATCH 035/101] imgui separator --- fastplotlib/ui/_base.py | 193 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 190 insertions(+), 3 deletions(-) diff --git a/fastplotlib/ui/_base.py b/fastplotlib/ui/_base.py index 9767cf76f..bc0280a4a 100644 --- a/fastplotlib/ui/_base.py +++ b/fastplotlib/ui/_base.py @@ -44,7 +44,7 @@ def __init__( location: Literal["bottom", "right"], title: str, window_flags: enum.IntFlag = imgui.WindowFlags_.no_collapse - | imgui.WindowFlags_.no_resize, + | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar, *args, **kwargs, ): @@ -111,6 +111,15 @@ def __init__( self._title = title self._window_flags = window_flags + self._resize_cursor_set = False + self._resize_blocked = False + self._right_gui_resizing = False + + self._separator_thickness = 14.0 + + self._collapsed = False + self._old_size = self.size + self._x, self._y, self._width, self._height = self.get_rect() self._figure.canvas.add_event_handler(self._set_rect, "resize") @@ -184,25 +193,203 @@ def get_rect(self) -> tuple[int, int, int, int]: return x_pos, y_pos, width, height + def _draw_resize_handle(self): + if self._location == "bottom": + imgui.set_cursor_pos((0, 0)) + imgui.invisible_button("##resize_handle", imgui.ImVec2(imgui.get_window_width(), self._separator_thickness)) + + hovered = imgui.is_item_hovered() + active = imgui.is_item_active() + + # Get the actual screen rect of the button after it's been laid out + rect_min = imgui.get_item_rect_min() + rect_max = imgui.get_item_rect_max() + + elif self._location == "right": + imgui.set_cursor_pos((0, 0)) + screen_pos = imgui.get_cursor_screen_pos() + win_height = imgui.get_window_height() + mouse_pos = imgui.get_mouse_pos() + + rect_min = imgui.ImVec2(screen_pos.x, screen_pos.y) + rect_max = imgui.ImVec2(screen_pos.x + self._separator_thickness, screen_pos.y + win_height) + + hovered = ( + rect_min.x <= mouse_pos.x <= rect_max.x + and rect_min.y <= mouse_pos.y <= rect_max.y + ) + + if hovered and imgui.is_mouse_clicked(0): + self._right_gui_resizing = True + + if not imgui.is_mouse_down(0): + self._right_gui_resizing = False + + active = self._right_gui_resizing + + imgui.set_cursor_pos((self._separator_thickness, 0)) + + if hovered and imgui.is_mouse_double_clicked(0): + if not self._collapsed: + self._old_size = self.size + if self._location == "bottom": + self.size = int(self._separator_thickness) + elif self._location == "right": + self.size = int(self._separator_thickness) + self._collapsed = True + else: + self.size = self._old_size + self._collapsed = False + + if hovered or active: + if not self._resize_cursor_set: + if self._location == "bottom": + self._figure.canvas.set_cursor("ns_resize") + + elif self._location == "right": + self._figure.canvas.set_cursor("ew_resize") + + self._resize_cursor_set = True + imgui.set_tooltip("Drag to resize, double click to expand/collapse") + + elif self._resize_cursor_set: + self._figure.canvas.set_cursor("default") + self._resize_cursor_set = False + + if active and imgui.is_mouse_dragging(0): + if self._location == "bottom": + delta = imgui.get_mouse_drag_delta(0).y + + elif self._location == "right": + delta = imgui.get_mouse_drag_delta(0).x + + imgui.reset_mouse_drag_delta(0) + px, py, pw, ph = self._figure.get_pygfx_render_area() + + if self._location == "bottom": + new_render_size = ph + delta + elif self._location == "right": + new_render_size = pw + delta + + # check if the new size would make the pygfx render area too small + if (delta < 0) and (new_render_size < 150): + print("not enough render area") + self._resize_blocked = True + + if self._resize_blocked: + # check if cursor has returned + if self._location == "bottom": + _min, pos, _max = rect_min.y, imgui.get_mouse_pos().y, rect_max.y + + elif self._location == "right": + _min, pos, _max = rect_min.x, imgui.get_mouse_pos().x, rect_max.x + + if ((_min - 5) <= pos <= (_max + 5)) and delta > 0: + # if the mouse cursor is back on the bar and the delta > 0, i.e. render area increasing + self._resize_blocked = False + + if not self._resize_blocked: + self.size = max(30, round(self.size - delta)) + self._collapsed = False + + draw_list = imgui.get_window_draw_list() + + line_color = ( + imgui.get_color_u32(imgui.ImVec4(0.9, 0.9, 0.9, 1.0)) + if (hovered or active) + else imgui.get_color_u32(imgui.ImVec4(0.5, 0.5, 0.5, 0.8)) + ) + bg_color = ( + imgui.get_color_u32(imgui.ImVec4(0.2, 0.2, 0.2, 0.8)) + if (hovered or active) + else imgui.get_color_u32(imgui.ImVec4(0.15, 0.15, 0.15, 0.6)) + ) + + # Background bar + draw_list.add_rect_filled( + imgui.ImVec2(rect_min.x, rect_min.y), + imgui.ImVec2(rect_max.x, rect_max.y), + bg_color, + ) + + # Three grip dots centered on the line + dot_spacing = 7.0 + dot_radius = 2 + if self._location == "bottom": + mid_y = (rect_min.y + rect_max.y) * 0.5 + center_x = (rect_min.x + rect_max.x) * 0.5 + for i in (-1, 0, 1): + cx = center_x + i * dot_spacing + draw_list.add_circle_filled(imgui.ImVec2(cx, mid_y), dot_radius, line_color) + + imgui.set_cursor_pos((0, imgui.get_cursor_pos_y() - imgui.get_style().item_spacing.y)) + + elif self._location == "right": + mid_x = (rect_min.x + rect_max.x) * 0.5 + center_y = (rect_min.y + rect_max.y) * 0.5 + for i in (-1, 0, 1): + cy = center_y + i * dot_spacing + draw_list.add_circle_filled( + imgui.ImVec2(mid_x, cy), dot_radius, line_color + ) + + def _draw_title(self, title: str): + padding = imgui.ImVec2(10, 4) + text_size = imgui.calc_text_size(title) + win_width = imgui.get_window_width() + box_size = imgui.ImVec2(win_width, text_size.y + padding.y * 2) + + box_screen_pos = imgui.get_cursor_screen_pos() + + draw_list = imgui.get_window_draw_list() + + # Background — use imgui's default title bar color + draw_list.add_rect_filled( + imgui.ImVec2(box_screen_pos.x, box_screen_pos.y), + imgui.ImVec2(box_screen_pos.x + box_size.x, box_screen_pos.y + box_size.y), + imgui.get_color_u32(imgui.Col_.title_bg_active), + ) + + # Centered text + text_pos = imgui.ImVec2( + box_screen_pos.x + (win_width - text_size.x) * 0.5, + box_screen_pos.y + padding.y, + ) + draw_list.add_text( + text_pos, imgui.get_color_u32(imgui.ImVec4(1, 1, 1, 1)), title + ) + + imgui.dummy(imgui.ImVec2(win_width, box_size.y)) + def draw_window(self): """helps simplify using imgui by managing window creation & position, and pushing/popping the ID""" # window position & size x, y, w, h = self.get_rect() imgui.set_next_window_size((self.width, self.height)) imgui.set_next_window_pos((self.x, self.y)) - # imgui.set_next_window_pos((x, y)) - # imgui.set_next_window_size((w, h)) flags = self._window_flags # begin window imgui.begin(self._title, p_open=None, flags=flags) + self._draw_resize_handle() + # push ID to prevent conflict between multiple figs with same UI imgui.push_id(self._id_counter) + # collapse the UI if the separator state is collapsed + # otherwise the UI renders partially on the separator for "right" guis and it looks weird + main_height = 1.0 if self._collapsed else 0.0 + imgui.begin_child("##main_ui", imgui.ImVec2(0, main_height)) + + self._draw_title(self._title) + + imgui.indent(6.0) # draw stuff from subclass into window self.update() + imgui.end_child() + # pop ID imgui.pop_id() From 0433058bd447f25afc4a09dcb9318477d275fa00 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 23 Feb 2026 03:24:56 -0500 Subject: [PATCH 036/101] fix and ui stuff --- .../widgets/nd_widget/_nd_positions/core.py | 2 +- fastplotlib/widgets/nd_widget/ndwidget.py | 166 +++++++++++++----- 2 files changed, 126 insertions(+), 42 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index b83b4dd4c..3a43c5a03 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -417,7 +417,7 @@ def graphic( @graphic.setter def graphic(self, graphic_type): - if isinstance(self.graphic, graphic_type): + if type(self.graphic) is graphic_type: return plot_area = self._graphic._plot_area diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 475987e0f..313781cd5 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -8,13 +8,17 @@ import numpy as np from ...layouts import ImguiFigure, Subplot -from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic +from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic, ImageVolumeGraphic from ...ui import EdgeWindow from .base import NDGraphic, NDProcessor from ._nd_image import NDImage, NDImageProcessor from ._nd_positions import NDPositions, NDPositionsProcessor +position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] +image_graphics = [ImageGraphic, ImageVolumeGraphic] + + @dataclass class ReferenceRangeContinuous: start: int | float @@ -30,7 +34,9 @@ def __getitem__(self, index: int): val = self.start + (self.step * index) if not self.start <= val <= self.stop: - raise IndexError(f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]") + raise IndexError( + f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]" + ) return val @@ -86,16 +92,25 @@ def add_nd_scatter(self, *args, **kwargs): return nd def add_nd_timeseries( - self, - *args, - graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, - **kwargs + self, + *args, + graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, + **kwargs, ): - nd = NDPositions(*args, graphic=graphic, multi=True, auto_x_range=True, linear_selector=True, **kwargs) + nd = NDPositions( + *args, + graphic=graphic, + multi=True, + auto_x_range=True, + linear_selector=True, + **kwargs, + ) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) self._subplot.add_graphic(nd._linear_selector) - nd._linear_selector.add_event_handler(partial(self._set_indices_from_selector, nd), "selection") + nd._linear_selector.add_event_handler( + partial(self._set_indices_from_selector, nd), "selection" + ) return nd @@ -127,7 +142,11 @@ def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): class NDWSliders(EdgeWindow): def __init__(self, figure, size, ndwidget): - super().__init__(figure=figure, size=size, title="NDWidget controls", location="bottom") + super().__init__( + figure=figure, size=size, title="NDWidget controls", location="bottom", + window_flags=imgui.WindowFlags_.no_collapse + | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar + ) self._ndwidget = ndwidget # n_sliders = self._image_widget.n_sliders @@ -155,42 +174,106 @@ def __init__(self, figure, size, ndwidget): # # self.pause = False + self._selected_subplot = self._ndwidget.figure[0, 0].name + self._selected_nd_graphic = 0 + + self._max_display_windows: dict[NDGraphic, float | int] = dict() + def update(self): indices_changed = False - for dim_index, (current_index, refr) in enumerate(zip(self._ndwidget.indices, self._ndwidget.ref_ranges)): - if isinstance(refr, ReferenceRangeContinuous): - changed, new_index = imgui.slider_float( - v=current_index, - v_min=refr.start, - v_max=refr.stop, - label=refr.unit - ) + if imgui.begin_tab_bar("NDWidget Controls"): + + if imgui.begin_tab_item("Indices")[0]: + for dim_index, (current_index, refr) in enumerate( + zip(self._ndwidget.indices, self._ndwidget.ref_ranges) + ): + if isinstance(refr, ReferenceRangeContinuous): + changed, new_index = imgui.slider_float( + v=current_index, + v_min=refr.start, + v_max=refr.stop, + label=refr.unit, + ) + + # TODO: refactor all this stuff, make fully fledged UI + if changed: + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + + elif imgui.is_item_hovered(): + if imgui.is_key_pressed(imgui.Key.right_arrow): + new_index = current_index + refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True - # TODO: refactor all this stuff, make fully fledged UI - if changed: - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index + if imgui.is_key_pressed(imgui.Key.left_arrow): + new_index = current_index - refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index - indices_changed = True + indices_changed = True - elif imgui.is_item_hovered(): - if imgui.is_key_pressed(imgui.Key.right_arrow): - new_index = current_index + refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index + if indices_changed: + self._ndwidget.indices = tuple(new_indices) - indices_changed = True + imgui.end_tab_item() - if imgui.is_key_pressed(imgui.Key.left_arrow): - new_index = current_index - refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index + if imgui.begin_tab_item("NDGraphic properties")[0]: + imgui.text("Subplots:") - indices_changed = True + self._draw_nd_graphics_props_tab() - if indices_changed: - self._ndwidget.indices = tuple(new_indices) + imgui.end_tab_item() + + imgui.end_tab_bar() + + def _draw_nd_graphics_props_tab(self): + for subplot in self._ndwidget.figure: + if imgui.tree_node(subplot.name): + self._draw_ndgraphics_node(subplot) + imgui.tree_pop() + + def _draw_ndgraphics_node(self, subplot: Subplot): + for ng in self._ndwidget[subplot].nd_graphics: + if imgui.tree_node(str(ng)): + if isinstance(ng, NDPositions): + self._draw_nd_pos_ui(subplot, ng) + imgui.tree_pop() + + def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): + for i, cls in enumerate(position_graphics): + if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): + nd_graphic.graphic = cls + subplot.auto_scale() + if i < len(position_graphics) - 1: + imgui.same_line() + + + if isinstance( + nd_graphic.display_window, (int, np.integer) + ): + slider = imgui.slider_int + input_ = imgui.input_int + type_ = int + else: + slider = imgui.slider_float + input_ = imgui.input_float + type_ = float + + changed, new = slider( + "display window", + v=nd_graphic.display_window, + v_min=type_(0), + v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), + ) + + if changed: + nd_graphic.display_window = new class NDWidget: @@ -210,9 +293,9 @@ def __init__(self, ref_ranges: list[tuple], **kwargs): self._figure = ImguiFigure(**kwargs) - self._subplots: dict[Subplot, NDWSubplot] = dict() + self._subplots_nd: dict[Subplot, NDWSubplot] = dict() for subplot in self.figure: - self._subplots[subplot] = NDWSubplot(self, subplot) + self._subplots_nd[subplot] = NDWSubplot(self, subplot) # starting index for all dims self._indices = tuple(refr[0] for refr in self.ref_ranges) @@ -237,15 +320,16 @@ def indices(self) -> tuple: @indices.setter def indices(self, new_indices: tuple[Any]): - for subplot in self._subplots.values(): + for subplot in self._subplots_nd.values(): for ndg in subplot.nd_graphics: ndg.indices = new_indices self._indices = new_indices - def __getitem__(self, key): - subplot = self.figure[key] - return self._subplots[subplot] + def __getitem__(self, key: str | tuple[int, int] | Subplot): + if not isinstance(key, Subplot): + key = self.figure[key] + return self._subplots_nd[key] def show(self, **kwargs): return self.figure.show(**kwargs) From f6322eea168b50a89c5f92c19a952477d92a410b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 23 Feb 2026 04:31:47 -0500 Subject: [PATCH 037/101] both auto x range modes working --- .../widgets/nd_widget/_nd_positions/core.py | 42 +++++++++++-- fastplotlib/widgets/nd_widget/ndwidget.py | 59 +++++++++++++------ 2 files changed, 78 insertions(+), 23 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 3a43c5a03..6a62a939b 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -359,7 +359,7 @@ def __init__( window_sizes: tuple[int | None] | None = None, index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, - auto_x_range: bool = False, + x_range_mode: Literal[None, "fixed-window", "view-range"] = None, linear_selector: bool = False, graphic_kwargs: dict = None, processor_kwargs: dict = None, @@ -386,10 +386,12 @@ def __init__( self._indices = tuple([0] * self._processor.n_slider_dims) - self._auto_x_range = auto_x_range - self._create_graphic(graphic) + self._x_range_mode = None + self._last_x_range = [0, 0] + self._block_auto_x = False + if linear_selector: self._linear_selector = LinearSelector(0, limits=(-np.inf, np.inf), edge_color="cyan") else: @@ -457,8 +459,9 @@ def indices(self, indices): # x range of the data xr = data_slice[0, 0, 0], data_slice[0, -1, 0] - if self._auto_x_range: + if self._x_range_mode is not None: self.graphic._plot_area.x_range = xr + self._last_x_range = xr # if the update_from_view is polling, prevents it if self._linear_selector is not None: with pause_events(self._linear_selector):#, event_handlers=[self._set_indices_from_selector]): @@ -558,3 +561,34 @@ def display_window(self) -> int | float | None: def display_window(self, dw: int | float | None): self.processor.display_window = dw self.indices = self.indices + + @property + def x_range_mode(self) -> Literal[None, "fixed-window", "view-range"]: + """x-range using a fixed window from the display window, or by polling the camera (view-range)""" + return self._x_range_mode + + @x_range_mode.setter + def x_range_mode(self, mode: Literal[None, "fixed-window", "view-range"]): + if self._x_range_mode == "view-range": + # old mode was view-range + self.graphic._plot_area.remove_animation( + self._update_from_view_range + ) + + if mode == "view-range": + self.graphic._plot_area.add_animations(self._update_from_view_range) + + self._x_range_mode = mode + + def _update_from_view_range(self): + xr = self.graphic._plot_area.x_range + if xr == self._last_x_range: + return + + self._last_x_range = self.graphic._plot_area.x_range + + self.display_window = xr[1] - xr[0] + indices = list(self.indices) + indices[-1] = (xr[0] + xr[1]) / 2 + + self.indices = indices diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 313781cd5..982d120b0 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -40,6 +40,10 @@ def __getitem__(self, index: int): return val + @property + def range(self) -> int | float: + return self.stop - self.start + @dataclass class ReferenceRangeDiscrete: @@ -95,13 +99,14 @@ def add_nd_timeseries( self, *args, graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, + x_range_mode="fixed-window", **kwargs, ): nd = NDPositions( *args, graphic=graphic, multi=True, - auto_x_range=True, + x_range_mode=x_range_mode, linear_selector=True, **kwargs, ) @@ -112,6 +117,8 @@ def add_nd_timeseries( partial(self._set_indices_from_selector, nd), "selection" ) + nd.x_range_mode = x_range_mode + return nd def add_nd_lines(self, *args, **kwargs): @@ -122,6 +129,7 @@ def add_nd_lines(self, *args, **kwargs): def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): # skip the NDPosition object which has the linear selector that triggered this event + print("setting from selector") skip_graphic._pause = True x = ev.info["value"] @@ -253,27 +261,40 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): if i < len(position_graphics) - 1: imgui.same_line() + changed, val = imgui.checkbox("use display window", nd_graphic.display_window is not None) + if changed: + if not val: + nd_graphic.display_window = None + else: + # pick a value 10% of the reference range + nd_graphic.display_window = self._ndwidget.ref_ranges[0].range * 0.1 + + if nd_graphic.display_window is not None: + if isinstance( + nd_graphic.display_window, (int, np.integer) + ): + slider = imgui.slider_int + input_ = imgui.input_int + type_ = int + else: + slider = imgui.slider_float + input_ = imgui.input_float + type_ = float + + changed, new = slider( + "display window", + v=nd_graphic.display_window, + v_min=type_(0), + v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), + ) - if isinstance( - nd_graphic.display_window, (int, np.integer) - ): - slider = imgui.slider_int - input_ = imgui.input_int - type_ = int - else: - slider = imgui.slider_float - input_ = imgui.input_float - type_ = float - - changed, new = slider( - "display window", - v=nd_graphic.display_window, - v_min=type_(0), - v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), - ) + if changed: + nd_graphic.display_window = new + options = [None, "fixed-window", "view-range"] + changed, option = imgui.combo("x-range mode", options.index(nd_graphic.x_range_mode), [str(o) for o in options]) if changed: - nd_graphic.display_window = new + nd_graphic.x_range_mode = options[option] class NDWidget: From a71c9318c9af529ebf51e681aa4b01568caa4211 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 23 Feb 2026 05:08:19 -0500 Subject: [PATCH 038/101] progress --- .../widgets/nd_widget/_nd_positions/core.py | 30 ++++++++++++++----- fastplotlib/widgets/nd_widget/ndwidget.py | 5 ++-- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 6a62a939b..77871e71b 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -389,8 +389,8 @@ def __init__( self._create_graphic(graphic) self._x_range_mode = None - self._last_x_range = [0, 0] - self._block_auto_x = False + self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) + self._block_update_indices = False if linear_selector: self._linear_selector = LinearSelector(0, limits=(-np.inf, np.inf), edge_color="cyan") @@ -434,9 +434,12 @@ def indices(self) -> tuple: @indices.setter def indices(self, indices): - if self._pause: + if self._block_update_indices: return + # this update must be non-reentrant + self._block_update_indices = True + data_slice = self.processor.get(indices) if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): @@ -461,16 +464,19 @@ def indices(self, indices): xr = data_slice[0, 0, 0], data_slice[0, -1, 0] if self._x_range_mode is not None: self.graphic._plot_area.x_range = xr - self._last_x_range = xr # if the update_from_view is polling, prevents it + + self._last_x_range[:] = xr # if the update_from_view is polling, prevents it if self._linear_selector is not None: - with pause_events(self._linear_selector):#, event_handlers=[self._set_indices_from_selector]): + with pause_events(self._linear_selector): self._linear_selector.limits = xr self._linear_selector.selection = indices[-1] # self._set_linear_selector(x_mid, limits=xr) self._indices = indices + self._block_update_indices = False + # def _set_linear_selector(self, x_mid, limits): # self._linear_selector.selection = x_mid # self._linear_selector.limits = limits @@ -582,13 +588,21 @@ def x_range_mode(self, mode: Literal[None, "fixed-window", "view-range"]): def _update_from_view_range(self): xr = self.graphic._plot_area.x_range - if xr == self._last_x_range: + + # the floating point error near zero gets nasty here + if np.allclose(xr, self._last_x_range, atol=1e-14): return - self._last_x_range = self.graphic._plot_area.x_range + self._last_x_range[:] = xr self.display_window = xr[1] - xr[0] + new_index = (xr[0] + xr[1]) / 2 + indices = list(self.indices) - indices[-1] = (xr[0] + xr[1]) / 2 + if indices[-1] == new_index: + return + + indices[-1] = new_index self.indices = indices + diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 982d120b0..dd99a7c72 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -129,8 +129,7 @@ def add_nd_lines(self, *args, **kwargs): def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): # skip the NDPosition object which has the linear selector that triggered this event - print("setting from selector") - skip_graphic._pause = True + skip_graphic._block_update_indices = True x = ev.info["value"] indices_new = list(self.ndw.indices) @@ -139,7 +138,7 @@ def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): self.ndw.indices = tuple(indices_new) # restore - skip_graphic._pause = False + skip_graphic._block_update_indices = False # def __repr__(self): # return "NDWidget Subplot" From a2529cc83be2a56036071029b1aa5f187d6fa466 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 24 Feb 2026 01:58:35 -0500 Subject: [PATCH 039/101] moving stuff --- fastplotlib/widgets/nd_widget/_index.py | 113 +++++++ fastplotlib/widgets/nd_widget/_ndw_subplot.py | 95 ++++++ fastplotlib/widgets/nd_widget/_ui.py | 161 ++++++++++ fastplotlib/widgets/nd_widget/ndwidget.py | 301 +----------------- 4 files changed, 375 insertions(+), 295 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_index.py create mode 100644 fastplotlib/widgets/nd_widget/_ndw_subplot.py create mode 100644 fastplotlib/widgets/nd_widget/_ui.py diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py new file mode 100644 index 000000000..dc3bdfec5 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass +from typing import Sequence, Any + + +@dataclass +class ReferenceRangeContinuous: + start: int | float + stop: int | float + step: int | float + unit: str + + def __getitem__(self, index: int): + """return the value at the index w.r.t. the step size""" + # if index is negative, turn to positive index + if index < 0: + raise ValueError("negative indexing not supported") + + val = self.start + (self.step * index) + if not self.start <= val <= self.stop: + raise IndexError( + f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]" + ) + + return val + + @property + def range(self) -> int | float: + return self.stop - self.start + + +@dataclass +class ReferenceRangeDiscrete: + options: Sequence[Any] + unit: str + + def __getitem__(self, index: int): + if index > len(self.options): + raise IndexError + + return self.options[index] + + def __len__(self): + return len(self.options) + + +class GlobalIndexVector: + def __init__(self): + self._ndgraphics = list() + self._index = list() + self._ref_ranges = list() + + @property + def ndgraphics(self): + return tuple(self._ndgraphics) + + @property + def index(self) -> tuple[Any]: + # TODO: clamp index to given range here + # graphics will clamp according to their own array sizes? + pass + + @property + def dims(self) -> tuple[str]: + return tuple(ref.unit for ref in self.ref_ranges) + + @property + def ref_ranges(self) -> tuple[ReferenceRangeContinuous]: + pass + + def __getitem__(self, item): + if isinstance(item, int): + # integer index in the ordered dict + return self.ref_ranges[item] + + for rr in self.ref_ranges: + if rr.unit == item: + return rr + + raise KeyError + + def __setitem__(self, key, value): + # TODO: set the index for the given dimension only + if isinstance(key, str): + for i, rr in enumerate(self.ref_ranges): + if rr.unit == key: + key = i + break + else: + raise KeyError + + index = list(self.index) + + # set index for given dim + index[key] = value + + def __repr__(self): + return "\n".join([f"{d}: {i}" for d, i in zip(self.dims, self.index)]) + + +class SelectionVector: + @property + def selection(self): + pass + + @property + def graphics(self): + pass + + def add_graphic(self): + pass + + def remove_graphic(self): + pass diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py new file mode 100644 index 000000000..f28c88a50 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -0,0 +1,95 @@ +from functools import partial + +import numpy as np + +from ... import ScatterCollection, LineCollection, LineStack, ImageGraphic +from ...layouts import Subplot +from . import NDImage, NDPositions +from .base import NDGraphic + + +class NDWSubplot: + def __init__(self, ndw, subplot: Subplot): + self.ndw = ndw + self._subplot = subplot + + self._nd_graphics = list() + + @property + def nd_graphics(self) -> list[NDGraphic]: + return self._nd_graphics + + def __getitem__(self, key): + if isinstance(key, (int, np.integer)): + return self.nd_graphics[key] + + for g in self.nd_graphics: + if g.name == key: + return g + + else: + raise KeyError(f"NDGraphc with given key not found: {key}") + + def add_nd_image(self, *args, **kwargs): + nd = NDImage(*args, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + return nd + + def add_nd_scatter(self, *args, **kwargs): + nd = NDPositions(*args, graphic=ScatterCollection, multi=True, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + + return nd + + def add_nd_timeseries( + self, + *args, + graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, + x_range_mode="fixed-window", + **kwargs, + ): + nd = NDPositions( + *args, + graphic=graphic, + multi=True, + x_range_mode=x_range_mode, + linear_selector=True, + **kwargs, + ) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + self._subplot.add_graphic(nd._linear_selector) + nd._linear_selector.add_event_handler( + partial(self._set_indices_from_selector, nd), "selection" + ) + + nd.x_range_mode = x_range_mode + + return nd + + def add_nd_lines(self, *args, **kwargs): + nd = NDPositions(*args, graphic=LineCollection, multi=True, **kwargs) + self._nd_graphics.append(nd) + self._subplot.add_graphic(nd.graphic) + return nd + + def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): + # skip the NDPosition object which has the linear selector that triggered this event + skip_graphic._block_update_indices = True + + x = ev.info["value"] + indices_new = list(self.ndw.indices) + # linear selector for NDPositions always acts on the `p` dim + indices_new[-1] = x + self.ndw.indices = tuple(indices_new) + + # restore + skip_graphic._block_update_indices = False + + # def __repr__(self): + # return "NDWidget Subplot" + # + # def __str__(self): + # return "NDWidget Subplot" diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py new file mode 100644 index 000000000..a9777e87f --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -0,0 +1,161 @@ +import numpy as np +from imgui_bundle import imgui + +from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic, ImageVolumeGraphic +from ...layouts import Subplot +from ...ui import EdgeWindow +from . import NDPositions +from ._index import ReferenceRangeContinuous +from .base import NDGraphic + +position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] +image_graphics = [ImageGraphic, ImageVolumeGraphic] + + +class NDWidgetUI(EdgeWindow): + def __init__(self, figure, size, ndwidget): + super().__init__( + figure=figure, size=size, title="NDWidget controls", location="bottom", + window_flags=imgui.WindowFlags_.no_collapse + | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar + ) + self._ndwidget = ndwidget + + # n_sliders = self._image_widget.n_sliders + # + # # whether or not a dimension is in play mode + # self._playing: list[bool] = [False] * n_sliders + # + # # approximate framerate for playing + # self._fps: list[int] = [20] * n_sliders + # + # # framerate converted to frame time + # self._frame_time: list[float] = [1 / 20] * n_sliders + # + # # last timepoint that a frame was displayed from a given dimension + # self._last_frame_time: list[float] = [perf_counter()] * n_sliders + # + # # loop playback + # self._loop = False + # + # # auto-plays the ImageWidget's left-most dimension in docs galleries + # if "DOCS_BUILD" in os.environ.keys(): + # if os.environ["DOCS_BUILD"] == "1": + # self._playing[0] = True + # self._loop = True + # + # self.pause = False + + self._selected_subplot = self._ndwidget.figure[0, 0].name + self._selected_nd_graphic = 0 + + self._max_display_windows: dict[NDGraphic, float | int] = dict() + + def update(self): + indices_changed = False + + if imgui.begin_tab_bar("NDWidget Controls"): + + if imgui.begin_tab_item("Indices")[0]: + for dim_index, (current_index, refr) in enumerate( + zip(self._ndwidget.indices, self._ndwidget.ref_ranges) + ): + if isinstance(refr, ReferenceRangeContinuous): + changed, new_index = imgui.slider_float( + v=current_index, + v_min=refr.start, + v_max=refr.stop, + label=refr.unit, + ) + + # TODO: refactor all this stuff, make fully fledged UI + if changed: + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + + elif imgui.is_item_hovered(): + if imgui.is_key_pressed(imgui.Key.right_arrow): + new_index = current_index + refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + + if imgui.is_key_pressed(imgui.Key.left_arrow): + new_index = current_index - refr.step + new_indices = list(self._ndwidget.indices) + new_indices[dim_index] = new_index + + indices_changed = True + + if indices_changed: + self._ndwidget.indices = tuple(new_indices) + + imgui.end_tab_item() + + if imgui.begin_tab_item("NDGraphic properties")[0]: + imgui.text("Subplots:") + + self._draw_nd_graphics_props_tab() + + imgui.end_tab_item() + + imgui.end_tab_bar() + + def _draw_nd_graphics_props_tab(self): + for subplot in self._ndwidget.figure: + if imgui.tree_node(subplot.name): + self._draw_ndgraphics_node(subplot) + imgui.tree_pop() + + def _draw_ndgraphics_node(self, subplot: Subplot): + for ng in self._ndwidget[subplot].nd_graphics: + if imgui.tree_node(str(ng)): + if isinstance(ng, NDPositions): + self._draw_nd_pos_ui(subplot, ng) + imgui.tree_pop() + + def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): + for i, cls in enumerate(position_graphics): + if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): + nd_graphic.graphic = cls + subplot.auto_scale() + if i < len(position_graphics) - 1: + imgui.same_line() + + changed, val = imgui.checkbox("use display window", nd_graphic.display_window is not None) + if changed: + if not val: + nd_graphic.display_window = None + else: + # pick a value 10% of the reference range + nd_graphic.display_window = self._ndwidget.ref_ranges[0].range * 0.1 + + if nd_graphic.display_window is not None: + if isinstance( + nd_graphic.display_window, (int, np.integer) + ): + slider = imgui.slider_int + input_ = imgui.input_int + type_ = int + else: + slider = imgui.slider_float + input_ = imgui.input_float + type_ = float + + changed, new = slider( + "display window", + v=nd_graphic.display_window, + v_min=type_(0), + v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), + ) + + if changed: + nd_graphic.display_window = new + + options = [None, "fixed-window", "view-range"] + changed, option = imgui.combo("x-range mode", options.index(nd_graphic.x_range_mode), [str(o) for o in options]) + if changed: + nd_graphic.x_range_mode = options[option] diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index dd99a7c72..427abba4e 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -1,303 +1,14 @@ -from dataclasses import dataclass -from functools import partial -import os -from time import perf_counter -from typing import Any, Sequence - -from imgui_bundle import imgui, icons_fontawesome_6 as fa -import numpy as np +from typing import Any +from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete +from ._ndw_subplot import NDWSubplot +from ._ui import NDWidgetUI from ...layouts import ImguiFigure, Subplot -from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic, ImageVolumeGraphic -from ...ui import EdgeWindow -from .base import NDGraphic, NDProcessor -from ._nd_image import NDImage, NDImageProcessor -from ._nd_positions import NDPositions, NDPositionsProcessor - - -position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] -image_graphics = [ImageGraphic, ImageVolumeGraphic] - - -@dataclass -class ReferenceRangeContinuous: - start: int | float - stop: int | float - step: int | float - unit: str - - def __getitem__(self, index: int): - """return the value at the index w.r.t. the step size""" - # if index is negative, turn to positive index - if index < 0: - raise ValueError("negative indexing not supported") - - val = self.start + (self.step * index) - if not self.start <= val <= self.stop: - raise IndexError( - f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]" - ) - - return val - - @property - def range(self) -> int | float: - return self.stop - self.start - - -@dataclass -class ReferenceRangeDiscrete: - options: Sequence[Any] - unit: str - - def __getitem__(self, index: int): - if index > len(self.options): - raise IndexError - - return self.options[index] - - def __len__(self): - return len(self.options) - - -class NDWSubplot: - def __init__(self, ndw, subplot: Subplot): - self.ndw = ndw - self._subplot = subplot - - self._nd_graphics = list() - - @property - def nd_graphics(self) -> list[NDGraphic]: - return self._nd_graphics - - def __getitem__(self, key): - if isinstance(key, (int, np.integer)): - return self.nd_graphics[key] - - for g in self.nd_graphics: - if g.name == key: - return g - - else: - raise KeyError(f"NDGraphc with given key not found: {key}") - - def add_nd_image(self, *args, **kwargs): - nd = NDImage(*args, **kwargs) - self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) - return nd - - def add_nd_scatter(self, *args, **kwargs): - nd = NDPositions(*args, graphic=ScatterCollection, multi=True, **kwargs) - self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) - - return nd - - def add_nd_timeseries( - self, - *args, - graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, - x_range_mode="fixed-window", - **kwargs, - ): - nd = NDPositions( - *args, - graphic=graphic, - multi=True, - x_range_mode=x_range_mode, - linear_selector=True, - **kwargs, - ) - self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) - self._subplot.add_graphic(nd._linear_selector) - nd._linear_selector.add_event_handler( - partial(self._set_indices_from_selector, nd), "selection" - ) - - nd.x_range_mode = x_range_mode - - return nd - - def add_nd_lines(self, *args, **kwargs): - nd = NDPositions(*args, graphic=LineCollection, multi=True, **kwargs) - self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) - return nd - - def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): - # skip the NDPosition object which has the linear selector that triggered this event - skip_graphic._block_update_indices = True - - x = ev.info["value"] - indices_new = list(self.ndw.indices) - # linear selector for NDPositions always acts on the `p` dim - indices_new[-1] = x - self.ndw.indices = tuple(indices_new) - - # restore - skip_graphic._block_update_indices = False - - # def __repr__(self): - # return "NDWidget Subplot" - # - # def __str__(self): - # return "NDWidget Subplot" - - -class NDWSliders(EdgeWindow): - def __init__(self, figure, size, ndwidget): - super().__init__( - figure=figure, size=size, title="NDWidget controls", location="bottom", - window_flags=imgui.WindowFlags_.no_collapse - | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar - ) - self._ndwidget = ndwidget - - # n_sliders = self._image_widget.n_sliders - # - # # whether or not a dimension is in play mode - # self._playing: list[bool] = [False] * n_sliders - # - # # approximate framerate for playing - # self._fps: list[int] = [20] * n_sliders - # - # # framerate converted to frame time - # self._frame_time: list[float] = [1 / 20] * n_sliders - # - # # last timepoint that a frame was displayed from a given dimension - # self._last_frame_time: list[float] = [perf_counter()] * n_sliders - # - # # loop playback - # self._loop = False - # - # # auto-plays the ImageWidget's left-most dimension in docs galleries - # if "DOCS_BUILD" in os.environ.keys(): - # if os.environ["DOCS_BUILD"] == "1": - # self._playing[0] = True - # self._loop = True - # - # self.pause = False - - self._selected_subplot = self._ndwidget.figure[0, 0].name - self._selected_nd_graphic = 0 - - self._max_display_windows: dict[NDGraphic, float | int] = dict() - - def update(self): - indices_changed = False - - if imgui.begin_tab_bar("NDWidget Controls"): - - if imgui.begin_tab_item("Indices")[0]: - for dim_index, (current_index, refr) in enumerate( - zip(self._ndwidget.indices, self._ndwidget.ref_ranges) - ): - if isinstance(refr, ReferenceRangeContinuous): - changed, new_index = imgui.slider_float( - v=current_index, - v_min=refr.start, - v_max=refr.stop, - label=refr.unit, - ) - - # TODO: refactor all this stuff, make fully fledged UI - if changed: - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True - - elif imgui.is_item_hovered(): - if imgui.is_key_pressed(imgui.Key.right_arrow): - new_index = current_index + refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True - - if imgui.is_key_pressed(imgui.Key.left_arrow): - new_index = current_index - refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True - - if indices_changed: - self._ndwidget.indices = tuple(new_indices) - - imgui.end_tab_item() - - if imgui.begin_tab_item("NDGraphic properties")[0]: - imgui.text("Subplots:") - - self._draw_nd_graphics_props_tab() - - imgui.end_tab_item() - - imgui.end_tab_bar() - - def _draw_nd_graphics_props_tab(self): - for subplot in self._ndwidget.figure: - if imgui.tree_node(subplot.name): - self._draw_ndgraphics_node(subplot) - imgui.tree_pop() - - def _draw_ndgraphics_node(self, subplot: Subplot): - for ng in self._ndwidget[subplot].nd_graphics: - if imgui.tree_node(str(ng)): - if isinstance(ng, NDPositions): - self._draw_nd_pos_ui(subplot, ng) - imgui.tree_pop() - - def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): - for i, cls in enumerate(position_graphics): - if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): - nd_graphic.graphic = cls - subplot.auto_scale() - if i < len(position_graphics) - 1: - imgui.same_line() - - changed, val = imgui.checkbox("use display window", nd_graphic.display_window is not None) - if changed: - if not val: - nd_graphic.display_window = None - else: - # pick a value 10% of the reference range - nd_graphic.display_window = self._ndwidget.ref_ranges[0].range * 0.1 - - if nd_graphic.display_window is not None: - if isinstance( - nd_graphic.display_window, (int, np.integer) - ): - slider = imgui.slider_int - input_ = imgui.input_int - type_ = int - else: - slider = imgui.slider_float - input_ = imgui.input_float - type_ = float - - changed, new = slider( - "display window", - v=nd_graphic.display_window, - v_min=type_(0), - v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), - ) - - if changed: - nd_graphic.display_window = new - - options = [None, "fixed-window", "view-range"] - changed, option = imgui.combo("x-range mode", options.index(nd_graphic.x_range_mode), [str(o) for o in options]) - if changed: - nd_graphic.x_range_mode = options[option] class NDWidget: def __init__(self, ref_ranges: list[tuple], **kwargs): + # TODO: this should maybe be an ordered dict?? self._ref_ranges = list() for r in ref_ranges: @@ -323,7 +34,7 @@ def __init__(self, ref_ranges: list[tuple], **kwargs): # hard code the expected height so that the first render looks right in tests, docs etc. ui_size = 57 + (50 * len(self.indices)) - self._sliders_ui = NDWSliders(self.figure, ui_size, self) + self._sliders_ui = NDWidgetUI(self.figure, ui_size, self) self.figure.add_gui(self._sliders_ui) @property From d53cb8f08479b5f4aba0739bf3c70ad28d74352d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 24 Feb 2026 02:57:44 -0500 Subject: [PATCH 040/101] much much better organization of things --- fastplotlib/widgets/nd_widget/_index.py | 76 ++++++++++++++----- .../widgets/nd_widget/_nd_positions/core.py | 49 ++++++------ fastplotlib/widgets/nd_widget/_ndw_subplot.py | 30 +++----- fastplotlib/widgets/nd_widget/_ui.py | 30 ++++++-- fastplotlib/widgets/nd_widget/base.py | 64 +++++++++++++++- fastplotlib/widgets/nd_widget/ndwidget.py | 41 ++++------ 6 files changed, 192 insertions(+), 98 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index dc3bdfec5..4cf1e0bd6 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Sequence, Any +from typing import Sequence, Any, Callable + +from .base import NDGraphic @dataclass @@ -44,20 +46,39 @@ def __len__(self): class GlobalIndexVector: - def __init__(self): - self._ndgraphics = list() - self._index = list() + def __init__(self, ref_ranges: list, get_ndgraphics: Callable): self._ref_ranges = list() - @property - def ndgraphics(self): - return tuple(self._ndgraphics) + for r in ref_ranges: + if len(r) == 4: + # assume start, stop, step, unit + refr = ReferenceRangeContinuous(*r) + elif len(r) == 2: + refr = ReferenceRangeDiscrete(*r) + else: + raise ValueError + + self._ref_ranges.append(refr) + + self._get_ndgraphics = get_ndgraphics + + # starting index for all dims + self._indices = [refr[0] for refr in self.ref_ranges] @property - def index(self) -> tuple[Any]: - # TODO: clamp index to given range here + def indices(self) -> tuple[Any]: + # TODO: clamp index to given ref range here # graphics will clamp according to their own array sizes? - pass + return tuple(self._indices) + + @indices.setter + def indices(self, new_indices: tuple[Any]): + self._indices[:] = new_indices + self._render_indices() + + def _render_indices(self): + for g in self._get_ndgraphics(): + g.indices = self.indices @property def dims(self) -> tuple[str]: @@ -65,16 +86,16 @@ def dims(self) -> tuple[str]: @property def ref_ranges(self) -> tuple[ReferenceRangeContinuous]: - pass + return tuple(self._ref_ranges) def __getitem__(self, item): if isinstance(item, int): - # integer index in the ordered dict - return self.ref_ranges[item] + # integer index in the list + return self._indices[item] - for rr in self.ref_ranges: + for i, rr in enumerate(self.ref_ranges): if rr.unit == item: - return rr + return self._indices[i] raise KeyError @@ -88,13 +109,30 @@ def __setitem__(self, key, value): else: raise KeyError - index = list(self.index) - # set index for given dim - index[key] = value + self._indices[key] = value + self._render_indices() + + def pop_dim(self): + pass + + def push_dim(self, ref_range: ReferenceRangeContinuous): + # TODO: implement pushing and popping dims + pass + + def __iter__(self): + for index in self.indices: + yield index + + def __len__(self): + return len(self._indices) + + def __eq__(self, other): + return self._indices == other def __repr__(self): - return "\n".join([f"{d}: {i}" for d, i in zip(self.dims, self.index)]) + named = ", ".join([f"{d}: {i}" for d, i in zip(self.dims, self.index)]) + return f"Indices: {named}" class SelectionVector: diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 77871e71b..59a240687 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -18,7 +18,8 @@ ) from ....graphics.utils import pause_events from ....graphics.selectors import LinearSelector -from ..base import NDProcessor, NDGraphic, WindowFuncCallable +from ..base import NDProcessor, NDGraphic, WindowFuncCallable, block_reentrance, block_indices +from .._index import GlobalIndexVector # TODO: Maybe get rid of n_display_dims in NDProcessor, @@ -339,9 +340,10 @@ def get(self, indices: tuple[Any, ...]): ] -class NDPositions: +class NDPositions(NDGraphic): def __init__( self, + global_index: GlobalIndexVector, data: Any, *args, graphic: Type[ @@ -359,8 +361,8 @@ def __init__( window_sizes: tuple[int | None] | None = None, index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, - x_range_mode: Literal[None, "fixed-window", "view-range"] = None, linear_selector: bool = False, + name: str = None, graphic_kwargs: dict = None, processor_kwargs: dict = None, ): @@ -390,15 +392,19 @@ def __init__( self._x_range_mode = None self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) - self._block_update_indices = False if linear_selector: self._linear_selector = LinearSelector(0, limits=(-np.inf, np.inf), edge_color="cyan") + self._linear_selector.add_event_handler(self._linear_selector_handler, "selection") else: self._linear_selector = None self._pause = False + self._global_index = global_index + + super().__init__(name) + @property def processor(self) -> NDPositionsProcessor: return self._processor @@ -433,13 +439,8 @@ def indices(self) -> tuple: return self._indices @indices.setter + @block_reentrance def indices(self, indices): - if self._block_update_indices: - return - - # this update must be non-reentrant - self._block_update_indices = True - data_slice = self.processor.get(indices) if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): @@ -465,7 +466,7 @@ def indices(self, indices): if self._x_range_mode is not None: self.graphic._plot_area.x_range = xr - self._last_x_range[:] = xr # if the update_from_view is polling, prevents it + self._last_x_range[:] = xr # if the update_from_view is polling, prevents it from being called if self._linear_selector is not None: with pause_events(self._linear_selector): @@ -475,11 +476,10 @@ def indices(self, indices): self._indices = indices - self._block_update_indices = False - - # def _set_linear_selector(self, x_mid, limits): - # self._linear_selector.selection = x_mid - # self._linear_selector.limits = limits + def _linear_selector_handler(self, ev): + with block_indices(self): + # linear selector always acts on the `p` dim + self._global_index[-1] = ev.info["value"] def _tooltip_handler(self, graphic, pick_info): if isinstance(self.graphic, (LineCollection, ScatterCollection)): @@ -598,11 +598,14 @@ def _update_from_view_range(self): self.display_window = xr[1] - xr[0] new_index = (xr[0] + xr[1]) / 2 - indices = list(self.indices) - if indices[-1] == new_index: - return - - indices[-1] = new_index - - self.indices = indices + # set the `p` dim on the global index vector + self._global_index[-1] = new_index + # indices = list(self.indices) + # if indices[-1] == new_index: + # return + # + # indices[-1] = new_index + # + # self.indices = indices + # diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index f28c88a50..39975741f 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -16,8 +16,8 @@ def __init__(self, ndw, subplot: Subplot): self._nd_graphics = list() @property - def nd_graphics(self) -> list[NDGraphic]: - return self._nd_graphics + def nd_graphics(self) -> tuple[NDGraphic]: + return tuple(self._nd_graphics) def __getitem__(self, key): if isinstance(key, (int, np.integer)): @@ -37,7 +37,9 @@ def add_nd_image(self, *args, **kwargs): return nd def add_nd_scatter(self, *args, **kwargs): - nd = NDPositions(*args, graphic=ScatterCollection, multi=True, **kwargs) + nd = NDPositions( + self.ndw.indices, *args, graphic=ScatterCollection, multi=True, **kwargs + ) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) @@ -51,19 +53,20 @@ def add_nd_timeseries( **kwargs, ): nd = NDPositions( + self.ndw.indices, *args, graphic=graphic, multi=True, - x_range_mode=x_range_mode, + # x_range_mode=x_range_mode, linear_selector=True, **kwargs, ) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) self._subplot.add_graphic(nd._linear_selector) - nd._linear_selector.add_event_handler( - partial(self._set_indices_from_selector, nd), "selection" - ) + # nd._linear_selector.add_event_handler( + # partial(self._set_indices_from_selector, nd), "selection" + # ) nd.x_range_mode = x_range_mode @@ -75,19 +78,6 @@ def add_nd_lines(self, *args, **kwargs): self._subplot.add_graphic(nd.graphic) return nd - def _set_indices_from_selector(self, skip_graphic: NDGraphic, ev): - # skip the NDPosition object which has the linear selector that triggered this event - skip_graphic._block_update_indices = True - - x = ev.info["value"] - indices_new = list(self.ndw.indices) - # linear selector for NDPositions always acts on the `p` dim - indices_new[-1] = x - self.ndw.indices = tuple(indices_new) - - # restore - skip_graphic._block_update_indices = False - # def __repr__(self): # return "NDWidget Subplot" # diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index a9777e87f..a2198d6c9 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -1,7 +1,13 @@ import numpy as np from imgui_bundle import imgui -from ...graphics import ScatterCollection, LineCollection, LineStack, ImageGraphic, ImageVolumeGraphic +from ...graphics import ( + ScatterCollection, + LineCollection, + LineStack, + ImageGraphic, + ImageVolumeGraphic, +) from ...layouts import Subplot from ...ui import EdgeWindow from . import NDPositions @@ -15,9 +21,13 @@ class NDWidgetUI(EdgeWindow): def __init__(self, figure, size, ndwidget): super().__init__( - figure=figure, size=size, title="NDWidget controls", location="bottom", + figure=figure, + size=size, + title="NDWidget controls", + location="bottom", window_flags=imgui.WindowFlags_.no_collapse - | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar + | imgui.WindowFlags_.no_resize + | imgui.WindowFlags_.no_title_bar, ) self._ndwidget = ndwidget @@ -125,7 +135,9 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): if i < len(position_graphics) - 1: imgui.same_line() - changed, val = imgui.checkbox("use display window", nd_graphic.display_window is not None) + changed, val = imgui.checkbox( + "use display window", nd_graphic.display_window is not None + ) if changed: if not val: nd_graphic.display_window = None @@ -134,9 +146,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): nd_graphic.display_window = self._ndwidget.ref_ranges[0].range * 0.1 if nd_graphic.display_window is not None: - if isinstance( - nd_graphic.display_window, (int, np.integer) - ): + if isinstance(nd_graphic.display_window, (int, np.integer)): slider = imgui.slider_int input_ = imgui.input_int type_ = int @@ -156,6 +166,10 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): nd_graphic.display_window = new options = [None, "fixed-window", "view-range"] - changed, option = imgui.combo("x-range mode", options.index(nd_graphic.x_range_mode), [str(o) for o in options]) + changed, option = imgui.combo( + "x-range mode", + options.index(nd_graphic.x_range_mode), + [str(o) for o in options], + ) if changed: nd_graphic.x_range_mode = options[option] diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index e46386e93..b78dbcbbb 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager import inspect from typing import Literal, Callable, Any from warnings import warn @@ -252,9 +253,40 @@ def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: pass +def block_reentrance(setter): + # decorator to block re-entrant indices setter + def set_indices_wrapper(self: NDGraphic, new_indices): + """ + wraps NDGraphic.indices + + self: NDGraphic instance + + new_indices: new indices to set + """ + # set_value is already in the middle of an execution, block re-entrance + if self._block_indices: + return + try: + # block re-execution of set_value until it has *fully* finished executing + self._block_indices = True + setter(self, new_indices) + except Exception as exc: + # raise original exception + raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_indices = False + + return set_indices_wrapper + + class NDGraphic: + def __init__(self, name: str | None): + self._name = name + self._block_indices = False + @property - def name(self) -> str: + def name(self) -> str | None: return self._name @property @@ -272,3 +304,33 @@ def indices(self) -> tuple[Any]: @indices.setter def indices(self, new: tuple): raise NotImplementedError + + +@contextmanager +def block_indices(ndgraphic: NDGraphic): + """ + Context manager for pausing Graphic events. + + Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. + + Examples + -------- + + .. code-block:: + + # pass in any number of graphics + with fpl.pause_events(graphic1, graphic2, graphic3): + # enter context manager + # all events are blocked from graphic1, graphic2, graphic3 + + # context manager exited, event states restored. + + """ + ndgraphic._block_indices = True + + try: + yield + except Exception as e: + raise e from None # indices setter has raised, the line above and the lines below are probably more relevant! + finally: + ndgraphic._block_indices = False diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 427abba4e..0caf9b9c0 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -1,6 +1,6 @@ from typing import Any -from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete +from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete, GlobalIndexVector from ._ndw_subplot import NDWSubplot from ._ui import NDWidgetUI from ...layouts import ImguiFigure, Subplot @@ -8,29 +8,13 @@ class NDWidget: def __init__(self, ref_ranges: list[tuple], **kwargs): - # TODO: this should maybe be an ordered dict?? - self._ref_ranges = list() - - for r in ref_ranges: - if len(r) == 4: - # assume start, stop, step, unit - refr = ReferenceRangeContinuous(*r) - elif len(r) == 2: - refr = ReferenceRangeDiscrete(*r) - else: - raise ValueError - - self._ref_ranges.append(refr) - + self._indices = GlobalIndexVector(ref_ranges, self._get_ndgraphics) self._figure = ImguiFigure(**kwargs) self._subplots_nd: dict[Subplot, NDWSubplot] = dict() for subplot in self.figure: self._subplots_nd[subplot] = NDWSubplot(self, subplot) - # starting index for all dims - self._indices = tuple(refr[0] for refr in self.ref_ranges) - # hard code the expected height so that the first render looks right in tests, docs etc. ui_size = 57 + (50 * len(self.indices)) @@ -42,25 +26,28 @@ def figure(self) -> ImguiFigure: return self._figure @property - def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete]: - return tuple(self._ref_ranges) - - @property - def indices(self) -> tuple: + def indices(self) -> GlobalIndexVector: return self._indices @indices.setter def indices(self, new_indices: tuple[Any]): - for subplot in self._subplots_nd.values(): - for ndg in subplot.nd_graphics: - ndg.indices = new_indices + self._indices.indices = new_indices - self._indices = new_indices + @property + def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete]: + return tuple(self._indices.ref_ranges) def __getitem__(self, key: str | tuple[int, int] | Subplot): if not isinstance(key, Subplot): key = self.figure[key] return self._subplots_nd[key] + def _get_ndgraphics(self): + gs = list() + for subplot in self._subplots_nd.values(): + gs.extend(subplot.nd_graphics) + + return tuple(gs) + def show(self, **kwargs): return self.figure.show(**kwargs) From fb42ae04185981ba79a4d10ad08b22b33e803d5e Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 24 Feb 2026 04:23:06 -0500 Subject: [PATCH 041/101] GlobalIndexVector working with ndpostions and ndimage --- fastplotlib/layouts/_plot_area.py | 2 +- fastplotlib/widgets/nd_widget/_index.py | 14 ++++---- fastplotlib/widgets/nd_widget/_nd_image.py | 11 +++---- .../widgets/nd_widget/_nd_positions/core.py | 33 ++++++++----------- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 4 +-- fastplotlib/widgets/nd_widget/ndwidget.py | 6 ++-- 6 files changed, 31 insertions(+), 39 deletions(-) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 27ec75eef..513a7ad47 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -874,7 +874,7 @@ def x_range(self) -> tuple[float, float]: @x_range.setter def x_range(self, xr: tuple[float, float]): width = xr[1] - xr[0] - x_mid = xr[0] + (width / 2) + x_mid = (xr[0] + xr[1]) / 2 self.camera.width = width self.camera.local.x = x_mid diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 4cf1e0bd6..ff2edd6f6 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -63,16 +63,16 @@ def __init__(self, ref_ranges: list, get_ndgraphics: Callable): self._get_ndgraphics = get_ndgraphics # starting index for all dims - self._indices = [refr[0] for refr in self.ref_ranges] + self._indices: list[int | float | Any] = [refr[0] for refr in self.ref_ranges] @property - def indices(self) -> tuple[Any]: + def indices(self) -> tuple[int | float | Any, ...]: # TODO: clamp index to given ref range here # graphics will clamp according to their own array sizes? return tuple(self._indices) @indices.setter - def indices(self, new_indices: tuple[Any]): + def indices(self, new_indices: tuple[int | float | Any, ...]): self._indices[:] = new_indices self._render_indices() @@ -81,11 +81,11 @@ def _render_indices(self): g.indices = self.indices @property - def dims(self) -> tuple[str]: - return tuple(ref.unit for ref in self.ref_ranges) + def dims(self) -> tuple[str, ...]: + return tuple([ref.unit for ref in self.ref_ranges]) @property - def ref_ranges(self) -> tuple[ReferenceRangeContinuous]: + def ref_ranges(self) -> tuple[ReferenceRangeContinuous, ...]: return tuple(self._ref_ranges) def __getitem__(self, item): @@ -131,7 +131,7 @@ def __eq__(self, other): return self._indices == other def __repr__(self): - named = ", ".join([f"{d}: {i}" for d, i in zip(self.dims, self.index)]) + named = ", ".join([f"{d}: {i}" for d, i in zip(self.dims, self._indices)]) return f"Indices: {named}" diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 3e54814b2..398e48dee 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -526,6 +526,7 @@ def _recompute_histogram(self): class NDImage(NDGraphic): def __init__( self, + global_index, data: Any, *args, graphic: type[ImageGraphic, ImageVolumeGraphic] = None, @@ -540,6 +541,8 @@ def __init__( if processor_kwargs is None: processor_kwargs = dict() + self._global_index = global_index + self._processor = processor( data, *args, @@ -549,8 +552,6 @@ def __init__( **processor_kwargs, ) - self._indices = tuple([0] * self._processor.n_slider_dims) - self._graphic = None self._create_graphic() @@ -582,7 +583,7 @@ def _create_graphic(self): case 3: cls = ImageVolumeGraphic - data_slice = self.processor.get(self.indices) + data_slice = self.processor.get(self._global_index.indices) old_graphic = self._graphic new_graphic = cls(data_slice) @@ -607,7 +608,7 @@ def n_display_dims(self, n: Literal[2 , 3]): @property def indices(self) -> tuple: - return self._indices + return self._global_index.indices @indices.setter def indices(self, indices): @@ -615,8 +616,6 @@ def indices(self, indices): self.graphic.data = data_slice - self._indices = indices - def _tooltip_handler(self, graphic, pick_info): # get graphic within the collection n_index = np.argwhere(self.graphic.graphics == graphic).item() diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 59a240687..e9be48368 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -366,6 +366,8 @@ def __init__( graphic_kwargs: dict = None, processor_kwargs: dict = None, ): + self._global_index = global_index + if issubclass(graphic, LineCollection): multi = True @@ -386,8 +388,6 @@ def __init__( self._processor.p_max = 1_000 - self._indices = tuple([0] * self._processor.n_slider_dims) - self._create_graphic(graphic) self._x_range_mode = None @@ -401,8 +401,6 @@ def __init__( self._pause = False - self._global_index = global_index - super().__init__(name) @property @@ -436,7 +434,7 @@ def graphic(self, graphic_type): @property def indices(self) -> tuple: - return self._indices + return self._global_index.indices @indices.setter @block_reentrance @@ -466,15 +464,14 @@ def indices(self, indices): if self._x_range_mode is not None: self.graphic._plot_area.x_range = xr - self._last_x_range[:] = xr # if the update_from_view is polling, prevents it from being called + # if the update_from_view is polling, this prevents it from being called by setting the new last xrange + # in theory, but this doesn't seem to fully work yet, not a big deal right now can check later + self._last_x_range[:] = self.graphic._plot_area.x_range if self._linear_selector is not None: with pause_events(self._linear_selector): self._linear_selector.limits = xr self._linear_selector.selection = indices[-1] - # self._set_linear_selector(x_mid, limits=xr) - - self._indices = indices def _linear_selector_handler(self, ev): with block_indices(self): @@ -566,6 +563,8 @@ def display_window(self) -> int | float | None: @display_window.setter def display_window(self, dw: int | float | None): self.processor.display_window = dw + + # force re-render self.indices = self.indices @property @@ -593,19 +592,15 @@ def _update_from_view_range(self): if np.allclose(xr, self._last_x_range, atol=1e-14): return + last_width = abs(self._last_x_range[1] - self._last_x_range[0]) self._last_x_range[:] = xr - self.display_window = xr[1] - xr[0] + new_width = abs(xr[1] - xr[0]) new_index = (xr[0] + xr[1]) / 2 + if (new_index == self._global_index[-1]) and (last_width == new_width): + return + + self.processor.display_window = new_width # set the `p` dim on the global index vector self._global_index[-1] = new_index - - # indices = list(self.indices) - # if indices[-1] == new_index: - # return - # - # indices[-1] = new_index - # - # self.indices = indices - # diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 39975741f..92ec69d74 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -1,5 +1,3 @@ -from functools import partial - import numpy as np from ... import ScatterCollection, LineCollection, LineStack, ImageGraphic @@ -31,7 +29,7 @@ def __getitem__(self, key): raise KeyError(f"NDGraphc with given key not found: {key}") def add_nd_image(self, *args, **kwargs): - nd = NDImage(*args, **kwargs) + nd = NDImage(self.ndw.indices, *args, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) return nd diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index 0caf9b9c0..c4bd8fb1e 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -30,12 +30,12 @@ def indices(self) -> GlobalIndexVector: return self._indices @indices.setter - def indices(self, new_indices: tuple[Any]): + def indices(self, new_indices: tuple[int | float | Any, ...]): self._indices.indices = new_indices @property - def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete]: - return tuple(self._indices.ref_ranges) + def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete, ...]: + return self._indices.ref_ranges def __getitem__(self, key: str | tuple[int, int] | Subplot): if not isinstance(key, Subplot): From fabb65a6678235088c3ede8fa1983c8bf2c2300f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 24 Feb 2026 05:33:57 -0500 Subject: [PATCH 042/101] examples --- examples/ndwidget/ndimage.py | 25 +++++++++++++++++++++ examples/ndwidget/timeseries.py | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 examples/ndwidget/ndimage.py create mode 100644 examples/ndwidget/timeseries.py diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py new file mode 100644 index 000000000..9fb6dd422 --- /dev/null +++ b/examples/ndwidget/ndimage.py @@ -0,0 +1,25 @@ +""" +NDWidget image +============== + +NDWidget image example +""" + +# test_example = true +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +import fastplotlib as fpl + + +a = np.random.rand(30, 1000, 64, 64) + + +ndw = fpl.NDWidget(ref_ranges=[(0, 30, 1, "um"), (0, 1000, 1, "t")], size=(800, 800)) +ndw.show() + +ndi = ndw[0, 0].add_nd_image(a, index_mappings=(int, int)) +# TODO: need to think about how to "auto ignore" reference range for a dim when switching between 2 & 3 dim images +# ndi.n_display_dims = 3 + +fpl.loop.run() diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py new file mode 100644 index 000000000..1dac31326 --- /dev/null +++ b/examples/ndwidget/timeseries.py @@ -0,0 +1,39 @@ +""" +NDWidget Timeseries +=================== + +NDWidget timeseries example +""" + +# test_example = true +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +import fastplotlib as fpl + +# generate some toy timeseries data +n_datapoints = 50_000 # number of datapoints per line +xs = np.linspace(0, 1000 * np.pi, n_datapoints) + +lines = list() +for i in range(1, 11): + l = np.column_stack( + [ + xs, + np.sin(xs * i) + ] + ) + lines.append(l) + +# timeseries data of shape [n_lines, n_datapoint, 2] +data = np.stack(lines) + +# must define a reference range, this would often be your time dimension and corresponds to your x-dimension +ref = [(0, xs[-1], 0.1, "angle")] + +ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) + +ndw[0, 0].add_nd_timeseries(data, index_mappings=(lambda xval: xs.searchsorted(xval),), x_range_mode="view-range") + +ndw.show(maintain_aspect=False) +fpl.loop.run() From 1642978aa6b1acea8e0b2a709f8e570eca29555c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 24 Feb 2026 20:28:04 -0500 Subject: [PATCH 043/101] progress --- examples/ndwidget/ndimage.py | 6 +++--- fastplotlib/widgets/nd_widget/_index.py | 2 +- fastplotlib/widgets/nd_widget/_nd_image.py | 12 ++++++------ fastplotlib/widgets/nd_widget/_nd_positions/core.py | 4 +++- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py index 9fb6dd422..0e44cd1b5 100644 --- a/examples/ndwidget/ndimage.py +++ b/examples/ndwidget/ndimage.py @@ -12,14 +12,14 @@ import fastplotlib as fpl -a = np.random.rand(30, 1000, 64, 64) +a = np.random.rand(1000, 30, 64, 64) -ndw = fpl.NDWidget(ref_ranges=[(0, 30, 1, "um"), (0, 1000, 1, "t")], size=(800, 800)) +ndw = fpl.NDWidget(ref_ranges=[(0, 1000, 1, "t"), (0, 30, 1, "um")], size=(800, 800)) ndw.show() ndi = ndw[0, 0].add_nd_image(a, index_mappings=(int, int)) # TODO: need to think about how to "auto ignore" reference range for a dim when switching between 2 & 3 dim images -# ndi.n_display_dims = 3 +ndi.n_display_dims = 3 fpl.loop.run() diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index ff2edd6f6..f30f93b94 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -46,7 +46,7 @@ def __len__(self): class GlobalIndexVector: - def __init__(self, ref_ranges: list, get_ndgraphics: Callable): + def __init__(self, ref_ranges: list, get_ndgraphics: Callable[[], tuple[NDGraphic]]): self._ref_ranges = list() for r in ref_ranges: diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 398e48dee..535927a03 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -583,16 +583,15 @@ def _create_graphic(self): case 3: cls = ImageVolumeGraphic - data_slice = self.processor.get(self._global_index.indices) + data_slice = self.processor.get(self.indices) old_graphic = self._graphic new_graphic = cls(data_slice) if old_graphic is not None: - g = self._graphic - plot_area = g._plot_area - self._graphic._plot_area.delete_graphic(g) - plot_area.add_graphic(self._graphic) + plot_area = old_graphic._plot_area + plot_area.delete_graphic(old_graphic) + plot_area.add_graphic(new_graphic) self._graphic = new_graphic @@ -608,10 +607,11 @@ def n_display_dims(self, n: Literal[2 , 3]): @property def indices(self) -> tuple: - return self._global_index.indices + return self._global_index.indices[-self.processor.n_slider_dims:] @indices.setter def indices(self, indices): + indices = indices[-self.processor.n_slider_dims:] data_slice = self.processor.get(indices) self.graphic.data = data_slice diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index e9be48368..d772e11e9 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -434,11 +434,13 @@ def graphic(self, graphic_type): @property def indices(self) -> tuple: - return self._global_index.indices + return self._global_index.indices[-self.processor.n_slider_dims:] @indices.setter @block_reentrance def indices(self, indices): + # upto the number of slider dims in this data + indices = indices[-self.processor.n_slider_dims:] data_slice = self.processor.get(indices) if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): From 8df9fb39e7c01fb42aa64e166571400fed31f7d2 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 26 Feb 2026 17:04:55 -0500 Subject: [PATCH 044/101] do not reset vmin vmax when replacing Image buffer --- fastplotlib/graphics/image.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 6dfb52238..8e11f4751 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -259,8 +259,6 @@ def data(self, data): wrap="clamp-to-edge", ) - self._material.clim = quick_min_max(self.data.value) - # remove tiles from the WorldObject -> Graphic map self._remove_group_graphic_map(self.world_object) From 7e832629e859a5dc01b3edc2b877fe136cba4aa5 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 28 Feb 2026 03:13:46 -0500 Subject: [PATCH 045/101] WIP migrate to xarray --- fastplotlib/widgets/nd_widget/base.py | 289 ++++++++++++-------------- 1 file changed, 136 insertions(+), 153 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index b78dbcbbb..bd32023c4 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -1,8 +1,11 @@ +from collections.abc import Callable, Hashable, Sequence from contextlib import contextmanager import inspect -from typing import Literal, Callable, Any +from numbers import Real +from typing import Literal, Any from warnings import warn +import xarray as xr import numpy as np from numpy.typing import ArrayLike @@ -17,45 +20,107 @@ def identity(index: int) -> int: return index +class BaseNDProcessor: + @property + def data(self) -> Any: + pass + + @property + def shape(self) -> dict[Hashable, int]: + pass + + @property + def ndim(self): + pass + + @property + def spatial_dims(self) -> tuple[Hashable, ...]: + pass + + @property + def slider_dims(self): + pass + + @property + def window_funcs( + self, + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: + # {dim: (func, size)} + pass + + @property + def window_funcs_order(self) -> tuple[Hashable]: + pass + + @property + def index_mappings(self) -> dict[Hashable, Callable[[Any], int] | ArrayLike]: + pass + + def get(self, **indices): + raise NotImplementedError + + class NDProcessor: def __init__( self, data, - n_display_dims: Literal[2, 3] = 2, - index_mappings: tuple[Callable[[Any], int] | None, ...] | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - window_order: tuple[int, ...] = None, + dims: Sequence[Hashable], + spatial_dims: Sequence[Hashable] | None, + index_mappings: dict[Hashable, Callable[[Any], int] | ArrayLike] = None, + window_funcs: dict[ + Hashable, tuple[WindowFuncCallable | None, int | float | None] + ] = None, + window_funcs_order: tuple[Hashable, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): - self._data = self._validate_data(data) + self._data = self._validate_data(data, dims) + self.spatial_dims = spatial_dims + self._index_mappings = tuple(self._validate_index_mappings(index_mappings)) self.window_funcs = window_funcs - self.window_sizes = window_sizes - self.window_order = window_order + self.window_order = window_funcs_order @property - def data(self) -> ArrayProtocol: + def data(self) -> xr.DataArray: return self._data @data.setter def data(self, data: ArrayProtocol): - self._data = self._validate_data(data) + self._data = self._validate_data(data, self.dims) @property - def shape(self) -> tuple[int, ...]: - return self.data.shape + def shape(self) -> dict[Hashable, int]: + """interpreted shape of the data""" + return {d: n for d, n in zip(self.dims, self.data.shape)} @property def ndim(self) -> int: - return len(self.shape) + """number of dims""" + return self.data.ndim - def _validate_data(self, data: ArrayProtocol): + @property + def dims(self) -> tuple[Hashable, ...]: + """dim names""" + return self.data.dims + + @property + def spatial_dims(self) -> tuple[Hashable, ...]: + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: Sequence[Hashable]): + for dim in tuple(sdims): + if dim not in self.dims: + raise KeyError + + self._spatial_dims = tuple(sdims) + + def _validate_data(self, data: ArrayProtocol, dims): if not isinstance(data, ArrayProtocol): raise TypeError("`data` must implement the ArrayProtocol") - return data + return xr.DataArray(data, dims=dims) @property def tooltip(self) -> bool: @@ -72,146 +137,74 @@ def tooltip_format(self, *args) -> str | None: @property def slider_dims(self): - raise NotImplementedError + return set(self.dims) - set(self.spatial_dims) @property def n_slider_dims(self): - raise NotImplementedError + return len(self.slider_dims) @property def window_funcs( self, - ) -> tuple[WindowFuncCallable | None, ...] | None: + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: """get or set window functions, see docstring for details""" return self._window_funcs @window_funcs.setter def window_funcs( self, - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, + window_funcs: ( + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]] | None + ), ): if window_funcs is None: - self._window_funcs = tuple([None] * self.n_slider_dims) - return - - if callable(window_funcs): - window_funcs = (window_funcs,) - - # if all are None - # if all([f is None for f in window_funcs]): - # self._window_funcs = tuple(window_funcs) - # return - - self._validate_window_func(window_funcs) - - self._window_funcs = tuple(window_funcs) - # self._recompute_histogram() - - def _validate_window_func(self, funcs): - if isinstance(funcs, (tuple, list)): - for f in funcs: - if f is None: - pass - elif callable(f): - sig = inspect.signature(f) - - if "axis" not in sig.parameters or "keepdims" not in sig.parameters: - raise TypeError( - f"Each window function must take an `axis` and `keepdims` argument, " - f"you passed: {f} with the following function signature: {sig}" - ) - else: - raise TypeError( - f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" - ) - - if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): - raise IndexError( - f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " - f"and you passed {len(funcs)} `window_funcs`: {funcs}" - ) - - @property - def window_sizes(self) -> tuple[int | None, ...] | None: - """get or set window sizes used for the corresponding window functions, see docstring for details""" - return self._window_sizes - - @window_sizes.setter - def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): - if window_sizes is None: - self._window_sizes = tuple([None] * self.n_slider_dims) + self._window_funcs = {d: None for d in self.data.dims} return - if isinstance(window_sizes, int): - window_sizes = (window_sizes,) + for k in window_funcs.keys(): + if k not in self.dims: + raise KeyError + if k in self.spatial_dims: + raise KeyError - # if all are None - if all([w is None for w in window_sizes]): - self._window_sizes = None - return + func = window_funcs[k][0] + size = window_funcs[k][1] - if not all([isinstance(w, (int)) or w is None for w in window_sizes]): - raise TypeError( - f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" - ) - - # if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): - # raise IndexError( - # f"number of `window_sizes` must be the same as the number of slider dims, " - # f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " - # f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" - # ) - - # make all window sizes are valid numbers - _window_sizes = list() - for i, w in enumerate(window_sizes): - if w is None: - _window_sizes.append(None) - continue - - if w < 0: - raise ValueError( - f"negative window size passed, all `window_sizes` must be positive " - f"integers or `None`, you passed: {_window_sizes}" - ) + if func is None: + pass + elif callable(func): + sig = inspect.signature(func) - if w == 0 or w == 1: - # this is not a real window, set as None - w = None - - elif w % 2 == 0: - # odd window sizes makes most sense - warn( - f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {func} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be a dict mapping dim names to a tuple of the window function callable and " + f"window size, {'name': (func, size), ...}.\nYou have passed: {window_funcs}" ) - w += 1 - _window_sizes.append(w) + if not isinstance(size, Real): + raise TypeError + elif size < 0: + raise ValueError - self._window_sizes = tuple(_window_sizes) + self._window_funcs = window_funcs @property - def window_order(self) -> tuple[int, ...] | None: + def window_order(self) -> tuple[Hashable, ...] | None: """get or set dimension order in which window functions are applied""" return self._window_order @window_order.setter - def window_order(self, order: tuple[int] | None): - if order is None: - self._window_order = None - return - - if order is not None: - if not all([d <= self.n_slider_dims for d in order]): - raise IndexError( - f"all `window_order` entries must be <= n_slider_dims\n" - f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" - ) - - if not all([d >= 0 for d in order]): - raise IndexError( - f"all `window_order` entires must be >= 0, you have passed: {order}" - ) + def window_order(self, order: tuple[Hashable] | None): + for d in order: + if d not in self.dims: + raise KeyError + if d in self.spatial_dims: + raise KeyError self._window_order = tuple(order) @@ -219,37 +212,27 @@ def window_order(self, order: tuple[int] | None): def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: pass - # @property - # def slider_dims(self) -> tuple[int, ...] | None: - # pass - @property def index_mappings(self) -> tuple[Callable[[Any], int]]: return self._index_mappings @index_mappings.setter - def index_mappings(self, maps: tuple[Callable[[Any], int] | None] | None): - self._index_mappings = tuple(self._validate_index_mappings(maps)) - - def _validate_index_mappings(self, maps): - if maps is None: - return tuple([identity] * self.n_slider_dims) - - if len(maps) != self.n_slider_dims: - raise IndexError - - _maps = list() - for m in maps: - if m is None: - _maps.append(identity) - elif callable(m): - _maps.append(identity) - else: - raise TypeError - - return tuple(maps) - - def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike]): + for d in maps.keys(): + if d not in self.dims: + raise KeyError + if d in self.spatial_dims: + raise KeyError + if isinstance(maps[d], ArrayProtocol): + # create a searchsorted mapping function automatically + maps[d] = maps[d].searchsorted + elif maps[d] is None: + # assign identity mapping + maps[d] = identity + + self._index_mappings = maps + + def get(self, indices: dict[Hashable, Any]): pass From b220f94a5a6942708062e1311ec51f8a238a2dd0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 28 Feb 2026 03:53:49 -0500 Subject: [PATCH 046/101] window funcs in NDProcessor class using xarray, WIP, not tested --- fastplotlib/widgets/nd_widget/base.py | 85 +++++++++++++++++++++++---- 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index bd32023c4..468d57d7d 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -76,7 +76,7 @@ def __init__( self._data = self._validate_data(data, dims) self.spatial_dims = spatial_dims - self._index_mappings = tuple(self._validate_index_mappings(index_mappings)) + self.index_mappings = index_mappings self.window_funcs = window_funcs self.window_order = window_funcs_order @@ -136,7 +136,7 @@ def tooltip_format(self, *args) -> str | None: return None @property - def slider_dims(self): + def slider_dims(self) -> set[Hashable]: return set(self.dims) - set(self.spatial_dims) @property @@ -146,7 +146,7 @@ def n_slider_dims(self): @property def window_funcs( self, - ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None]: """get or set window functions, see docstring for details""" return self._window_funcs @@ -154,7 +154,7 @@ def window_funcs( def window_funcs( self, window_funcs: ( - dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]] | None + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] | None ), ): if window_funcs is None: @@ -186,11 +186,20 @@ def window_funcs( f"window size, {'name': (func, size), ...}.\nYou have passed: {window_funcs}" ) - if not isinstance(size, Real): + if size is None: + pass + + elif not isinstance(size, Real): raise TypeError + elif size < 0: raise ValueError + # fill in rest with None + for d in self.slider_dims: + if d not in window_funcs.keys(): + window_funcs[d] = None + self._window_funcs = window_funcs @property @@ -200,11 +209,8 @@ def window_order(self) -> tuple[Hashable, ...] | None: @window_order.setter def window_order(self, order: tuple[Hashable] | None): - for d in order: - if d not in self.dims: - raise KeyError - if d in self.spatial_dims: - raise KeyError + if set(order) != self.slider_dims: + raise ValueError("Order must specify all dims") self._window_order = tuple(order) @@ -213,7 +219,7 @@ def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: pass @property - def index_mappings(self) -> tuple[Callable[[Any], int]]: + def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: return self._index_mappings @index_mappings.setter @@ -232,8 +238,63 @@ def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike]) self._index_mappings = maps + def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: + if set(indices.keys()) != set(self.slider_dims): + raise IndexError( + f"Must provide an index for all slider dims: {self.slider_dims}, you have provided: {indices.keys()}" + ) + + indexer = dict() + # go through each slider dim and accumulate slice objects + for dim in self.slider_dims: + # index for this dim in reference space + index_ref = indices[dim] + + # if a window function exists for this dim + if self.window_funcs[dim] is not None: + # window size in reference units + w = self.window_funcs[dim][1] + + # half window in reference units + hw = w / 2 + + # start in reference units + start_ref = index_ref - hw + # stop in ref units + stop_ref = index_ref + hw + + # map start and stop ref to array indices + start = self.index_mappings[dim](start_ref) + stop = self.index_mappings[dim](stop_ref) + + # cmap within array bounds + start = max(min(self.shape[dim] - 1, start), 0) + stop = max(min(self.shape[dim] - 1, stop), 0) + indexer[dim] = slice(start, stop, 1) + else: + # no window func for this dim, direct indexing + # index mapped to array index + index = self.index_mappings[dim](index_ref) + + # clamp within the bounds + start = max(min(self.shape[dim] - 1, index), 0) + + # stop index is just the start index + 1 + indexer[dim] = slice(start, start + 1, 1) + + # apply indexer with any specified windows, return the underlying numpy array + data_sliced = self.data.isel(indexer).values + + # apply window funcs in the specified order + for dim in self.window_order: + func, _ = self.window_funcs[dim] + + data_sliced = func(data_sliced, axis=self.dims.index(dim), keepdims=True) + + return data_sliced + def get(self, indices: dict[Hashable, Any]): - pass + window_output = self._apply_window_functions(indices) def block_reentrance(setter): From 6f09b5d58f30dda463ea9938c706ac013770ed5b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 28 Feb 2026 04:03:21 -0500 Subject: [PATCH 047/101] typo --- fastplotlib/widgets/nd_widget/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index 468d57d7d..3b64de487 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -267,7 +267,7 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: start = self.index_mappings[dim](start_ref) stop = self.index_mappings[dim](stop_ref) - # cmap within array bounds + # clamp within array bounds start = max(min(self.shape[dim] - 1, start), 0) stop = max(min(self.shape[dim] - 1, stop), 0) indexer[dim] = slice(start, stop, 1) From be437fa1d13d55d211d63adfd5041f9e19a0e405 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 1 Mar 2026 21:55:43 -0500 Subject: [PATCH 048/101] basic single index slicing working with xarray --- fastplotlib/widgets/nd_widget/base.py | 70 ++++++++++++++++++--------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index 3b64de487..973b18331 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -17,7 +17,7 @@ def identity(index: int) -> int: - return index + return round(index) class BaseNDProcessor: @@ -89,6 +89,12 @@ def data(self) -> xr.DataArray: def data(self, data: ArrayProtocol): self._data = self._validate_data(data, self.dims) + def _validate_data(self, data: ArrayProtocol, dims): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return xr.DataArray(data, dims=dims) + @property def shape(self) -> dict[Hashable, int]: """interpreted shape of the data""" @@ -110,18 +116,12 @@ def spatial_dims(self) -> tuple[Hashable, ...]: @spatial_dims.setter def spatial_dims(self, sdims: Sequence[Hashable]): - for dim in tuple(sdims): + for dim in sdims: if dim not in self.dims: raise KeyError self._spatial_dims = tuple(sdims) - def _validate_data(self, data: ArrayProtocol, dims): - if not isinstance(data, ArrayProtocol): - raise TypeError("`data` must implement the ArrayProtocol") - - return xr.DataArray(data, dims=dims) - @property def tooltip(self) -> bool: """ @@ -146,7 +146,7 @@ def n_slider_dims(self): @property def window_funcs( self, - ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None]: + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: """get or set window functions, see docstring for details""" return self._window_funcs @@ -154,11 +154,13 @@ def window_funcs( def window_funcs( self, window_funcs: ( - dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] | None + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] + | None ), ): if window_funcs is None: - self._window_funcs = {d: None for d in self.data.dims} + # tuple of (None, None) makes the checks easier in _apply_window_funcs + self._window_funcs = {d: (None, None) for d in self.data.dims} return for k in window_funcs.keys(): @@ -198,19 +200,26 @@ def window_funcs( # fill in rest with None for d in self.slider_dims: if d not in window_funcs.keys(): - window_funcs[d] = None + window_funcs[d] = (None, None) self._window_funcs = window_funcs @property - def window_order(self) -> tuple[Hashable, ...] | None: + def window_order(self) -> tuple[Hashable, ...]: """get or set dimension order in which window functions are applied""" return self._window_order @window_order.setter def window_order(self, order: tuple[Hashable] | None): - if set(order) != self.slider_dims: - raise ValueError("Order must specify all dims") + if order is None: + self._window_order = tuple() + return + + if not set(order).issubset(self.slider_dims): + raise ValueError( + f"each dimension in `window_order` must be a slider dim. You passed order: {order} " + f"and the slider dims are: {self.slider_dims}" + ) self._window_order = tuple(order) @@ -223,19 +232,31 @@ def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: return self._index_mappings @index_mappings.setter - def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike]): + def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None): + if maps is None: + self._index_mappings = {d: identity for d in self.dims} + return + for d in maps.keys(): if d not in self.dims: raise KeyError + if d in self.spatial_dims: - raise KeyError + raise KeyError("index mappings only apply to slider dims, not spatial dims") + if isinstance(maps[d], ArrayProtocol): # create a searchsorted mapping function automatically maps[d] = maps[d].searchsorted + elif maps[d] is None: # assign identity mapping maps[d] = identity + for d in self.dims: + # fill in any unspecified maps with identity + if d not in maps.keys(): + maps[d] = identity + self._index_mappings = maps def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: @@ -250,13 +271,13 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: # index for this dim in reference space index_ref = indices[dim] - # if a window function exists for this dim - if self.window_funcs[dim] is not None: - # window size in reference units - w = self.window_funcs[dim][1] + # get window func and size in reference units + wf, ws = self.window_funcs[dim] + # if a window function exists for this dim, and it's specified in the window order + if (wf is not None) and (ws is not None) and (dim in self.window_order): # half window in reference units - hw = w / 2 + hw = ws / 2 # start in reference units start_ref = index_ref - hw @@ -287,6 +308,9 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: # apply window funcs in the specified order for dim in self.window_order: + if self.window_funcs[dim] is None: + continue + func, _ = self.window_funcs[dim] data_sliced = func(data_sliced, axis=self.dims.index(dim), keepdims=True) @@ -294,7 +318,7 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: return data_sliced def get(self, indices: dict[Hashable, Any]): - window_output = self._apply_window_functions(indices) + raise NotImplementedError def block_reentrance(setter): From 404bb7bc192d3181cdfe140674cef52e6f64801c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 04:49:28 -0500 Subject: [PATCH 049/101] window funcs working for NDPositions and NDPP_Pands --- fastplotlib/widgets/nd_widget/_nd_image.py | 17 +- .../nd_widget/_nd_positions/_pandas.py | 56 +-- .../widgets/nd_widget/_nd_positions/core.py | 406 ++++++++---------- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 5 +- fastplotlib/widgets/nd_widget/base.py | 43 +- 5 files changed, 234 insertions(+), 293 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 535927a03..262693004 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -1,3 +1,4 @@ +from collections.abc import Hashable import inspect from typing import Literal, Callable, Type, Any from warnings import warn @@ -143,20 +144,12 @@ def rgb(self, rgb: bool): self._rgb = rgb @property - def n_slider_dims(self) -> int: - """number of slider dimensions""" - if self._data is None: - return 0 - - return self.ndim - self.n_display_dims - int(self.rgb) + def slider_dims(self) -> set[Hashable]: + return set(self.dims) - set(self.spatial_dims) @property - def slider_dims(self) -> tuple[int, ...] | None: - """tuple indicating the slider dimension indices""" - if self.n_slider_dims == 0: - return None - - return tuple(range(self.n_slider_dims)) + def n_slider_dims(self): + return len(self.slider_dims) @property def slider_dims_shape(self) -> tuple[int, ...] | None: diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index 3e03b9c2d..296787d56 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import pandas as pd @@ -8,13 +10,11 @@ class NDPP_Pandas(NDPositionsProcessor): def __init__( self, data: pd.DataFrame, + spatial_dims: tuple[str, str, str], # [l, p, d] dims in order columns: list[tuple[str, str] | tuple[str, str, str]], tooltip_columns: list[str] = None, - max_display_datapoints: int = 1_000, **kwargs, ): - data = data - self._columns = columns if tooltip_columns is not None: @@ -26,17 +26,22 @@ def __init__( self._tooltip_columns = None self._tooltip = False + self._dims = spatial_dims + super().__init__( data=data, - max_display_datapoints=max_display_datapoints, + dims=spatial_dims, + spatial_dims=spatial_dims, **kwargs, ) + self._dw_slice = None + @property def data(self) -> pd.DataFrame: return self._data - def _validate_data(self, data: pd.DataFrame): + def _validate_data(self, data: pd.DataFrame, dims): if not isinstance(data, pd.DataFrame): raise TypeError @@ -47,56 +52,41 @@ def columns(self) -> list[tuple[str, str] | tuple[str, str, str]]: return self._columns @property - def multi(self) -> bool: - return True - - @multi.setter - def multi(self, v): - pass + def dims(self) -> tuple[str, str, str]: + return self._dims @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> dict[str, int]: # n_graphical_elements, n_timepoints, 2 - return len(self.columns), self.data.index.size, 2 + return {self.dims[0]: len(self.columns), self.dims[1]: self.data.index.size, self.dims[2]: 2} @property def ndim(self) -> int: return len(self.shape) - @property - def n_slider_dims(self) -> int: - return 1 - @property def tooltip(self) -> bool: return self._tooltip def tooltip_format(self, n: int, p: int): # datapoint index w.r.t. full data - p += self._slices[-1].start + p += self._dw_slice.start return str(self.data[self._tooltip_columns[n]][p]) - def get(self, indices: tuple[float | int, ...]) -> np.ndarray: - if not isinstance(indices, tuple): - raise TypeError(".get() must receive a tuple of float | int indices") - + def get(self, indices: dict[str, Any]) -> np.ndarray: # TODO: LOD by using a step size according to max_p # TODO: Also what to do if display_window is None and data # hasn't changed when indices keeps getting set, cache? - # assume no additional slider dims, only time slider dim - if self.display_window is not None: - self._slices = self._get_dw_slices(indices) - gdata_shape = len(self.columns), self._slices[-1].stop - self._slices[-1].start, 3 - else: - gdata_shape = len(self.columns), self.data.shape[0], 3 - self._slices = (slice(None),) + # assume no additional slider dims + self._dw_slice = self._get_dw_slice(indices) + gdata_shape = len(self.columns), self._dw_slice.stop - self._dw_slice.start, 3 - gdata = np.zeros(shape=gdata_shape, dtype=np.float32) + graphic_data = np.zeros(shape=gdata_shape, dtype=np.float32) for i, col in enumerate(self.columns): - gdata[i, :, :len(col)] = np.column_stack( - [self.data[c][self._slices[-1]] for c in col] + graphic_data[i, :, :len(col)] = np.column_stack( + [self.data[c][self._dw_slice] for c in col] ) - return gdata + return self._apply_dw_window_func(graphic_data) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index d772e11e9..42577e387 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -1,9 +1,12 @@ +from collections.abc import Callable, Hashable, Sequence, Iterable from functools import partial -from typing import Literal, Callable, Any, Type +from typing import Literal, Any, Type from warnings import warn import numpy as np from numpy.lib.stride_tricks import sliding_window_view +from numpy.typing import ArrayLike +import xarray as xr from ....utils import subsample_array, ArrayProtocol @@ -18,7 +21,13 @@ ) from ....graphics.utils import pause_events from ....graphics.selectors import LinearSelector -from ..base import NDProcessor, NDGraphic, WindowFuncCallable, block_reentrance, block_indices +from ..base import ( + NDProcessor, + NDGraphic, + WindowFuncCallable, + block_reentrance, + block_indices, +) from .._index import GlobalIndexVector @@ -29,28 +38,62 @@ class NDPositionsProcessor(NDProcessor): def __init__( self, data: Any, - multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + dims: Sequence[str], + # TODO: allow stack_dim to be None and auto-add new dim of size 1 in get logic + spatial_dims: tuple[ + str | None, str, str + ], # [stack_dim, n_datapoints, spatial_dim], IN ORDER!! + index_mappings: dict[str, Callable[[Any], int] | ArrayLike] = None, display_window: int | float | None = 100, # window for n_datapoints dim only max_display_datapoints: int = 1_000, - datapoints_window_func: Callable | None = None, - datapoints_window_size: int | None = None, + datapoints_window_func: tuple[Callable, str, int | float] | None = None, **kwargs, ): + """ + + Parameters + ---------- + data + dims + spatial_dims + index_mappings + display_window + max_display_datapoints + datapoints_window_func: + Important note: if used, display_window is approximate and not exact due to padding from the window size + kwargs + """ self._display_window = display_window self._max_display_datapoints = max_display_datapoints - # TOOD: this does data validation twice and is a bit messy, cleanup - self._data = self._validate_data(data) - self.multi = multi - - super().__init__(data=data, **kwargs) + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + index_mappings=index_mappings, + **kwargs, + ) self._datapoints_window_func = datapoints_window_func - self._datapoints_window_size = datapoints_window_size - def _validate_data(self, data: ArrayProtocol): - # TODO: determine right validation shape etc. - return data + @property + def spatial_dims(self) -> tuple[str, str, str]: + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: tuple[str, str, str]): + if len(sdims) != 3: + raise IndexError + + if not all([d in self.dims for d in sdims]): + raise KeyError + + self._spatial_dims = tuple(sdims) + + @property + def slider_dims(self) -> set[Hashable]: + # append `p` dim to slider dims + return tuple([*super().slider_dims, self.spatial_dims[1]]) @property def display_window(self) -> int | float | None: @@ -80,264 +123,169 @@ def max_display_datapoints(self, n: int): self._max_display_datapoints = n - @property - def multi(self) -> bool: - return self._multi - - @multi.setter - def multi(self, m: bool): - if m and self.data.ndim < 3: - # p is p-datapoints, n is how many lines to show simultaneously (for line collection/stack) - raise ValueError( - "ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]" - ) - - self._multi = m - - @property - def slider_dims(self) -> tuple[int, ...]: - """slider dimensions""" - return tuple(range(self.ndim - 2 - int(self.multi))) + (self.ndim - 2,) - - @property - def n_slider_dims(self) -> int: - return self.ndim - 1 - int(self.multi) - # TODO: validation for datapoints_window_func and size @property - def datapoints_window_func(self) -> tuple[Callable, str] | None: + def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None: """ Callable and str indicating which dims to apply window function along: 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' '""" return self._datapoints_window_func - @property - def datapoints_window_size(self) -> Callable | None: - return self._datapoints_window_size - - def _apply_window_functions(self, indices: tuple[int, ...]): - """applies the window functions for each dimension specified""" - # window size for each dim - winds = self._window_sizes - # window function for each dim - funcs = self._window_funcs - - # TODO: use tuple of None for window funcs and sizes to indicate all None, instead of just None - # print(winds) - # print(funcs) - # - # if winds is None or funcs is None: - # # no window funcs or window sizes, just slice data and return - # # clamp to max bounds - # indexer = list() - # print(indices) - # print(self.shape) - # for dim, i in enumerate(indices): - # i = min(self.shape[dim] - 1, i) - # indexer.append(i) - # - # return self.data[tuple(indexer)] - - # order in which window funcs are applied - order = self._window_order - - if order is not None: - # remove any entries in `window_order` where the specified dim - # has a window function or window size specified as `None` - # example: - # window_sizes = (3, 2) - # window_funcs = (np.mean, None) - # order = (0, 1) - # `1` is removed from the order since that window_func is `None` - order = tuple( - d for d in order if winds[d] is not None and funcs[d] is not None - ) - else: - # sequential order - order = list() - for d in range(self.n_slider_dims): - if winds[d] is not None and funcs[d] is not None: - order.append(d) - - # the final indexer which will be used on the data array - indexer = list() - - for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): - # clamp i within the max bounds - i = min(self.shape[dim_index] - 1, i) - - if (w is not None) and (f is not None): - # specify slice window if both window size and function for this dim are not None - hw = int((w - 1) / 2) # half window + def _get_dw_slice(self, indices: dict[str, Any]) -> slice: + # given indices, return slice required to obtain display window - # start index cannot be less than 0 - start = max(0, i - hw) + # n_datapoints dim name + # display_window acts on this dim + p_dim = self.spatial_dims[1] - # stop index cannot exceed the bounds of this dimension - stop = min(self.shape[dim_index], i + hw) - - s = slice(start, stop, 1) - else: - s = slice(i, i + 1, 1) - - indexer.append(s) - - # apply indexer to slice data with the specified windows - data_sliced = self.data[tuple(indexer)] - - # finally apply the window functions in the specified order - for dim in order: - f = funcs[dim] - - data_sliced = f(data_sliced, axis=dim, keepdims=True) - - return data_sliced - - def _get_dw_slices(self, indices) -> tuple[slice] | tuple[slice, slice]: - # given indices, return slice using display window - - # display window is interpreted using the index mapping for the `p` dim - dw = self.display_window - - if dw is None: + if self.display_window is None: # just return everything - return (slice(None),) + return slice(0, self.shape[p_dim] - 1) - if dw == 0: + if self.display_window == 0: # just map p dimension at this index and return - index_p = self.index_mappings[-1](indices[-1]) - return (slice(index_p, index_p + 1),) + index = self._ref_index_to_array_index(p_dim, indices[p_dim]) + return slice(index, index + 1) + + # half window size, in reference units + hw = self.display_window / 2 + + if self.datapoints_window_func is not None: + # add half datapoints_window_func size here, assumes the reference space is somewhat continuous + # and the display_window and datapoints window size map to their actual size values + hw += self._ref_index_to_array_index(p_dim, self.datapoints_window_func[2] / 2) # display window is in reference units, apply display window and then map to array indices - # clamp w.r.t. 0 and processor shape `p` dim - hw = dw / 2 - index_p_start = max(self.index_mappings[-1](indices[-1] - hw), 0) - index_p_stop = min(self.index_mappings[-1](indices[-1] + hw), self.shape[-2]) - if index_p_start >= index_p_stop: - index_p_stop = index_p_start + 1 + # start in reference units + start_ref = indices[p_dim] - hw + # stop in reference units + stop_ref = indices[p_dim] + hw - # round to the nearest integer since to use as arra indices - slices = [slice(round(index_p_start), round(index_p_stop))] + # map to array indices + start = self._ref_index_to_array_index(p_dim, start_ref) + stop = self._ref_index_to_array_index(p_dim, stop_ref) - if self.multi: - slices.insert(0, slice(None)) + if start >= stop: + stop = start + 1 - return tuple(slices) + return slice(start, stop) - def get(self, indices: tuple[Any, ...]): + def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: """ - slices through all slider dims and outputs an array that can be used to set graphic data + Takes array where display window has already been applied and applies window functions on the `p` dim. - Note that we do not use __getitem__ here since the index is a tuple specifying a single integer - index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + Parameters + ---------- + array: np.ndarray + array of shape: [l, display_window, 2 | 3] + + Returns + ------- + np.ndarray + array with window functions applied along `p` dim """ - # apply any slider index mappings - array_indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) - - if len(array_indices) > 1: - # there are dims in addition to the n_datapoints dim - # apply window funcs - # window_output array should be of shape [n_datapoints, 2 | 3] - window_output = self._apply_window_functions(array_indices[:-1]).squeeze() - else: - window_output = self.data + if self.display_window == 0: + # can't apply window func when there is only 1 datapoint + return array - if self.display_window is not None: - # display_window is in reference units - slices = self._get_dw_slices(indices) - - # if self.display_window is not None: - # # display window is interpreted using the index mapping for the `p` dim - # dw = self.index_mappings[-1](self.display_window) - # - # if dw == 1: - # slices = [slice(indices[-1], indices[-1] + 1)] - # - # else: - # # half window size - # hw = dw // 2 - # - # # for now assume just a single index provided that indicates x axis value - # start = max(indices[-1] - hw, 0) - # stop = start + dw - # # also add window size of `p` dim so window_func output has the same number of datapoints - # if ( - # self.datapoints_window_func is not None - # and self.datapoints_window_size is not None - # ): - # stop += self.datapoints_window_size - 1 - # # TODO: pad with constant if we're using a window func and the index is near the end - # - # # TODO: uncomment this once we have resizeable buffers!! - # # stop = min(indices[-1] + hw, self.shape[-2]) - # - # slices = [slice(start, stop)] - # - # if self.multi: - # # n - 2 dim is n_lines or n_scatters - # slices.insert(0, slice(None)) + p_dim = self.spatial_dims[1] - # data that will be used for the graphical representation - # a copy is made, if there were no window functions then this is a view of the original data - graphic_data = window_output[tuple(slices)] + # display window in array index space + dw = self.index_mappings[p_dim](self.display_window) - dw = self.index_mappings[-1](self.display_window) + # step size based on max number of datapoints to render + step = max(1, dw // self.max_display_datapoints) # apply window function on the `p` n_datapoints dim if ( self.datapoints_window_func is not None - and self.datapoints_window_size is not None - # if there are too many points to efficiently compute the window func + # if there are too many points to efficiently compute the window func, skip # applying a window func also requires making a copy so that's a further performance hit and (dw < self.max_display_datapoints * 2) ): # get windows - # graphic_data will be of shape: [n, p + (ws - 1), 2 | 3] + # graphic_data will be of shape: [n, p, 2 | 3] # where: # n - number of lines, scatters, heatmap rows # p - number of datapoints/samples - wf = self.datapoints_window_func[0] - apply_dims = self.datapoints_window_func[1] - ws = self.datapoints_window_size + # ws is in ref units + wf, apply_dims, ws = self.datapoints_window_func + + # map ws in ref units to array index + ws = self._ref_index_to_array_index(p_dim, ws) + + if ws % 2 == 0: + # odd size windows are easier to handle + ws += 1 + + hw = ws // 2 + start, stop = hw, array.shape[1] - hw # apply user's window func # result will be of shape [n, p, 2 | 3] if apply_dims == "all": # windows will be of shape [n, p, 1 | 2 | 3, ws] - windows = sliding_window_view(graphic_data, ws, axis=-2) - return wf(windows, axis=-1) + windows = sliding_window_view(array, ws, axis=-2) + return wf(windows, axis=-1)[:, ::step] # map user dims str to tuple of numerical dims dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) - # windows will be of shape [n, p, 1 | 2 | 3, ws] + # windows will be of shape [n, (p - ws + 1), 1 | 2 | 3, ws] windows = sliding_window_view( - graphic_data[..., dims], ws, axis=-2 + array[..., dims], ws, axis=-2 ).squeeze() # make a copy because we need to modify it - graphic_data = graphic_data.copy() + array = array[:, start:stop].copy() # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary - # we need to slice upto dw since we add the `datapoints_window_size` above - graphic_data[..., :dw, dims] = wf(windows, axis=-1).reshape( - graphic_data.shape[0], dw, len(dims) + array[..., dims] = wf(windows, axis=-1).reshape( + *array.shape[:-1], len(dims) ) - return graphic_data[ - ..., : dw : max(1, dw // self.max_display_datapoints), : - ] + return array[:, ::step] + + return array[:, ::step] + + def get(self, indices: dict[str, Any]): + """ + slices through all slider dims and outputs an array that can be used to set graphic data + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + """ + # # map slider dim indices to array indices + # array_indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) + + if len(self.slider_dims) > 0: + # there are dims in addition to the spatial dims + window_output = self._apply_window_functions(indices).squeeze() + else: + # no slider dims, use all the data + window_output = self.data + + # verify window output only has the spatial dims + if not set(window_output.dims) == set(self.spatial_dims): + raise IndexError - return graphic_data[ - ..., - : graphic_data.shape[-2] : max( - 1, graphic_data.shape[-2] // self.max_display_datapoints - ), - :, - ] + # get slice obj for display window + dw_slice = self._get_dw_slice(indices) + + # data that will be used for the graphical representation + # a copy is made, if there were no window functions then this is a view of the original data + p_dim = self.spatial_dims[1] + + # slice the datapoints to be displayed in the graphic using the display window slice + # transpose to match spatial dims order, get numpy array, this is a view + graphic_data = ( + window_output.isel({p_dim: dw_slice}).transpose(*self.spatial_dims).values + ) + + return self._apply_dw_window_func(graphic_data) class NDPositions(NDGraphic): @@ -355,7 +303,6 @@ def __init__( | ImageGraphic ], processor: type[NDPositionsProcessor] = NDPositionsProcessor, - multi: bool = False, display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, window_sizes: tuple[int | None] | None = None, @@ -368,16 +315,12 @@ def __init__( ): self._global_index = global_index - if issubclass(graphic, LineCollection): - multi = True - if processor_kwargs is None: processor_kwargs = dict() self._processor = processor( data, *args, - multi=multi, display_window=display_window, max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, @@ -394,8 +337,12 @@ def __init__( self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) if linear_selector: - self._linear_selector = LinearSelector(0, limits=(-np.inf, np.inf), edge_color="cyan") - self._linear_selector.add_event_handler(self._linear_selector_handler, "selection") + self._linear_selector = LinearSelector( + 0, limits=(-np.inf, np.inf), edge_color="cyan" + ) + self._linear_selector.add_event_handler( + self._linear_selector_handler, "selection" + ) else: self._linear_selector = None @@ -434,15 +381,17 @@ def graphic(self, graphic_type): @property def indices(self) -> tuple: - return self._global_index.indices[-self.processor.n_slider_dims:] + return self._global_index.indices[-self.processor.n_slider_dims :] @indices.setter @block_reentrance def indices(self, indices): # upto the number of slider dims in this data - indices = indices[-self.processor.n_slider_dims:] + indices = indices[-self.processor.n_slider_dims :] data_slice = self.processor.get(indices) + # TODO: set other graphic features, colors, sizes, markers, etc. + if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): self.graphic.data[:, : data_slice.shape[-1]] = data_slice @@ -504,9 +453,6 @@ def _create_graphic( data_slice = self.processor.get(self.indices) if issubclass(graphic_cls, ImageGraphic): - if not self.processor.multi: - raise ValueError - if self.processor.shape[-1] != 2: raise ValueError @@ -578,9 +524,7 @@ def x_range_mode(self) -> Literal[None, "fixed-window", "view-range"]: def x_range_mode(self, mode: Literal[None, "fixed-window", "view-range"]): if self._x_range_mode == "view-range": # old mode was view-range - self.graphic._plot_area.remove_animation( - self._update_from_view_range - ) + self.graphic._plot_area.remove_animation(self._update_from_view_range) if mode == "view-range": self.graphic._plot_area.add_animations(self._update_from_view_range) diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 92ec69d74..4503c59ae 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -36,7 +36,7 @@ def add_nd_image(self, *args, **kwargs): def add_nd_scatter(self, *args, **kwargs): nd = NDPositions( - self.ndw.indices, *args, graphic=ScatterCollection, multi=True, **kwargs + self.ndw.indices, *args, graphic=ScatterCollection, **kwargs ) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) @@ -54,7 +54,6 @@ def add_nd_timeseries( self.ndw.indices, *args, graphic=graphic, - multi=True, # x_range_mode=x_range_mode, linear_selector=True, **kwargs, @@ -71,7 +70,7 @@ def add_nd_timeseries( return nd def add_nd_lines(self, *args, **kwargs): - nd = NDPositions(*args, graphic=LineCollection, multi=True, **kwargs) + nd = NDPositions(*args, graphic=LineCollection, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) return nd diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index 973b18331..c81053a62 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -93,6 +93,9 @@ def _validate_data(self, data: ArrayProtocol, dims): if not isinstance(data, ArrayProtocol): raise TypeError("`data` must implement the ArrayProtocol") + if data.ndim != len(dims): + raise IndexError("must specify a dim for every dimension in the data array") + return xr.DataArray(data, dims=dims) @property @@ -160,13 +163,11 @@ def window_funcs( ): if window_funcs is None: # tuple of (None, None) makes the checks easier in _apply_window_funcs - self._window_funcs = {d: (None, None) for d in self.data.dims} + self._window_funcs = {d: (None, None) for d in self.slider_dims} return for k in window_funcs.keys(): - if k not in self.dims: - raise KeyError - if k in self.spatial_dims: + if k not in self.slider_dims: raise KeyError func = window_funcs[k][0] @@ -238,12 +239,9 @@ def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | return for d in maps.keys(): - if d not in self.dims: + if d not in self.slider_dims: raise KeyError - if d in self.spatial_dims: - raise KeyError("index mappings only apply to slider dims, not spatial dims") - if isinstance(maps[d], ArrayProtocol): # create a searchsorted mapping function automatically maps[d] = maps[d].searchsorted @@ -259,15 +257,24 @@ def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | self._index_mappings = maps - def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: + def _ref_index_to_array_index(self, dim: str, ref_index: Any) -> int: + # wraps index mappings, clamps between 0 and max array index for this dimension + index = self.index_mappings[dim](ref_index) + + return max(min(index, self.shape[dim] - 1), 0) + + def _get_slider_dims_indexer(self, indices) -> dict: if set(indices.keys()) != set(self.slider_dims): raise IndexError( f"Must provide an index for all slider dims: {self.slider_dims}, you have provided: {indices.keys()}" ) indexer = dict() + # get only slider dims which are not also spatial dims (example: p dim for positional data) + # since that is dealt with separately + slider_dims = set(self.slider_dims) - set(self.spatial_dims) # go through each slider dim and accumulate slice objects - for dim in self.slider_dims: + for dim in slider_dims: # index for this dim in reference space index_ref = indices[dim] @@ -303,8 +310,16 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: # stop index is just the start index + 1 indexer[dim] = slice(start, start + 1, 1) - # apply indexer with any specified windows, return the underlying numpy array - data_sliced = self.data.isel(indexer).values + return indexer + + def _apply_window_functions(self, indices) -> xr.DataArray: + """slice with windows at given indices and apply window functions""" + indexer = self._get_slider_dims_indexer(indices) + + # there is significant overhead with passing xarray objects to numpy for things like np.mean() + # so convert to numpy, apply window functions, then convert back to xarray + # creating an xarray object from a numpy array has very little overhead, ~10 microseconds + array = self.data.isel(indexer).values # apply window funcs in the specified order for dim in self.window_order: @@ -313,9 +328,9 @@ def _apply_window_functions(self, indices: dict[Hashable, Any]) -> np.ndarray: func, _ = self.window_funcs[dim] - data_sliced = func(data_sliced, axis=self.dims.index(dim), keepdims=True) + array = func(array, axis=self.dims.index(dim), keepdims=True) - return data_sliced + return xr.DataArray(array, dims=self.dims) def get(self, indices: dict[Hashable, Any]): raise NotImplementedError From 42eb3f9df9e92d9813ba327e950487e9c2950ca3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 06:05:52 -0500 Subject: [PATCH 050/101] display_window window funcs working for NDPositions --- fastplotlib/widgets/nd_widget/_index.py | 91 ++++++++----------- .../widgets/nd_widget/_nd_positions/core.py | 33 ++++--- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 1 + fastplotlib/widgets/nd_widget/ndwidget.py | 14 +-- 4 files changed, 63 insertions(+), 76 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index f30f93b94..9d96626a7 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -6,10 +6,11 @@ @dataclass class ReferenceRangeContinuous: + name: str + unit: str start: int | float stop: int | float step: int | float - unit: str def __getitem__(self, index: int): """return the value at the index w.r.t. the step size""" @@ -32,8 +33,9 @@ def range(self) -> int | float: @dataclass class ReferenceRangeDiscrete: - options: Sequence[Any] + name: str unit: str + options: Sequence[Any] def __getitem__(self, index: int): if index > len(self.options): @@ -45,72 +47,51 @@ def __len__(self): return len(self.options) -class GlobalIndexVector: - def __init__(self, ref_ranges: list, get_ndgraphics: Callable[[], tuple[NDGraphic]]): - self._ref_ranges = list() +class GlobalIndex: + def __init__(self, ref_ranges: dict[str, tuple], get_ndgraphics: Callable[[], tuple[NDGraphic]]): + self._ref_ranges = dict() - for r in ref_ranges: - if len(r) == 4: - # assume start, stop, step, unit - refr = ReferenceRangeContinuous(*r) - elif len(r) == 2: - refr = ReferenceRangeDiscrete(*r) + for r in ref_ranges.values(): + if len(r) == 5: + # assume name, unit, start, stop, step + rr = ReferenceRangeContinuous(*r) + elif len(r) == 3: + rr = ReferenceRangeDiscrete(*r) else: raise ValueError - self._ref_ranges.append(refr) + self._ref_ranges[rr.name] = rr self._get_ndgraphics = get_ndgraphics # starting index for all dims - self._indices: list[int | float | Any] = [refr[0] for refr in self.ref_ranges] + self._indices: dict[str, int | float | Any] = {rr.name: rr.start for rr in self._ref_ranges.values()} + + def set(self, indices: dict[str, Any]): + for k in self._indices: + self._indices[k] = indices[k] - @property - def indices(self) -> tuple[int | float | Any, ...]: - # TODO: clamp index to given ref range here - # graphics will clamp according to their own array sizes? - return tuple(self._indices) - - @indices.setter - def indices(self, new_indices: tuple[int | float | Any, ...]): - self._indices[:] = new_indices self._render_indices() def _render_indices(self): for g in self._get_ndgraphics(): - g.indices = self.indices + g.indices = {d: self._indices[d] for d in g.processor.slider_dims} @property - def dims(self) -> tuple[str, ...]: - return tuple([ref.unit for ref in self.ref_ranges]) + def ref_ranges(self) -> dict[str, ReferenceRangeContinuous | ReferenceRangeDiscrete]: + return self._ref_ranges - @property - def ref_ranges(self) -> tuple[ReferenceRangeContinuous, ...]: - return tuple(self._ref_ranges) - - def __getitem__(self, item): - if isinstance(item, int): - # integer index in the list - return self._indices[item] - - for i, rr in enumerate(self.ref_ranges): - if rr.unit == item: - return self._indices[i] - - raise KeyError - - def __setitem__(self, key, value): - # TODO: set the index for the given dimension only - if isinstance(key, str): - for i, rr in enumerate(self.ref_ranges): - if rr.unit == key: - key = i - break - else: - raise KeyError + def __getitem__(self, dim): + return self._indices[dim] + + def __setitem__(self, dim, value): + # set index for given dim and render - # set index for given dim - self._indices[key] = value + # clamp within reference range + if isinstance(self.ref_ranges[dim], ReferenceRangeContinuous): + value = max(min(value, self.ref_ranges[dim].stop - self.ref_ranges[dim].step), self.ref_ranges[dim].start) + + self._indices[dim] = value self._render_indices() def pop_dim(self): @@ -121,7 +102,7 @@ def push_dim(self, ref_range: ReferenceRangeContinuous): pass def __iter__(self): - for index in self.indices: + for index in self._indices: yield index def __len__(self): @@ -131,8 +112,10 @@ def __eq__(self, other): return self._indices == other def __repr__(self): - named = ", ".join([f"{d}: {i}" for d, i in zip(self.dims, self._indices)]) - return f"Indices: {named}" + return f"Global Index: {self._indices}" + + def __str__(self): + return str(self._indices) class SelectionVector: diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 42577e387..35183b031 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -28,7 +28,7 @@ block_reentrance, block_indices, ) -from .._index import GlobalIndexVector +from .._index import GlobalIndex # TODO: Maybe get rid of n_display_dims in NDProcessor, @@ -154,7 +154,7 @@ def _get_dw_slice(self, indices: dict[str, Any]) -> slice: if self.datapoints_window_func is not None: # add half datapoints_window_func size here, assumes the reference space is somewhat continuous # and the display_window and datapoints window size map to their actual size values - hw += self._ref_index_to_array_index(p_dim, self.datapoints_window_func[2] / 2) + hw += self.datapoints_window_func[2] / 2 # display window is in reference units, apply display window and then map to array indices # start in reference units @@ -215,7 +215,8 @@ def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: wf, apply_dims, ws = self.datapoints_window_func # map ws in ref units to array index - ws = self._ref_index_to_array_index(p_dim, ws) + # min window size is 3 + ws = max(self._ref_index_to_array_index(p_dim, ws), 3) if ws % 2 == 0: # odd size windows are easier to handle @@ -291,8 +292,10 @@ def get(self, indices: dict[str, Any]): class NDPositions(NDGraphic): def __init__( self, - global_index: GlobalIndexVector, + global_index: GlobalIndex, data: Any, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], *args, graphic: Type[ LineGraphic @@ -305,7 +308,6 @@ def __init__( processor: type[NDPositionsProcessor] = NDPositionsProcessor, display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, linear_selector: bool = False, @@ -320,11 +322,12 @@ def __init__( self._processor = processor( data, + dims, + spatial_dims, *args, display_window=display_window, max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, - window_sizes=window_sizes, index_mappings=index_mappings, **processor_kwargs, ) @@ -380,14 +383,12 @@ def graphic(self, graphic_type): plot_area.add_graphic(self._graphic) @property - def indices(self) -> tuple: - return self._global_index.indices[-self.processor.n_slider_dims :] + def indices(self) -> dict[Hashable, Any]: + return {d: self._global_index[d] for d in self.processor.slider_dims} @indices.setter @block_reentrance def indices(self, indices): - # upto the number of slider dims in this data - indices = indices[-self.processor.n_slider_dims :] data_slice = self.processor.get(indices) # TODO: set other graphic features, colors, sizes, markers, etc. @@ -422,12 +423,13 @@ def indices(self, indices): if self._linear_selector is not None: with pause_events(self._linear_selector): self._linear_selector.limits = xr - self._linear_selector.selection = indices[-1] + # linear selector acts on `p` dim + self._linear_selector.selection = indices[self.processor.spatial_dims[1]] def _linear_selector_handler(self, ev): with block_indices(self): # linear selector always acts on the `p` dim - self._global_index[-1] = ev.info["value"] + self._global_index[self.processor.spatial_dims[1]] = ev.info["value"] def _tooltip_handler(self, graphic, pick_info): if isinstance(self.graphic, (LineCollection, ScatterCollection)): @@ -453,7 +455,8 @@ def _create_graphic( data_slice = self.processor.get(self.indices) if issubclass(graphic_cls, ImageGraphic): - if self.processor.shape[-1] != 2: + # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap + if self.processor.shape[self.processor.spatial_dims[-1]] != 2: raise ValueError image_data, x0, x_scale = self._create_heatmap_data(data_slice) @@ -544,9 +547,9 @@ def _update_from_view_range(self): new_width = abs(xr[1] - xr[0]) new_index = (xr[0] + xr[1]) / 2 - if (new_index == self._global_index[-1]) and (last_width == new_width): + if (new_index == self._global_index[self.processor.spatial_dims[1]]) and (last_width == new_width): return self.processor.display_window = new_width # set the `p` dim on the global index vector - self._global_index[-1] = new_index + self._global_index[self.processor.spatial_dims[1]] = new_index diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 4503c59ae..677661b9b 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -35,6 +35,7 @@ def add_nd_image(self, *args, **kwargs): return nd def add_nd_scatter(self, *args, **kwargs): + # TODO: better func signature here, send all kwargs to processor_kwargs nd = NDPositions( self.ndw.indices, *args, graphic=ScatterCollection, **kwargs ) diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index c4bd8fb1e..b755296ee 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -1,14 +1,14 @@ from typing import Any -from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete, GlobalIndexVector +from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete, GlobalIndex from ._ndw_subplot import NDWSubplot from ._ui import NDWidgetUI from ...layouts import ImguiFigure, Subplot class NDWidget: - def __init__(self, ref_ranges: list[tuple], **kwargs): - self._indices = GlobalIndexVector(ref_ranges, self._get_ndgraphics) + def __init__(self, ref_ranges: dict[str, tuple], **kwargs): + self._indices = GlobalIndex(ref_ranges, self._get_ndgraphics) self._figure = ImguiFigure(**kwargs) self._subplots_nd: dict[Subplot, NDWSubplot] = dict() @@ -26,15 +26,15 @@ def figure(self) -> ImguiFigure: return self._figure @property - def indices(self) -> GlobalIndexVector: + def indices(self) -> GlobalIndex: return self._indices @indices.setter - def indices(self, new_indices: tuple[int | float | Any, ...]): - self._indices.indices = new_indices + def indices(self, new_indices: dict[str, int | float | Any]): + self._indices.set = new_indices @property - def ref_ranges(self) -> tuple[ReferenceRangeContinuous | ReferenceRangeDiscrete, ...]: + def ref_ranges(self) -> dict[str, ReferenceRangeContinuous | ReferenceRangeDiscrete]: return self._indices.ref_ranges def __getitem__(self, key: str | tuple[int, int] | Subplot): From c251a6d438ac74f614f504b8dbfd3fefa8dcaad3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 06:24:53 -0500 Subject: [PATCH 051/101] imgui stuff --- fastplotlib/widgets/nd_widget/_index.py | 2 +- fastplotlib/widgets/nd_widget/_ui.py | 32 ++++++----------------- fastplotlib/widgets/nd_widget/ndwidget.py | 2 +- 3 files changed, 10 insertions(+), 26 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 9d96626a7..20d9abc7d 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -102,7 +102,7 @@ def push_dim(self, ref_range: ReferenceRangeContinuous): pass def __iter__(self): - for index in self._indices: + for index in self._indices.items(): yield index def __len__(self): diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index a2198d6c9..17f908384 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -62,46 +62,30 @@ def __init__(self, figure, size, ndwidget): self._max_display_windows: dict[NDGraphic, float | int] = dict() def update(self): - indices_changed = False - if imgui.begin_tab_bar("NDWidget Controls"): if imgui.begin_tab_item("Indices")[0]: - for dim_index, (current_index, refr) in enumerate( - zip(self._ndwidget.indices, self._ndwidget.ref_ranges) - ): + for dim, current_index in self._ndwidget.indices: + refr = self._ndwidget.ref_ranges[dim] + if isinstance(refr, ReferenceRangeContinuous): changed, new_index = imgui.slider_float( v=current_index, v_min=refr.start, v_max=refr.stop, - label=refr.unit, + label=dim, ) # TODO: refactor all this stuff, make fully fledged UI if changed: - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True + self._ndwidget.indices[dim] = new_index elif imgui.is_item_hovered(): if imgui.is_key_pressed(imgui.Key.right_arrow): - new_index = current_index + refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True - - if imgui.is_key_pressed(imgui.Key.left_arrow): - new_index = current_index - refr.step - new_indices = list(self._ndwidget.indices) - new_indices[dim_index] = new_index - - indices_changed = True + self._ndwidget.indices[dim] = current_index + refr.step - if indices_changed: - self._ndwidget.indices = tuple(new_indices) + elif imgui.is_key_pressed(imgui.Key.left_arrow): + self._ndwidget.indices[dim] = current_index - refr.step imgui.end_tab_item() diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/ndwidget.py index b755296ee..534c1a922 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/ndwidget.py @@ -31,7 +31,7 @@ def indices(self) -> GlobalIndex: @indices.setter def indices(self, new_indices: dict[str, int | float | Any]): - self._indices.set = new_indices + self._indices.set(new_indices) @property def ref_ranges(self) -> dict[str, ReferenceRangeContinuous | ReferenceRangeDiscrete]: From 4ba98078b688eea2ce986fb2cfe448a9b9e29e99 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 08:04:52 -0500 Subject: [PATCH 052/101] finish migrate NDImage to xarray, basics work --- fastplotlib/widgets/nd_widget/_index.py | 1 + fastplotlib/widgets/nd_widget/_nd_image.py | 494 ++++-------------- .../nd_widget/_nd_positions/_pandas.py | 2 + .../widgets/nd_widget/_nd_positions/core.py | 9 +- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 7 +- fastplotlib/widgets/nd_widget/base.py | 24 +- 6 files changed, 144 insertions(+), 393 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 20d9abc7d..d7f60ba7e 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -75,6 +75,7 @@ def set(self, indices: dict[str, Any]): def _render_indices(self): for g in self._get_ndgraphics(): + # only provide slider indices to the graphic g.indices = {d: self._indices[d] for d in g.processor.slider_dims} @property diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 262693004..16a4686ec 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -1,10 +1,11 @@ -from collections.abc import Hashable +from collections.abc import Hashable, Sequence import inspect from typing import Literal, Callable, Type, Any from warnings import warn import numpy as np from numpy.typing import ArrayLike +import xarray as xr from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS from ...graphics import ImageGraphic, ImageVolumeGraphic @@ -15,14 +16,16 @@ class NDImageProcessor(NDProcessor): def __init__( self, data: ArrayLike | None, - n_display_dims: Literal[2, 3] = 2, - rgb: bool = False, + dims: Sequence[Hashable], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, - window_sizes: tuple[int | None, ...] | int = None, window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayLike], ArrayLike] = None, compute_histogram: bool = True, - index_mappings = None, + index_mappings=None, ): """ An ND image that supports computing window functions, and functions over spatial dimensions. @@ -71,33 +74,32 @@ def __init__( # set as False until data, window funcs stuff and spatial func is all set self._compute_histogram = False - self.data = data - self.n_display_dims = n_display_dims - self.rgb = rgb - - self.window_funcs = window_funcs - self.window_sizes = window_sizes - self.window_order = window_order - - self._spatial_func = spatial_func + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + index_mappings=index_mappings, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + ) + self.rgb_dim = rgb_dim self._compute_histogram = compute_histogram self._recompute_histogram() - self._index_mappings = self._validate_index_mappings(index_mappings) - @property - def data(self) -> ArrayLike | None: + def data(self) -> xr.DataArray | None: """get or set the data array""" return self._data @data.setter def data(self, data: ArrayLike): # check that all array-like attributes are present - if data is None: - self._data = None - return + self._data = self._validate_data(data, self.dims) + self._recompute_histogram() + def _validate_data(self, data: ArrayProtocol, dims): if not isinstance(data, ArrayProtocol): raise TypeError( f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" @@ -109,235 +111,21 @@ def data(self, data: ArrayLike): f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" ) - self._data = data - self._recompute_histogram() + return xr.DataArray(data, dims=dims) @property - def ndim(self) -> int: - if self.data is None: - return 0 - - return self.data.ndim - - @property - def shape(self) -> tuple[int, ...]: - if self._data is None: - return tuple() - - return self.data.shape - - @property - def rgb(self) -> bool: - """whether or not the data is rgb(a)""" + def rgb_dim(self) -> str | None: + """indicates the rgb dim if one exists""" return self._rgb - @rgb.setter - def rgb(self, rgb: bool): - if not isinstance(rgb, bool): - raise TypeError - - if rgb and self.ndim < 3: - raise IndexError( - f"require 3 or more dims for RGB, you have: {self.ndim} dims" - ) + @rgb_dim.setter + def rgb_dim(self, rgb: str | None): + if rgb is not None: + if rgb not in self.dims: + raise KeyError self._rgb = rgb - @property - def slider_dims(self) -> set[Hashable]: - return set(self.dims) - set(self.spatial_dims) - - @property - def n_slider_dims(self): - return len(self.slider_dims) - - @property - def slider_dims_shape(self) -> tuple[int, ...] | None: - if self.n_slider_dims == 0: - return None - - return tuple(self.shape[i] for i in self.slider_dims) - - @property - def n_display_dims(self) -> Literal[2, 3]: - """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" - return self._n_display_dims - - # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget - @n_display_dims.setter - def n_display_dims(self, n: Literal[2, 3]): - if not (n == 2 or n == 3): - raise ValueError( - f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" - ) - self._n_display_dims = n - self._recompute_histogram() - - @property - def max_n_display_dims(self) -> int: - """maximum number of possible display dims""" - # min 2, max 3, accounts for if data is None and ndim is 0 - return max(2, min(3, self.ndim - int(self.rgb))) - - @property - def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: - """tuple indicating the display dimension indices""" - return tuple(range(self.data.ndim))[self.n_slider_dims :] - - @property - def window_funcs( - self, - ) -> tuple[WindowFuncCallable | None, ...] | None: - """get or set window functions, see docstring for details""" - return self._window_funcs - - @window_funcs.setter - def window_funcs( - self, - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, - ): - if window_funcs is None: - self._window_funcs = None - return - - if callable(window_funcs): - window_funcs = (window_funcs,) - - # if all are None - if all([f is None for f in window_funcs]): - self._window_funcs = None - return - - self._validate_window_func(window_funcs) - - self._window_funcs = tuple(window_funcs) - self._recompute_histogram() - - def _validate_window_func(self, funcs): - if isinstance(funcs, (tuple, list)): - for f in funcs: - if f is None: - pass - elif callable(f): - sig = inspect.signature(f) - - if "axis" not in sig.parameters or "keepdims" not in sig.parameters: - raise TypeError( - f"Each window function must take an `axis` and `keepdims` argument, " - f"you passed: {f} with the following function signature: {sig}" - ) - else: - raise TypeError( - f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" - ) - - if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): - raise IndexError( - f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " - f"and you passed {len(funcs)} `window_funcs`: {funcs}" - ) - - @property - def window_sizes(self) -> tuple[int | None, ...] | None: - """get or set window sizes used for the corresponding window functions, see docstring for details""" - return self._window_sizes - - @window_sizes.setter - def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): - if window_sizes is None: - self._window_sizes = None - return - - if isinstance(window_sizes, int): - window_sizes = (window_sizes,) - - # if all are None - if all([w is None for w in window_sizes]): - self._window_sizes = None - return - - if not all([isinstance(w, (int)) or w is None for w in window_sizes]): - raise TypeError( - f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" - ) - - if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): - raise IndexError( - f"number of `window_sizes` must be the same as the number of slider dims, " - f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " - f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" - ) - - # make all window sizes are valid numbers - _window_sizes = list() - for i, w in enumerate(window_sizes): - if w is None: - _window_sizes.append(None) - continue - - if w < 0: - raise ValueError( - f"negative window size passed, all `window_sizes` must be positive " - f"integers or `None`, you passed: {_window_sizes}" - ) - - if w == 0 or w == 1: - # this is not a real window, set as None - w = None - - elif w % 2 == 0: - # odd window sizes makes most sense - warn( - f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" - ) - w += 1 - - _window_sizes.append(w) - - self._window_sizes = tuple(_window_sizes) - self._recompute_histogram() - - @property - def window_order(self) -> tuple[int, ...] | None: - """get or set dimension order in which window functions are applied""" - return self._window_order - - @window_order.setter - def window_order(self, order: tuple[int] | None): - if order is None: - self._window_order = None - return - - if order is not None: - if not all([d <= self.n_slider_dims for d in order]): - raise IndexError( - f"all `window_order` entries must be <= n_slider_dims\n" - f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" - ) - - if not all([d >= 0 for d in order]): - raise IndexError( - f"all `window_order` entires must be >= 0, you have passed: {order}" - ) - - self._window_order = tuple(order) - self._recompute_histogram() - - @property - def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: - """get or set a spatial_func function, see docstring for details""" - return self._spatial_func - - @spatial_func.setter - def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): - if not (callable(func) or func is not None): - raise TypeError( - f"`spatial_func` must be a callable or `None`, you have passed: {func}" - ) - - self._spatial_func = func - self._recompute_histogram() - @property def compute_histogram(self) -> bool: return self._compute_histogram @@ -345,7 +133,7 @@ def compute_histogram(self) -> bool: @compute_histogram.setter def compute_histogram(self, compute: bool): if compute: - if self._compute_histogram is False: + if not self._compute_histogram: # compute a histogram self._recompute_histogram() self._compute_histogram = True @@ -362,79 +150,7 @@ def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: """ return self._histogram - def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: - """applies the window functions for each dimension specified""" - # window size for each dim - winds = self._window_sizes - # window function for each dim - funcs = self._window_funcs - - if winds is None or funcs is None: - # no window funcs or window sizes, just slice data and return - # clamp to max bounds - indexer = list() - for dim, i in enumerate(indices): - i = min(self.shape[dim] - 1, i) - indexer.append(i) - - return self.data[tuple(indexer)] - - # order in which window funcs are applied - order = self._window_order - - if order is not None: - # remove any entries in `window_order` where the specified dim - # has a window function or window size specified as `None` - # example: - # window_sizes = (3, 2) - # window_funcs = (np.mean, None) - # order = (0, 1) - # `1` is removed from the order since that window_func is `None` - order = tuple( - d for d in order if winds[d] is not None and funcs[d] is not None - ) - else: - # sequential order - order = list() - for d in range(self.n_slider_dims): - if winds[d] is not None and funcs[d] is not None: - order.append(d) - - # the final indexer which will be used on the data array - indexer = list() - - for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): - # clamp i within the max bounds - i = min(self.shape[dim_index] - 1, i) - - if (w is not None) and (f is not None): - # specify slice window if both window size and function for this dim are not None - hw = int((w - 1) / 2) # half window - - # start index cannot be less than 0 - start = max(0, i - hw) - - # stop index cannot exceed the bounds of this dimension - stop = min(self.shape[dim_index] - 1, i + hw) - - s = slice(start, stop, 1) - else: - s = slice(i, i + 1, 1) - - indexer.append(s) - - # apply indexer to slice data with the specified windows - data_sliced = self.data[tuple(indexer)] - - # finally apply the window functions in the specified order - for dim in order: - f = funcs[dim] - - data_sliced = f(data_sliced, axis=dim, keepdims=True) - - return data_sliced - - def get(self, indices: tuple[int, ...]) -> ArrayLike | None: + def get(self, indices: dict[str, Any]) -> ArrayLike | None: """ Get the data at the given index, process data through the window functions. @@ -448,46 +164,25 @@ def get(self, indices: tuple[int, ...]) -> ArrayLike | None: Example: get((100, 5)) """ - if self.data is None: - return None - - # apply any slider index mappings - indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) - - if self.n_slider_dims != 0: - if len(indices) != self.n_slider_dims: - raise IndexError( - f"Must specify index for every slider dim, you have specified an index: {indices}\n" - f"But there are: {self.n_slider_dims} slider dims." - ) - # get output after processing through all window funcs - # squeeze to remove all dims of size 1 - window_output = self._apply_window_function(indices).squeeze() + if len(self.slider_dims) > 0: + # there are dims in addition to the spatial dims + window_output = self._apply_window_functions(indices).squeeze() else: - # data is a static image or volume + # no slider dims, use all the data window_output = self.data + if window_output.ndim != len(self.spatial_dims): + raise ValueError + # apply spatial_func if self.spatial_func is not None: - final_output = self.spatial_func(window_output) - if final_output.ndim != (self.n_display_dims + int(self.rgb)): - raise IndexError( - f"Final output after of the `spatial_func` must match the number of display dims." - f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " - f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" - ) - else: - # check that output ndim after window functions matches display dims - final_output = window_output - if final_output.ndim != (self.n_display_dims + int(self.rgb)): - raise IndexError( - f"Final output after of the `window_funcs` must match the number of display dims." - f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " - f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " - f"expected {self.n_display_dims} dims" - ) - - return final_output + spatial_out = self._spatial_func(window_output) + if spatial_out.ndim != len(self.spatial_dims): + raise ValueError + + return spatial_out.transpose(*self.spatial_dims).values + + return window_output.transpose(*self.spatial_dims).values def _recompute_histogram(self): """ @@ -506,11 +201,11 @@ def _recompute_histogram(self): # spatial functions often operate on the spatial dims, ex: a gaussian kernel # so their results require the full spatial resolution, the histogram of a # spatially subsampled image will be very different - ignore_dims = self.display_dims + ignore_dims = [self.dims.index(dim) for dim in self.spatial_dims] else: ignore_dims = None - sub = subsample_array(self.data, ignore_dims=ignore_dims) + sub = subsample_array(self.data.values, ignore_dims=ignore_dims) sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] self._histogram = np.histogram(sub_real, bins=100) @@ -520,36 +215,38 @@ class NDImage(NDGraphic): def __init__( self, global_index, - data: Any, - *args, - graphic: type[ImageGraphic, ImageVolumeGraphic] = None, - processor: type[NDImageProcessor] = NDImageProcessor, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - index_mappings: tuple[Callable[[Any], int] | None] | None = None, - graphic_kwargs: dict = None, - processor_kwargs: dict = None, + data: ArrayLike | None, + dims: Sequence[Hashable], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + index_mappings=None, name: str = None, ): - if processor_kwargs is None: - processor_kwargs = dict() self._global_index = global_index - self._processor = processor( + self._processor = NDImageProcessor( data, - *args, + dims=dims, + spatial_dims=spatial_dims, + rgb_dim=rgb_dim, window_funcs=window_funcs, - window_sizes=window_sizes, + window_order=window_order, + spatial_func=spatial_func, + compute_histogram=compute_histogram, index_mappings=index_mappings, - **processor_kwargs, ) self._graphic = None self._create_graphic() - - self._name = name + super().__init__(name) @property def processor(self) -> NDImageProcessor: @@ -558,9 +255,7 @@ def processor(self) -> NDImageProcessor: @property def graphic( self, - ) -> ( - ImageGraphic | ImageVolumeGraphic - ): + ) -> ImageGraphic | ImageVolumeGraphic: """LineStack or ImageGraphic for heatmaps""" return self._graphic @@ -570,7 +265,7 @@ def graphic(self, graphic_type): pass def _create_graphic(self): - match self.processor.n_display_dims: + match len(self.processor.spatial_dims): case 2: cls = ImageGraphic case 3: @@ -587,24 +282,59 @@ def _create_graphic(self): plot_area.add_graphic(new_graphic) self._graphic = new_graphic + if self._graphic._plot_area is not None: + self._reset_camera() + + def _reset_camera(self): + plot_area = self._graphic._plot_area + + # set camera to a nice position for 2D or 3D + if isinstance(self._graphic, ImageGraphic): + # set camera orthogonal to the xy plane, flip y axis + plot_area.camera.set_state( + { + "position": [0, 0, -1], + "rotation": [0, 0, 0, 1], + "scale": [1, -1, 1], + "reference_up": [0, 1, 0], + "fov": 0, + "depth_range": None, + } + ) + + plot_area.controller = "panzoom" + plot_area.axes.intersection = None + plot_area.auto_scale() + + else: + plot_area.camera.fov = 50 + plot_area.controller = "orbit" + + # make sure all 3D dimension camera scales are positive + # MIP rendering doesn't work with negative camera scales + for dim in ["x", "y", "z"]: + if getattr(plot_area.camera.local, f"scale_{dim}") < 0: + setattr(plot_area.camera.local, f"scale_{dim}", 1) + + plot_area.auto_scale() @property - def n_display_dims(self) -> Literal[2, 3]: - return self.processor.n_display_dims + def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: + return self.processor.spatial_dims - @n_display_dims.setter - def n_display_dims(self, n: Literal[2 , 3]): - self.processor.n_display_dims = n + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): + self.processor.spatial_dims = dims + # shape has probably changed, recreate graphic self._create_graphic() @property - def indices(self) -> tuple: - return self._global_index.indices[-self.processor.n_slider_dims:] + def indices(self) -> dict[Hashable, Any]: + return {d: self._global_index[d] for d in self.processor.slider_dims} @indices.setter def indices(self, indices): - indices = indices[-self.processor.n_slider_dims:] data_slice = self.processor.get(indices) self.graphic.data = data_slice diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index 296787d56..acfc84630 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -89,4 +89,6 @@ def get(self, indices: dict[str, Any]) -> np.ndarray: [self.data[c][self._dw_slice] for c in col] ) + fin + return self._apply_dw_window_func(graphic_data) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index 35183b031..fe3068757 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -252,6 +252,13 @@ def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: return array[:, ::step] + def _apply_spatial_func(self, array: np.ndarray): + if self.spatial_func is not None: + return self.spatial_func(array) + + def _finalize_(self, array): + return self._apply_spatial_func(self._apply_dw_window_func(array)) + def get(self, indices: dict[str, Any]): """ slices through all slider dims and outputs an array that can be used to set graphic data @@ -286,7 +293,7 @@ def get(self, indices: dict[str, Any]): window_output.isel({p_dim: dw_slice}).transpose(*self.spatial_dims).values ) - return self._apply_dw_window_func(graphic_data) + return self._finalize_(graphic_data) class NDPositions(NDGraphic): diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 677661b9b..5e625cc99 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -32,13 +32,12 @@ def add_nd_image(self, *args, **kwargs): nd = NDImage(self.ndw.indices, *args, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) + nd._reset_camera() return nd def add_nd_scatter(self, *args, **kwargs): # TODO: better func signature here, send all kwargs to processor_kwargs - nd = NDPositions( - self.ndw.indices, *args, graphic=ScatterCollection, **kwargs - ) + nd = NDPositions(self.ndw.indices, *args, graphic=ScatterCollection, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) @@ -71,7 +70,7 @@ def add_nd_timeseries( return nd def add_nd_lines(self, *args, **kwargs): - nd = NDPositions(*args, graphic=LineCollection, **kwargs) + nd = NDPositions(self.ndw.indices, *args, graphic=LineCollection, **kwargs) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) return nd diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/base.py index c81053a62..4d55a3514 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/base.py @@ -70,7 +70,7 @@ def __init__( window_funcs: dict[ Hashable, tuple[WindowFuncCallable | None, int | float | None] ] = None, - window_funcs_order: tuple[Hashable, ...] = None, + window_order: tuple[Hashable, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data, dims) @@ -79,7 +79,8 @@ def __init__( self.index_mappings = index_mappings self.window_funcs = window_funcs - self.window_order = window_funcs_order + self.window_order = window_order + self.spatial_func = spatial_func @property def data(self) -> xr.DataArray: @@ -115,6 +116,7 @@ def dims(self) -> tuple[Hashable, ...]: @property def spatial_dims(self) -> tuple[Hashable, ...]: + """Spatial dims, **in order**)""" return self._spatial_dims @spatial_dims.setter @@ -225,8 +227,15 @@ def window_order(self, order: tuple[Hashable] | None): self._window_order = tuple(order) @property - def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: - pass + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + return self._spatial_func + + @spatial_func.setter + def spatial_func(self, func: Callable[[xr.DataArray], xr.DataArray]) -> Callable | None: + if not callable(func) and func is not None: + raise TypeError + + self._spatial_func = func @property def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: @@ -278,8 +287,11 @@ def _get_slider_dims_indexer(self, indices) -> dict: # index for this dim in reference space index_ref = indices[dim] - # get window func and size in reference units - wf, ws = self.window_funcs[dim] + if dim not in self.window_funcs.keys(): + wf, ws = None, None + else: + # get window func and size in reference units + wf, ws = self.window_funcs[dim] # if a window function exists for this dim, and it's specified in the window order if (wf is not None) and (ws is not None) and (dim in self.window_order): From 4f005ed5fd04097c3dc17aa951ad722b07759f4d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 09:19:46 -0500 Subject: [PATCH 053/101] NDImage working mostly, behavior viz is back --- fastplotlib/widgets/nd_widget/_nd_image.py | 3 ++- .../widgets/nd_widget/_nd_positions/_pandas.py | 2 -- .../widgets/nd_widget/_nd_positions/core.py | 14 +++++++------- fastplotlib/widgets/nd_widget/_ui.py | 7 +++++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 16a4686ec..152f59379 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -265,7 +265,7 @@ def graphic(self, graphic_type): pass def _create_graphic(self): - match len(self.processor.spatial_dims): + match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)): case 2: cls = ImageGraphic case 3: @@ -282,6 +282,7 @@ def _create_graphic(self): plot_area.add_graphic(new_graphic) self._graphic = new_graphic + if self._graphic._plot_area is not None: self._reset_camera() diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index acfc84630..296787d56 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -89,6 +89,4 @@ def get(self, indices: dict[str, Any]) -> np.ndarray: [self.data[c][self._dw_slice] for c in col] ) - fin - return self._apply_dw_window_func(graphic_data) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/core.py index fe3068757..fd2914079 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/core.py @@ -252,10 +252,12 @@ def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: return array[:, ::step] - def _apply_spatial_func(self, array: np.ndarray): + def _apply_spatial_func(self, array: xr.DataArray) -> xr.DataArray: if self.spatial_func is not None: return self.spatial_func(array) + return array + def _finalize_(self, array): return self._apply_spatial_func(self._apply_dw_window_func(array)) @@ -266,11 +268,9 @@ def get(self, indices: dict[str, Any]): Note that we do not use __getitem__ here since the index is a tuple specifying a single integer index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. """ - # # map slider dim indices to array indices - # array_indices = tuple([m(i) for m, i in zip(self.index_mappings, indices)]) - if len(self.slider_dims) > 0: - # there are dims in addition to the spatial dims + if len(self.slider_dims) > 1: + # there are slider dims in addition to the datapoints_dim window_output = self._apply_window_functions(indices).squeeze() else: # no slider dims, use all the data @@ -290,10 +290,10 @@ def get(self, indices: dict[str, Any]): # slice the datapoints to be displayed in the graphic using the display window slice # transpose to match spatial dims order, get numpy array, this is a view graphic_data = ( - window_output.isel({p_dim: dw_slice}).transpose(*self.spatial_dims).values + window_output.isel({p_dim: dw_slice}).transpose(*self.spatial_dims) ) - return self._finalize_(graphic_data) + return self._finalize_(graphic_data).values class NDPositions(NDGraphic): diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index 17f908384..147202e69 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -122,12 +122,15 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): changed, val = imgui.checkbox( "use display window", nd_graphic.display_window is not None ) + + p_dim = nd_graphic.processor.spatial_dims[1] + if changed: if not val: nd_graphic.display_window = None else: # pick a value 10% of the reference range - nd_graphic.display_window = self._ndwidget.ref_ranges[0].range * 0.1 + nd_graphic.display_window = self._ndwidget.ref_ranges[p_dim].range * 0.1 if nd_graphic.display_window is not None: if isinstance(nd_graphic.display_window, (int, np.integer)): @@ -143,7 +146,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): "display window", v=nd_graphic.display_window, v_min=type_(0), - v_max=type_(self._ndwidget.ref_ranges[0].stop * 0.25), + v_max=type_(self._ndwidget.ref_ranges[p_dim].stop * 0.25), ) if changed: From 597c48b1c0060179d4749224c4099e9414ec2cdc Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 09:32:00 -0500 Subject: [PATCH 054/101] better flipping logic --- fastplotlib/layouts/_figure.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 00b915b1f..2b22cbd23 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -609,7 +609,9 @@ def show( for subplot in self._subplots.ravel(): for g in subplot.graphics: if isinstance(g, ImageGraphic): - subplot.camera.local.scale_y *= -1 + if subplot.camera.local.scale_y == 1: + # if it's 1 it's likely not been touched manually before show was called + subplot.camera.local.scale_y = -1 break if autoscale: From 68571b75e53f9533bf96a8bc12038d8a64219ed7 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 3 Mar 2026 10:13:35 -0500 Subject: [PATCH 055/101] update examples --- examples/ndwidget/README.rst | 2 ++ examples/ndwidget/ndimage.py | 21 +++++++++---- examples/ndwidget/timeseries.py | 53 ++++++++++++++++++++++++--------- 3 files changed, 57 insertions(+), 19 deletions(-) create mode 100644 examples/ndwidget/README.rst diff --git a/examples/ndwidget/README.rst b/examples/ndwidget/README.rst new file mode 100644 index 000000000..28ed4d752 --- /dev/null +++ b/examples/ndwidget/README.rst @@ -0,0 +1,2 @@ +NDWidget Examples +================= diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py index 0e44cd1b5..7400f12e3 100644 --- a/examples/ndwidget/ndimage.py +++ b/examples/ndwidget/ndimage.py @@ -12,14 +12,25 @@ import fastplotlib as fpl -a = np.random.rand(1000, 30, 64, 64) +data = np.random.rand(1000, 30, 64, 64) +# must define a reference range for each dim +ref = { + "time": ("time", "s", 0, 1000, 1), + "depth": ("depth", "um", 0, 30, 1), +} -ndw = fpl.NDWidget(ref_ranges=[(0, 1000, 1, "t"), (0, 30, 1, "um")], size=(800, 800)) + +ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) ndw.show() -ndi = ndw[0, 0].add_nd_image(a, index_mappings=(int, int)) -# TODO: need to think about how to "auto ignore" reference range for a dim when switching between 2 & 3 dim images -ndi.n_display_dims = 3 +ndi = ndw[0, 0].add_nd_image( + data, + ("time", "depth", "m", "n"), # specify all dim names + ("m", "n"), # specify spatial dims IN ORDER, rest are auto slider dims +) + +# change spatial dims on the fly +ndi.spatial_dims = ("depth", "m", "n") fpl.loop.run() diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py index 1dac31326..fefc385df 100644 --- a/examples/ndwidget/timeseries.py +++ b/examples/ndwidget/timeseries.py @@ -12,28 +12,53 @@ import fastplotlib as fpl # generate some toy timeseries data -n_datapoints = 50_000 # number of datapoints per line +n_datapoints = 100_000 # number of datapoints per line +n_freqs = 20 # number of frequencies +n_ampls = 15 # number of amplitudes +n_lines = 8 + xs = np.linspace(0, 1000 * np.pi, n_datapoints) -lines = list() -for i in range(1, 11): - l = np.column_stack( - [ - xs, - np.sin(xs * i) - ] - ) - lines.append(l) +data = np.zeros(shape=(n_freqs, n_ampls, n_lines, n_datapoints, 2), dtype=np.float32) + +for freq in range(data.shape[0]): + for ampl in range(data.shape[1]): + ys = np.sin(xs * (freq + 1)) * (ampl + 1) + np.random.normal(0, 0.1, size=n_datapoints) + line = np.column_stack([xs, ys]) + data[freq, ampl] = np.stack([line] * n_lines) -# timeseries data of shape [n_lines, n_datapoint, 2] -data = np.stack(lines) # must define a reference range, this would often be your time dimension and corresponds to your x-dimension -ref = [(0, xs[-1], 0.1, "angle")] +ref = { + "freq": ("freq", "Hz", 1, n_freqs + 1, 1), + "ampl": ("ampl", "arbitrary", 1, n_ampls + 1, 1), + "angle": ( + "angle", + "rad", + 0, + xs[-1], + 0.1, + ), +} ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) -ndw[0, 0].add_nd_timeseries(data, index_mappings=(lambda xval: xs.searchsorted(xval),), x_range_mode="view-range") +nd_lines = ndw[0, 0].add_nd_timeseries( + data, + ("freq", "ampl", "n_lines", "angle", "d"), + ("n_lines", "angle", "d"), + index_mappings={ + "angle": xs, + "ampl": lambda x: int(x + 1), + "freq": lambda x: int(x + 1), + }, + x_range_mode="view-range", +) + +nd_lines.graphic.cmap = "tab10" + +subplot = ndw.figure[0, 0] +subplot.controller.add_camera(subplot.camera, include_state={"x", "width"}) ndw.show(maintain_aspect=False) fpl.loop.run() From cad17c844fe524a140c863fca3dc3a1279678f41 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 00:08:13 -0500 Subject: [PATCH 056/101] more progress --- fastplotlib/widgets/nd_widget/__init__.py | 4 +- .../widgets/nd_widget/{base.py => _base.py} | 40 ------- fastplotlib/widgets/nd_widget/_index.py | 67 ++++++----- fastplotlib/widgets/nd_widget/_nd_image.py | 5 +- .../nd_widget/_nd_positions/__init__.py | 2 +- .../{core.py => _nd_positions.py} | 112 +++++++++--------- .../nd_widget/_nd_positions/_pandas.py | 2 +- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 2 +- .../nd_widget/{ndwidget.py => _ndwidget.py} | 4 +- fastplotlib/widgets/nd_widget/_ui.py | 9 +- 10 files changed, 109 insertions(+), 138 deletions(-) rename fastplotlib/widgets/nd_widget/{base.py => _base.py} (94%) rename fastplotlib/widgets/nd_widget/_nd_positions/{core.py => _nd_positions.py} (85%) rename fastplotlib/widgets/nd_widget/{ndwidget.py => _ndwidget.py} (89%) diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 7855327d9..0617a729d 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -1,10 +1,10 @@ from ...layouts import IMGUI if IMGUI: - from .base import NDProcessor + from ._base import NDProcessor from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras from ._nd_image import NDImageProcessor, NDImage - from .ndwidget import NDWidget + from ._ndwidget import NDWidget else: class NDWidget: def __init__(self, *args, **kwargs): diff --git a/fastplotlib/widgets/nd_widget/base.py b/fastplotlib/widgets/nd_widget/_base.py similarity index 94% rename from fastplotlib/widgets/nd_widget/base.py rename to fastplotlib/widgets/nd_widget/_base.py index 4d55a3514..ea4844fdb 100644 --- a/fastplotlib/widgets/nd_widget/base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -20,46 +20,6 @@ def identity(index: int) -> int: return round(index) -class BaseNDProcessor: - @property - def data(self) -> Any: - pass - - @property - def shape(self) -> dict[Hashable, int]: - pass - - @property - def ndim(self): - pass - - @property - def spatial_dims(self) -> tuple[Hashable, ...]: - pass - - @property - def slider_dims(self): - pass - - @property - def window_funcs( - self, - ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: - # {dim: (func, size)} - pass - - @property - def window_funcs_order(self) -> tuple[Hashable]: - pass - - @property - def index_mappings(self) -> dict[Hashable, Callable[[Any], int] | ArrayLike]: - pass - - def get(self, **indices): - raise NotImplementedError - - class NDProcessor: def __init__( self, diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index d7f60ba7e..31d026beb 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -1,17 +1,21 @@ from dataclasses import dataclass from typing import Sequence, Any, Callable -from .base import NDGraphic +from ._base import NDGraphic @dataclass -class ReferenceRangeContinuous: - name: str - unit: str +class RangeContinuous: start: int | float stop: int | float step: int | float + def __post_init__(self): + if self.start >= self.stop: + raise IndexError( + f"start must be less than stop, {self.start} !< {self.stop}" + ) + def __getitem__(self, index: int): """return the value at the index w.r.t. the step size""" # if index is negative, turn to positive index @@ -32,9 +36,7 @@ def range(self) -> int | float: @dataclass -class ReferenceRangeDiscrete: - name: str - unit: str +class RangeDiscrete: options: Sequence[Any] def __getitem__(self, index: int): @@ -48,38 +50,54 @@ def __len__(self): class GlobalIndex: - def __init__(self, ref_ranges: dict[str, tuple], get_ndgraphics: Callable[[], tuple[NDGraphic]]): + def __init__( + self, + ref_ranges: dict[str, tuple], + get_ndgraphics: Callable[[], tuple[NDGraphic]], + ): self._ref_ranges = dict() - for r in ref_ranges.values(): - if len(r) == 5: - # assume name, unit, start, stop, step - rr = ReferenceRangeContinuous(*r) - elif len(r) == 3: - rr = ReferenceRangeDiscrete(*r) + for name, r in ref_ranges.items(): + if len(r) == 3: + # assume start, stop, step + self._ref_ranges[name] = RangeContinuous(*r) + + elif len(r) == 1: + # assume just options + self._ref_ranges[name] = RangeDiscrete(*r) + else: raise ValueError - self._ref_ranges[rr.name] = rr - self._get_ndgraphics = get_ndgraphics # starting index for all dims - self._indices: dict[str, int | float | Any] = {rr.name: rr.start for rr in self._ref_ranges.values()} + self._indices: dict[str, int | float | Any] = { + name: rr.start for name, rr in self._ref_ranges.items() + } def set(self, indices: dict[str, Any]): - for k in self._indices: - self._indices[k] = indices[k] + for dim, value in indices.items(): + self._indices[dim] = self._clamp(value) self._render_indices() + def _clamp(self, dim, value): + if isinstance(self.ref_ranges[dim], RangeContinuous): + return max( + min(value, self.ref_ranges[dim].stop - self.ref_ranges[dim].step), + self.ref_ranges[dim].start, + ) + + return value + def _render_indices(self): for g in self._get_ndgraphics(): # only provide slider indices to the graphic g.indices = {d: self._indices[d] for d in g.processor.slider_dims} @property - def ref_ranges(self) -> dict[str, ReferenceRangeContinuous | ReferenceRangeDiscrete]: + def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: return self._ref_ranges def __getitem__(self, dim): @@ -87,18 +105,13 @@ def __getitem__(self, dim): def __setitem__(self, dim, value): # set index for given dim and render - - # clamp within reference range - if isinstance(self.ref_ranges[dim], ReferenceRangeContinuous): - value = max(min(value, self.ref_ranges[dim].stop - self.ref_ranges[dim].step), self.ref_ranges[dim].start) - - self._indices[dim] = value + self._indices[dim] = self._clamp(dim, value) self._render_indices() def pop_dim(self): pass - def push_dim(self, ref_range: ReferenceRangeContinuous): + def push_dim(self, ref_range: RangeContinuous): # TODO: implement pushing and popping dims pass diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 152f59379..f6a41cd4f 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -1,7 +1,6 @@ from collections.abc import Hashable, Sequence import inspect -from typing import Literal, Callable, Type, Any -from warnings import warn +from typing import Callable, Any import numpy as np from numpy.typing import ArrayLike @@ -9,7 +8,7 @@ from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS from ...graphics import ImageGraphic, ImageVolumeGraphic -from .base import NDProcessor, NDGraphic, WindowFuncCallable +from ._base import NDProcessor, NDGraphic, WindowFuncCallable class NDImageProcessor(NDProcessor): diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py index 03bb0e8f7..60703f8c2 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py @@ -1,6 +1,6 @@ import importlib -from .core import NDPositions, NDPositionsProcessor +from ._nd_positions import NDPositions, NDPositionsProcessor class Extras: pass diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/core.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py similarity index 85% rename from fastplotlib/widgets/nd_widget/_nd_positions/core.py rename to fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index fd2914079..08b5406ba 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/core.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -1,15 +1,12 @@ -from collections.abc import Callable, Hashable, Sequence, Iterable +from collections.abc import Callable, Hashable, Sequence from functools import partial from typing import Literal, Any, Type -from warnings import warn import numpy as np from numpy.lib.stride_tricks import sliding_window_view from numpy.typing import ArrayLike import xarray as xr -from ....utils import subsample_array, ArrayProtocol - from ....graphics import ( Graphic, ImageGraphic, @@ -21,7 +18,7 @@ ) from ....graphics.utils import pause_events from ....graphics.selectors import LinearSelector -from ..base import ( +from .._base import ( NDProcessor, NDGraphic, WindowFuncCallable, @@ -38,10 +35,10 @@ class NDPositionsProcessor(NDProcessor): def __init__( self, data: Any, - dims: Sequence[str], + dims: Sequence[Hashable], # TODO: allow stack_dim to be None and auto-add new dim of size 1 in get logic spatial_dims: tuple[ - str | None, str, str + Hashable | None, Hashable, Hashable ], # [stack_dim, n_datapoints, spatial_dim], IN ORDER!! index_mappings: dict[str, Callable[[Any], int] | ArrayLike] = None, display_window: int | float | None = 100, # window for n_datapoints dim only @@ -192,63 +189,64 @@ def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: p_dim = self.spatial_dims[1] # display window in array index space - dw = self.index_mappings[p_dim](self.display_window) + if self.display_window is not None: + dw = self.index_mappings[p_dim](self.display_window) - # step size based on max number of datapoints to render - step = max(1, dw // self.max_display_datapoints) + # step size based on max number of datapoints to render + step = max(1, dw // self.max_display_datapoints) - # apply window function on the `p` n_datapoints dim - if ( - self.datapoints_window_func is not None - # if there are too many points to efficiently compute the window func, skip - # applying a window func also requires making a copy so that's a further performance hit - and (dw < self.max_display_datapoints * 2) - ): - # get windows + # apply window function on the `p` n_datapoints dim + if ( + self.datapoints_window_func is not None + # if there are too many points to efficiently compute the window func, skip + # applying a window func also requires making a copy so that's a further performance hit + and (dw < self.max_display_datapoints * 2) + ): + # get windows - # graphic_data will be of shape: [n, p, 2 | 3] - # where: - # n - number of lines, scatters, heatmap rows - # p - number of datapoints/samples + # graphic_data will be of shape: [n, p, 2 | 3] + # where: + # n - number of lines, scatters, heatmap rows + # p - number of datapoints/samples - # ws is in ref units - wf, apply_dims, ws = self.datapoints_window_func + # ws is in ref units + wf, apply_dims, ws = self.datapoints_window_func - # map ws in ref units to array index - # min window size is 3 - ws = max(self._ref_index_to_array_index(p_dim, ws), 3) + # map ws in ref units to array index + # min window size is 3 + ws = max(self._ref_index_to_array_index(p_dim, ws), 3) - if ws % 2 == 0: - # odd size windows are easier to handle - ws += 1 + if ws % 2 == 0: + # odd size windows are easier to handle + ws += 1 - hw = ws // 2 - start, stop = hw, array.shape[1] - hw + hw = ws // 2 + start, stop = hw, array.shape[1] - hw - # apply user's window func - # result will be of shape [n, p, 2 | 3] - if apply_dims == "all": - # windows will be of shape [n, p, 1 | 2 | 3, ws] - windows = sliding_window_view(array, ws, axis=-2) - return wf(windows, axis=-1)[:, ::step] + # apply user's window func + # result will be of shape [n, p, 2 | 3] + if apply_dims == "all": + # windows will be of shape [n, p, 1 | 2 | 3, ws] + windows = sliding_window_view(array, ws, axis=-2) + return wf(windows, axis=-1)[:, ::step] - # map user dims str to tuple of numerical dims - dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) + # map user dims str to tuple of numerical dims + dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) - # windows will be of shape [n, (p - ws + 1), 1 | 2 | 3, ws] - windows = sliding_window_view( - array[..., dims], ws, axis=-2 - ).squeeze() + # windows will be of shape [n, (p - ws + 1), 1 | 2 | 3, ws] + windows = sliding_window_view(array[..., dims], ws, axis=-2).squeeze() - # make a copy because we need to modify it - array = array[:, start:stop].copy() + # make a copy because we need to modify it + array = array[:, start:stop].copy() - # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary - array[..., dims] = wf(windows, axis=-1).reshape( - *array.shape[:-1], len(dims) - ) + # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary + array[..., dims] = wf(windows, axis=-1).reshape( + *array.shape[:-1], len(dims) + ) + + return array[:, ::step] - return array[:, ::step] + step = max(1, array.shape[1] // self.max_display_datapoints) return array[:, ::step] @@ -289,8 +287,8 @@ def get(self, indices: dict[str, Any]): # slice the datapoints to be displayed in the graphic using the display window slice # transpose to match spatial dims order, get numpy array, this is a view - graphic_data = ( - window_output.isel({p_dim: dw_slice}).transpose(*self.spatial_dims) + graphic_data = window_output.isel({p_dim: dw_slice}).transpose( + *self.spatial_dims ) return self._finalize_(graphic_data).values @@ -431,7 +429,9 @@ def indices(self, indices): with pause_events(self._linear_selector): self._linear_selector.limits = xr # linear selector acts on `p` dim - self._linear_selector.selection = indices[self.processor.spatial_dims[1]] + self._linear_selector.selection = indices[ + self.processor.spatial_dims[1] + ] def _linear_selector_handler(self, ev): with block_indices(self): @@ -554,7 +554,9 @@ def _update_from_view_range(self): new_width = abs(xr[1] - xr[0]) new_index = (xr[0] + xr[1]) / 2 - if (new_index == self._global_index[self.processor.spatial_dims[1]]) and (last_width == new_width): + if (new_index == self._global_index[self.processor.spatial_dims[1]]) and ( + last_width == new_width + ): return self.processor.display_window = new_width diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index 296787d56..26acfd73d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from .core import NDPositionsProcessor +from ._nd_positions import NDPositionsProcessor class NDPP_Pandas(NDPositionsProcessor): diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 5e625cc99..ef42e65bb 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -3,7 +3,7 @@ from ... import ScatterCollection, LineCollection, LineStack, ImageGraphic from ...layouts import Subplot from . import NDImage, NDPositions -from .base import NDGraphic +from ._base import NDGraphic class NDWSubplot: diff --git a/fastplotlib/widgets/nd_widget/ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py similarity index 89% rename from fastplotlib/widgets/nd_widget/ndwidget.py rename to fastplotlib/widgets/nd_widget/_ndwidget.py index 534c1a922..20f09ba55 100644 --- a/fastplotlib/widgets/nd_widget/ndwidget.py +++ b/fastplotlib/widgets/nd_widget/_ndwidget.py @@ -1,6 +1,6 @@ from typing import Any -from ._index import ReferenceRangeContinuous, ReferenceRangeDiscrete, GlobalIndex +from ._index import RangeContinuous, RangeDiscrete, GlobalIndex from ._ndw_subplot import NDWSubplot from ._ui import NDWidgetUI from ...layouts import ImguiFigure, Subplot @@ -34,7 +34,7 @@ def indices(self, new_indices: dict[str, int | float | Any]): self._indices.set(new_indices) @property - def ref_ranges(self) -> dict[str, ReferenceRangeContinuous | ReferenceRangeDiscrete]: + def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: return self._indices.ref_ranges def __getitem__(self, key: str | tuple[int, int] | Subplot): diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index 147202e69..3223fe595 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -11,8 +11,8 @@ from ...layouts import Subplot from ...ui import EdgeWindow from . import NDPositions -from ._index import ReferenceRangeContinuous -from .base import NDGraphic +from ._index import RangeContinuous +from ._base import NDGraphic position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] image_graphics = [ImageGraphic, ImageVolumeGraphic] @@ -56,9 +56,6 @@ def __init__(self, figure, size, ndwidget): # # self.pause = False - self._selected_subplot = self._ndwidget.figure[0, 0].name - self._selected_nd_graphic = 0 - self._max_display_windows: dict[NDGraphic, float | int] = dict() def update(self): @@ -68,7 +65,7 @@ def update(self): for dim, current_index in self._ndwidget.indices: refr = self._ndwidget.ref_ranges[dim] - if isinstance(refr, ReferenceRangeContinuous): + if isinstance(refr, RangeContinuous): changed, new_index = imgui.slider_float( v=current_index, v_min=refr.start, From 4c902ba0234f7a09ba7674e248efddc377423328 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 00:08:36 -0500 Subject: [PATCH 057/101] update example --- examples/ndwidget/ndimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py index 7400f12e3..4212f46b6 100644 --- a/examples/ndwidget/ndimage.py +++ b/examples/ndwidget/ndimage.py @@ -16,8 +16,8 @@ # must define a reference range for each dim ref = { - "time": ("time", "s", 0, 1000, 1), - "depth": ("depth", "um", 0, 30, 1), + "time": (0, 1000, 1), + "depth": (0, 30, 1), } From 1248e8e4acc2cf211619981aad0f9f269693d608 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 00:58:44 -0500 Subject: [PATCH 058/101] histogram working for images --- fastplotlib/widgets/nd_widget/_base.py | 49 +++++++++++-- fastplotlib/widgets/nd_widget/_nd_image.py | 69 ++++++++++++++++++- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 12 ++-- 3 files changed, 120 insertions(+), 10 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index ea4844fdb..421b43360 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -191,7 +191,9 @@ def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: return self._spatial_func @spatial_func.setter - def spatial_func(self, func: Callable[[xr.DataArray], xr.DataArray]) -> Callable | None: + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: if not callable(func) and func is not None: raise TypeError @@ -202,7 +204,9 @@ def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: return self._index_mappings @index_mappings.setter - def index_mappings(self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None): + def index_mappings( + self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None + ): if maps is None: self._index_mappings = {d: identity for d in self.dims} return @@ -353,13 +357,50 @@ def graphic(self) -> Graphic: raise NotImplementedError @property - def indices(self) -> tuple[Any]: + def indices(self) -> dict[Hashable, Any]: raise NotImplementedError @indices.setter - def indices(self, new: tuple): + def indices(self, new: dict[Hashable, Any]): raise NotImplementedError + # aliases for easier access to processor properties + @property + def window_funcs( + self, + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: + """get or set window functions, see docstring for details""" + return self.processor.window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: ( + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] + | None + ), + ): + self.processor.window_funcs = window_funcs + + @property + def window_order(self) -> tuple[Hashable, ...]: + """get or set dimension order in which window functions are applied""" + return self.processor.window_order + + @window_order.setter + def window_order(self, order: tuple[Hashable] | None): + self.processor.window_order = order + + @property + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: + self.processor.spatial_func = func + @contextmanager def block_indices(ndgraphic: NDGraphic): diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index f6a41cd4f..12a2b791e 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -8,6 +8,7 @@ from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS from ...graphics import ImageGraphic, ImageVolumeGraphic +from ...tools import HistogramLUTTool from ._base import NDProcessor, NDGraphic, WindowFuncCallable @@ -204,7 +205,9 @@ def _recompute_histogram(self): else: ignore_dims = None - sub = subsample_array(self.data.values, ignore_dims=ignore_dims) + # TODO: account for window funcs + + sub = subsample_array(self.data, ignore_dims=ignore_dims) sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] self._histogram = np.histogram(sub_real, bins=100) @@ -242,7 +245,8 @@ def __init__( index_mappings=index_mappings, ) - self._graphic = None + self._graphic: ImageGraphic | None = None + self._histogram_widget: HistogramLUTTool | None = None self._create_graphic() super().__init__(name) @@ -276,15 +280,55 @@ def _create_graphic(self): new_graphic = cls(data_slice) if old_graphic is not None: + # carry over some attributes from old graphic + attrs = dict.fromkeys(["cmap", "interpolation", "cmap_interpolation"]) + for k in attrs: + attrs[k] = getattr(old_graphic, k) + plot_area = old_graphic._plot_area plot_area.delete_graphic(old_graphic) plot_area.add_graphic(new_graphic) + # set cmap and interpolation + for attr, val in attrs.keys(): + setattr(new_graphic, attr, val) + self._graphic = new_graphic if self._graphic._plot_area is not None: self._reset_camera() + self._reset_histogram() + + def _reset_histogram(self): + # reset histogram + if self._graphic._plot_area is None: + return + + if not self.processor.compute_histogram: + # hide right dock if histogram not desired + self._graphic._plot_area.docks["right"].size = 0 + return + + if self.processor.histogram: + if self._histogram_widget: + # histogram widget exists, update it + self._histogram_widget.histogram = self.processor.histogram + self._histogram_widget.images = self.graphic + if self.graphic._plot_area.docks["right"].size < 1: + self.graphic._plot_area.docks["right"].size = 80 + else: + # make hist tool + self._histogram_widget = HistogramLUTTool( + histogram=self.processor.histogram, + images=self.graphic, + name=f"hist-{hex(id(self.graphic))}", + ) + self.graphic._plot_area.docks["right"].add_graphic(self._histogram_widget) + self.graphic._plot_area.docks["right"].size = 80 + + self.graphic.reset_vmin_vmax() + def _reset_camera(self): plot_area = self._graphic._plot_area @@ -339,6 +383,27 @@ def indices(self, indices): self.graphic.data = data_slice + @property + def compute_histogram(self) -> bool: + return self.processor.compute_histogram + + @compute_histogram.setter + def compute_histogram(self, v: bool): + self.processor.compute_histogram = v + self._reset_histogram() + + @property + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: + self.processor.spatial_func = func + self.processor._recompute_histogram() + self._reset_histogram() + def _tooltip_handler(self, graphic, pick_info): # get graphic within the collection n_index = np.argwhere(self.graphic.graphics == graphic).item() diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index ef42e65bb..5a0b00da2 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -33,6 +33,10 @@ def add_nd_image(self, *args, **kwargs): self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) nd._reset_camera() + + # graphic._plot_area must exist before this is called + nd._reset_histogram() + return nd def add_nd_scatter(self, *args, **kwargs): @@ -54,19 +58,19 @@ def add_nd_timeseries( self.ndw.indices, *args, graphic=graphic, - # x_range_mode=x_range_mode, linear_selector=True, **kwargs, ) self._nd_graphics.append(nd) self._subplot.add_graphic(nd.graphic) self._subplot.add_graphic(nd._linear_selector) - # nd._linear_selector.add_event_handler( - # partial(self._set_indices_from_selector, nd), "selection" - # ) + # need plot_area to exist before these this can be called nd.x_range_mode = x_range_mode + # probably don't want to maintain aspect + self._subplot.camera.maintain_aspect = False + return nd def add_nd_lines(self, *args, **kwargs): From 2ca3fbfb7063a8fca6eeb819430dde1918b09c94 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 01:13:59 -0500 Subject: [PATCH 059/101] NDProcessor property aliases --- fastplotlib/widgets/nd_widget/_base.py | 43 +++++++++++++++++ fastplotlib/widgets/nd_widget/_index.py | 46 +++++++++++++++++++ .../nd_widget/_nd_positions/_nd_positions.py | 10 ++++ fastplotlib/widgets/nd_widget/_ndwidget.py | 3 ++ 4 files changed, 102 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 421b43360..6a997b206 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -365,6 +365,43 @@ def indices(self, new: dict[Hashable, Any]): raise NotImplementedError # aliases for easier access to processor properties + @property + def data(self) -> Any: + return self.processor.data + + @data.setter + def data(self, data: Any): + self.processor.data = data + # force a re-render + self.indices = self.indices + + @property + def shape(self) -> dict[Hashable, int]: + """interpreted shape of the data""" + self.processor.shape + + @property + def ndim(self) -> int: + """number of dims""" + return self.processor.ndim + + @property + def dims(self) -> tuple[Hashable, ...]: + """dim names""" + return self.processor.dims + + @property + def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: + return self.processor.index_mappings + + @index_mappings.setter + def index_mappings( + self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None + ): + self.processor.index_mappings = maps + # force a re-render + self.indices = self.indices + @property def window_funcs( self, @@ -381,6 +418,8 @@ def window_funcs( ), ): self.processor.window_funcs = window_funcs + # force a re-render + self.indices = self.indices @property def window_order(self) -> tuple[Hashable, ...]: @@ -390,6 +429,8 @@ def window_order(self) -> tuple[Hashable, ...]: @window_order.setter def window_order(self, order: tuple[Hashable] | None): self.processor.window_order = order + # force a re-render + self.indices = self.indices @property def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: @@ -400,6 +441,8 @@ def spatial_func( self, func: Callable[[xr.DataArray], xr.DataArray] ) -> Callable | None: self.processor.spatial_func = func + # force a re-render + self.indices = self.indices @contextmanager diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 31d026beb..b4cca34fe 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -76,6 +76,8 @@ def __init__( name: rr.start for name, rr in self._ref_ranges.items() } + self._indices_changed_handlers + def set(self, indices: dict[str, Any]): for dim, value in indices.items(): self._indices[dim] = self._clamp(value) @@ -115,6 +117,50 @@ def push_dim(self, ref_range: RangeContinuous): # TODO: implement pushing and popping dims pass + def add_event_handler(self, handler: callable, event: str = "indices"): + """ + Register an event handler. + + Currently the only event that ImageWidget supports is "indices". This event is + emitted whenever the indices of the ImageWidget changes. + + Parameters + ---------- + handler: callable + callback function, must take a tuple of int as the only argument. This tuple will be the `indices` + + event: str, "indices" + the only supported event is "indices" + + Example + ------- + + .. code-block:: py + + def my_handler(indices): + print(indices) + # example prints: {"t": 100, "z": 15} if the index has 2 slider dimensions "t" and "z" + + # create an NDWidget + ndw = NDWidget(...) + + # add event handler + ndw.indices.add_event_handler(my_handler) + + """ + if event != "indices": + raise ValueError("`indices` is the only event supported by `GlobalIndex`") + + self._indices_changed_handlers.add(handler) + + def remove_event_handler(self, handler: callable): + """Remove a registered event handler""" + self._indices_changed_handlers.remove(handler) + + def clear_event_handlers(self): + """Clear all registered event handlers""" + self._indices_changed_handlers.clear() + def __iter__(self): for index in self._indices.items(): yield index diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 08b5406ba..843953d67 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -387,6 +387,16 @@ def graphic(self, graphic_type): self._create_graphic(graphic_type) plot_area.add_graphic(self._graphic) + @property + def spatial_dims(self) -> tuple[str, str, str]: + return self.processor.spatial_dims + + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str, str]): + self.processor.spatial_dims = dims + # force re-render + self.indices = self.indices + @property def indices(self) -> dict[Hashable, Any]: return {d: self._global_index[d] for d in self.processor.slider_dims} diff --git a/fastplotlib/widgets/nd_widget/_ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py index 20f09ba55..a67c9d18d 100644 --- a/fastplotlib/widgets/nd_widget/_ndwidget.py +++ b/fastplotlib/widgets/nd_widget/_ndwidget.py @@ -51,3 +51,6 @@ def _get_ndgraphics(self): def show(self, **kwargs): return self.figure.show(**kwargs) + + def close(self): + self.figure.close() From 782951f529f1da0c9fe13a22ecc863132069cbe9 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 01:25:54 -0500 Subject: [PATCH 060/101] more aliasing --- fastplotlib/widgets/nd_widget/_base.py | 10 ++++++++ fastplotlib/widgets/nd_widget/_nd_image.py | 2 +- .../nd_widget/_nd_positions/_nd_positions.py | 25 ++++++++++++++++--- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 6a997b206..de9826030 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -390,6 +390,16 @@ def dims(self) -> tuple[Hashable, ...]: """dim names""" return self.processor.dims + @property + def spatial_dims(self) -> tuple[str, ...]: + # number of spatial dims for positional data is always 3 + # for image is 2 or 3, so it must be implemented in subclass + raise NotImplementedError + + @property + def slider_dims(self) -> set[Hashable]: + return self.processor.slider_dims + @property def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: return self.processor.index_mappings diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 12a2b791e..3b363027d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -265,7 +265,7 @@ def graphic( @graphic.setter def graphic(self, graphic_type): # TODO implement if graphic type changes to custom user subclass - pass + raise NotImplementedError def _create_graphic(self): match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)): diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 843953d67..20fec1fbc 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -124,11 +124,18 @@ def max_display_datapoints(self, n: int): @property def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None: """ - Callable and str indicating which dims to apply window function along: + Callable, str indicating which dims to apply window function along, window_size in reference space: 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' '""" return self._datapoints_window_func + @datapoints_window_func.setter + def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]): + if len(funcs) != 3: + raise TypeError + + self._datapoints_window_func = tuple(funcs) + def _get_dw_slice(self, indices: dict[str, Any]) -> slice: # given indices, return slice required to obtain display window @@ -168,7 +175,7 @@ def _get_dw_slice(self, indices: dict[str, Any]) -> slice: return slice(start, stop) - def _apply_dw_window_func(self, array: np.ndarray) -> np.ndarray: + def _apply_dw_window_func(self, array: xr.DataArray) -> xr.DataArray: """ Takes array where display window has already been applied and applies window functions on the `p` dim. @@ -256,7 +263,7 @@ def _apply_spatial_func(self, array: xr.DataArray) -> xr.DataArray: return array - def _finalize_(self, array): + def _finalize_(self, array: xr.DataArray) -> xr.DataArray: return self._apply_spatial_func(self._apply_dw_window_func(array)) def get(self, indices: dict[str, Any]): @@ -535,6 +542,18 @@ def display_window(self, dw: int | float | None): # force re-render self.indices = self.indices + @property + def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None: + """ + Callable, str indicating which dims to apply window function along, window_size in reference space: + 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' + '""" + return self.processor.datapoints_window_func + + @datapoints_window_func.setter + def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]): + self.processor.datapoints_window_func = funcs + @property def x_range_mode(self) -> Literal[None, "fixed-window", "view-range"]: """x-range using a fixed window from the display window, or by polling the camera (view-range)""" From 5c6c360a72c53a32e5de1df685de417b8b96d539 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 01:58:13 -0500 Subject: [PATCH 061/101] fix --- fastplotlib/widgets/nd_widget/_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index b4cca34fe..9ba9d03eb 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -76,7 +76,7 @@ def __init__( name: rr.start for name, rr in self._ref_ranges.items() } - self._indices_changed_handlers + self._indices_changed_handlers = set() def set(self, indices: dict[str, Any]): for dim, value in indices.items(): From ca44b94f7bec3bcd29c1841dc0d088a87ab85d9e Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 02:07:27 -0500 Subject: [PATCH 062/101] update example --- examples/ndwidget/timeseries.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py index fefc385df..a0a3074ff 100644 --- a/examples/ndwidget/timeseries.py +++ b/examples/ndwidget/timeseries.py @@ -23,22 +23,18 @@ for freq in range(data.shape[0]): for ampl in range(data.shape[1]): - ys = np.sin(xs * (freq + 1)) * (ampl + 1) + np.random.normal(0, 0.1, size=n_datapoints) + ys = np.sin(xs * (freq + 1)) * (ampl + 1) + np.random.normal( + 0, 0.1, size=n_datapoints + ) line = np.column_stack([xs, ys]) data[freq, ampl] = np.stack([line] * n_lines) # must define a reference range, this would often be your time dimension and corresponds to your x-dimension ref = { - "freq": ("freq", "Hz", 1, n_freqs + 1, 1), - "ampl": ("ampl", "arbitrary", 1, n_ampls + 1, 1), - "angle": ( - "angle", - "rad", - 0, - xs[-1], - 0.1, - ), + "freq": (1, n_freqs + 1, 1), + "ampl": (1, n_ampls + 1, 1), + "angle": (0, xs[-1], 0.1), } ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) From 2f680077d545690704c34d2462c03f9c19b3b56c Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 02:17:13 -0500 Subject: [PATCH 063/101] fix --- fastplotlib/widgets/nd_widget/_nd_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 3b363027d..038e7d82f 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -290,7 +290,7 @@ def _create_graphic(self): plot_area.add_graphic(new_graphic) # set cmap and interpolation - for attr, val in attrs.keys(): + for attr, val in attrs.items(): setattr(new_graphic, attr, val) self._graphic = new_graphic From 7f2bcad312f900625a6b1ddd80cc5b4f4c0fb511 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 02:36:59 -0500 Subject: [PATCH 064/101] ui --- fastplotlib/widgets/nd_widget/_ui.py | 175 +++++++++++++++++++-------- 1 file changed, 123 insertions(+), 52 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index 3223fe595..be0999fe6 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -1,5 +1,8 @@ +import os +from time import perf_counter + import numpy as np -from imgui_bundle import imgui +from imgui_bundle import imgui, icons_fontawesome_6 as fa from ...graphics import ( ScatterCollection, @@ -31,69 +34,137 @@ def __init__(self, figure, size, ndwidget): ) self._ndwidget = ndwidget - # n_sliders = self._image_widget.n_sliders - # - # # whether or not a dimension is in play mode - # self._playing: list[bool] = [False] * n_sliders - # - # # approximate framerate for playing - # self._fps: list[int] = [20] * n_sliders - # - # # framerate converted to frame time - # self._frame_time: list[float] = [1 / 20] * n_sliders - # - # # last timepoint that a frame was displayed from a given dimension - # self._last_frame_time: list[float] = [perf_counter()] * n_sliders - # - # # loop playback - # self._loop = False - # - # # auto-plays the ImageWidget's left-most dimension in docs galleries - # if "DOCS_BUILD" in os.environ.keys(): - # if os.environ["DOCS_BUILD"] == "1": - # self._playing[0] = True - # self._loop = True - # - # self.pause = False + ref_ranges = self._ndwidget.ref_ranges - self._max_display_windows: dict[NDGraphic, float | int] = dict() + # whether or not a dimension is in play mode + self._playing = {dim: False for dim in ref_ranges.keys()} - def update(self): - if imgui.begin_tab_bar("NDWidget Controls"): + # approximate framerate for playing + self._fps = {dim: 20 for dim in ref_ranges.keys()} - if imgui.begin_tab_item("Indices")[0]: - for dim, current_index in self._ndwidget.indices: - refr = self._ndwidget.ref_ranges[dim] + # framerate converted to frame time + self._frame_time = {dim: 1 / 20 for dim in ref_ranges.keys()} - if isinstance(refr, RangeContinuous): - changed, new_index = imgui.slider_float( - v=current_index, - v_min=refr.start, - v_max=refr.stop, - label=dim, - ) + # last timepoint that a frame was displayed from a given dimension + self._last_frame_time = {dim: perf_counter() for dim in ref_ranges.keys()} - # TODO: refactor all this stuff, make fully fledged UI - if changed: - self._ndwidget.indices[dim] = new_index + # loop playback + self._loop ={dim: False for dim in ref_ranges.keys()} - elif imgui.is_item_hovered(): - if imgui.is_key_pressed(imgui.Key.right_arrow): - self._ndwidget.indices[dim] = current_index + refr.step + # auto-plays the ImageWidget's left-most dimension in docs galleries + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": + self._playing[0] = True + self._loop = True - elif imgui.is_key_pressed(imgui.Key.left_arrow): - self._ndwidget.indices[dim] = current_index - refr.step + self._max_display_windows: dict[NDGraphic, float | int] = dict() - imgui.end_tab_item() + def _set_index(self, dim, index): + if index >= self._ndwidget.ref_ranges[dim].stop: + if self._loop[dim]: + index = self._ndwidget.ref_ranges[dim].start + else: + index = self._ndwidget.ref_ranges[dim].stop + self._playing[dim] = False - if imgui.begin_tab_item("NDGraphic properties")[0]: - imgui.text("Subplots:") + self._ndwidget.indices[dim] = index - self._draw_nd_graphics_props_tab() + def update(self): + now = perf_counter() - imgui.end_tab_item() + for dim, current_index in self._ndwidget.indices: + # push id since we have the same buttons for each dim + imgui.push_id(f"{self._id_counter}_{dim}") - imgui.end_tab_bar() + rr = self._ndwidget.ref_ranges[dim] + + if self._playing[dim]: + # show pause button if playing + if imgui.button(label=fa.ICON_FA_PAUSE): + # if pause button clicked, then set playing to false + self._playing[dim] = False + + # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index + if now - self._last_frame_time[dim] >= self._frame_time[dim]: + self._set_index(dim, current_index + rr.step) + self._last_frame_time[dim] = now + + else: + # we are not playing, so display play button + if imgui.button(label=fa.ICON_FA_PLAY): + # if play button is clicked, set last frame time to 0 so that index increments on next render + self._last_frame_time[dim] = 0 + # set playing to True since play button was clicked + self._playing[dim] = True + + imgui.same_line() + # step back one frame button + if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: + self._set_index(dim, current_index - rr.step) + + imgui.same_line() + # step forward one frame button + if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: + self._set_index(dim, current_index + rr.step) + + imgui.same_line() + # stop button + if imgui.button(label=fa.ICON_FA_STOP): + self._playing[dim] = False + self._last_frame_time[dim] = 0 + self._ndwidget.indices[dim] = rr.start + + imgui.same_line() + # loop checkbox + _, self._loop[dim] = imgui.checkbox(label=fa.ICON_FA_ROTATE, v=self._loop[dim]) + if imgui.is_item_hovered(0): + imgui.set_tooltip("loop playback") + + imgui.same_line() + imgui.text("framerate :") + imgui.same_line() + imgui.set_next_item_width(100) + # framerate int entry + fps_changed, value = imgui.input_int( + label="fps", v=self._fps[dim], step_fast=5 + ) + if imgui.is_item_hovered(0): + imgui.set_tooltip( + "framerate is approximate and less reliable as it approaches your monitor refresh rate" + ) + if fps_changed: + if value < 1: + value = 1 + if value > 50: + value = 50 + self._fps[dim] = value + self._frame_time[dim] = 1 / value + + imgui.text(str(dim)) + imgui.same_line() + # so that slider occupies full width + imgui.set_next_item_width(self.width * 0.85) + + if isinstance(rr, RangeContinuous): + changed, new_index = imgui.slider_float( + v=current_index, + v_min=rr.start, + v_max=rr.stop - rr.step, + label=f"##{dim}", + ) + + # TODO: refactor all this stuff, make fully fledged UI + if changed: + self._ndwidget.indices[dim] = new_index + + elif imgui.is_item_hovered(): + if imgui.is_key_pressed(imgui.Key.right_arrow): + self._set_index(dim, current_index + rr.step) + + elif imgui.is_key_pressed(imgui.Key.left_arrow): + self._set_index(dim, current_index - rr.step) + + imgui.pop_id() def _draw_nd_graphics_props_tab(self): for subplot in self._ndwidget.figure: From c79c2835aecaea52ddeff20ae50ca3ff697818ad Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 02:56:10 -0500 Subject: [PATCH 065/101] cleanup old iw-array, imports, add deprecation warning on old iw --- fastplotlib/widgets/__init__.py | 10 +- fastplotlib/widgets/image_widget/__init__.py | 1 - .../widgets/image_widget/_nd_iw_backup.py | 1007 ----------------- .../widgets/image_widget/_processor.py | 519 --------- .../widgets/image_widget/_properties.py | 139 --- fastplotlib/widgets/image_widget/_sliders.py | 91 +- fastplotlib/widgets/image_widget/_widget.py | 5 + fastplotlib/widgets/nd_widget/__init__.py | 16 +- 8 files changed, 65 insertions(+), 1723 deletions(-) delete mode 100644 fastplotlib/widgets/image_widget/_nd_iw_backup.py delete mode 100644 fastplotlib/widgets/image_widget/_processor.py delete mode 100644 fastplotlib/widgets/image_widget/_properties.py diff --git a/fastplotlib/widgets/__init__.py b/fastplotlib/widgets/__init__.py index 04102dbdf..4347f6c80 100644 --- a/fastplotlib/widgets/__init__.py +++ b/fastplotlib/widgets/__init__.py @@ -1,4 +1,12 @@ -from .nd_widget import NDWidget +from .nd_widget import ( + NDWidget, + NDProcessor, + NDGraphic, + NDPositionsProcessor, + NDPositions, + NDImageProcessor, + NDImage, +) from .image_widget import ImageWidget __all__ = ["NDWidget", "ImageWidget"] diff --git a/fastplotlib/widgets/image_widget/__init__.py b/fastplotlib/widgets/image_widget/__init__.py index dc5daea55..70a1aa8ae 100644 --- a/fastplotlib/widgets/image_widget/__init__.py +++ b/fastplotlib/widgets/image_widget/__init__.py @@ -2,7 +2,6 @@ if IMGUI: from ._widget import ImageWidget - from ._processor import NDImageProcessor else: diff --git a/fastplotlib/widgets/image_widget/_nd_iw_backup.py b/fastplotlib/widgets/image_widget/_nd_iw_backup.py deleted file mode 100644 index 7db265c0c..000000000 --- a/fastplotlib/widgets/image_widget/_nd_iw_backup.py +++ /dev/null @@ -1,1007 +0,0 @@ -from typing import Callable, Sequence, Literal -from warnings import warn - -import numpy as np - -from rendercanvas import BaseRenderCanvas - -from ...layouts import ImguiFigure as Figure -from ...graphics import ImageGraphic, ImageVolumeGraphic -from ...utils import calculate_figure_shape, quick_min_max, ArrayProtocol -from ...tools import HistogramLUTTool -from ._sliders import ImageWidgetSliders -from ._processor import NDImageProcessor, WindowFuncCallable -from ._properties import ImageWidgetProperty, Indices - - -IMGUI_SLIDER_HEIGHT = 49 - - -class ImageWidget: - def __init__( - self, - data: ArrayProtocol | Sequence[ArrayProtocol | None] | None, - processors: NDImageProcessor | Sequence[NDImageProcessor] = NDImageProcessor, - n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]] = 2, - slider_dim_names: Sequence[str] | None = None, # dim names left -> right - rgb: bool | Sequence[bool] = False, - cmap: str | Sequence[str] = "plasma", - window_funcs: ( - tuple[WindowFuncCallable | None, ...] - | WindowFuncCallable - | None - | Sequence[ - tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None - ] - ) = None, - window_sizes: ( - tuple[int | None, ...] | Sequence[tuple[int | None, ...] | None] - ) = None, - window_order: tuple[int, ...] | Sequence[tuple[int, ...] | None] = None, - spatial_func: ( - Callable[[ArrayProtocol], ArrayProtocol] - | Sequence[Callable[[ArrayProtocol], ArrayProtocol]] - | None - ) = None, - sliders_dim_order: Literal["right", "left"] = "right", - figure_shape: tuple[int, int] = None, - names: Sequence[str] = None, - figure_kwargs: dict = None, - histogram_widget: bool = True, - histogram_init_quantile: int = (0, 100), - graphic_kwargs: dict | Sequence[dict] = None, - ): - """ - This widget facilitates high-level navigation through image stacks, which are arrays containing one or more - images. It includes sliders for key dimensions such as "t" (time) and "z", enabling users to smoothly navigate - through one or multiple image stacks simultaneously. - - Allowed dimensions orders for each image stack: Note that each has a an optional (c) channel which refers to - RGB(A) a channel. So this channel should be either 3 or 4. - - Parameters - ---------- - data: ArrayProtocol | Sequence[ArrayProtocol | None] | None - array-like or a list of array-like, each array must have a minimum of 2 dimensions - - processors: NDImageProcessor | Sequence[NDImageProcessor], default NDImageProcessor - The image processors used for each n-dimensional data array - - n_display_dims: Literal[2, 3] | Sequence[Literal[2, 3]], default 2 - number of display dimensions - - slider_dim_names: Sequence[str], optional - optional list/tuple of names for each slider dim - - rgb: bool | Sequence[bool], default - whether or not each data array represents RGB(A) images - - figure_shape: Optional[Tuple[int, int]] - manually provide the shape for the Figure, otherwise the number of rows and columns is estimated - - figure_kwargs: dict, optional - passed to ``Figure`` - - names: Optional[str] - gives names to the subplots - - histogram_widget: bool, default False - make histogram LUT widget for each subplot - - rgb: bool | list[bool], default None - bool or list of bool for each input data array in the ImageWidget, indicating whether the corresponding - data arrays are grayscale or RGB(A). - - graphic_kwargs: Any - passed to each ImageGraphic in the ImageWidget figure subplots - - """ - - if figure_kwargs is None: - figure_kwargs = dict() - - if isinstance(data, ArrayProtocol) or (data is None): - data = [data] - - elif isinstance(data, (list, tuple)): - # verify that it's a list of np.ndarray - if not all([isinstance(d, ArrayProtocol) or d is None for d in data]): - raise TypeError( - f"`data` must be an array-like type or a list/tuple of array-like or None. " - f"You have passed the following type {type(data)}" - ) - - else: - raise TypeError( - f"`data` must be an array-like type or a list/tuple of array-like or None. " - f"You have passed the following type {type(data)}" - ) - - if issubclass(processors, NDImageProcessor): - processors = [processors] * len(data) - - elif isinstance(processors, (tuple, list)): - if not all([issubclass(p, NDImageProcessor) for p in processors]): - raise TypeError( - f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " - f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" - ) - - else: - raise TypeError( - f"`processors` must be a `NDImageProcess` class, a subclass of `NDImageProcessor`, or a " - f"list/tuple of `NDImageProcess` subclasses. You have passed: {processors}" - ) - - # subplot layout - if figure_shape is None: - if "shape" in figure_kwargs: - figure_shape = figure_kwargs["shape"] - else: - figure_shape = calculate_figure_shape(len(data)) - - # Regardless of how figure_shape is computed, below code - # verifies that figure shape is large enough for the number of image arrays passed - if figure_shape[0] * figure_shape[1] < len(data): - original_shape = (figure_shape[0], figure_shape[1]) - figure_shape = calculate_figure_shape(len(data)) - warn( - f"Original `figure_shape` was: {original_shape} " - f" but data length is {len(data)}" - f" Resetting figure shape to: {figure_shape}" - ) - - elif isinstance(rgb, bool): - rgb = [rgb] * len(data) - - if not all([isinstance(v, bool) for v in rgb]): - raise TypeError( - f"`rgb` parameter must be a bool or a Sequence of bool, you have passed: {rgb}" - ) - - if not len(rgb) == len(data): - raise ValueError( - f"len(rgb) != len(data), {len(rgb)} != {len(data)}. These must be equal" - ) - - if names is not None: - if not all([isinstance(n, str) for n in names]): - raise TypeError("optional argument `names` must be a Sequence of str") - - if len(names) != len(data): - raise ValueError( - "number of `names` for subplots must be same as the number of data arrays" - ) - - # verify window funcs - if window_funcs is None: - win_funcs = [None] * len(data) - - elif callable(window_funcs) or all( - [callable(f) or f is None for f in window_funcs] - ): - # across all data arrays - # one window function defined for all dims, or window functions defined per-dim - win_funcs = [window_funcs] * len(data) - - # if the above two clauses didn't trigger, then window_funcs defined per-dim, per data array - elif len(window_funcs) != len(data): - raise IndexError - else: - win_funcs = window_funcs - - # verify window sizes - if window_sizes is None: - win_sizes = [window_sizes] * len(data) - - elif isinstance(window_sizes, int): - win_sizes = [window_sizes] * len(data) - - elif all([isinstance(size, int) or size is None for size in window_sizes]): - # window sizes defined per-dim across all data arrays - win_sizes = [window_sizes] * len(data) - - elif len(window_sizes) != len(data): - # window sizes defined per-dim, per data array - raise IndexError - else: - win_sizes = window_sizes - - # verify window orders - if window_order is None: - win_order = [None] * len(data) - - elif all([isinstance(o, int) for o in order]): - # window order defined per-dim across all data arrays - win_order = [window_order] * len(data) - - elif len(window_order) != len(data): - raise IndexError - - else: - win_order = window_order - - # verify spatial_func - if spatial_func is None: - spatial_func = [None] * len(data) - - elif callable(spatial_func): - # same spatial_func for all data arrays - spatial_func = [spatial_func] * len(data) - - elif len(spatial_func) != len(data): - raise IndexError - - else: - spatial_func = spatial_func - - # verify number of display dims - if isinstance(n_display_dims, (int, np.integer)): - n_display_dims = [n_display_dims] * len(data) - - elif isinstance(n_display_dims, (tuple, list)): - if not all([isinstance(n, (int, np.integer)) for n in n_display_dims]): - raise TypeError - - if len(n_display_dims) != len(data): - raise IndexError - else: - raise TypeError - - n_display_dims = tuple(n_display_dims) - - if sliders_dim_order not in ("right",): - raise ValueError( - f"Only 'right' slider dims order is currently supported, you passed: {sliders_dim_order}" - ) - self._sliders_dim_order = sliders_dim_order - - self._slider_dim_names = None - self.slider_dim_names = slider_dim_names - - self._histogram_widget = histogram_widget - - # make NDImageArrays - self._image_processors: list[NDImageProcessor] = list() - for i in range(len(data)): - cls = processors[i] - image_processor = cls( - data=data[i], - rgb=rgb[i], - n_display_dims=n_display_dims[i], - window_funcs=win_funcs[i], - window_sizes=win_sizes[i], - window_order=win_order[i], - spatial_func=spatial_func[i], - compute_histogram=self._histogram_widget, - ) - - self._image_processors.append(image_processor) - - self._data = ImageWidgetProperty(self, "data") - self._rgb = ImageWidgetProperty(self, "rgb") - self._n_display_dims = ImageWidgetProperty(self, "n_display_dims") - self._window_funcs = ImageWidgetProperty(self, "window_funcs") - self._window_sizes = ImageWidgetProperty(self, "window_sizes") - self._window_order = ImageWidgetProperty(self, "window_order") - self._spatial_func = ImageWidgetProperty(self, "spatial_func") - - if len(set(n_display_dims)) > 1: - # assume user wants one controller for 2D images and another for 3D image volumes - n_subplots = np.prod(figure_shape) - controller_ids = [0] * n_subplots - controller_types = ["panzoom"] * n_subplots - - for i in range(len(data)): - if n_display_dims[i] == 2: - controller_ids[i] = 1 - else: - controller_ids[i] = 2 - controller_types[i] = "orbit" - - # needs to be a list of list - controller_ids = [controller_ids] - - else: - controller_ids = "sync" - controller_types = None - - figure_kwargs_default = { - "controller_ids": controller_ids, - "controller_types": controller_types, - "names": names, - } - - # update the default kwargs with any user-specified kwargs - # user specified kwargs will overwrite the defaults - figure_kwargs_default.update(figure_kwargs) - figure_kwargs_default["shape"] = figure_shape - - if graphic_kwargs is None: - graphic_kwargs = [dict()] * len(data) - - elif isinstance(graphic_kwargs, dict): - graphic_kwargs = [graphic_kwargs] * len(data) - - elif len(graphic_kwargs) != len(data): - raise IndexError - - if cmap is None: - cmap = [None] * len(data) - - elif isinstance(cmap, str): - cmap = [cmap] * len(data) - - elif not all([isinstance(c, str) for c in cmap]): - raise TypeError(f"`cmap` must be a or a list/tuple of ") - - self._figure: Figure = Figure(**figure_kwargs_default) - - self._indices = Indices(list(0 for i in range(self.n_sliders)), self) - - for i, subplot in zip(range(len(self._image_processors)), self.figure): - image_data = self._get_image( - self._image_processors[i], tuple(self._indices) - ) - - if image_data is None: - # this subplot/data array is blank, skip - continue - - # next 20 lines are just vmin, vmax parsing - vmin_specified, vmax_specified = None, None - if "vmin" in graphic_kwargs[i].keys(): - vmin_specified = graphic_kwargs[i].pop("vmin") - if "vmax" in graphic_kwargs[i].keys(): - vmax_specified = graphic_kwargs[i].pop("vmax") - - if (vmin_specified is None) or (vmax_specified is None): - # if either vmin or vmax are not specified, calculate an estimate by subsampling - vmin_estimate, vmax_estimate = quick_min_max( - self._image_processors[i].data - ) - - # decide vmin, vmax passed to ImageGraphic constructor based on whether it's user specified or now - if vmin_specified is None: - # user hasn't specified vmin, use estimated value - vmin = vmin_estimate - else: - # user has provided a specific value, use that - vmin = vmin_specified - - if vmax_specified is None: - vmax = vmax_estimate - else: - vmax = vmax_specified - else: - # both vmin and vmax are specified - vmin, vmax = vmin_specified, vmax_specified - - graphic_kwargs[i]["cmap"] = cmap[i] - - if self._image_processors[i].n_display_dims == 2: - # create an Image - graphic = ImageGraphic( - data=image_data, - name="image_widget_managed", - vmin=vmin, - vmax=vmax, - **graphic_kwargs[i], - ) - elif self._image_processors[i].n_display_dims == 3: - # create an ImageVolume - graphic = ImageVolumeGraphic( - data=image_data, - name="image_widget_managed", - vmin=vmin, - vmax=vmax, - **graphic_kwargs[i], - ) - subplot.camera.fov = 50 - - subplot.add_graphic(graphic) - - self._reset_histogram(subplot, self._image_processors[i]) - - self._sliders_ui = ImageWidgetSliders( - figure=self.figure, - size=57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders), - location="bottom", - title="ImageWidget Controls", - image_widget=self, - ) - - self.figure.add_gui(self._sliders_ui) - - self._indices_changed_handlers = set() - - self._reentrant_block = False - - @property - def data(self) -> ImageWidgetProperty[ArrayProtocol | None]: - """get or set the nd-image data arrays""" - return self._data - - @data.setter - def data(self, new_data: Sequence[ArrayProtocol | None]): - if isinstance(new_data, ArrayProtocol) or new_data is None: - new_data = [new_data] * len(self._image_processors) - - if len(new_data) != len(self._image_processors): - raise IndexError - - # if the data array hasn't been changed - # graphics will not be reset for this data index - skip_indices = list() - - for i, (new_data, image_processor) in enumerate( - zip(new_data, self._image_processors) - ): - if new_data is image_processor.data: - skip_indices.append(i) - continue - - image_processor.data = new_data - - self._reset(skip_indices) - - @property - def rgb(self) -> ImageWidgetProperty[bool]: - """get or set the rgb toggle for each data array""" - return self._rgb - - @rgb.setter - def rgb(self, rgb: Sequence[bool]): - if isinstance(rgb, bool): - rgb = [rgb] * len(self._image_processors) - - if len(rgb) != len(self._image_processors): - raise IndexError - - # if the rgb option hasn't been changed - # graphics will not be reset for this data index - skip_indices = list() - - for i, (new, image_processor) in enumerate(zip(rgb, self._image_processors)): - if image_processor.rgb == new: - skip_indices.append(i) - continue - - image_processor.rgb = new - - self._reset(skip_indices) - - @property - def n_display_dims(self) -> ImageWidgetProperty[Literal[2, 3]]: - """Get or set the number of display dimensions for each data array, 2 is a 2D image, 3 is a 3D volume image""" - return self._n_display_dims - - @n_display_dims.setter - def n_display_dims(self, new_ndd: Sequence[Literal[2, 3]] | Literal[2, 3]): - if isinstance(new_ndd, (int, np.integer)): - if new_ndd == 2 or new_ndd == 3: - new_ndd = [new_ndd] * len(self._image_processors) - else: - raise ValueError - - if len(new_ndd) != len(self._image_processors): - raise IndexError - - if not all([(n == 2) or (n == 3) for n in new_ndd]): - raise ValueError - - # if the n_display_dims hasn't been changed for this data array - # graphics will not be reset for this data array index - skip_indices = list() - - # first update image arrays - for i, (image_processor, new) in enumerate( - zip(self._image_processors, new_ndd) - ): - if new > image_processor.max_n_display_dims: - raise IndexError( - f"number of display dims exceeds maximum number of possible " - f"display dimensions: {image_processor.max_n_display_dims}, for array at index: " - f"{i} with shape: {image_processor.shape}, and rgb set to: {image_processor.rgb}" - ) - - if image_processor.n_display_dims == new: - skip_indices.append(i) - else: - image_processor.n_display_dims = new - - self._reset(skip_indices) - - @property - def window_funcs(self) -> ImageWidgetProperty[tuple[WindowFuncCallable | None] | None]: - """get or set the window functions""" - return self._window_funcs - - @window_funcs.setter - def window_funcs(self, new_funcs: Sequence[WindowFuncCallable | None] | None): - if callable(new_funcs) or new_funcs is None: - new_funcs = [new_funcs] * len(self._image_processors) - - if len(new_funcs) != len(self._image_processors): - raise IndexError - - self._set_image_processor_funcs("window_funcs", new_funcs) - - @property - def window_sizes(self) -> ImageWidgetProperty[tuple[int | None, ...] | None]: - """get or set the window sizes""" - return self._window_sizes - - @window_sizes.setter - def window_sizes( - self, new_sizes: Sequence[tuple[int | None, ...] | int | None] | int | None - ): - if isinstance(new_sizes, int) or new_sizes is None: - # same window for all data arrays - new_sizes = [new_sizes] * len(self._image_processors) - - if len(new_sizes) != len(self._image_processors): - raise IndexError - - self._set_image_processor_funcs("window_sizes", new_sizes) - - @property - def window_order(self) -> ImageWidgetProperty[tuple[int, ...] | None]: - """get or set order in which window functions are applied over dimensions""" - return self._window_order - - @window_order.setter - def window_order(self, new_order: Sequence[tuple[int, ...]]): - if new_order is None: - new_order = [new_order] * len(self._image_processors) - - if all([isinstance(order, (int, np.integer))] for order in new_order): - # same order specified across all data arrays - new_order = [new_order] * len(self._image_processors) - - if len(new_order) != len(self._image_processors): - raise IndexError - - self._set_image_processor_funcs("window_order", new_order) - - @property - def spatial_func(self) -> ImageWidgetProperty[Callable | None]: - """Get or set a spatial_func that operates on the spatial dimensions of the 2D or 3D image""" - return self._spatial_func - - @spatial_func.setter - def spatial_func(self, funcs: Callable | Sequence[Callable] | None): - if callable(funcs) or funcs is None: - funcs = [funcs] * len(self._image_processors) - - if len(funcs) != len(self._image_processors): - raise IndexError - - self._set_image_processor_funcs("spatial_func", funcs) - - def _set_image_processor_funcs(self, attr, new_values): - """sets window_funcs, window_sizes, window_order, or spatial_func and updates displayed data and histograms""" - for new, image_processor, subplot in zip( - new_values, self._image_processors, self.figure - ): - if getattr(image_processor, attr) == new: - continue - - setattr(image_processor, attr, new) - - # window functions and spatial functions will only change the histogram - # they do not change the collections of dimensions, so we don't need to call _reset_dimensions - # they also do not change the image graphic, so we do not need to call _reset_image_graphics - self._reset_histogram(subplot, image_processor) - - # update the displayed image data in the graphics - self.indices = self.indices - - @property - def indices(self) -> ImageWidgetProperty[int]: - """ - Get or set the current indices. - - Returns - ------- - indices: ImageWidgetProperty[int] - integer index for each slider dimension - - """ - return self._indices - - @indices.setter - def indices(self, new_indices: Sequence[int]): - if self._reentrant_block: - return - - try: - self._reentrant_block = True # block re-execution until new_indices has *fully* completed execution - - if len(new_indices) != self.n_sliders: - raise IndexError( - f"len(new_indices) != ImageWidget.n_sliders, {len(new_indices)} != {self.n_sliders}. " - f"The length of the new_indices must be the same as the number of sliders" - ) - - if any([i < 0 for i in new_indices]): - raise IndexError( - f"only positive index values are supported, you have passed: {new_indices}" - ) - - for image_processor, graphic in zip(self._image_processors, self.graphics): - new_data = self._get_image(image_processor, indices=new_indices) - if new_data is None: - continue - - graphic.data = new_data - - self._indices._fpl_set(new_indices) - - # call any event handlers - for handler in self._indices_changed_handlers: - handler(tuple(self.indices)) - - except Exception as exc: - # raise original exception - raise exc # indices setter has raised. The lines above below are probably more relevant! - finally: - # set_value has finished executing, now allow future executions - self._reentrant_block = False - - @property - def histogram_widget(self) -> bool: - """show or hide the histograms""" - return self._histogram_widget - - @histogram_widget.setter - def histogram_widget(self, show_histogram: bool): - if not isinstance(show_histogram, bool): - raise TypeError( - f"`histogram_widget` can be set with a bool, you have passed: {show_histogram}" - ) - - for subplot, image_processor in zip(self.figure, self._image_processors): - image_processor.compute_histogram = show_histogram - self._reset_histogram(subplot, image_processor) - - @property - def n_sliders(self) -> int: - """number of sliders""" - return max([a.n_slider_dims for a in self._image_processors]) - - @property - def bounds(self) -> tuple[int, ...]: - """The max bound across all dimensions across all data arrays""" - # initialize with 0 - bounds = [0] * self.n_sliders - - # TODO: implement left -> right slider dims ordering, right now it's only right -> left - # in reverse because dims go left <- right - for i, dim in enumerate(range(-1, -self.n_sliders - 1, -1)): - # across each dim - for array in self._image_processors: - if i > array.n_slider_dims - 1: - continue - # across each data array - # dims go left <- right - bounds[dim] = max(array.slider_dims_shape[dim], bounds[dim]) - - return bounds - - @property - def slider_dim_names(self) -> tuple[str, ...]: - return self._slider_dim_names - - @slider_dim_names.setter - def slider_dim_names(self, names: Sequence[str]): - if names is None: - self._slider_dim_names = None - return - - if not all([isinstance(n, str) for n in names]): - raise TypeError(f"`slider_dim_names` must be set with a list/tuple of , you passed: {names}") - - if len(set(names)) != len(names): - raise ValueError( - f"`slider_dim_names` must be unique, you passed: {names}" - ) - - self._slider_dim_names = tuple(names) - - def _get_image( - self, image_processor: NDImageProcessor, indices: Sequence[int] - ) -> ArrayProtocol: - """Get a processed 2d or 3d image from the NDImage at the given indices""" - n = image_processor.n_slider_dims - - if self._sliders_dim_order == "right": - return image_processor.get(indices[-n:]) - - elif self._sliders_dim_order == "left": - # TODO: left -> right is not fully implemented yet in ImageWidget - return image_processor.get(indices[:n]) - - def _reset_dimensions(self): - """reset the dimensions w.r.t. current collection of NDImageProcessors""" - # TODO: implement left -> right slider dims ordering, right now it's only right -> left - # add or remove dims from indices - # trim any excess dimensions - while len(self._indices) > self.n_sliders: - # remove outer most dims first - self._indices.pop_dim() - self._sliders_ui.pop_dim() - - # add any new dimensions that aren't present - while len(self.indices) < self.n_sliders: - # insert right -> left - self._indices.push_dim() - self._sliders_ui.push_dim() - - self._sliders_ui.size = 57 + (IMGUI_SLIDER_HEIGHT * self.n_sliders) - - def _reset_image_graphics(self, subplot, image_processor): - """delete and create a new image graphic if necessary""" - new_image = self._get_image(image_processor, indices=tuple(self.indices)) - if new_image is None: - if "image_widget_managed" in subplot: - # delete graphic from this subplot if present - subplot.delete_graphic(subplot["image_widget_managed"]) - # skip this subplot - return - - # check if a graphic exists - if "image_widget_managed" in subplot: - # create a new graphic only if the Texture buffer shape doesn't match - if subplot["image_widget_managed"].data.value.shape == new_image.shape: - return - - # keep cmap - cmap = subplot["image_widget_managed"].cmap - if cmap is None: - # ex: going from rgb -> grayscale - cmap = "plasma" - # delete graphic since it will be replaced - subplot.delete_graphic(subplot["image_widget_managed"]) - else: - # default cmap - cmap = "plasma" - - if image_processor.n_display_dims == 2: - g = subplot.add_image( - data=new_image, cmap=cmap, name="image_widget_managed" - ) - - # set camera orthogonal to the xy plane, flip y axis - subplot.camera.set_state( - { - "position": [0, 0, -1], - "rotation": [0, 0, 0, 1], - "scale": [1, -1, 1], - "reference_up": [0, 1, 0], - "fov": 0, - "depth_range": None, - } - ) - - subplot.controller = "panzoom" - subplot.axes.intersection = None - subplot.auto_scale() - - elif image_processor.n_display_dims == 3: - g = subplot.add_image_volume( - data=new_image, cmap=cmap, name="image_widget_managed" - ) - subplot.camera.fov = 50 - subplot.controller = "orbit" - - # make sure all 3D dimension camera scales are positive - # MIP rendering doesn't work with negative camera scales - for dim in ["x", "y", "z"]: - if getattr(subplot.camera.local, f"scale_{dim}") < 0: - setattr(subplot.camera.local, f"scale_{dim}", 1) - - subplot.auto_scale() - - def _reset_histogram(self, subplot, image_processor): - """reset the histogram""" - if not self._histogram_widget: - subplot.docks["right"].size = 0 - return - - if image_processor.histogram is None: - # no histogram available for this processor - # either there is no data array in this subplot, - # or a histogram routine does not exist for this processor - subplot.docks["right"].size = 0 - return - - if "image_widget_managed" not in subplot: - # no image in this subplot - subplot.docks["right"].size = 0 - return - - image = subplot["image_widget_managed"] - - if "histogram_lut" in subplot.docks["right"]: - hlut: HistogramLUTTool = subplot.docks["right"]["histogram_lut"] - hlut.histogram = image_processor.histogram - hlut.images = image - if subplot.docks["right"].size < 1: - subplot.docks["right"].size = 80 - - else: - # need to make one - hlut = HistogramLUTTool( - histogram=image_processor.histogram, - images=image, - name="histogram_lut", - ) - - subplot.docks["right"].add_graphic(hlut) - subplot.docks["right"].size = 80 - - self.reset_vmin_vmax() - - def _reset(self, skip_data_indices: tuple[int, ...] = None): - if skip_data_indices is None: - skip_data_indices = tuple() - - # reset the slider indices according to the new collection of dimensions - self._reset_dimensions() - # update graphics where display dims have changed accordings to indices - for i, (subplot, image_processor) in enumerate( - zip(self.figure, self._image_processors) - ): - if i in skip_data_indices: - continue - - self._reset_image_graphics(subplot, image_processor) - self._reset_histogram(subplot, image_processor) - - # force an update - self.indices = self.indices - - @property - def figure(self) -> Figure: - """ - ``Figure`` used by `ImageWidget`. - """ - return self._figure - - @property - def graphics(self) -> list[ImageGraphic]: - """List of ``ImageWidget`` managed graphics.""" - iw_managed = list() - for subplot in self.figure: - if "image_widget_managed" in subplot: - iw_managed.append(subplot["image_widget_managed"]) - else: - iw_managed.append(None) - return tuple(iw_managed) - - @property - def cmap(self) -> tuple[str | None, ...]: - """get the cmaps, or set the cmap across all images""" - return tuple(g.cmap for g in self.graphics) - - @cmap.setter - def cmap(self, name: str): - for g in self.graphics: - if g is None: - # no data at this index - continue - - if g.cmap is None: - # if rgb - continue - - g.cmap = name - - def add_event_handler(self, handler: callable, event: str = "indices"): - """ - Register an event handler. - - Currently the only event that ImageWidget supports is "indices". This event is - emitted whenever the indices of the ImageWidget changes. - - Parameters - ---------- - handler: callable - callback function, must take a tuple of int as the only argument. This tuple will be the `indices` - - event: str, "indices" - the only supported event is "indices" - - Example - ------- - - .. code-block:: py - - def my_handler(indices): - print(indices) - # example prints: (100, 15) if the data has 2 slider dimensions with sliders at positions 100, 15 - - # create an image widget - iw = ImageWidget(...) - - # add event handler - iw.add_event_handler(my_handler) - - """ - if event != "indices": - raise ValueError("`indices` is the only event supported by `ImageWidget`") - - self._indices_changed_handlers.add(handler) - - def remove_event_handler(self, handler: callable): - """Remove a registered event handler""" - self._indices_changed_handlers.remove(handler) - - def clear_event_handlers(self): - """Clear all registered event handlers""" - self._indices_changed_handlers.clear() - - def reset_vmin_vmax(self): - """ - Reset the vmin and vmax w.r.t. the full data - """ - for image_processor, subplot in zip(self._image_processors, self.figure): - if "histogram_lut" not in subplot.docks["right"]: - continue - - if image_processor.histogram is None: - continue - - hlut = subplot.docks["right"]["histogram_lut"] - hlut.histogram = image_processor.histogram - - edges = image_processor.histogram[1] - - hlut.vmin, hlut.vmax = edges[0], edges[-1] - - def reset_vmin_vmax_frame(self): - """ - Resets the vmin vmax and HistogramLUT widgets w.r.t. the current data shown in the - ImageGraphic instead of the data in the full data array. For example, if a post-processing - function is used, the range of values in the ImageGraphic can be very different from the - range of values in the full data array. - """ - - for subplot, image_processor in zip(self.figure, self._image_processors): - if "histogram_lut" not in subplot.docks["right"]: - continue - - if image_processor.histogram is None: - continue - - hlut = subplot.docks["right"]["histogram_lut"] - # set the data using the current image graphic data - image = subplot["image_widget_managed"] - freqs, edges = np.histogram(image.data.value, bins=100) - hlut.histogram = (freqs, edges) - hlut.vmin, hlut.vmax = edges[0], edges[-1] - - def show(self, **kwargs): - """ - Show the widget. - - Parameters - ---------- - - kwargs: Any - passed to `Figure.show()`t - - Returns - ------- - BaseRenderCanvas - In Qt or GLFW, the canvas window containing the Figure will be shown. - In jupyter, it will display the plot in the output cell or sidecar. - - """ - - return self.figure.show(**kwargs) - - def close(self): - """Close Widget""" - self.figure.close() diff --git a/fastplotlib/widgets/image_widget/_processor.py b/fastplotlib/widgets/image_widget/_processor.py deleted file mode 100644 index 0dce84a5e..000000000 --- a/fastplotlib/widgets/image_widget/_processor.py +++ /dev/null @@ -1,519 +0,0 @@ -import inspect -from typing import Literal, Callable -from warnings import warn - -import numpy as np -from numpy.typing import ArrayLike - -from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS - - -# must take arguments: array-like, `axis`: int, `keepdims`: bool -WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] - - -class NDImageProcessor: - def __init__( - self, - data: ArrayLike | None, - n_display_dims: Literal[2, 3] = 2, - rgb: bool = False, - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, - window_sizes: tuple[int | None, ...] | int = None, - window_order: tuple[int, ...] = None, - spatial_func: Callable[[ArrayLike], ArrayLike] = None, - compute_histogram: bool = True, - ): - """ - An ND image that supports computing window functions, and functions over spatial dimensions. - - Parameters - ---------- - data: ArrayLike - array-like data, must have 2 or more dimensions - - n_display_dims: int, 2 or 3, default 2 - number of display dimensions - - rgb: bool, default False - whether the image data is RGB(A) or not - - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional - A function or a ``tuple`` of functions that are applied to a rolling window of the data. - - You can provide unique window functions for each dimension. If you want to apply a window function - only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. - - A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which - the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window - function **must** return an array that has the same number of dimensions as the original ``data`` array, - therefore the size of the dimension along which the window was applied will reduce to ``1``. - - The output array-like type from a window function **must** support a ``.squeeze()`` method, but the - function itself should NOT squeeze the output array. - - window_sizes: tuple[int | None, ...], optional - ``tuple`` of ``int`` that specifies the window size for each dimension. - - window_order: tuple[int, ...] | None, optional - order in which to apply the window functions, by default just applies it from the left-most dim to the - right-most slider dim. - - spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional - A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. - This function is applied after the window functions (if present). - - compute_histogram: bool, default True - Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. - Disable if slow. - - """ - # set as False until data, window funcs stuff and spatial func is all set - self._compute_histogram = False - - self.data = data - self.n_display_dims = n_display_dims - self.rgb = rgb - - self.window_funcs = window_funcs - self.window_sizes = window_sizes - self.window_order = window_order - - self._spatial_func = spatial_func - - self._compute_histogram = compute_histogram - self._recompute_histogram() - - @property - def data(self) -> ArrayLike | None: - """get or set the data array""" - return self._data - - @data.setter - def data(self, data: ArrayLike): - # check that all array-like attributes are present - if data is None: - self._data = None - return - - if not isinstance(data, ArrayProtocol): - raise TypeError( - f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" - f"{ARRAY_LIKE_ATTRS}, or they must be `None`" - ) - - if data.ndim < 2: - raise IndexError( - f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" - ) - - self._data = data - self._recompute_histogram() - - @property - def ndim(self) -> int: - if self.data is None: - return 0 - - return self.data.ndim - - @property - def shape(self) -> tuple[int, ...]: - if self._data is None: - return tuple() - - return self.data.shape - - @property - def rgb(self) -> bool: - """whether or not the data is rgb(a)""" - return self._rgb - - @rgb.setter - def rgb(self, rgb: bool): - if not isinstance(rgb, bool): - raise TypeError - - if rgb and self.ndim < 3: - raise IndexError( - f"require 3 or more dims for RGB, you have: {self.ndim} dims" - ) - - self._rgb = rgb - - @property - def n_slider_dims(self) -> int: - """number of slider dimensions""" - if self._data is None: - return 0 - - return self.ndim - self.n_display_dims - int(self.rgb) - - @property - def slider_dims(self) -> tuple[int, ...] | None: - """tuple indicating the slider dimension indices""" - if self.n_slider_dims == 0: - return None - - return tuple(range(self.n_slider_dims)) - - @property - def slider_dims_shape(self) -> tuple[int, ...] | None: - if self.n_slider_dims == 0: - return None - - return tuple(self.shape[i] for i in self.slider_dims) - - @property - def n_display_dims(self) -> Literal[2, 3]: - """get or set the number of display dimensions, `2` for 2D image and `3` for volume images""" - return self._n_display_dims - - # TODO: make n_display_dims settable, requires thinking about inserting and poping indices in ImageWidget - @n_display_dims.setter - def n_display_dims(self, n: Literal[2, 3]): - if not (n == 2 or n == 3): - raise ValueError( - f"`n_display_dims` must be an with a value of 2 or 3, you have passed: {n}" - ) - self._n_display_dims = n - self._recompute_histogram() - - @property - def max_n_display_dims(self) -> int: - """maximum number of possible display dims""" - # min 2, max 3, accounts for if data is None and ndim is 0 - return max(2, min(3, self.ndim - int(self.rgb))) - - @property - def display_dims(self) -> tuple[int, int] | tuple[int, int, int]: - """tuple indicating the display dimension indices""" - return tuple(range(self.data.ndim))[self.n_slider_dims :] - - @property - def window_funcs( - self, - ) -> tuple[WindowFuncCallable | None, ...] | None: - """get or set window functions, see docstring for details""" - return self._window_funcs - - @window_funcs.setter - def window_funcs( - self, - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable | None, - ): - if window_funcs is None: - self._window_funcs = None - return - - if callable(window_funcs): - window_funcs = (window_funcs,) - - # if all are None - if all([f is None for f in window_funcs]): - self._window_funcs = None - return - - self._validate_window_func(window_funcs) - - self._window_funcs = tuple(window_funcs) - self._recompute_histogram() - - def _validate_window_func(self, funcs): - if isinstance(funcs, (tuple, list)): - for f in funcs: - if f is None: - pass - elif callable(f): - sig = inspect.signature(f) - - if "axis" not in sig.parameters or "keepdims" not in sig.parameters: - raise TypeError( - f"Each window function must take an `axis` and `keepdims` argument, " - f"you passed: {f} with the following function signature: {sig}" - ) - else: - raise TypeError( - f"`window_funcs` must be of type: tuple[Callable | None, ...], you have passed: {funcs}" - ) - - if not (len(funcs) == self.n_slider_dims or self.n_slider_dims == 0): - raise IndexError( - f"number of `window_funcs` must be the same as the number of slider dims: {self.n_slider_dims}, " - f"and you passed {len(funcs)} `window_funcs`: {funcs}" - ) - - @property - def window_sizes(self) -> tuple[int | None, ...] | None: - """get or set window sizes used for the corresponding window functions, see docstring for details""" - return self._window_sizes - - @window_sizes.setter - def window_sizes(self, window_sizes: tuple[int | None, ...] | int | None): - if window_sizes is None: - self._window_sizes = None - return - - if isinstance(window_sizes, int): - window_sizes = (window_sizes,) - - # if all are None - if all([w is None for w in window_sizes]): - self._window_sizes = None - return - - if not all([isinstance(w, (int)) or w is None for w in window_sizes]): - raise TypeError( - f"`window_sizes` must be of type: tuple[int | None, ...] | int | None, you have passed: {window_sizes}" - ) - - if not (len(window_sizes) == self.n_slider_dims or self.n_slider_dims == 0): - raise IndexError( - f"number of `window_sizes` must be the same as the number of slider dims, " - f"i.e. `data.ndim` - n_display_dims, your data array has {self.ndim} dimensions " - f"and you passed {len(window_sizes)} `window_sizes`: {window_sizes}" - ) - - # make all window sizes are valid numbers - _window_sizes = list() - for i, w in enumerate(window_sizes): - if w is None: - _window_sizes.append(None) - continue - - if w < 0: - raise ValueError( - f"negative window size passed, all `window_sizes` must be positive " - f"integers or `None`, you passed: {_window_sizes}" - ) - - if w == 0 or w == 1: - # this is not a real window, set as None - w = None - - elif w % 2 == 0: - # odd window sizes makes most sense - warn( - f"provided even window size: {w} in dim: {i}, adding `1` to make it odd" - ) - w += 1 - - _window_sizes.append(w) - - self._window_sizes = tuple(_window_sizes) - self._recompute_histogram() - - @property - def window_order(self) -> tuple[int, ...] | None: - """get or set dimension order in which window functions are applied""" - return self._window_order - - @window_order.setter - def window_order(self, order: tuple[int] | None): - if order is None: - self._window_order = None - return - - if order is not None: - if not all([d <= self.n_slider_dims for d in order]): - raise IndexError( - f"all `window_order` entries must be <= n_slider_dims\n" - f"`n_slider_dims` is: {self.n_slider_dims}, you have passed `window_order`: {order}" - ) - - if not all([d >= 0 for d in order]): - raise IndexError( - f"all `window_order` entires must be >= 0, you have passed: {order}" - ) - - self._window_order = tuple(order) - self._recompute_histogram() - - @property - def spatial_func(self) -> Callable[[ArrayLike], ArrayLike] | None: - """get or set a spatial_func function, see docstring for details""" - return self._spatial_func - - @spatial_func.setter - def spatial_func(self, func: Callable[[ArrayLike], ArrayLike] | None): - if not (callable(func) or func is not None): - raise TypeError( - f"`spatial_func` must be a callable or `None`, you have passed: {func}" - ) - - self._spatial_func = func - self._recompute_histogram() - - @property - def compute_histogram(self) -> bool: - return self._compute_histogram - - @compute_histogram.setter - def compute_histogram(self, compute: bool): - if compute: - if self._compute_histogram is False: - # compute a histogram - self._recompute_histogram() - self._compute_histogram = True - else: - self._compute_histogram = False - self._histogram = None - - @property - def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: - """ - an estimate of the histogram of the data, (histogram_values, bin_edges). - - returns `None` if `compute_histogram` is `False` - """ - return self._histogram - - def _apply_window_function(self, indices: tuple[int, ...]) -> ArrayLike: - """applies the window functions for each dimension specified""" - # window size for each dim - winds = self._window_sizes - # window function for each dim - funcs = self._window_funcs - - if winds is None or funcs is None: - # no window funcs or window sizes, just slice data and return - # clamp to max bounds - indexer = list() - for dim, i in enumerate(indices): - i = min(self.shape[dim] - 1, i) - indexer.append(i) - - return self.data[tuple(indexer)] - - # order in which window funcs are applied - order = self._window_order - - if order is not None: - # remove any entries in `window_order` where the specified dim - # has a window function or window size specified as `None` - # example: - # window_sizes = (3, 2) - # window_funcs = (np.mean, None) - # order = (0, 1) - # `1` is removed from the order since that window_func is `None` - order = tuple( - d for d in order if winds[d] is not None and funcs[d] is not None - ) - else: - # sequential order - order = list() - for d in range(self.n_slider_dims): - if winds[d] is not None and funcs[d] is not None: - order.append(d) - - # the final indexer which will be used on the data array - indexer = list() - - for dim_index, (i, w, f) in enumerate(zip(indices, winds, funcs)): - # clamp i within the max bounds - i = min(self.shape[dim_index] - 1, i) - - if (w is not None) and (f is not None): - # specify slice window if both window size and function for this dim are not None - hw = int((w - 1) / 2) # half window - - # start index cannot be less than 0 - start = max(0, i - hw) - - # stop index cannot exceed the bounds of this dimension - stop = min(self.shape[dim_index] - 1, i + hw) - - s = slice(start, stop, 1) - else: - s = slice(i, i + 1, 1) - - indexer.append(s) - - # apply indexer to slice data with the specified windows - data_sliced = self.data[tuple(indexer)] - - # finally apply the window functions in the specified order - for dim in order: - f = funcs[dim] - - data_sliced = f(data_sliced, axis=dim, keepdims=True) - - return data_sliced - - def get(self, indices: tuple[int, ...]) -> ArrayLike | None: - """ - Get the data at the given index, process data through the window functions. - - Note that we do not use __getitem__ here since the index is a tuple specifying a single integer - index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. - - Parameters - ---------- - indices: tuple[int, ...] - Get the processed data at this index. Must provide a value for each dimension. - Example: get((100, 5)) - - """ - if self.data is None: - return None - - if self.n_slider_dims != 0: - if len(indices) != self.n_slider_dims: - raise IndexError( - f"Must specify index for every slider dim, you have specified an index: {indices}\n" - f"But there are: {self.n_slider_dims} slider dims." - ) - # get output after processing through all window funcs - # squeeze to remove all dims of size 1 - window_output = self._apply_window_function(indices).squeeze() - else: - # data is a static image or volume - window_output = self.data - - # apply spatial_func - if self.spatial_func is not None: - final_output = self.spatial_func(window_output) - if final_output.ndim != (self.n_display_dims + int(self.rgb)): - raise IndexError( - f"Final output after of the `spatial_func` must match the number of display dims." - f"Output after `spatial_func` returned an array with {final_output.ndim} dims and " - f"of shape: {final_output.shape}, expected {self.n_display_dims} dims" - ) - else: - # check that output ndim after window functions matches display dims - final_output = window_output - if final_output.ndim != (self.n_display_dims + int(self.rgb)): - raise IndexError( - f"Final output after of the `window_funcs` must match the number of display dims." - f"Output after `window_funcs` returned an array with {window_output.ndim} dims and " - f"of shape: {window_output.shape}{' with rgb(a) channels' if self.rgb else ''}, " - f"expected {self.n_display_dims} dims" - ) - - return final_output - - def _recompute_histogram(self): - """ - - Returns - ------- - (histogram_values, bin_edges) - - """ - if not self._compute_histogram or self.data is None: - self._histogram = None - return - - if self.spatial_func is not None: - # don't subsample spatial dims if a spatial function is used - # spatial functions often operate on the spatial dims, ex: a gaussian kernel - # so their results require the full spatial resolution, the histogram of a - # spatially subsampled image will be very different - ignore_dims = self.display_dims - else: - ignore_dims = None - - sub = subsample_array(self.data, ignore_dims=ignore_dims) - sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] - - self._histogram = np.histogram(sub_real, bins=100) diff --git a/fastplotlib/widgets/image_widget/_properties.py b/fastplotlib/widgets/image_widget/_properties.py deleted file mode 100644 index 060314439..000000000 --- a/fastplotlib/widgets/image_widget/_properties.py +++ /dev/null @@ -1,139 +0,0 @@ -from pprint import pformat -from typing import Iterable - -import numpy as np - -from ._processor import NDImageProcessor - - -class ImageWidgetProperty: - __class_getitem__ = classmethod(type(list[int])) - - def __init__( - self, - image_widget, - attribute: str, - ): - self._image_widget = image_widget - self._image_processors: list[NDImageProcessor] = image_widget._image_processors - self._attribute = attribute - - def _get_key(self, key: slice | int | np.integer | str) -> int | slice: - if not isinstance(key, (slice | int, np.integer, str)): - raise TypeError( - f"can index `{self._attribute}` only with a , , or a indicating the subplot name." - f"You tried to index with: {key}" - ) - - if isinstance(key, str): - for i, subplot in enumerate(self._image_widget.figure): - if subplot.name == key: - key = i - break - else: - raise IndexError(f"No subplot with given name: {key}") - - return key - - def __getitem__(self, key): - key = self._get_key(key) - # return image processor attribute at this index - if isinstance(key, (int, np.integer)): - return getattr(self._image_processors[key], self._attribute) - - # if it's a slice - processors = self._image_processors[key] - - return tuple(getattr(p, self._attribute) for p in processors) - - def __setitem__(self, key, value): - key = self._get_key(key) - - # get the values from the ImageWidget property - new_values = list(getattr(p, self._attribute) for p in self._image_processors) - - # set the new value at this slice - new_values[key] = value - - # call the setter - setattr(self._image_widget, self._attribute, new_values) - - def __iter__(self): - for image_processor in self._image_processors: - yield getattr(image_processor, self._attribute) - - def __repr__(self): - return f"{self._attribute}: {pformat(self[:])}" - - def __eq__(self, other): - return self[:] == other - - -class Indices: - def __init__( - self, - indices: list[int], - image_widget, - ): - self._data = indices - - self._image_widget = image_widget - - def __iter__(self): - for i in self._data: - yield i - - def _parse_key(self, key: int | np.integer | str) -> int: - if not isinstance(key, (int, np.integer, str)): - raise TypeError( - f"indices can only be indexed with or types, you have used: {key}" - ) - - if isinstance(key, str): - # get integer index from user's names - names = self._image_widget._slider_dim_names - if key not in names: - raise KeyError( - f"dim with name: {key} not found in slider_dim_names, current names are: {names}" - ) - - key = names.index(key) - - return key - - def __getitem__(self, key: int | np.integer | str) -> int | tuple[int]: - if isinstance(key, str): - key = self._parse_key(key) - - return self._data[key] - - def __setitem__(self, key, value): - key = self._parse_key(key) - - if not isinstance(value, (int, np.integer)): - raise TypeError( - f"indices values can only be set with integers, you have tried to set the value: {value}" - ) - - new_indices = list(self._data) - new_indices[key] = value - - self._image_widget.indices = new_indices - - def _fpl_set(self, values): - self._data[:] = values - - def pop_dim(self): - self._data.pop(0) - - def push_dim(self): - self._data.insert(0, 0) - - def __len__(self): - return len(self._data) - - def __eq__(self, other): - return self._data == other - - def __repr__(self): - return f"indices: {self._data}" diff --git a/fastplotlib/widgets/image_widget/_sliders.py b/fastplotlib/widgets/image_widget/_sliders.py index 1945b8cfb..393b13273 100644 --- a/fastplotlib/widgets/image_widget/_sliders.py +++ b/fastplotlib/widgets/image_widget/_sliders.py @@ -11,66 +11,50 @@ def __init__(self, figure, size, location, title, image_widget): super().__init__(figure=figure, size=size, location=location, title=title) self._image_widget = image_widget - n_sliders = self._image_widget.n_sliders - # whether or not a dimension is in play mode - self._playing: list[bool] = [False] * n_sliders + self._playing: dict[str, bool] = {"t": False, "z": False} # approximate framerate for playing - self._fps: list[int] = [20] * n_sliders - + self._fps: dict[str, int] = {"t": 20, "z": 20} # framerate converted to frame time - self._frame_time: list[float] = [1 / 20] * n_sliders + self._frame_time: dict[str, float] = {"t": 1 / 20, "z": 1 / 20} # last timepoint that a frame was displayed from a given dimension - self._last_frame_time: list[float] = [perf_counter()] * n_sliders + self._last_frame_time: dict[str, float] = {"t": 0, "z": 0} - # loop playback self._loop = False - # auto-plays the ImageWidget's left-most dimension in docs galleries - if "DOCS_BUILD" in os.environ.keys(): - if os.environ["DOCS_BUILD"] == "1": - self._playing[0] = True + if "RTD_BUILD" in os.environ.keys(): + if os.environ["RTD_BUILD"] == "1": + self._playing["t"] = True self._loop = True - self.pause = False - - def pop_dim(self): - """pop right most dim""" - i = 0 # len(self._image_widget.indices) - 1 - for l in [self._playing, self._fps, self._frame_time, self._last_frame_time]: - l.pop(i) - - def push_dim(self): - """push a new dim""" - self._playing.insert(0, False) - self._fps.insert(0, 20) - self._frame_time.insert(0, 1 / 20) - self._last_frame_time.insert(0, perf_counter()) - - def set_index(self, dim: int, new_index: int): - """set the index of the ImageWidget""" + def set_index(self, dim: str, index: int): + """set the current_index of the ImageWidget""" # make sure the max index for this dim is not exceeded - max_index = self._image_widget.bounds[dim] - 1 - if new_index > max_index: + max_index = self._image_widget._dims_max_bounds[dim] - 1 + if index > max_index: if self._loop: # loop back to index zero if looping is enabled - new_index = 0 + index = 0 else: # if looping not enabled, stop playing this dimension self._playing[dim] = False return - # set new index - new_indices = list(self._image_widget.indices) - new_indices[dim] = new_index - self._image_widget.indices = new_indices + # set current_index + self._image_widget.current_index = {dim: min(index, max_index)} def update(self): """called on every render cycle to update the GUI elements""" + # store the new index of the image widget ("t" and "z") + new_index = dict() + + # flag if the index changed + flag_index_changed = False + # reset vmin-vmax using full orig data if imgui.button(label=fa.ICON_FA_CIRCLE_HALF_STROKE + fa.ICON_FA_FILM): self._image_widget.reset_vmin_vmax() @@ -88,7 +72,7 @@ def update(self): now = perf_counter() # buttons and slider UI elements for each dim - for dim in range(self._image_widget.n_sliders): + for dim in self._image_widget.slider_dims: imgui.push_id(f"{self._id_counter}_{dim}") if self._playing[dim]: @@ -99,7 +83,7 @@ def update(self): # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index if now - self._last_frame_time[dim] >= self._frame_time[dim]: - self.set_index(dim, self._image_widget.indices[dim] + 1) + self.set_index(dim, self._image_widget.current_index[dim] + 1) self._last_frame_time[dim] = now else: @@ -113,12 +97,12 @@ def update(self): imgui.same_line() # step back one frame button if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.indices[dim] - 1) + self.set_index(dim, self._image_widget.current_index[dim] - 1) imgui.same_line() # step forward one frame button if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: - self.set_index(dim, self._image_widget.indices[dim] + 1) + self.set_index(dim, self._image_widget.current_index[dim] + 1) imgui.same_line() # stop button @@ -153,15 +137,10 @@ def update(self): self._fps[dim] = value self._frame_time[dim] = 1 / value - val = self._image_widget.indices[dim] - vmax = self._image_widget.bounds[dim] - 1 - - dim_name = dim - if self._image_widget._slider_dim_names is not None: - if dim < len(self._image_widget._slider_dim_names): - dim_name = self._image_widget._slider_dim_names[dim] + val = self._image_widget.current_index[dim] + vmax = self._image_widget._dims_max_bounds[dim] - 1 - imgui.text(f"dim '{dim_name}:' ") + imgui.text(f"{dim}: ") imgui.same_line() # so that slider occupies full width imgui.set_next_item_width(self.width * 0.85) @@ -175,12 +154,18 @@ def update(self): # slider for this dimension changed, index = imgui.slider_int( - f"d: {dim}", v=val, v_min=0, v_max=vmax, flags=flags + f"{dim}", v=val, v_min=0, v_max=vmax, flags=flags ) - if changed: - new_indices = list(self._image_widget.indices) - new_indices[dim] = index - self._image_widget.indices = new_indices + new_index[dim] = index + + # if the slider value changed for this dimension + flag_index_changed |= changed imgui.pop_id() + + if flag_index_changed: + # if any slider dim changed set the new index of the image widget + self._image_widget.current_index = new_index + + self.size = int(imgui.get_window_height()) diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 86a01b083..0b0f25164 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -358,6 +358,11 @@ def __init__( passed to each ImageGraphic in the ImageWidget figure subplots """ + warnings.warn( + "`ImageWidget` is deprecated and will be removed in a" + " future release, please migrate to NDWidget", + DeprecationWarning + ) self._initialized = False if figure_kwargs is None: diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py index 0617a729d..378f7dfcd 100644 --- a/fastplotlib/widgets/nd_widget/__init__.py +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -1,14 +1,24 @@ from ...layouts import IMGUI -if IMGUI: - from ._base import NDProcessor +try: + import imgui_bundle +except ImportError: + HAS_XARRAY = False +else: + HAS_XARRAY = True + + +if IMGUI and HAS_XARRAY: + from ._base import NDProcessor, NDGraphic from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras from ._nd_image import NDImageProcessor, NDImage from ._ndwidget import NDWidget + else: + class NDWidget: def __init__(self, *args, **kwargs): raise ModuleNotFoundError( - "NDWidget requires `imgui-bundle` to be installed.\n" + "NDWidget requires `imgui-bundle` and `xarray` to be installed.\n" "pip install imgui-bundle" ) From 611f8979e114e7231c56c4b7311464106b03d759 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 02:58:25 -0500 Subject: [PATCH 066/101] add ndwidget section to deps with xarray --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 73dfd7ee3..30d194e79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,8 @@ tests = [ "ome-zarr", ] imgui = ["wgpu[imgui]"] -dev = ["fastplotlib[docs,notebook,tests,imgui]"] +ndwidget = ["wgpu[imgui]", "xarray"] +dev = ["fastplotlib[docs,notebook,tests,imgui,ndwidget]"] [project.urls] Homepage = "https://www.fastplotlib.org/" From 05907ed83d7bcaf11fb323702e0b20e3a167f238 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 4 Mar 2026 21:11:24 -0500 Subject: [PATCH 067/101] nice repr for NDProcessor --- fastplotlib/widgets/nd_widget/_base.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index de9826030..00c190418 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -2,6 +2,8 @@ from contextlib import contextmanager import inspect from numbers import Real +from pprint import pformat +import textwrap from typing import Literal, Any from warnings import warn @@ -311,6 +313,20 @@ def _apply_window_functions(self, indices) -> xr.DataArray: def get(self, indices: dict[Hashable, Any]): raise NotImplementedError + def __repr__(self): + tab = "\t" + return ( + f"{self.__class__.__name__}\n" + f"shape:\n\t{self.shape}\n" + f"dims:\n\t{self.dims}\n" + f"spatial_dims:\n\t{self.spatial_dims}\n" + f"slider_dims:\n\t{self.slider_dims}\n" + f"index_mappings:\n{textwrap.indent(pformat(self.index_mappings, width=120), prefix=tab)}\n" + f"window_funcs:\n{textwrap.indent(pformat(self.window_funcs, width=120), prefix=tab)}\n" + f"window_order:\n\t{self.window_order}\n" + f"spatial_func:\n\t{self.spatial_func}\n" + ) + def block_reentrance(setter): # decorator to block re-entrant indices setter @@ -454,6 +470,12 @@ def spatial_func( # force a re-render self.indices = self.indices + def __repr__(self): + return ( + f"graphic: {self.graphic}\n" + f"processor:\n{self.processor}" + ) + @contextmanager def block_indices(ndgraphic: NDGraphic): From 04718f8794bf155cb05e07af8adc204920cb2944 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 5 Mar 2026 01:15:22 -0500 Subject: [PATCH 068/101] imgui right click menu for ndgraphics --- examples/ndwidget/ndimage.py | 3 +- examples/ndwidget/timeseries.py | 1 + fastplotlib/layouts/_imgui_figure.py | 9 ++- .../ui/right_click_menus/_standard_menu.py | 6 ++ fastplotlib/widgets/nd_widget/_ndwidget.py | 5 +- fastplotlib/widgets/nd_widget/_ui.py | 72 ++++++++++++++----- 6 files changed, 72 insertions(+), 24 deletions(-) diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py index 4212f46b6..80c010ea1 100644 --- a/examples/ndwidget/ndimage.py +++ b/examples/ndwidget/ndimage.py @@ -28,9 +28,10 @@ data, ("time", "depth", "m", "n"), # specify all dim names ("m", "n"), # specify spatial dims IN ORDER, rest are auto slider dims + name="4d-image", ) # change spatial dims on the fly -ndi.spatial_dims = ("depth", "m", "n") +# ndi.spatial_dims = ("depth", "m", "n") fpl.loop.run() diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py index a0a3074ff..e506182e3 100644 --- a/examples/ndwidget/timeseries.py +++ b/examples/ndwidget/timeseries.py @@ -49,6 +49,7 @@ "freq": lambda x: int(x + 1), }, x_range_mode="view-range", + name="nd-sine" ) nd_lines.graphic.cmap = "tab10" diff --git a/fastplotlib/layouts/_imgui_figure.py b/fastplotlib/layouts/_imgui_figure.py index 33cc6d925..15b3d7c45 100644 --- a/fastplotlib/layouts/_imgui_figure.py +++ b/fastplotlib/layouts/_imgui_figure.py @@ -44,6 +44,7 @@ def __init__( canvas_kwargs: dict = None, size: tuple[int, int] = (500, 300), names: list | np.ndarray = None, + std_right_click_menu: type[Popup] = StandardRightClickMenu, ): self._guis: dict[str, EdgeWindow] = {k: None for k in GUI_EDGES} @@ -105,7 +106,7 @@ def __init__( toolbar = SubplotToolbar(subplot=subplot) self._subplot_toolbars[i] = toolbar - self._right_click_menu = StandardRightClickMenu(figure=self) + self._std_right_click_menu = std_right_click_menu(figure=self) self._popups: dict[str, Popup] = {} @@ -118,6 +119,10 @@ def __init__( def default_imgui_font(self) -> imgui.ImFont: return self._default_imgui_font + @property + def std_right_click_menu(self) -> Popup: + return self._std_right_click_menu + @property def guis(self) -> dict[str, EdgeWindow]: """GUI windows added to the Figure""" @@ -158,7 +163,7 @@ def _draw_imgui(self) -> imgui.ImDrawData: for popup in self._popups.values(): popup.update() - self._right_click_menu.update() + self._std_right_click_menu.update() # imgui.end_frame() diff --git a/fastplotlib/ui/right_click_menus/_standard_menu.py b/fastplotlib/ui/right_click_menus/_standard_menu.py index bb9e5bdef..78b5f4c9f 100644 --- a/fastplotlib/ui/right_click_menus/_standard_menu.py +++ b/fastplotlib/ui/right_click_menus/_standard_menu.py @@ -47,6 +47,10 @@ def cleanup(self): """called when the popup disappears""" self.is_open = False + def _extra_menu(self): + # extra menu items, optional, implement in subclass + pass + def update(self): if imgui.is_mouse_down(1) and not self._mouse_down: # mouse button was pressed down, store this position @@ -182,4 +186,6 @@ def update(self): imgui.end_menu() + self._extra_menu() + imgui.end_popup() diff --git a/fastplotlib/widgets/nd_widget/_ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py index a67c9d18d..8449a2c70 100644 --- a/fastplotlib/widgets/nd_widget/_ndwidget.py +++ b/fastplotlib/widgets/nd_widget/_ndwidget.py @@ -2,14 +2,15 @@ from ._index import RangeContinuous, RangeDiscrete, GlobalIndex from ._ndw_subplot import NDWSubplot -from ._ui import NDWidgetUI +from ._ui import NDWidgetUI, RightClickMenu from ...layouts import ImguiFigure, Subplot class NDWidget: def __init__(self, ref_ranges: dict[str, tuple], **kwargs): self._indices = GlobalIndex(ref_ranges, self._get_ndgraphics) - self._figure = ImguiFigure(**kwargs) + self._figure = ImguiFigure(std_right_click_menu=RightClickMenu, **kwargs) + self._figure.std_right_click_menu.set_nd_widget(self) self._subplots_nd: dict[Subplot, NDWSubplot] = dict() for subplot in self.figure: diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index be0999fe6..eba5a97d3 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -11,14 +11,15 @@ ImageGraphic, ImageVolumeGraphic, ) +from ...utils import quick_min_max from ...layouts import Subplot -from ...ui import EdgeWindow -from . import NDPositions +from ...ui import EdgeWindow, StandardRightClickMenu from ._index import RangeContinuous from ._base import NDGraphic +from ._nd_positions import NDPositions +from ._nd_image import NDImage position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] -image_graphics = [ImageGraphic, ImageVolumeGraphic] class NDWidgetUI(EdgeWindow): @@ -49,7 +50,7 @@ def __init__(self, figure, size, ndwidget): self._last_frame_time = {dim: perf_counter() for dim in ref_ranges.keys()} # loop playback - self._loop ={dim: False for dim in ref_ranges.keys()} + self._loop = {dim: False for dim in ref_ranges.keys()} # auto-plays the ImageWidget's left-most dimension in docs galleries if "DOCS_BUILD" in os.environ.keys(): @@ -116,7 +117,9 @@ def update(self): imgui.same_line() # loop checkbox - _, self._loop[dim] = imgui.checkbox(label=fa.ICON_FA_ROTATE, v=self._loop[dim]) + _, self._loop[dim] = imgui.checkbox( + label=fa.ICON_FA_ROTATE, v=self._loop[dim] + ) if imgui.is_item_hovered(0): imgui.set_tooltip("loop playback") @@ -166,26 +169,57 @@ def update(self): imgui.pop_id() - def _draw_nd_graphics_props_tab(self): - for subplot in self._ndwidget.figure: - if imgui.tree_node(subplot.name): - self._draw_ndgraphics_node(subplot) - imgui.tree_pop() - def _draw_ndgraphics_node(self, subplot: Subplot): - for ng in self._ndwidget[subplot].nd_graphics: - if imgui.tree_node(str(ng)): - if isinstance(ng, NDPositions): - self._draw_nd_pos_ui(subplot, ng) - imgui.tree_pop() +class RightClickMenu(StandardRightClickMenu): + def __init__(self, figure): + self._ndwidget = None + super().__init__(figure=figure) + + def set_nd_widget(self, ndw): + self._ndwidget = ndw + + def _extra_menu(self): + if self._ndwidget is None: + return + + if imgui.begin_menu("ND Graphics"): + subplot = self.get_subplot() + for ndg in self._ndwidget[subplot].nd_graphics: + if imgui.begin_menu( + f"{ndg.name if ndg.name is not None else hex(id(ndg))}" + ): + if isinstance(ndg, NDPositions): + self._draw_nd_pos_ui(subplot, ndg) + elif isinstance(ndg, NDImage): + self._draw_nd_image_ui(subplot, ndg) + imgui.end_menu() + imgui.end_menu() + + def _draw_nd_image_ui(self, subplot, nd_image: NDImage): + _min, _max = quick_min_max(nd_image.graphic.data.value) + changed, vmin = imgui.slider_float( + "vmin", nd_image.graphic.vmin, v_min=_min, v_max=_max + ) + if changed: + nd_image.graphic.vmin = vmin + + changed, vmax = imgui.slider_float( + "vmax", nd_image.graphic.vmax, v_min=_min, v_max=_max + ) + if changed: + nd_image.graphic.vmax = vmax + + changed, new_gamma = imgui.slider_float( + "gamma", nd_image.graphic._material.gamma, 0.01, 5 + ) + if changed: + nd_image.graphic._material.gamma = new_gamma def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): for i, cls in enumerate(position_graphics): if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): nd_graphic.graphic = cls subplot.auto_scale() - if i < len(position_graphics) - 1: - imgui.same_line() changed, val = imgui.checkbox( "use display window", nd_graphic.display_window is not None @@ -214,7 +248,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): "display window", v=nd_graphic.display_window, v_min=type_(0), - v_max=type_(self._ndwidget.ref_ranges[p_dim].stop * 0.25), + v_max=type_(self._ndwidget.ref_ranges[p_dim].stop * 0.1), ) if changed: From 0f03bd2d60432dd23e40bcb6a04e3bca66c49796 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 5 Mar 2026 01:47:45 -0500 Subject: [PATCH 069/101] better --- fastplotlib/widgets/nd_widget/_ui.py | 38 +++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index eba5a97d3..843d93961 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -2,7 +2,7 @@ from time import perf_counter import numpy as np -from imgui_bundle import imgui, icons_fontawesome_6 as fa +from imgui_bundle import imgui, imgui_ctx, icons_fontawesome_6 as fa from ...graphics import ( ScatterCollection, @@ -173,6 +173,8 @@ def update(self): class RightClickMenu(StandardRightClickMenu): def __init__(self, figure): self._ndwidget = None + self._ndgraphic_windows = set() + super().__init__(figure=figure) def set_nd_widget(self, ndw): @@ -185,16 +187,34 @@ def _extra_menu(self): if imgui.begin_menu("ND Graphics"): subplot = self.get_subplot() for ndg in self._ndwidget[subplot].nd_graphics: - if imgui.begin_menu( - f"{ndg.name if ndg.name is not None else hex(id(ndg))}" - ): - if isinstance(ndg, NDPositions): - self._draw_nd_pos_ui(subplot, ndg) - elif isinstance(ndg, NDImage): - self._draw_nd_image_ui(subplot, ndg) - imgui.end_menu() + name = ndg.name if ndg.name is not None else hex(id(ndg)) + if imgui.menu_item( + f"{name}", "", False + )[0]: + self._ndgraphic_windows.add(ndg) + imgui.end_menu() + def update(self): + super().update() + subplot = self.get_subplot() + + for ndg in list(self._ndgraphic_windows): # set -> list so we can change size during iteration + name = ndg.name if ndg.name is not None else hex(id(ndg)) + imgui.set_next_window_size((0, 0)) + _, open = imgui.begin(name, True) + + if isinstance(ndg, NDPositions): + self._draw_nd_pos_ui(subplot, ndg) + + elif isinstance(ndg, NDImage): + self._draw_nd_image_ui(subplot, ndg) + + if not open: + self._ndgraphic_windows.remove(ndg) + + imgui.end() + def _draw_nd_image_ui(self, subplot, nd_image: NDImage): _min, _max = quick_min_max(nd_image.graphic.data.value) changed, vmin = imgui.slider_float( From 057308d8ef290bd0ab61bc822b33cfdb4c29d899 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 5 Mar 2026 01:59:28 -0500 Subject: [PATCH 070/101] controller options separate window --- .../ui/right_click_menus/_standard_menu.py | 80 +++++++++++-------- 1 file changed, 48 insertions(+), 32 deletions(-) diff --git a/fastplotlib/ui/right_click_menus/_standard_menu.py b/fastplotlib/ui/right_click_menus/_standard_menu.py index 78b5f4c9f..9c659f4a7 100644 --- a/fastplotlib/ui/right_click_menus/_standard_menu.py +++ b/fastplotlib/ui/right_click_menus/_standard_menu.py @@ -31,6 +31,8 @@ def __init__(self, figure): # whether the right click menu is currently open or not self.is_open: bool = False + self._controller_window_open: bool | PlotArea = False + def get_subplot(self) -> PlotArea | bool | None: """get the subplot that a click occurred in""" if self._last_right_click_pos is None: @@ -151,41 +153,55 @@ def update(self): imgui.separator() # controller options - if imgui.begin_menu("Controller"): - _, enabled = imgui.menu_item( - "Enabled", "", self.get_subplot().controller.enabled - ) - - self.get_subplot().controller.enabled = enabled - - changed, damping = imgui.slider_float( - "Damping", - v=self.get_subplot().controller.damping, - v_min=0.0, - v_max=10.0, - ) - - if changed: - self.get_subplot().controller.damping = damping + if imgui.menu_item("Controller Options", "", False)[0]: + self._controller_window_open = self.get_subplot() - imgui.separator() - imgui.text("Controller type:") - # switching between different controllers - for name, controller_type_iter in controller_types.items(): - current_type = type(self.get_subplot().controller) + self._extra_menu() - clicked, _ = imgui.menu_item( - label=name, - shortcut="", - p_selected=current_type is controller_type_iter, - ) + imgui.end_popup() - if clicked and (current_type is not controller_type_iter): - # menu item was clicked and the desired controller isn't the current one - self.get_subplot().controller = name + if self._controller_window_open: + self._draw_controller_window() + + def _draw_controller_window(self): + subplot = self._controller_window_open + + imgui.set_next_window_size((0, 0)) + _, keep_open = imgui.begin(f"Controller", True) + imgui.text(f"subplot: {subplot.name}") + _, enabled = imgui.menu_item( + "Enabled", "", subplot.controller.enabled + ) + + subplot.controller.enabled = enabled + + changed, damping = imgui.slider_float( + "Damping", + v=subplot.controller.damping, + v_min=0.0, + v_max=10.0, + ) + + if changed: + subplot.controller.damping = damping + + imgui.separator() + imgui.text("Controller type:") + # switching between different controllers + for name, controller_type_iter in controller_types.items(): + current_type = type(subplot.controller) + + clicked, _ = imgui.menu_item( + label=name, + shortcut="", + p_selected=current_type is controller_type_iter, + ) - imgui.end_menu() + if clicked and (current_type is not controller_type_iter): + # menu item was clicked and the desired controller isn't the current one + subplot.controller = name - self._extra_menu() + if not keep_open: + self._controller_window_open = False - imgui.end_popup() + imgui.end() From 0335af787ebd7768c341a92aa8e3f4bc77208ee1 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 5 Mar 2026 02:35:04 -0500 Subject: [PATCH 071/101] update imgui --- fastplotlib/widgets/nd_widget/_base.py | 59 ++++++++++++-------------- fastplotlib/widgets/nd_widget/_ui.py | 4 +- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 00c190418..4541640a6 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -328,33 +328,6 @@ def __repr__(self): ) -def block_reentrance(setter): - # decorator to block re-entrant indices setter - def set_indices_wrapper(self: NDGraphic, new_indices): - """ - wraps NDGraphic.indices - - self: NDGraphic instance - - new_indices: new indices to set - """ - # set_value is already in the middle of an execution, block re-entrance - if self._block_indices: - return - try: - # block re-execution of set_value until it has *fully* finished executing - self._block_indices = True - setter(self, new_indices) - except Exception as exc: - # raise original exception - raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! - finally: - # set_value has finished executing, now allow future executions - self._block_indices = False - - return set_indices_wrapper - - class NDGraphic: def __init__(self, name: str | None): self._name = name @@ -471,10 +444,7 @@ def spatial_func( self.indices = self.indices def __repr__(self): - return ( - f"graphic: {self.graphic}\n" - f"processor:\n{self.processor}" - ) + return f"graphic: {self.graphic}\n" f"processor:\n{self.processor}" @contextmanager @@ -505,3 +475,30 @@ def block_indices(ndgraphic: NDGraphic): raise e from None # indices setter has raised, the line above and the lines below are probably more relevant! finally: ndgraphic._block_indices = False + + +def block_reentrance(setter): + # decorator to block re-entrant indices setter + def set_indices_wrapper(self: NDGraphic, new_indices): + """ + wraps NDGraphic.indices + + self: NDGraphic instance + + new_indices: new indices to set + """ + # set_value is already in the middle of an execution, block re-entrance + if self._block_indices: + return + try: + # block re-execution of set_value until it has *fully* finished executing + self._block_indices = True + setter(self, new_indices) + except Exception as exc: + # raise original exception + raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_indices = False + + return set_indices_wrapper diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index 843d93961..a75d99e00 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -197,12 +197,12 @@ def _extra_menu(self): def update(self): super().update() - subplot = self.get_subplot() for ndg in list(self._ndgraphic_windows): # set -> list so we can change size during iteration name = ndg.name if ndg.name is not None else hex(id(ndg)) + subplot = ndg.graphic._plot_area imgui.set_next_window_size((0, 0)) - _, open = imgui.begin(name, True) + _, open = imgui.begin(f"subplot: {subplot.name}, {name}", True) if isinstance(ndg, NDPositions): self._draw_nd_pos_ui(subplot, ndg) From 5ee11eca9224c3b7b0338d31f4be144f384229cf Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 5 Mar 2026 20:47:20 -0500 Subject: [PATCH 072/101] fix --- fastplotlib/widgets/image_widget/_widget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 0b0f25164..6d262678d 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -358,7 +358,7 @@ def __init__( passed to each ImageGraphic in the ImageWidget figure subplots """ - warnings.warn( + warn( "`ImageWidget` is deprecated and will be removed in a" " future release, please migrate to NDWidget", DeprecationWarning From eb918461cb7598ea34098e0775b46bbbbc0f7247 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 7 Mar 2026 19:41:37 -0500 Subject: [PATCH 073/101] fix compute histogram --- fastplotlib/widgets/nd_widget/_nd_image.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 038e7d82f..5589cd221 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -208,6 +208,11 @@ def _recompute_histogram(self): # TODO: account for window funcs sub = subsample_array(self.data, ignore_dims=ignore_dims) + + if isinstance(sub, xr.DataArray): + # can't do the isnan and isinf boolean indexing below on xarray + sub = sub.values + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] self._histogram = np.histogram(sub_real, bins=100) From 64d61508ef6c80466080e978fd4e3e0a0686f7dc Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 14 Mar 2026 02:24:16 -0400 Subject: [PATCH 074/101] other features WIP --- .../nd_widget/_nd_positions/_nd_positions.py | 233 +++++++++++++++++- .../nd_widget/_nd_positions/_pandas.py | 10 +- 2 files changed, 232 insertions(+), 11 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 20fec1fbc..476f22920 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -16,6 +16,7 @@ ScatterGraphic, ScatterCollection, ) +from ....graphics.features.utils import parse_colors from ....graphics.utils import pause_events from ....graphics.selectors import LinearSelector from .._base import ( @@ -32,6 +33,8 @@ # we will know the display dims automatically here from the last dim # so maybe we only need it for images? class NDPositionsProcessor(NDProcessor): + _other_features = ["colors", "markers", "cmaps_transforms", "alphas", "sizes"] + def __init__( self, data: Any, @@ -44,6 +47,11 @@ def __init__( display_window: int | float | None = 100, # window for n_datapoints dim only max_display_datapoints: int = 1_000, datapoints_window_func: tuple[Callable, str, int | float] | None = None, + colors: Sequence[str] | np.ndarray = None, + markers: Sequence[str] | np.ndarray = None, + cmaps_transforms: np.ndarray = None, + alpha: np.ndarray = None, + sizes: Sequence[float] = None, **kwargs, ): """ @@ -73,6 +81,100 @@ def __init__( self._datapoints_window_func = datapoints_window_func + self.colors = colors + self.markers = markers + self.cmaps_transforms = cmaps_transforms + self.alphas = alpha + self.sizes = sizes + + def _check_get_datapoints_dim_size(self, check_prop: str, check_shape: int) -> tuple[int, int]: + # this function exists because it's used repeatedly for colors, markers, etc. + # shape for [l, p] dims must match, or l must be 1 + shape = tuple([self.shape[dim] for dim in self.spatial_dims[:2]]) + + if check_shape[0] != 1 and check_shape != shape: + raise IndexError( + f"Number of {check_prop} must match the size of the datapoints dim in the data" + ) + + return shape + + @property + def colors(self) -> np.ndarray | None | Callable: + return self._colors + + @colors.setter + def colors(self, new): + if callable(new): + self._colors = new + return + + if new is None: + self._colors = None + return + + n = self._check_get_datapoints_dim_size("colors", new.shape) + self._colors = parse_colors(new, n_colors=n) + + @property + def markers(self) -> np.ndarray | None: + return self._markers + + @markers.setter + def markers(self, new: Sequence[str] | None): + if new is None: + self._markers = None + return + + self._check_get_datapoints_dim_size("markers", len(new)) + self._markers = np.asarray(new) + + @property + def cmaps_transforms(self) -> Sequence[str] | None: + return self._cmaps_transforms + + @cmaps_transforms.setter + def cmaps_transforms(self, new: Sequence[str] | None): + if new is None: + self._cmaps_transforms = None + return + + self._check_get_datapoints_dim_size("markers", len(new)) + self._cmap_transforms = np.asarray(new) + + @property + def alphas(self) -> np.ndarray | None: + return self._alphas + + @alphas.setter + def alphas(self, new: Sequence[float] | None): + if new is None: + self._alphas = None + return + + self._check_get_datapoints_dim_size("alphas", len(new)) + alphas = np.asarray(new) + + self._alphas = alphas + + @property + def sizes(self) -> np.ndarray | None: + return self._sizes + + @sizes.setter + def sizes(self, new: Sequence[float] | None): + if new is None: + self._sizes = None + return + + self._check_get_datapoints_dim_size("alphas", len(new)) + new = np.array(new) + + if new.ndim != 1: + raise ValueError + + self._sizes = new + @property def spatial_dims(self) -> tuple[str, str, str]: return self._spatial_dims @@ -173,9 +275,14 @@ def _get_dw_slice(self, indices: dict[str, Any]) -> slice: if start >= stop: stop = start + 1 - return slice(start, stop) + w = stop - start + + # get step size + step = max(1, w // self.max_display_datapoints) - def _apply_dw_window_func(self, array: xr.DataArray) -> xr.DataArray: + return slice(start, stop, step) + + def _apply_dw_window_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: """ Takes array where display window has already been applied and applies window functions on the `p` dim. @@ -257,16 +364,35 @@ def _apply_dw_window_func(self, array: xr.DataArray) -> xr.DataArray: return array[:, ::step] - def _apply_spatial_func(self, array: xr.DataArray) -> xr.DataArray: + def _apply_spatial_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: if self.spatial_func is not None: return self.spatial_func(array) return array - def _finalize_(self, array: xr.DataArray) -> xr.DataArray: + def _finalize_(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: return self._apply_spatial_func(self._apply_dw_window_func(array)) - def get(self, indices: dict[str, Any]): + def _get_other_features(self, data_slice: np.ndarray, dw_slice: slice) -> dict[str, np.ndarray]: + other = dict.fromkeys(self._other_features) + for attr in self._other_features: + val = getattr(self, attr) + + if callable(val): + # if it's a callable, give it the data and display window slice, it must return the appropriate + # type of array for that graphic feature + val = val(data_slice, dw_slice) + + match val: + case None: + other[attr] = None + + case _: + other[attr] = val[dw_slice] + + return other + + def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]: """ slices through all slider dims and outputs an array that can be used to set graphic data @@ -298,7 +424,13 @@ def get(self, indices: dict[str, Any]): *self.spatial_dims ) - return self._finalize_(graphic_data).values + data = self._finalize_(graphic_data).values + other = self._get_other_features(data, dw_slice) + + return { + "data": data, + **other, + } class NDPositions(NDGraphic): @@ -323,6 +455,15 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, linear_selector: bool = False, + colors: Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray] = None, + # TODO: cleanup how this cmap stuff works, require a cmap to be set per-graphic + # before allowing cmaps_transform, validate that stuff makes sense etc. + cmap: str = None, # across the line/scatter collection + cmaps: Sequence[str] = None, # for each individual line/scatter + cmaps_transforms: np.ndarray = None, # for each individual line/scatter + markers: Sequence[str] = None, + sizes: Sequence[float] = None, + alpha: Sequence[float] = None, name: str = None, graphic_kwargs: dict = None, processor_kwargs: dict = None, @@ -332,6 +473,9 @@ def __init__( if processor_kwargs is None: processor_kwargs = dict() + if graphic_kwargs is None: + self._graphic_kwargs = dict() + self._processor = processor( data, dims, @@ -341,6 +485,11 @@ def __init__( max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, index_mappings=index_mappings, + colors=colors, + markers=markers, + cmaps_transforms=cmaps_transforms, + alpha=alpha, + sizes=sizes, **processor_kwargs, ) @@ -394,6 +543,21 @@ def graphic(self, graphic_type): self._create_graphic(graphic_type) plot_area.add_graphic(self._graphic) + @property + def cmap(self) -> str | None: + # across all lines/scatters, or heatmap cmap + pass + + @property + def cmaps(self) -> np.ndarray[str] | None: + # per-line/scatter + pass + + @property + def cmaps_transforms(self) -> np.ndarray | None: + # PER line/scatter, only allowed after `cmaps` is set. + pass + @property def spatial_dims(self) -> tuple[str, str, str]: return self.processor.spatial_dims @@ -411,7 +575,8 @@ def indices(self) -> dict[Hashable, Any]: @indices.setter @block_reentrance def indices(self, indices): - data_slice = self.processor.get(indices) + new_features = self.processor.get(indices) + data_slice = new_features["data"] # TODO: set other graphic features, colors, sizes, markers, etc. @@ -419,7 +584,8 @@ def indices(self, indices): self.graphic.data[:, : data_slice.shape[-1]] = data_slice elif isinstance(self.graphic, (LineCollection, ScatterCollection)): - for g, new_data in zip(self.graphic.graphics, data_slice): + for l, g in enumerate(self.graphic.graphics): + new_data = data_slice[l] if g.data.value.shape[0] != new_data.shape[0]: # will replace buffer internally g.data = new_data @@ -427,6 +593,29 @@ def indices(self, indices): # if data are only xy, set only xy g.data[:, : new_data.shape[1]] = new_data + for feature in ["colors", "sizes", "markers"]: + value = new_features[feature] + + match value: + case None: + pass + case _: + if feature == "colors": + g.color_mode = "vertex" + + setattr(g, feature, value[l]) + + if self.cmaps is not None: + match new_features["cmaps_transforms"]: + case None: + pass + case _: + setattr( + getattr(g, "cmap"), # indv_graphic.cmap + "transform", + new_features["cmaps_transforms"], + ) + elif isinstance(self.graphic, ImageGraphic): image_data, x0, x_scale = self._create_heatmap_data(data_slice) self.graphic.data = image_data @@ -476,7 +665,8 @@ def _create_graphic( if not issubclass(graphic_cls, Graphic): raise TypeError - data_slice = self.processor.get(self.indices) + new_features = self.processor.get(self.indices) + data_slice = new_features["data"] if issubclass(graphic_cls, ImageGraphic): # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap @@ -495,6 +685,31 @@ def _create_graphic( kwargs = dict() self._graphic = graphic_cls(data_slice, **kwargs) + if isinstance(self._graphic, (LineCollection, ScatterCollection)): + for l, g in enumerate(self.graphic.graphics): + for feature in ["colors", "sizes", "markers"]: + value = new_features[feature] + + match value: + case None: + pass + case _: + if feature == "colors": + g.color_mode = "vertex" + + setattr(g, feature, value[l]) + + if self.cmaps is not None: + match new_features["cmaps_transforms"]: + case None: + pass + case _: + setattr( + getattr(g, "cmap"), # indv_graphic.cmap + "transform", + new_features["cmaps_transforms"], + ) + if self.processor.tooltip: if isinstance(self._graphic, (LineCollection, ScatterCollection)): for g in self._graphic.graphics: diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index 26acfd73d..740dfe21e 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -73,7 +73,7 @@ def tooltip_format(self, n: int, p: int): p += self._dw_slice.start return str(self.data[self._tooltip_columns[n]][p]) - def get(self, indices: dict[str, Any]) -> np.ndarray: + def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]: # TODO: LOD by using a step size according to max_p # TODO: Also what to do if display_window is None and data # hasn't changed when indices keeps getting set, cache? @@ -89,4 +89,10 @@ def get(self, indices: dict[str, Any]) -> np.ndarray: [self.data[c][self._dw_slice] for c in col] ) - return self._apply_dw_window_func(graphic_data) + data = self._finalize_(graphic_data) + other = self._get_other_features(data, self._dw_slice) + + return { + "data": data, + **other, + } From 8b8626d7dc582e88cb387899402528458600312d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 04:25:23 -0400 Subject: [PATCH 075/101] basics of other features works with ScatterStack for colors, markers, sizes, need to keep testing --- fastplotlib/graphics/__init__.py | 3 +- fastplotlib/graphics/features/_scatter.py | 6 +- fastplotlib/graphics/scatter_collection.py | 56 ++--- .../nd_widget/_nd_positions/_nd_positions.py | 225 ++++++++++++------ fastplotlib/widgets/nd_widget/_ndw_subplot.py | 2 +- 5 files changed, 192 insertions(+), 100 deletions(-) diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 8734a5e72..cca2afc21 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -7,7 +7,7 @@ from .mesh import MeshGraphic, SurfaceGraphic, PolygonGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack -from .scatter_collection import ScatterCollection +from .scatter_collection import ScatterCollection, ScatterStack __all__ = [ "Graphic", @@ -23,4 +23,5 @@ "LineCollection", "LineStack", "ScatterCollection", + "ScatterStack", ] diff --git a/fastplotlib/graphics/features/_scatter.py b/fastplotlib/graphics/features/_scatter.py index 36c8527be..685bbe6ec 100644 --- a/fastplotlib/graphics/features/_scatter.py +++ b/fastplotlib/graphics/features/_scatter.py @@ -100,7 +100,7 @@ def searchsorted_markers_to_int_array(markers_str_array: np.ndarray[str]): return marker_int_searchsorted_vals[indices] -def parse_markers_init(markers: str | Sequence[str] | np.ndarray, n_datapoints: int): +def parse_markers(markers: str | Sequence[str] | np.ndarray, n_datapoints: int): # first validate then allocate buffers if isinstance(markers, str): @@ -155,7 +155,7 @@ def __init__( Manages the markers buffer for the scatter points. Supports fancy indexing. """ - markers_int_array, self._markers_readable_array = parse_markers_init( + markers_int_array, self._markers_readable_array = parse_markers( markers, n_datapoints ) @@ -205,7 +205,7 @@ def set_value(self, graphic, value): if isinstance(value, (np.ndarray, list, tuple)): if self.buffer.data.shape[0] != len(value): # need to create a new buffer - markers_int_array, self._markers_readable_array = parse_markers_init( + markers_int_array, self._markers_readable_array = parse_markers( value, len(value) ) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py index b8e7556ad..8762a9fb3 100644 --- a/fastplotlib/graphics/scatter_collection.py +++ b/fastplotlib/graphics/scatter_collection.py @@ -102,7 +102,6 @@ def cmap(self, args): class ScatterCollectionIndexer(CollectionIndexer, _ScatterCollectionProperties): """Indexer for scatter collections""" - pass @@ -117,11 +116,14 @@ def __init__( cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, sizes: float | Sequence[float] = 5.0, + uniform_size: bool = True, + markers: np.ndarray | Sequence[str] = None, + uniform_marker: bool = True, + edge_width: float = 1.0, name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - kwargs_lines: list[dict] = None, **kwargs, ): """ @@ -186,13 +188,6 @@ def __init__( f"len(metadata) != len(data)\n{len(metadatas)} != {len(data)}" ) - if kwargs_lines is not None: - if len(kwargs_lines) != len(data): - raise ValueError( - f"len(kwargs_lines) != len(data)\n" - f"{len(kwargs_lines)} != {len(data)}" - ) - self._cmap_transform = cmap_transform self._cmap_str = cmap @@ -259,9 +254,6 @@ def __init__( "or must be a tuple/list of colors represented by a string with the same length as the data" ) - if kwargs_lines is None: - kwargs_lines = dict() - self._set_world_object(pygfx.Group()) for i, d in enumerate(data): @@ -286,14 +278,34 @@ def __init__( else: _name = None + if markers is not None: + if isinstance(markers, (tuple, list, np.ndarray)): + markers_ = markers[i] + else: + markers_ = markers + else: + markers_ = "o" + + if sizes is not None: + if isinstance(sizes, (tuple, list, np.ndarray)): + sizes_ = sizes[i] + else: + sizes_ = sizes + else: + sizes_ = 5 + lg = ScatterGraphic( data=d, colors=_c, - sizes=sizes, + sizes=sizes_, + markers=markers_, cmap=_cmap, name=_name, metadata=_m, - **kwargs_lines, + uniform_marker=uniform_marker, + uniform_size=uniform_size, + edge_width=edge_width, + **kwargs, ) self.add_graphic(lg) @@ -519,19 +531,16 @@ def _get_linear_selector_init_args(self, axis, padding): class ScatterStack(ScatterCollection): def __init__( self, - data: List[np.ndarray], - thickness: float | Iterable[float] = 2.0, - colors: str | Iterable[str] | np.ndarray | Iterable[np.ndarray] = "w", - cmap: Iterable[str] | str = None, + data: np.ndarray | List[np.ndarray], + colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", + cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, separation: float = 0.0, separation_axis: str = "y", - kwargs_lines: list[dict] = None, **kwargs, ): """ @@ -584,17 +593,12 @@ def __init__( separation_axis: str, default "y" axis in which the line graphics in the stack should be separated - - kwargs_lines: list[dict], optional - list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` - kwargs_collection kwargs for the collection, passed to GraphicCollection """ super().__init__( data=data, - thickness=thickness, colors=colors, cmap=cmap, cmap_transform=cmap_transform, @@ -602,8 +606,6 @@ def __init__( names=names, metadata=metadata, metadatas=metadatas, - isolated_buffer=isolated_buffer, - kwargs_lines=kwargs_lines, **kwargs, ) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 476f22920..a0636c31f 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -15,6 +15,7 @@ LineCollection, ScatterGraphic, ScatterCollection, + ScatterStack, ) from ....graphics.features.utils import parse_colors from ....graphics.utils import pause_events @@ -28,12 +29,15 @@ ) from .._index import GlobalIndex +# types for the other features +FeatureCallable = Callable[[np.ndarray, slice], np.ndarray] +ColorsType = np.ndarray | FeatureCallable | None +MarkersType = Sequence[str] | np.ndarray | FeatureCallable | None +SizesType = Sequence[float] | np.ndarray | FeatureCallable | None + -# TODO: Maybe get rid of n_display_dims in NDProcessor, -# we will know the display dims automatically here from the last dim -# so maybe we only need it for images? class NDPositionsProcessor(NDProcessor): - _other_features = ["colors", "markers", "cmaps_transforms", "alphas", "sizes"] + _other_features = ["colors", "markers", "cmap_transform_each", "sizes"] def __init__( self, @@ -47,11 +51,10 @@ def __init__( display_window: int | float | None = 100, # window for n_datapoints dim only max_display_datapoints: int = 1_000, datapoints_window_func: tuple[Callable, str, int | float] | None = None, - colors: Sequence[str] | np.ndarray = None, - markers: Sequence[str] | np.ndarray = None, - cmaps_transforms: np.ndarray = None, - alpha: np.ndarray = None, - sizes: Sequence[float] = None, + colors: ColorsType = None, + markers: MarkersType = None, + cmap_transform_each: np.ndarray = None, + sizes: SizesType = None, **kwargs, ): """ @@ -63,7 +66,8 @@ def __init__( spatial_dims index_mappings display_window - max_display_datapoints + max_display_datapoints: int, default 1_000 + this is approximate since floor division is used to determine the step size of the current display window slice datapoints_window_func: Important note: if used, display_window is approximate and not exact due to padding from the window size kwargs @@ -83,29 +87,48 @@ def __init__( self.colors = colors self.markers = markers - self.cmaps_transforms = cmaps_transforms - self.alphas = alpha + self.cmap_transform_each = cmap_transform_each self.sizes = sizes - def _check_get_datapoints_dim_size(self, check_prop: str, check_shape: int) -> tuple[int, int]: + def _check_shape_feature( + self, prop: str, check_shape: tuple[int, int] + ) -> tuple[int, int]: # this function exists because it's used repeatedly for colors, markers, etc. # shape for [l, p] dims must match, or l must be 1 shape = tuple([self.shape[dim] for dim in self.spatial_dims[:2]]) - if check_shape[0] != 1 and check_shape != shape: + if check_shape[1] != shape[1]: + raise IndexError( + f"shape of first two dims of {prop} must must be [l, p] or [1, p].\n" + f"required `p` dim shape is: {shape[1]}, {check_shape[1]} was provided" + ) + + if check_shape[0] != 1 and check_shape[0] != shape[0]: raise IndexError( - f"Number of {check_prop} must match the size of the datapoints dim in the data" + f"shape of first two dims of {prop} must must be [l, p] or [1, p]\n" + f"required `l` dim shape is {shape[0]} | 1, {check_shape[0]} was provided" ) return shape @property - def colors(self) -> np.ndarray | None | Callable: + def colors(self) -> ColorsType: + """ + A callable that dynamically creates colors for the current display window, or array of colors per-datapoint. + + Array must be of shape [l, p, 4] for unique colors per line/scatter, or [1, p, 4] for identical colors per + line/scatter. + + Callable must return an array of shape [l, pw, 4] or [1, pw, 4], where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ return self._colors @colors.setter def colors(self, new): if callable(new): + # custom callable that creates the colors self._colors = new return @@ -113,65 +136,106 @@ def colors(self, new): self._colors = None return - n = self._check_get_datapoints_dim_size("colors", new.shape) - self._colors = parse_colors(new, n_colors=n) + # as array so we can check shape + new = np.asarray(new) + if new.ndim == 2: + # only [p, 4] provided, broadcast to [1, p, 4] + new = new[None] + + shape = self._check_shape_feature("colors", new.shape[:2]) + + if new.shape[0] == 1: + # same colors across all graphical elements + self._colors = parse_colors(new[0], n_colors=shape[1])[None] + + else: + # colors specified for each individual line/scatter + new_ = np.zeros(shape=(*self.data.shape[:2], 4), dtype=np.float32) + for i in range(shape[0]): + new_[i] = parse_colors(new[i], n_colors=shape[1]) + + self._colors = new_ @property - def markers(self) -> np.ndarray | None: + def markers(self) -> MarkersType: + """ + A callable that dynamically creates markers for the current display window, or array of markers per-datapoint. + + Array must be of shape [l, p] for unique markers per line/scatter, or [p,] or [1, p] for identical markers per + line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ return self._markers @markers.setter - def markers(self, new: Sequence[str] | None): + def markers(self, new: MarkersType): + if callable(new): + # custom callable that creates the markers dynamically + self._markers = new + return + if new is None: self._markers = None return - self._check_get_datapoints_dim_size("markers", len(new)) - self._markers = np.asarray(new) + # as array so we can check shape + new = np.asarray(new) - @property - def cmaps_transforms(self) -> Sequence[str] | None: - return self._cmaps_transforms + # if 1-dim, assume it's specifying markers over `p` dim, so set `l` dim to 1 + if new.ndim == 1: + new = new[None] - @cmaps_transforms.setter - def cmaps_transforms(self, new: Sequence[str] | None): - if new is None: - self._cmaps_transforms = None - return + self._check_shape_feature("markers", new.shape[:2]) - self._check_get_datapoints_dim_size("markers", len(new)) - self._cmap_transforms = np.asarray(new) + self._markers = np.asarray(new) @property - def alphas(self) -> np.ndarray | None: - return self._alphas + def cmap_transform_each(self) -> Sequence[str] | None: + return self._cmap_transform_each - @alphas.setter - def alphas(self, new: Sequence[float] | None): + @cmap_transform_each.setter + def cmap_transform_each(self, new: Sequence[str] | None): if new is None: - self._alphas = None + self._cmap_transform_each = None return - self._check_get_datapoints_dim_size("alphas", len(new)) - alphas = np.asarray(new) - - self._alphas = alphas + self._check_shape_feature("markers", len(new)) + self._cmap_transforms = np.asarray(new) @property - def sizes(self) -> np.ndarray | None: + def sizes(self) -> SizesType: return self._sizes @sizes.setter - def sizes(self, new: Sequence[float] | None): + def sizes(self, new: SizesType): + """ + A callable that dynamically creates sizes for the current display window, or array of sizes per-datapoint. + + Array must be of shape [l, p] for unique sizes per line/scatter, or [p,] or [1, p] for identical markers per + line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + if callable(new): + # custom callable + self._sizes = new + return + if new is None: self._sizes = None return - self._check_get_datapoints_dim_size("alphas", len(new)) new = np.array(new) + # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1 + if new.ndim == 1: + new = new[None] - if new.ndim != 1: - raise ValueError + self._check_shape_feature("sizes", new.shape) self._sizes = new @@ -247,7 +311,7 @@ def _get_dw_slice(self, indices: dict[str, Any]) -> slice: if self.display_window is None: # just return everything - return slice(0, self.shape[p_dim] - 1) + return slice(0, self.shape[p_dim]) if self.display_window == 0: # just map p dimension at this index and return @@ -282,7 +346,9 @@ def _get_dw_slice(self, indices: dict[str, Any]) -> slice: return slice(start, stop, step) - def _apply_dw_window_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: + def _apply_dw_window_func( + self, array: xr.DataArray | np.ndarray + ) -> xr.DataArray | np.ndarray: """ Takes array where display window has already been applied and applies window functions on the `p` dim. @@ -364,7 +430,9 @@ def _apply_dw_window_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArra return array[:, ::step] - def _apply_spatial_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: + def _apply_spatial_func( + self, array: xr.DataArray | np.ndarray + ) -> xr.DataArray | np.ndarray: if self.spatial_func is not None: return self.spatial_func(array) @@ -373,22 +441,38 @@ def _apply_spatial_func(self, array: xr.DataArray | np.ndarray) -> xr.DataArray def _finalize_(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: return self._apply_spatial_func(self._apply_dw_window_func(array)) - def _get_other_features(self, data_slice: np.ndarray, dw_slice: slice) -> dict[str, np.ndarray]: + def _get_other_features( + self, data_slice: np.ndarray, dw_slice: slice + ) -> dict[str, np.ndarray]: other = dict.fromkeys(self._other_features) for attr in self._other_features: val = getattr(self, attr) + if val is None: + other[attr] = None + continue + if callable(val): # if it's a callable, give it the data and display window slice, it must return the appropriate # type of array for that graphic feature - val = val(data_slice, dw_slice) + val_sliced = val(data_slice, dw_slice) + + else: + # if no l dim, broadcast to [1, p] + if val.ndim == 1: + val = val[None] - match val: - case None: - other[attr] = None + # apply current display window slice + val_sliced = val[:, dw_slice] - case _: - other[attr] = val[dw_slice] + # check if l dim size is 1 + if val_sliced.shape[0] == 1: + # broadcast across all graphical elements + n_graphics = self.shape[self.spatial_dims[0]] + print(val_sliced.shape, n_graphics) + val_sliced = np.broadcast_to(val_sliced, shape=(n_graphics, *val_sliced.shape[1:])) + + other[attr] = val_sliced return other @@ -447,6 +531,7 @@ def __init__( | LineStack | ScatterGraphic | ScatterCollection + | ScatterStack | ImageGraphic ], processor: type[NDPositionsProcessor] = NDPositionsProcessor, @@ -455,15 +540,16 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, linear_selector: bool = False, - colors: Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray] = None, + colors: ( + Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray] + ) = None, # TODO: cleanup how this cmap stuff works, require a cmap to be set per-graphic # before allowing cmaps_transform, validate that stuff makes sense etc. cmap: str = None, # across the line/scatter collection - cmaps: Sequence[str] = None, # for each individual line/scatter - cmaps_transforms: np.ndarray = None, # for each individual line/scatter + cmap_each: Sequence[str] = None, # for each individual line/scatter + cmap_transform_each: np.ndarray = None, # for each individual line/scatter markers: Sequence[str] = None, sizes: Sequence[float] = None, - alpha: Sequence[float] = None, name: str = None, graphic_kwargs: dict = None, processor_kwargs: dict = None, @@ -475,6 +561,8 @@ def __init__( if graphic_kwargs is None: self._graphic_kwargs = dict() + else: + self._graphic_kwargs = graphic_kwargs self._processor = processor( data, @@ -487,8 +575,7 @@ def __init__( index_mappings=index_mappings, colors=colors, markers=markers, - cmaps_transforms=cmaps_transforms, - alpha=alpha, + cmap_transform_each=cmap_transform_each, sizes=sizes, **processor_kwargs, ) @@ -527,6 +614,7 @@ def graphic( | LineStack | ScatterGraphic | ScatterCollection + | ScatterStack | ImageGraphic ): """LineStack or ImageGraphic for heatmaps""" @@ -606,14 +694,14 @@ def indices(self, indices): setattr(g, feature, value[l]) if self.cmaps is not None: - match new_features["cmaps_transforms"]: + match new_features["cmap_transform_each"]: case None: pass case _: setattr( - getattr(g, "cmap"), # indv_graphic.cmap + getattr(g, "cmap"), # ind_graphic.cmap "transform", - new_features["cmaps_transforms"], + new_features["cmap_transform_each"], ) elif isinstance(self.graphic, ImageGraphic): @@ -659,6 +747,7 @@ def _create_graphic( | LineStack | ScatterGraphic | ScatterCollection + | ScatterStack | ImageGraphic ], ): @@ -679,10 +768,10 @@ def _create_graphic( ) else: - if issubclass(graphic_cls, LineStack): - kwargs = {"separation": 0.0} + if issubclass(graphic_cls, (LineStack, ScatterStack)): + kwargs = {"separation": 0.0, **self._graphic_kwargs} else: - kwargs = dict() + kwargs = self._graphic_kwargs self._graphic = graphic_cls(data_slice, **kwargs) if isinstance(self._graphic, (LineCollection, ScatterCollection)): diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 5a0b00da2..0783379ec 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -1,6 +1,6 @@ import numpy as np -from ... import ScatterCollection, LineCollection, LineStack, ImageGraphic +from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic from ...layouts import Subplot from . import NDImage, NDPositions from ._base import NDGraphic From 6727f45f8a0e13fa8420db4fb9367d550e981991 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 04:40:24 -0400 Subject: [PATCH 076/101] require min pygfx v0.16.0 due to gc hash fix necessary for NDWidget --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 30d194e79..0a9371891 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ keywords = [ requires-python = ">= 3.10" dependencies = [ "numpy>=1.23.0", - "pygfx==0.15.3", + "pygfx==0.16.0", "wgpu", # Let pygfx constrain the wgpu version "cmap>=0.1.3", # (this comment keeps this list multiline in VSCode) From f8b1ea41ef31f970ee3f9e092ad11ef9941f8b78 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 05:16:38 -0400 Subject: [PATCH 077/101] fix PlotArea.y_range --- fastplotlib/layouts/_plot_area.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 513a7ad47..974b6f653 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -885,15 +885,15 @@ def y_range(self) -> tuple[float, float]: Only valid for orthographic projections of the xy plane. Use camera.set_state() to set the camera position for arbitrary projections. """ - hh = self.camera.width / 2 + hh = self.camera.height / 2 y = self.camera.local.y return y - hh, y + hh @y_range.setter def y_range(self, yr: tuple[float, float]): - width = yr[1] - yr[0] - y_mid = yr[0] + (width / 2) - self.camera.width = width + height = yr[1] - yr[0] + y_mid = yr[0] + (height / 2) + self.camera.height = height self.camera.local.y = y_mid def remove_graphic(self, graphic: Graphic): From 9374820645481f57d182c65987cf9a1090904099 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 05:17:09 -0400 Subject: [PATCH 078/101] fix to create isolated buffer for colors when buffer replaced --- fastplotlib/graphics/features/_positions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastplotlib/graphics/features/_positions.py b/fastplotlib/graphics/features/_positions.py index 7b67e6bd7..71767e3ec 100644 --- a/fastplotlib/graphics/features/_positions.py +++ b/fastplotlib/graphics/features/_positions.py @@ -89,7 +89,9 @@ def set_value( new_colors = parse_colors(value, len(value)) # create the new buffer, old buffer should get dereferenced - self._fpl_buffer = pygfx.Buffer(new_colors) + # make sure new buffer is isolated (i.e. allocate a buffer, then set the values) + self._fpl_buffer = pygfx.Buffer(np.zeros(new_colors.shape, dtype=np.float32)) + self._fpl_buffer.data[:] = new_colors graphic.world_object.geometry.colors = self._fpl_buffer if len(self._event_handlers) < 1: From ef7c29cbee5eac1baa0c3abc9b4ee29da82535de Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 05:44:31 -0400 Subject: [PATCH 079/101] np.empty --- fastplotlib/graphics/features/_positions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastplotlib/graphics/features/_positions.py b/fastplotlib/graphics/features/_positions.py index 71767e3ec..507fc1ee0 100644 --- a/fastplotlib/graphics/features/_positions.py +++ b/fastplotlib/graphics/features/_positions.py @@ -90,8 +90,9 @@ def set_value( # create the new buffer, old buffer should get dereferenced # make sure new buffer is isolated (i.e. allocate a buffer, then set the values) - self._fpl_buffer = pygfx.Buffer(np.zeros(new_colors.shape, dtype=np.float32)) - self._fpl_buffer.data[:] = new_colors + buff = np.empty(new_colors.shape, dtype=np.float32) + buff[:] = new_colors + self._fpl_buffer = pygfx.Buffer(buff) graphic.world_object.geometry.colors = self._fpl_buffer if len(self._event_handlers) < 1: From c591d5cbe9b100fb54bc8a45ebdd4cdc33f7c1de Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 07:24:54 -0400 Subject: [PATCH 080/101] cmap_transform_each WIP --- fastplotlib/widgets/nd_widget/_base.py | 2 +- fastplotlib/widgets/nd_widget/_nd_image.py | 3 +- .../nd_widget/_nd_positions/_nd_positions.py | 112 +++++++++++++++--- 3 files changed, 98 insertions(+), 19 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 4541640a6..1cfa2c42c 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -367,7 +367,7 @@ def data(self, data: Any): @property def shape(self) -> dict[Hashable, int]: """interpreted shape of the data""" - self.processor.shape + return self.processor.shape @property def ndim(self) -> int: diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 5589cd221..f78bf7ce9 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -236,6 +236,8 @@ def __init__( name: str = None, ): + super().__init__(name) + self._global_index = global_index self._processor = NDImageProcessor( @@ -254,7 +256,6 @@ def __init__( self._histogram_widget: HistogramLUTTool | None = None self._create_graphic() - super().__init__(name) @property def processor(self) -> NDImageProcessor: diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index a0636c31f..6d1054746 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -1,6 +1,7 @@ from collections.abc import Callable, Hashable, Sequence from functools import partial from typing import Literal, Any, Type +from warnings import warn import numpy as np from numpy.lib.stride_tricks import sliding_window_view @@ -36,6 +37,18 @@ SizesType = Sequence[float] | np.ndarray | FeatureCallable | None +def default_cmap_transform_each(p: int, data_slice: np.ndarray, s: slice): + # create a cmap transform based on the `p` dim size + n_displayed = data_slice.shape[1] + + # linspace that's just normalized 0 - 1 within `p` dim size + return np.linspace( + start=s.start / p, + stop=s.stop / p, + num=n_displayed, + endpoint=False # since we use a slice object for the displayed data, the last point isn't included + ) + class NDPositionsProcessor(NDProcessor): _other_features = ["colors", "markers", "cmap_transform_each", "sizes"] @@ -193,17 +206,40 @@ def markers(self, new: MarkersType): self._markers = np.asarray(new) @property - def cmap_transform_each(self) -> Sequence[str] | None: + def cmap_transform_each(self) -> np.ndarray | FeatureCallable | None: return self._cmap_transform_each @cmap_transform_each.setter - def cmap_transform_each(self, new: Sequence[str] | None): + def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): + """ + A callable that dynamically creates cmap transforms for the current display window, or array + of transforms per-datapoint. + + Array must be of shape [l, p] for unique transforms per line/scatter, or [p,] or [1, p] for identical markers + per line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + if callable(new): + self._cmap_transform_each = new + return + if new is None: + # default transform is just a transform based on the `p` dim size self._cmap_transform_each = None return - self._check_shape_feature("markers", len(new)) - self._cmap_transforms = np.asarray(new) + new = np.asarray(new) + + # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1 + if new.ndim == 1: + new = new[None] + + self._check_shape_feature("cmap_transform_each", new.shape) + + self._cmap_transform_each = new @property def sizes(self) -> SizesType: @@ -470,7 +506,9 @@ def _get_other_features( # broadcast across all graphical elements n_graphics = self.shape[self.spatial_dims[0]] print(val_sliced.shape, n_graphics) - val_sliced = np.broadcast_to(val_sliced, shape=(n_graphics, *val_sliced.shape[1:])) + val_sliced = np.broadcast_to( + val_sliced, shape=(n_graphics, *val_sliced.shape[1:]) + ) other[attr] = val_sliced @@ -554,6 +592,8 @@ def __init__( graphic_kwargs: dict = None, processor_kwargs: dict = None, ): + super().__init__(name) + self._global_index = global_index if processor_kwargs is None: @@ -580,7 +620,9 @@ def __init__( **processor_kwargs, ) - self._processor.p_max = 1_000 + self.cmap = cmap + self.cmap_each = cmap_each + self.cmap_transform_each = cmap_transform_each self._create_graphic(graphic) @@ -599,7 +641,6 @@ def __init__( self._pause = False - super().__init__(name) @property def processor(self) -> NDPositionsProcessor: @@ -633,18 +674,53 @@ def graphic(self, graphic_type): @property def cmap(self) -> str | None: - # across all lines/scatters, or heatmap cmap - pass + return self._cmap + + @cmap.setter + def cmap(self, new: str | None): + self._cmap = new @property - def cmaps(self) -> np.ndarray[str] | None: + def cmap_each(self) -> np.ndarray[str] | None: # per-line/scatter - pass + return self._cmap_each + + @cmap_each.setter + def cmap_each(self, new: Sequence[str] | None): + if isinstance(new, str): + new = [new] + if new is None: + self._cmap_each = None + + new = np.asarray(new) + + if new.ndim != 1: + raise ValueError + + l_dim_size = self.processor.shape[self.processor.spatial_dims[0]] + # same cmap for all if size == 1, or specific cmap for each in `l` dim + if new.size != 1 and new.size != l_dim_size: + raise ValueError + + self._cmap_each = np.broadcast_to(new, shape=(l_dim_size,)) @property - def cmaps_transforms(self) -> np.ndarray | None: + def cmap_transform_each(self) -> np.ndarray | None: # PER line/scatter, only allowed after `cmaps` is set. - pass + return self.processor.cmap_transform_each + + @cmap_transform_each.setter + def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): + if self.cmap_each is None: + self.processor.cmap_transform_each = None + warn("must set `cmap_each` before `cmap_transform_each`") + if new is None and self.cmap_each is not None: + # default transform is just a transform based on the `p` dim size + new = partial( + default_cmap_transform_each, self.shape[self.spatial_dims[1]] + ) + + self.processor.cmap_transform_each = new @property def spatial_dims(self) -> tuple[str, str, str]: @@ -693,7 +769,7 @@ def indices(self, indices): setattr(g, feature, value[l]) - if self.cmaps is not None: + if self.cmap_each is not None: match new_features["cmap_transform_each"]: case None: pass @@ -788,15 +864,17 @@ def _create_graphic( setattr(g, feature, value[l]) - if self.cmaps is not None: - match new_features["cmaps_transforms"]: + if self.cmap_each is not None: + g.color_mode = "vertex" + g.cmap = self.cmap_each[l] + match new_features["cmap_transform_each"]: case None: pass case _: setattr( getattr(g, "cmap"), # indv_graphic.cmap "transform", - new_features["cmaps_transforms"], + new_features["cmap_transform_each"], ) if self.processor.tooltip: From 7f5e5e523bdcfb73e3549a5e6044251e8a439ce3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 08:12:14 -0400 Subject: [PATCH 081/101] progress --- .../widgets/nd_widget/_nd_positions/_nd_positions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 6d1054746..f1e170745 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -687,10 +687,12 @@ def cmap_each(self) -> np.ndarray[str] | None: @cmap_each.setter def cmap_each(self, new: Sequence[str] | None): - if isinstance(new, str): - new = [new] if new is None: self._cmap_each = None + return + + if isinstance(new, str): + new = [new] new = np.asarray(new) @@ -714,6 +716,8 @@ def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): if self.cmap_each is None: self.processor.cmap_transform_each = None warn("must set `cmap_each` before `cmap_transform_each`") + return + if new is None and self.cmap_each is not None: # default transform is just a transform based on the `p` dim size new = partial( From 62ed9390c6b0932edbd78d75f41d3f6767584cc9 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 15 Mar 2026 19:11:53 -0400 Subject: [PATCH 082/101] fix --- fastplotlib/graphics/features/_scatter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fastplotlib/graphics/features/_scatter.py b/fastplotlib/graphics/features/_scatter.py index 685bbe6ec..e41115ae3 100644 --- a/fastplotlib/graphics/features/_scatter.py +++ b/fastplotlib/graphics/features/_scatter.py @@ -569,6 +569,7 @@ def set_value(self, graphic, value): # create new buffer value = self._fix_sizes(value, len(value)) data = np.empty(shape=(len(value),), dtype=np.float32) + data[:] = value # create the new buffer, old buffer should get dereferenced self._fpl_buffer = pygfx.Buffer(data) From 148c6c38f06509138df59240463a8f2b90d3a4d3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Mar 2026 05:05:07 -0400 Subject: [PATCH 083/101] replace graphic when data changed, tweak index_mappings --- fastplotlib/widgets/nd_widget/_base.py | 18 +++++++++++++++--- .../nd_widget/_nd_positions/_nd_positions.py | 3 +++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 1cfa2c42c..707480e58 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -35,7 +35,7 @@ def __init__( window_order: tuple[Hashable, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): - self._data = self._validate_data(data, dims) + self._data = self._validate_data(data, tuple(dims)) self.spatial_dims = spatial_dims self.index_mappings = index_mappings @@ -214,8 +214,10 @@ def index_mappings( return for d in maps.keys(): - if d not in self.slider_dims: - raise KeyError + if d not in self.dims: + raise KeyError( + f"`index_mapping` provided for non-existent dimension: {d}, existing dims are: {self.dims}" + ) if isinstance(maps[d], ArrayProtocol): # create a searchsorted mapping function automatically @@ -333,6 +335,9 @@ def __init__(self, name: str | None): self._name = name self._block_indices = False + def _create_graphic(self, graphic_cls: type): + raise NotImplementedError + @property def name(self) -> str | None: return self._name @@ -361,6 +366,13 @@ def data(self) -> Any: @data.setter def data(self, data: Any): self.processor.data = data + # create a new graphic when data has changed + plot_area = self._graphic._plot_area + plot_area.delete_graphic(self._graphic) + + self._create_graphic(self.graphic.__class__) + plot_area.add_graphic(self._graphic) + # force a re-render self.indices = self.indices diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index f1e170745..e5800d538 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -957,6 +957,9 @@ def x_range_mode(self, mode: Literal[None, "fixed-window", "view-range"]): self._x_range_mode = mode def _update_from_view_range(self): + if self._graphic is None: + return + xr = self.graphic._plot_area.x_range # the floating point error near zero gets nasty here From db492190dc18f716e8444097be816e524c398fb2 Mon Sep 17 00:00:00 2001 From: Kushal Kolar Date: Tue, 17 Mar 2026 07:23:13 -0400 Subject: [PATCH 084/101] Update installation docs (#1013) * add simplejpeg to notebook deps * Update guide.rst * Update guide.rst * Update README.md --- README.md | 33 ++++++++++++++---------------- docs/source/user_guide/guide.rst | 35 +++++++++++++++++++------------- pyproject.toml | 1 + 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index da5ed64f8..c8e64e65e 100644 --- a/README.md +++ b/README.md @@ -63,31 +63,28 @@ Questions, issues, ideas? You are welcome to post an [issue](https://github.com/ To install use pip: -```bash -# with imgui and jupyterlab -pip install -U "fastplotlib[notebook,imgui]" +### With imgui support (recommended) -# minimal install, install glfw, pyqt6 or pyside6 separately -pip install -U fastplotlib +Without jupyterlab support, install desired GUI framework such as glfw, PyQt6, or PySide6 separately. -# with imgui -pip install -U "fastplotlib[imgui]" + pip install -U "fastplotlib[imgui]" -# to use in jupyterlab without imgui -pip install -U "fastplotlib[notebook]" -``` +With jupyterlab support. -We strongly recommend installing ``simplejpeg`` for use in notebooks, you must first install [libjpeg-turbo](https://libjpeg-turbo.org/) + pip install -U "fastplotlib[notebook,imgui]" -- If you use ``conda``, you can get ``libjpeg-turbo`` through conda. -- If you are on linux, you can get it through your distro's package manager. -- For Windows and Mac compiled binaries are available on their release page: https://github.com/libjpeg-turbo/libjpeg-turbo/releases +### Without imgui -Once you have ``libjpeg-turbo``: +Minimal, install desired GUI library such as PyQt6, PySide6, or glfw separately. + + pip install fastplotlib + +With jupyterlab support only. + + pip install -U "fastplotlib[notebook]" + +Fastplotlib is also available on conda-forge. For imgui support you will need to separately install `imgui-bundle`, and for jupyterlab you will need to install `jupyter-rfb` and `simplejpeg` which are all available on conda-forge. -```bash -pip install simplejpeg -``` > **Note:** > `fastplotlib` and `pygfx` are fast evolving projects, the version available through pip might be outdated, you will need to follow the "For developers" instructions below if you want the latest features. You can find the release history here: https://github.com/fastplotlib/fastplotlib/releases diff --git a/docs/source/user_guide/guide.rst b/docs/source/user_guide/guide.rst index bd0352aa7..c3487de2e 100644 --- a/docs/source/user_guide/guide.rst +++ b/docs/source/user_guide/guide.rst @@ -6,31 +6,38 @@ Installation To install use pip: -.. code-block:: +With imgui support (recommended) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # with imgui and jupyterlab - pip install -U "fastplotlib[notebook,imgui]" +Without jupyterlab support, install desired GUI framework such as glfw, PyQt6, or PySide6 separately. - # minimal install, install glfw, pyqt6 or pyside6 separately - pip install -U fastplotlib +.. code-block:: - # with imgui pip install -U "fastplotlib[imgui]" - # to use in jupyterlab, no imgui - pip install -U "fastplotlib[notebook]" +With jupyterlab support. -We strongly recommend installing ``simplejpeg`` for use in notebooks, you must first install `libjpeg-turbo `_. +.. code-block:: + + pip install -U "fastplotlib[notebook,imgui]" -- If you use ``conda``, you can get ``libjpeg-turbo`` through conda. -- If you are on linux you can get it through your distro's package manager. -- For Windows and Mac compiled binaries are available on their release page: https://github.com/libjpeg-turbo/libjpeg-turbo/releases +Without imgui +^^^^^^^^^^^^^ -Once you have ``libjpeg-turbo``: +Minimal, install desired GUI library such as PyQt6, PySide6, or glfw separately. .. code-block:: - pip install simplejpeg + pip install fastplotlib + +With jupyterlab support only. + +.. code-block:: + + pip install -U "fastplotlib[notebook]" + +Fastplotlib is also available on conda-forge. For imgui support you will need to separately install ``imgui-bundle``, and for jupyterlab you will need to install ``jupyter-rfb`` and ``simplejpeg`` which are all available on conda-forge. + What is ``fastplotlib``? ------------------------ diff --git a/pyproject.toml b/pyproject.toml index 0a9371891..b91b168c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ notebook = [ "jupyter-rfb>=0.5.1", "ipywidgets>=8.0.0,<9", "sidecar", + "simplejpeg", ] tests = [ "pytest", From f6f81412c52db29b2607e857f005af4d77e7f91b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Mar 2026 19:07:32 -0400 Subject: [PATCH 085/101] cmap lib handles image colormaps now --- fastplotlib/graphics/features/_image.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index cb66bb1ef..af0783c71 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -1,14 +1,13 @@ from itertools import product - from math import ceil +import cmap as cmap_lib import numpy as np import pygfx from ._base import GraphicFeature, GraphicFeatureEvent, block_reentrance from ...utils import ( - make_colors, get_cmap_texture, ) @@ -239,8 +238,8 @@ def value(self) -> str: @block_reentrance def set_value(self, graphic, value: str): - new_colors = make_colors(256, value) - graphic._material.map.texture.data[:] = new_colors + colormap = pygfx.cm.create_colormap(cmap_lib.Colormap(value).lut()) + graphic._material.map = colormap graphic._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1)) self._value = value From c2dff8691d1985404d904d30b768431ebb8f8d6d Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Mar 2026 19:10:29 -0400 Subject: [PATCH 086/101] multi-windows ndwidget, maintain features like cmap when switching graphics --- fastplotlib/graphics/line_collection.py | 9 +- fastplotlib/graphics/scatter_collection.py | 36 ++- fastplotlib/layouts/_graphic_methods_mixin.py | 139 +++++++++-- fastplotlib/widgets/nd_widget/_index.py | 39 ++- .../nd_widget/_nd_positions/_nd_positions.py | 232 ++++++++++++------ fastplotlib/widgets/nd_widget/_ndwidget.py | 33 ++- fastplotlib/widgets/nd_widget/_ui.py | 17 +- 7 files changed, 376 insertions(+), 129 deletions(-) diff --git a/fastplotlib/graphics/line_collection.py b/fastplotlib/graphics/line_collection.py index 5ec56777e..351f3368e 100644 --- a/fastplotlib/graphics/line_collection.py +++ b/fastplotlib/graphics/line_collection.py @@ -1,3 +1,5 @@ +from itertools import repeat +from numbers import Number from typing import * import numpy as np @@ -105,8 +107,11 @@ def thickness(self) -> np.ndarray: return np.asarray([g.thickness for g in self]) @thickness.setter - def thickness(self, values: np.ndarray | list[float]): - if not len(values) == len(self): + def thickness(self, values: float | Sequence[float]): + if isinstance(values, Number): + values = repeat(values, len(self)) + + elif not len(values) == len(self): raise IndexError for g, v in zip(self, values): diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py index 8762a9fb3..f0993dd46 100644 --- a/fastplotlib/graphics/scatter_collection.py +++ b/fastplotlib/graphics/scatter_collection.py @@ -1,3 +1,5 @@ +from itertools import repeat +from numbers import Number from typing import * import numpy as np @@ -59,7 +61,7 @@ def colors(self, values: str | np.ndarray | tuple[float] | list[float] | list[st @property def data(self) -> CollectionFeature: - """get or set data of lines in the collection""" + """get or set data of scatters in the collection""" return CollectionFeature(self.graphics, "data") @data.setter @@ -99,6 +101,38 @@ def cmap(self, args): n_colors=len(self), cmap_name=name, transform=transform ) + @property + def markers(self) -> CollectionFeature: + """get or set markers of scatters in the collection""" + return CollectionFeature(self.graphics, "markers") + + @markers.setter + def markers(self, values: str | Sequence[str]): + if isinstance(values, str): + values = repeat(values, len(self)) + + elif len(values) != len(self): + raise IndexError("len(markers) must be the same as the number of ScatterGraphics in the collection") + + for g, v in zip(self, values): + g.markers = v + + @property + def sizes(self) -> CollectionFeature: + """get or set sizes of scatter points in the collection""" + return CollectionFeature(self.graphics, "sizes") + + @sizes.setter + def sizes(self, values): + if isinstance(values, Number): + values = repeat(values, len(self)) + + elif len(values) != len(self): + raise IndexError("len(sizes) must be the same as the number of ScatterGraphics in the collection") + + for g, v in zip(self, values): + g.sizes = v + class ScatterCollectionIndexer(CollectionIndexer, _ScatterCollectionProperties): """Indexer for scatter collections""" diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index bd01855bd..1fbf337e2 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -33,7 +33,7 @@ def add_image( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - **kwargs, + **kwargs ) -> ImageGraphic: """ @@ -74,7 +74,7 @@ def add_image( cmap, interpolation, cmap_interpolation, - **kwargs, + **kwargs ) def add_image_volume( @@ -92,7 +92,7 @@ def add_image_volume( substep_size: float = 0.1, emissive: str | tuple | numpy.ndarray = (0, 0, 0), shininess: int = 30, - **kwargs, + **kwargs ) -> ImageVolumeGraphic: """ @@ -169,7 +169,7 @@ def add_image_volume( substep_size, emissive, shininess, - **kwargs, + **kwargs ) def add_line_collection( @@ -185,7 +185,7 @@ def add_line_collection( metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineCollection: """ @@ -256,7 +256,7 @@ def add_line_collection( metadata, metadatas, kwargs_lines, - **kwargs, + **kwargs ) def add_line( @@ -268,7 +268,7 @@ def add_line( cmap_transform: Union[numpy.ndarray, Sequence] = None, color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", - **kwargs, + **kwargs ) -> LineGraphic: """ @@ -322,7 +322,7 @@ def add_line( cmap_transform, color_mode, size_space, - **kwargs, + **kwargs ) def add_line_stack( @@ -339,7 +339,7 @@ def add_line_stack( separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineStack: """ @@ -415,7 +415,7 @@ def add_line_stack( separation, separation_axis, kwargs_lines, - **kwargs, + **kwargs ) def add_mesh( @@ -434,7 +434,7 @@ def add_mesh( | numpy.ndarray ) = None, clim: tuple[float, float] = None, - **kwargs, + **kwargs ) -> MeshGraphic: """ @@ -488,7 +488,7 @@ def add_mesh( mapcoords, cmap, clim, - **kwargs, + **kwargs ) def add_polygon( @@ -505,7 +505,7 @@ def add_polygon( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> PolygonGraphic: """ @@ -555,12 +555,15 @@ def add_scatter_collection( cmap: Union[Sequence[str], str] = None, cmap_transform: Union[numpy.ndarray, List] = None, sizes: Union[float, Sequence[float]] = 5.0, + uniform_size: bool = True, + markers: Union[numpy.ndarray, Sequence[str]] = None, + uniform_marker: bool = True, + edge_width: float = 1.0, name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> ScatterCollection: """ @@ -618,12 +621,15 @@ def add_scatter_collection( cmap, cmap_transform, sizes, + uniform_size, + markers, + uniform_marker, + edge_width, name, names, metadata, metadatas, - kwargs_lines, - **kwargs, + **kwargs ) def add_scatter( @@ -648,7 +654,7 @@ def add_scatter( sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, uniform_size: bool = True, size_space: str = "screen", - **kwargs, + **kwargs ) -> ScatterGraphic: """ @@ -778,7 +784,92 @@ def add_scatter( sizes, uniform_size, size_space, - **kwargs, + **kwargs + ) + + def add_scatter_stack( + self, + data: Union[numpy.ndarray, List[numpy.ndarray]], + colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", + cmap: Union[Sequence[str], str] = None, + cmap_transform: Union[numpy.ndarray, List] = None, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Union[Sequence[Any], numpy.ndarray] = None, + separation: float = 0.0, + separation_axis: str = "y", + **kwargs + ) -> ScatterStack: + """ + + Create a stack of :class:`.LineGraphic` that are separated along the "x" or "y" axis. + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + thickness: float or Iterable of float, default 2.0 + | if ``float``, single thickness will be used for all lines + | if ``list`` of ``float``, each value will apply to the individual lines + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + metadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + separation: float, default 0.0 + space in between each line graphic in the stack + + separation_axis: str, default "y" + axis in which the line graphics in the stack should be separated + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + + """ + return self._create_graphic( + ScatterStack, + data, + colors, + cmap, + cmap_transform, + name, + names, + metadata, + metadatas, + separation, + separation_axis, + **kwargs ) def add_surface( @@ -795,7 +886,7 @@ def add_surface( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> SurfaceGraphic: """ @@ -849,7 +940,7 @@ def add_text( screen_space: bool = True, offset: tuple[float] = (0, 0, 0), anchor: str = "middle-center", - **kwargs, + **kwargs ) -> TextGraphic: """ @@ -900,7 +991,7 @@ def add_text( screen_space, offset, anchor, - **kwargs, + **kwargs ) def add_vectors( @@ -910,7 +1001,7 @@ def add_vectors( color: Union[str, Sequence[float], numpy.ndarray] = "w", size: float = None, vector_shape_options: dict = None, - **kwargs, + **kwargs ) -> VectorsGraphic: """ @@ -955,5 +1046,5 @@ def add_vectors( color, size, vector_shape_options, - **kwargs, + **kwargs ) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 9ba9d03eb..3cb8a71f9 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -1,7 +1,11 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import Sequence, Any, Callable -from ._base import NDGraphic +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ._ndwidget import NDWidget @dataclass @@ -49,11 +53,10 @@ def __len__(self): return len(self.options) -class GlobalIndex: +class ReferenceIndex: def __init__( self, ref_ranges: dict[str, tuple], - get_ndgraphics: Callable[[], tuple[NDGraphic]], ): self._ref_ranges = dict() @@ -69,8 +72,6 @@ def __init__( else: raise ValueError - self._get_ndgraphics = get_ndgraphics - # starting index for all dims self._indices: dict[str, int | float | Any] = { name: rr.start for name, rr in self._ref_ranges.items() @@ -78,9 +79,18 @@ def __init__( self._indices_changed_handlers = set() + self._ndwidgets: list[NDWidget] = list() + + def _add_ndwidget_(self, ndw: NDWidget): + from ._ndwidget import NDWidget + if not isinstance(ndw, NDWidget): + raise TypeError + + self._ndwidgets.append(ndw) + def set(self, indices: dict[str, Any]): for dim, value in indices.items(): - self._indices[dim] = self._clamp(value) + self._indices[dim] = self._clamp(dim, value) self._render_indices() @@ -94,9 +104,13 @@ def _clamp(self, dim, value): return value def _render_indices(self): - for g in self._get_ndgraphics(): - # only provide slider indices to the graphic - g.indices = {d: self._indices[d] for d in g.processor.slider_dims} + for ndw in self._ndwidgets: + for g in ndw.ndgraphics: + if g.data is None: + continue + # only provide slider indices to the graphic + g.indices = {d: self._indices[d] for d in g.processor.slider_dims} + print(g) @property def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: @@ -117,7 +131,7 @@ def push_dim(self, ref_range: RangeContinuous): # TODO: implement pushing and popping dims pass - def add_event_handler(self, handler: callable, event: str = "indices"): + def add_event_handler(self, handler: Callable, event: str = "indices"): """ Register an event handler. @@ -126,7 +140,7 @@ def add_event_handler(self, handler: callable, event: str = "indices"): Parameters ---------- - handler: callable + handler: Callable callback function, must take a tuple of int as the only argument. This tuple will be the `indices` event: str, "indices" @@ -153,7 +167,7 @@ def my_handler(indices): self._indices_changed_handlers.add(handler) - def remove_event_handler(self, handler: callable): + def remove_event_handler(self, handler: Callable): """Remove a registered event handler""" self._indices_changed_handlers.remove(handler) @@ -178,6 +192,7 @@ def __str__(self): return str(self._indices) +# TODO: Not sure if we'll actually do this here, just a placeholder for now class SelectionVector: @property def selection(self): diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index e5800d538..638722716 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -28,7 +28,7 @@ block_reentrance, block_indices, ) -from .._index import GlobalIndex +from .._index import ReferenceIndex # types for the other features FeatureCallable = Callable[[np.ndarray, slice], np.ndarray] @@ -46,9 +46,10 @@ def default_cmap_transform_each(p: int, data_slice: np.ndarray, s: slice): start=s.start / p, stop=s.stop / p, num=n_displayed, - endpoint=False # since we use a slice object for the displayed data, the last point isn't included + endpoint=False, # since we use a slice object for the displayed data, the last point isn't included ) + class NDPositionsProcessor(NDProcessor): _other_features = ["colors", "markers", "cmap_transform_each", "sizes"] @@ -227,7 +228,6 @@ def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): return if new is None: - # default transform is just a transform based on the `p` dim size self._cmap_transform_each = None return @@ -485,7 +485,6 @@ def _get_other_features( val = getattr(self, attr) if val is None: - other[attr] = None continue if callable(val): @@ -558,7 +557,7 @@ def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]: class NDPositions(NDGraphic): def __init__( self, - global_index: GlobalIndex, + ref_index: ReferenceIndex, data: Any, dims: Sequence[str], spatial_dims: tuple[str, str, str], @@ -586,15 +585,18 @@ def __init__( cmap: str = None, # across the line/scatter collection cmap_each: Sequence[str] = None, # for each individual line/scatter cmap_transform_each: np.ndarray = None, # for each individual line/scatter - markers: Sequence[str] = None, - sizes: Sequence[float] = None, + markers: np.ndarray = None, # across the scatter collection, shape [l,] + markers_each: Sequence[str] = None, # for each individual scatter, shape [l, p] + sizes: np.ndarray = None, # across the scatter collection, shape [l,] + sizes_each: Sequence[float] = None, # for each individual scatter, shape [l, p] + thickness: np.ndarray = None, # for each line, shape [l,] name: str = None, graphic_kwargs: dict = None, processor_kwargs: dict = None, ): super().__init__(name) - self._global_index = global_index + self._ref_index = ref_index if processor_kwargs is None: processor_kwargs = dict() @@ -614,13 +616,17 @@ def __init__( window_funcs=window_funcs, index_mappings=index_mappings, colors=colors, - markers=markers, + markers=markers_each, cmap_transform_each=cmap_transform_each, - sizes=sizes, + sizes=sizes_each, **processor_kwargs, ) - self.cmap = cmap + self._cmap = cmap + self._sizes = sizes + self._markers = markers + self._thickness = thickness + self.cmap_each = cmap_each self.cmap_transform_each = cmap_transform_each @@ -641,7 +647,6 @@ def __init__( self._pause = False - @property def processor(self) -> NDPositionsProcessor: return self._processor @@ -672,60 +677,6 @@ def graphic(self, graphic_type): self._create_graphic(graphic_type) plot_area.add_graphic(self._graphic) - @property - def cmap(self) -> str | None: - return self._cmap - - @cmap.setter - def cmap(self, new: str | None): - self._cmap = new - - @property - def cmap_each(self) -> np.ndarray[str] | None: - # per-line/scatter - return self._cmap_each - - @cmap_each.setter - def cmap_each(self, new: Sequence[str] | None): - if new is None: - self._cmap_each = None - return - - if isinstance(new, str): - new = [new] - - new = np.asarray(new) - - if new.ndim != 1: - raise ValueError - - l_dim_size = self.processor.shape[self.processor.spatial_dims[0]] - # same cmap for all if size == 1, or specific cmap for each in `l` dim - if new.size != 1 and new.size != l_dim_size: - raise ValueError - - self._cmap_each = np.broadcast_to(new, shape=(l_dim_size,)) - - @property - def cmap_transform_each(self) -> np.ndarray | None: - # PER line/scatter, only allowed after `cmaps` is set. - return self.processor.cmap_transform_each - - @cmap_transform_each.setter - def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): - if self.cmap_each is None: - self.processor.cmap_transform_each = None - warn("must set `cmap_each` before `cmap_transform_each`") - return - - if new is None and self.cmap_each is not None: - # default transform is just a transform based on the `p` dim size - new = partial( - default_cmap_transform_each, self.shape[self.spatial_dims[1]] - ) - - self.processor.cmap_transform_each = new - @property def spatial_dims(self) -> tuple[str, str, str]: return self.processor.spatial_dims @@ -738,7 +689,7 @@ def spatial_dims(self, dims: tuple[str, str, str]): @property def indices(self) -> dict[Hashable, Any]: - return {d: self._global_index[d] for d in self.processor.slider_dims} + return {d: self._ref_index[d] for d in self.processor.slider_dims} @indices.setter @block_reentrance @@ -810,7 +761,7 @@ def indices(self, indices): def _linear_selector_handler(self, ev): with block_indices(self): # linear selector always acts on the `p` dim - self._global_index[self.processor.spatial_dims[1]] = ev.info["value"] + self._ref_index[self.processor.spatial_dims[1]] = ev.info["value"] def _tooltip_handler(self, graphic, pick_info): if isinstance(self.graphic, (LineCollection, ScatterCollection)): @@ -837,6 +788,18 @@ def _create_graphic( new_features = self.processor.get(self.indices) data_slice = new_features["data"] + # store any cmap, sizes, thickness, etc. to assign to new graphic + graphic_attrs = dict() + for attr in ["cmap", "markers", "sizes", "thickness"]: + if attr in new_features.keys(): + if new_features[attr] is not None: + # markers and sizes defined for each line via processor takes priority + continue + + val = getattr(self, attr) + if val is not None: + graphic_attrs[attr] = val + if issubclass(graphic_cls, ImageGraphic): # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap if self.processor.shape[self.processor.spatial_dims[-1]] != 2: @@ -854,6 +817,10 @@ def _create_graphic( kwargs = self._graphic_kwargs self._graphic = graphic_cls(data_slice, **kwargs) + for attr in graphic_attrs.keys(): + if hasattr(self._graphic, attr): + setattr(self._graphic, attr, graphic_attrs[attr]) + if isinstance(self._graphic, (LineCollection, ScatterCollection)): for l, g in enumerate(self.graphic.graphics): for feature in ["colors", "sizes", "markers"]: @@ -972,11 +939,136 @@ def _update_from_view_range(self): new_width = abs(xr[1] - xr[0]) new_index = (xr[0] + xr[1]) / 2 - if (new_index == self._global_index[self.processor.spatial_dims[1]]) and ( + if (new_index == self._ref_index[self.processor.spatial_dims[1]]) and ( last_width == new_width ): return self.processor.display_window = new_width # set the `p` dim on the global index vector - self._global_index[self.processor.spatial_dims[1]] = new_index + self._ref_index[self.processor.spatial_dims[1]] = new_index + + @property + def cmap(self) -> str | None: + return self._cmap + + @cmap.setter + def cmap(self, new: str | None): + if new is None: + # just set a default + if isinstance(self.graphic, (LineCollection, ScatterCollection)): + self.graphic.colors = "w" + else: + self.graphic.cmap = "plasma" + + self._cmap = None + return + + self._graphic.cmap = new + self._cmap = new + # force a re-render + self.indices = self.indices + + @property + def cmap_each(self) -> np.ndarray[str] | None: + # per-line/scatter + return self._cmap_each + + @cmap_each.setter + def cmap_each(self, new: Sequence[str] | None): + if new is None: + self._cmap_each = None + return + + if isinstance(new, str): + new = [new] + + new = np.asarray(new) + + if new.ndim != 1: + raise ValueError + + l_dim_size = self.processor.shape[self.processor.spatial_dims[0]] + # same cmap for all if size == 1, or specific cmap for each in `l` dim + if new.size != 1 and new.size != l_dim_size: + raise ValueError + + self._cmap_each = np.broadcast_to(new, shape=(l_dim_size,)) + + @property + def cmap_transform_each(self) -> np.ndarray | None: + # PER line/scatter, only allowed after `cmaps` is set. + return self.processor.cmap_transform_each + + @cmap_transform_each.setter + def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): + if new is None: + self.processor.cmap_transform_each = None + + if self.cmap_each is None: + self.processor.cmap_transform_each = None + warn("must set `cmap_each` before `cmap_transform_each`") + return + + if new is None and self.cmap_each is not None: + # default transform is just a transform based on the `p` dim size + new = partial(default_cmap_transform_each, self.shape[self.spatial_dims[1]]) + + self.processor.cmap_transform_each = new + + @property + def markers(self) -> str | Sequence[str] | None: + return self._markers + + @markers.setter + def markers(self, new: str | None): + if not isinstance(self.graphic, ScatterCollection): + self._markers = None + return + + if new is None: + # just set a default + new = "circle" + + self.graphic.markers = new + self._markers = new + # force a re-render + self.indices = self.indices + + @property + def sizes(self) -> float | Sequence[float] | None: + return self._sizes + + @sizes.setter + def sizes(self, new: float | Sequence[float] | None): + if not isinstance(self.graphic, ScatterCollection): + self._sizes = None + return + + if new is None: + # just set a default + new = 5.0 + + self.graphic.sizes = new + self._sizes = new + # force a re-render + self.indices = self.indices + + @property + def thickness(self) -> float | Sequence[float] | None: + return self._thickness + + @thickness.setter + def thickness(self, new: float | Sequence[float] | None): + if not isinstance(self.graphic, LineCollection): + self._thickness = None + return + + if new is None: + # just set a default + new = 2.0 + + self.graphic.thickness = new + self._thickness = new + # force a re-render + self.indices = self.indices \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/_ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py index 8449a2c70..9ddfa8986 100644 --- a/fastplotlib/widgets/nd_widget/_ndwidget.py +++ b/fastplotlib/widgets/nd_widget/_ndwidget.py @@ -1,14 +1,22 @@ -from typing import Any +from __future__ import annotations -from ._index import RangeContinuous, RangeDiscrete, GlobalIndex +from typing import Any, Optional + +from ._index import RangeContinuous, RangeDiscrete, ReferenceIndex from ._ndw_subplot import NDWSubplot from ._ui import NDWidgetUI, RightClickMenu from ...layouts import ImguiFigure, Subplot class NDWidget: - def __init__(self, ref_ranges: dict[str, tuple], **kwargs): - self._indices = GlobalIndex(ref_ranges, self._get_ndgraphics) + def __init__(self, ref_ranges: dict[str, tuple], ref_index: Optional[ReferenceIndex] = None, **kwargs): + if ref_index is None: + self._indices = ReferenceIndex(ref_ranges) + else: + self._indices = ref_index + + self._indices._add_ndwidget_(self) + self._figure = ImguiFigure(std_right_click_menu=RightClickMenu, **kwargs) self._figure.std_right_click_menu.set_nd_widget(self) @@ -27,7 +35,7 @@ def figure(self) -> ImguiFigure: return self._figure @property - def indices(self) -> GlobalIndex: + def indices(self) -> ReferenceIndex: return self._indices @indices.setter @@ -35,21 +43,22 @@ def indices(self, new_indices: dict[str, int | float | Any]): self._indices.set(new_indices) @property - def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: + def ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: return self._indices.ref_ranges - def __getitem__(self, key: str | tuple[int, int] | Subplot): - if not isinstance(key, Subplot): - key = self.figure[key] - return self._subplots_nd[key] - - def _get_ndgraphics(self): + @property + def ndgraphics(self): gs = list() for subplot in self._subplots_nd.values(): gs.extend(subplot.nd_graphics) return tuple(gs) + def __getitem__(self, key: str | tuple[int, int] | Subplot): + if not isinstance(key, Subplot): + key = self.figure[key] + return self._subplots_nd[key] + def show(self, **kwargs): return self.figure.show(**kwargs) diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index a75d99e00..0e73f524f 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -6,6 +6,7 @@ from ...graphics import ( ScatterCollection, + ScatterStack, LineCollection, LineStack, ImageGraphic, @@ -19,7 +20,7 @@ from ._nd_positions import NDPositions from ._nd_image import NDImage -position_graphics = [ScatterCollection, LineCollection, LineStack, ImageGraphic] +position_graphics = [ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic] class NDWidgetUI(EdgeWindow): @@ -35,7 +36,7 @@ def __init__(self, figure, size, ndwidget): ) self._ndwidget = ndwidget - ref_ranges = self._ndwidget.ref_ranges + ref_ranges = self._ndwidget.ranges # whether or not a dimension is in play mode self._playing = {dim: False for dim in ref_ranges.keys()} @@ -61,11 +62,11 @@ def __init__(self, figure, size, ndwidget): self._max_display_windows: dict[NDGraphic, float | int] = dict() def _set_index(self, dim, index): - if index >= self._ndwidget.ref_ranges[dim].stop: + if index >= self._ndwidget.ranges[dim].stop: if self._loop[dim]: - index = self._ndwidget.ref_ranges[dim].start + index = self._ndwidget.ranges[dim].start else: - index = self._ndwidget.ref_ranges[dim].stop + index = self._ndwidget.ranges[dim].stop self._playing[dim] = False self._ndwidget.indices[dim] = index @@ -77,7 +78,7 @@ def update(self): # push id since we have the same buttons for each dim imgui.push_id(f"{self._id_counter}_{dim}") - rr = self._ndwidget.ref_ranges[dim] + rr = self._ndwidget.ranges[dim] if self._playing[dim]: # show pause button if playing @@ -252,7 +253,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): nd_graphic.display_window = None else: # pick a value 10% of the reference range - nd_graphic.display_window = self._ndwidget.ref_ranges[p_dim].range * 0.1 + nd_graphic.display_window = self._ndwidget.ranges[p_dim].range * 0.1 if nd_graphic.display_window is not None: if isinstance(nd_graphic.display_window, (int, np.integer)): @@ -268,7 +269,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): "display window", v=nd_graphic.display_window, v_min=type_(0), - v_max=type_(self._ndwidget.ref_ranges[p_dim].stop * 0.1), + v_max=type_(self._ndwidget.ranges[p_dim].stop * 0.1), ) if changed: From 0b603e91d3bfca7a14c01063765ef7fe2de5a385 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Tue, 17 Mar 2026 23:27:31 -0400 Subject: [PATCH 087/101] progress --- fastplotlib/widgets/nd_widget/_base.py | 64 +++++++++---- fastplotlib/widgets/nd_widget/_index.py | 1 - fastplotlib/widgets/nd_widget/_nd_image.py | 62 ++++++------- .../nd_widget/_nd_positions/_nd_positions.py | 90 +++++++++++-------- .../nd_widget/_nd_positions/_pandas.py | 2 +- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 29 ++---- fastplotlib/widgets/nd_widget/_ui.py | 8 +- 7 files changed, 139 insertions(+), 117 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 707480e58..397f6bd37 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -4,13 +4,14 @@ from numbers import Real from pprint import pformat import textwrap -from typing import Literal, Any +from typing import Literal, Any, Type from warnings import warn import xarray as xr import numpy as np from numpy.typing import ArrayLike +from ...layouts import Subplot from ...utils import subsample_array, ArrayProtocol from ...graphics import Graphic @@ -35,7 +36,8 @@ def __init__( window_order: tuple[Hashable, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): - self._data = self._validate_data(data, tuple(dims)) + self._dims = tuple(dims) + self._data = self._validate_data(data) self.spatial_dims = spatial_dims self.index_mappings = index_mappings @@ -50,16 +52,19 @@ def data(self) -> xr.DataArray: @data.setter def data(self, data: ArrayProtocol): - self._data = self._validate_data(data, self.dims) + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if data is None: + return None - def _validate_data(self, data: ArrayProtocol, dims): if not isinstance(data, ArrayProtocol): raise TypeError("`data` must implement the ArrayProtocol") - if data.ndim != len(dims): + if data.ndim != len(self.dims): raise IndexError("must specify a dim for every dimension in the data array") - return xr.DataArray(data, dims=dims) + return xr.DataArray(data, dims=self.dims) @property def shape(self) -> dict[Hashable, int]: @@ -74,7 +79,7 @@ def ndim(self) -> int: @property def dims(self) -> tuple[Hashable, ...]: """dim names""" - return self.data.dims + return self._dims @property def spatial_dims(self) -> tuple[Hashable, ...]: @@ -316,26 +321,48 @@ def get(self, indices: dict[Hashable, Any]): raise NotImplementedError def __repr__(self): + if self.data is None: + return ( + f"{self.__class__.__name__}\n" + f"data is None, dims: {self.dims}" + ) tab = "\t" - return ( + + wf = {k: v for k, v in self.window_funcs.items() if v != (None, None)} + + r = ( f"{self.__class__.__name__}\n" f"shape:\n\t{self.shape}\n" f"dims:\n\t{self.dims}\n" f"spatial_dims:\n\t{self.spatial_dims}\n" f"slider_dims:\n\t{self.slider_dims}\n" f"index_mappings:\n{textwrap.indent(pformat(self.index_mappings, width=120), prefix=tab)}\n" - f"window_funcs:\n{textwrap.indent(pformat(self.window_funcs, width=120), prefix=tab)}\n" - f"window_order:\n\t{self.window_order}\n" - f"spatial_func:\n\t{self.spatial_func}\n" ) + if len(wf) > 0: + r += ( + f"window_funcs:\n{textwrap.indent(pformat(wf, width=120), prefix=tab)}\n" + f"window_order:\n\t{self.window_order}\n" + ) + + if self.spatial_func is not None: + r += f"spatial_func:\n\t{self.spatial_func}\n" + + return r + class NDGraphic: - def __init__(self, name: str | None): + def __init__( + self, + subplot: Subplot, + name: str | None, + ): + self._subplot = subplot self._name = name self._block_indices = False + self._graphic: Graphic | None = None - def _create_graphic(self, graphic_cls: type): + def _create_graphic(self): raise NotImplementedError @property @@ -367,11 +394,12 @@ def data(self) -> Any: def data(self, data: Any): self.processor.data = data # create a new graphic when data has changed - plot_area = self._graphic._plot_area - plot_area.delete_graphic(self._graphic) + if self.graphic is not None: + # it is already None is it was initialized with no data + self._subplot.delete_graphic(self.graphic) + self._graphic = None - self._create_graphic(self.graphic.__class__) - plot_area.add_graphic(self._graphic) + self._create_graphic() # force a re-render self.indices = self.indices @@ -456,7 +484,7 @@ def spatial_func( self.indices = self.indices def __repr__(self): - return f"graphic: {self.graphic}\n" f"processor:\n{self.processor}" + return f"graphic: {self.graphic.__class__.__name__}\n" f"processor:\n{self.processor}" @contextmanager diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 3cb8a71f9..e91319893 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -110,7 +110,6 @@ def _render_indices(self): continue # only provide slider indices to the graphic g.indices = {d: self._indices[d] for d in g.processor.slider_dims} - print(g) @property def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index f78bf7ce9..0261a64e5 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -6,6 +6,7 @@ from numpy.typing import ArrayLike import xarray as xr +from ...layouts import Subplot from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS from ...graphics import ImageGraphic, ImageVolumeGraphic from ...tools import HistogramLUTTool @@ -96,10 +97,10 @@ def data(self) -> xr.DataArray | None: @data.setter def data(self, data: ArrayLike): # check that all array-like attributes are present - self._data = self._validate_data(data, self.dims) + self._data = self._validate_data(data) self._recompute_histogram() - def _validate_data(self, data: ArrayProtocol, dims): + def _validate_data(self, data: ArrayProtocol): if not isinstance(data, ArrayProtocol): raise TypeError( f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" @@ -111,7 +112,7 @@ def _validate_data(self, data: ArrayProtocol, dims): f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" ) - return xr.DataArray(data, dims=dims) + return xr.DataArray(data, dims=self.dims) @property def rgb_dim(self) -> str | None: @@ -221,7 +222,8 @@ def _recompute_histogram(self): class NDImage(NDGraphic): def __init__( self, - global_index, + ref_index, + subplot: Subplot, data: ArrayLike | None, dims: Sequence[Hashable], spatial_dims: ( @@ -236,9 +238,9 @@ def __init__( name: str = None, ): - super().__init__(name) + super().__init__(subplot, name) - self._global_index = global_index + self._ref_index = ref_index self._processor = NDImageProcessor( data, @@ -268,11 +270,6 @@ def graphic( """LineStack or ImageGraphic for heatmaps""" return self._graphic - @graphic.setter - def graphic(self, graphic_type): - # TODO implement if graphic type changes to custom user subclass - raise NotImplementedError - def _create_graphic(self): match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)): case 2: @@ -291,9 +288,8 @@ def _create_graphic(self): for k in attrs: attrs[k] = getattr(old_graphic, k) - plot_area = old_graphic._plot_area - plot_area.delete_graphic(old_graphic) - plot_area.add_graphic(new_graphic) + self._subplot.delete_graphic(old_graphic) + self._subplot.add_graphic(new_graphic) # set cmap and interpolation for attr, val in attrs.items(): @@ -301,19 +297,19 @@ def _create_graphic(self): self._graphic = new_graphic - if self._graphic._plot_area is not None: - self._reset_camera() + self._subplot.add_graphic(self._graphic) + self._reset_camera() self._reset_histogram() def _reset_histogram(self): # reset histogram - if self._graphic._plot_area is None: + if self.graphic is None: return if not self.processor.compute_histogram: # hide right dock if histogram not desired - self._graphic._plot_area.docks["right"].size = 0 + self._subplot.docks["right"].size = 0 return if self.processor.histogram: @@ -321,8 +317,8 @@ def _reset_histogram(self): # histogram widget exists, update it self._histogram_widget.histogram = self.processor.histogram self._histogram_widget.images = self.graphic - if self.graphic._plot_area.docks["right"].size < 1: - self.graphic._plot_area.docks["right"].size = 80 + if self._subplot.docks["right"].size < 1: + self._subplot.docks["right"].size = 80 else: # make hist tool self._histogram_widget = HistogramLUTTool( @@ -330,18 +326,16 @@ def _reset_histogram(self): images=self.graphic, name=f"hist-{hex(id(self.graphic))}", ) - self.graphic._plot_area.docks["right"].add_graphic(self._histogram_widget) - self.graphic._plot_area.docks["right"].size = 80 + self._subplot.docks["right"].add_graphic(self._histogram_widget) + self._subplot.docks["right"].size = 80 self.graphic.reset_vmin_vmax() def _reset_camera(self): - plot_area = self._graphic._plot_area - # set camera to a nice position for 2D or 3D if isinstance(self._graphic, ImageGraphic): # set camera orthogonal to the xy plane, flip y axis - plot_area.camera.set_state( + self._subplot.camera.set_state( { "position": [0, 0, -1], "rotation": [0, 0, 0, 1], @@ -352,21 +346,21 @@ def _reset_camera(self): } ) - plot_area.controller = "panzoom" - plot_area.axes.intersection = None - plot_area.auto_scale() + self._subplot.controller = "panzoom" + self._subplot.axes.intersection = None + self._subplot.auto_scale() else: - plot_area.camera.fov = 50 - plot_area.controller = "orbit" + self._subplot.camera.fov = 50 + self._subplot.controller = "orbit" # make sure all 3D dimension camera scales are positive # MIP rendering doesn't work with negative camera scales for dim in ["x", "y", "z"]: - if getattr(plot_area.camera.local, f"scale_{dim}") < 0: - setattr(plot_area.camera.local, f"scale_{dim}", 1) + if getattr(self._subplot.camera.local, f"scale_{dim}") < 0: + setattr(self._subplot.camera.local, f"scale_{dim}", 1) - plot_area.auto_scale() + self._subplot.auto_scale() @property def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: @@ -381,7 +375,7 @@ def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): @property def indices(self) -> dict[Hashable, Any]: - return {d: self._global_index[d] for d in self.processor.slider_dims} + return {d: self._ref_index[d] for d in self.processor.slider_dims} @indices.setter def indices(self, indices): diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 638722716..ebfe3d476 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -8,6 +8,7 @@ from numpy.typing import ArrayLike import xarray as xr +from ....layouts import Subplot from ....graphics import ( Graphic, ImageGraphic, @@ -558,11 +559,12 @@ class NDPositions(NDGraphic): def __init__( self, ref_index: ReferenceIndex, + subplot: Subplot, data: Any, dims: Sequence[str], spatial_dims: tuple[str, str, str], *args, - graphic: Type[ + graphic_type: Type[ LineGraphic | LineCollection | LineStack @@ -577,6 +579,7 @@ def __init__( index_mappings: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, linear_selector: bool = False, + x_range_mode: Literal["fixed", "auto"] | None = None, colors: ( Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray] ) = None, @@ -594,7 +597,7 @@ def __init__( graphic_kwargs: dict = None, processor_kwargs: dict = None, ): - super().__init__(name) + super().__init__(subplot, name) self._ref_index = ref_index @@ -630,9 +633,11 @@ def __init__( self.cmap_each = cmap_each self.cmap_transform_each = cmap_transform_each - self._create_graphic(graphic) + self._graphic_type = graphic_type + self._create_graphic() self._x_range_mode = None + self.x_range_mode = x_range_mode self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) if linear_selector: @@ -662,20 +667,33 @@ def graphic( | ScatterCollection | ScatterStack | ImageGraphic + | None ): """LineStack or ImageGraphic for heatmaps""" return self._graphic - @graphic.setter - def graphic(self, graphic_type): + @property + def graphic_type( + self, + ) -> Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ScatterStack + | ImageGraphic + ]: + return self._graphic_type + + @graphic_type.setter + def graphic_type(self, graphic_type): if type(self.graphic) is graphic_type: return - plot_area = self._graphic._plot_area - plot_area.delete_graphic(self._graphic) - - self._create_graphic(graphic_type) - plot_area.add_graphic(self._graphic) + self._subplot.delete_graphic(self._graphic) + self._graphic_type = graphic_type + self._create_graphic() @property def spatial_dims(self) -> tuple[str, str, str]: @@ -694,6 +712,9 @@ def indices(self) -> dict[Hashable, Any]: @indices.setter @block_reentrance def indices(self, indices): + if self.data is None: + return + new_features = self.processor.get(indices) data_slice = new_features["data"] @@ -743,7 +764,7 @@ def indices(self, indices): # x range of the data xr = data_slice[0, 0, 0], data_slice[0, -1, 0] - if self._x_range_mode is not None: + if self.x_range_mode is not None: self.graphic._plot_area.x_range = xr # if the update_from_view is polling, this prevents it from being called by setting the new last xrange @@ -770,20 +791,9 @@ def _tooltip_handler(self, graphic, pick_info): p_index = pick_info["vertex_index"] return self.processor.tooltip_format(n_index, p_index) - def _create_graphic( - self, - graphic_cls: Type[ - LineGraphic - | LineCollection - | LineStack - | ScatterGraphic - | ScatterCollection - | ScatterStack - | ImageGraphic - ], - ): - if not issubclass(graphic_cls, Graphic): - raise TypeError + def _create_graphic(self): + if self.data is None: + return new_features = self.processor.get(self.indices) data_slice = new_features["data"] @@ -800,22 +810,22 @@ def _create_graphic( if val is not None: graphic_attrs[attr] = val - if issubclass(graphic_cls, ImageGraphic): + if issubclass(self._graphic_type, ImageGraphic): # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap if self.processor.shape[self.processor.spatial_dims[-1]] != 2: raise ValueError image_data, x0, x_scale = self._create_heatmap_data(data_slice) - self._graphic = graphic_cls( + self._graphic = self._graphic_type( image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1) ) else: - if issubclass(graphic_cls, (LineStack, ScatterStack)): + if issubclass(self._graphic_type, (LineStack, ScatterStack)): kwargs = {"separation": 0.0, **self._graphic_kwargs} else: kwargs = self._graphic_kwargs - self._graphic = graphic_cls(data_slice, **kwargs) + self._graphic = self._graphic_type(data_slice, **kwargs) for attr in graphic_attrs.keys(): if hasattr(self._graphic, attr): @@ -853,6 +863,8 @@ def _create_graphic( for g in self._graphic.graphics: g.tooltip_format = partial(self._tooltip_handler, g) + self._subplot.add_graphic(self._graphic) + def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: """return [n_rows, n_cols] shape data""" # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense @@ -908,18 +920,18 @@ def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]): self.processor.datapoints_window_func = funcs @property - def x_range_mode(self) -> Literal[None, "fixed-window", "view-range"]: - """x-range using a fixed window from the display window, or by polling the camera (view-range)""" + def x_range_mode(self) -> Literal["fixed", "auto"] | None: + """x-range using a fixed window from the display window, or by polling the camera (auto)""" return self._x_range_mode @x_range_mode.setter - def x_range_mode(self, mode: Literal[None, "fixed-window", "view-range"]): - if self._x_range_mode == "view-range": - # old mode was view-range - self.graphic._plot_area.remove_animation(self._update_from_view_range) + def x_range_mode(self, mode: Literal[None, "fixed", "auto"]): + if self._x_range_mode == "auto": + # old mode was auto + self._subplot.remove_animation(self._update_from_view_range) - if mode == "view-range": - self.graphic._plot_area.add_animations(self._update_from_view_range) + if mode == "auto": + self._subplot.add_animations(self._update_from_view_range) self._x_range_mode = mode @@ -927,7 +939,7 @@ def _update_from_view_range(self): if self._graphic is None: return - xr = self.graphic._plot_area.x_range + xr = self._subplot.x_range # the floating point error near zero gets nasty here if np.allclose(xr, self._last_x_range, atol=1e-14): @@ -1071,4 +1083,4 @@ def thickness(self, new: float | Sequence[float] | None): self.graphic.thickness = new self._thickness = new # force a re-render - self.indices = self.indices \ No newline at end of file + self.indices = self.indices diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py index 740dfe21e..1b94e1cbc 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -41,7 +41,7 @@ def __init__( def data(self) -> pd.DataFrame: return self._data - def _validate_data(self, data: pd.DataFrame, dims): + def _validate_data(self, data: pd.DataFrame): if not isinstance(data, pd.DataFrame): raise TypeError diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 0783379ec..0f53951bd 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -1,3 +1,4 @@ +from typing import Literal import numpy as np from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic @@ -29,40 +30,35 @@ def __getitem__(self, key): raise KeyError(f"NDGraphc with given key not found: {key}") def add_nd_image(self, *args, **kwargs): - nd = NDImage(self.ndw.indices, *args, **kwargs) + nd = NDImage(self.ndw.indices, self._subplot, *args, **kwargs) self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) - nd._reset_camera() - - # graphic._plot_area must exist before this is called - nd._reset_histogram() return nd def add_nd_scatter(self, *args, **kwargs): # TODO: better func signature here, send all kwargs to processor_kwargs - nd = NDPositions(self.ndw.indices, *args, graphic=ScatterCollection, **kwargs) + nd = NDPositions(self.ndw.indices, self._subplot, *args, graphic_type=ScatterCollection, **kwargs) self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) return nd def add_nd_timeseries( self, *args, - graphic: type[LineCollection | LineStack | ImageGraphic] = LineStack, - x_range_mode="fixed-window", + graphic_type: type[LineCollection | LineStack | ScatterStack | ImageGraphic] = LineStack, + x_range_mode: Literal["fixed", "auto"] | None = "auto", **kwargs, ): nd = NDPositions( self.ndw.indices, + self._subplot, *args, - graphic=graphic, + graphic_type=graphic_type, linear_selector=True, + x_range_mode=x_range_mode, **kwargs, ) self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) self._subplot.add_graphic(nd._linear_selector) # need plot_area to exist before these this can be called @@ -74,13 +70,6 @@ def add_nd_timeseries( return nd def add_nd_lines(self, *args, **kwargs): - nd = NDPositions(self.ndw.indices, *args, graphic=LineCollection, **kwargs) + nd = NDPositions(self.ndw.indices, self._subplot, *args, graphic_type=LineCollection, **kwargs) self._nd_graphics.append(nd) - self._subplot.add_graphic(nd.graphic) return nd - - # def __repr__(self): - # return "NDWidget Subplot" - # - # def __str__(self): - # return "NDWidget Subplot" diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index 0e73f524f..e5ba7daf8 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -20,7 +20,7 @@ from ._nd_positions import NDPositions from ._nd_image import NDImage -position_graphics = [ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic] +position_graphic_types = [ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic] class NDWidgetUI(EdgeWindow): @@ -237,9 +237,9 @@ def _draw_nd_image_ui(self, subplot, nd_image: NDImage): nd_image.graphic._material.gamma = new_gamma def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): - for i, cls in enumerate(position_graphics): + for i, cls in enumerate(position_graphic_types): if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): - nd_graphic.graphic = cls + nd_graphic.graphic_type = cls subplot.auto_scale() changed, val = imgui.checkbox( @@ -275,7 +275,7 @@ def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): if changed: nd_graphic.display_window = new - options = [None, "fixed-window", "view-range"] + options = [None, "fixed", "auto"] changed, option = imgui.combo( "x-range mode", options.index(nd_graphic.x_range_mode), From 7e862ee6fac597a9ec681fa6b121cc1b49cd81d0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Mar 2026 00:03:03 -0400 Subject: [PATCH 088/101] lighting objects only when a mesh is added --- fastplotlib/layouts/_plot_area.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 974b6f653..ac1d8dc3d 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -10,7 +10,7 @@ from ._utils import create_controller from ..graphics._base import Graphic, WORLD_OBJECT_TO_GRAPHIC -from ..graphics import ImageGraphic +from ..graphics import ImageGraphic, MeshGraphic from ..graphics.selectors._base_selector import BaseSelector from ._graphic_methods_mixin import GraphicMethodsMixin from ..legends import Legend @@ -120,11 +120,8 @@ def __init__( self._background = pygfx.Background(None, self._background_material) self.scene.add(self._background) - self._ambient_light = pygfx.AmbientLight() - self._directional_light = pygfx.DirectionalLight() - - self.scene.add(self._ambient_light) - self.scene.add(self._camera.add(self._directional_light)) + self._ambient_light = None + self._directional_light = None self._tooltip = Tooltip() self.get_figure()._fpl_overlay_scene.add(self._tooltip._fpl_world_object) @@ -293,12 +290,12 @@ def background_color(self, colors: str | tuple[float]): self._background_material.set_colors(*colors) @property - def ambient_light(self) -> pygfx.AmbientLight: + def ambient_light(self) -> pygfx.AmbientLight | None: """the ambient lighting in the scene""" return self._ambient_light @property - def directional_light(self) -> pygfx.DirectionalLight: + def directional_light(self) -> pygfx.DirectionalLight | None: """the directional lighting on the camera in the scene""" return self._directional_light @@ -631,6 +628,13 @@ def add_graphic(self, graphic: Graphic, center: bool = True): if isinstance(graphic, ImageGraphic): self._sort_images_by_depth() + if isinstance(graphic, MeshGraphic): + self._ambient_light = pygfx.AmbientLight() + self._directional_light = pygfx.DirectionalLight() + + self.scene.add(self._ambient_light) + self.scene.add(self._camera.add(self._directional_light)) + def insert_graphic( self, graphic: Graphic, From 4a6643867e9eed4f80a664c7d510b1bbe0830dbe Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Mar 2026 01:19:29 -0400 Subject: [PATCH 089/101] fix --- fastplotlib/layouts/_plot_area.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index ac1d8dc3d..f90cdcf87 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -176,8 +176,9 @@ def camera(self, new_camera: str | pygfx.PerspectiveCamera): # user wants to set completely new camera, remove current camera from controller if isinstance(new_camera, pygfx.PerspectiveCamera): self.controller.remove_camera(self._camera) - # add directional light to new camera - new_camera.add(self._directional_light) + if self._directional_light is not None: + # add directional light to new camera + new_camera.add(self._directional_light) # add new camera to controller self.controller.add_camera(new_camera) From 06d526620ab0cde515195369b969fe023e4d805f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Mar 2026 01:20:15 -0400 Subject: [PATCH 090/101] update axes only when camera or view changes --- fastplotlib/graphics/_axes.py | 37 +++++++++++++++-------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/fastplotlib/graphics/_axes.py b/fastplotlib/graphics/_axes.py index 5b4c21682..56ca792a4 100644 --- a/fastplotlib/graphics/_axes.py +++ b/fastplotlib/graphics/_axes.py @@ -301,6 +301,8 @@ def __init__( self._basis = None self.basis = basis + self._last_state = self._get_view_state() + @property def world_object(self) -> pygfx.WorldObject: return self._world_object @@ -402,6 +404,14 @@ def intersection(self, intersection: tuple[float, float, float] | None): self._intersection = tuple(float(v) for v in intersection) + def _get_view_state(self) -> tuple[bytes, tuple[int, int], tuple[int, int], bytes]: + viewport = self._plot_area.viewport + cam_matrix = self._plot_area.camera.camera_matrix.tobytes() + scale = self._plot_area.camera.local.scale.tobytes() + + return (cam_matrix, viewport.rect, viewport.logical_size, scale) + + def update_using_bbox(self, bbox): """ Update the w.r.t. the given bbox @@ -444,6 +454,10 @@ def update_using_camera(self): if not self.visible: return + state = self._get_view_state() + if state == self._last_state: + # no changes in the camera or viewport rect + return if self._plot_area.camera.fov == 0: xpos, ypos, width, height = self._plot_area.viewport.rect @@ -453,27 +467,6 @@ def update_using_camera(self): xmin, xmax = xpos, xpos + width ymin, ymax = ypos + height, ypos - # apply quaternion to account for rotation of axes - # xmin, _, _ = vec_transform_quat( - # [xmin, ypos + height / 2, 0], - # self.x.local.rotation - # ) - # - # xmax, _, _ = vec_transform_quat( - # [xmax, ypos + height / 2, 0], - # self.x.local.rotation, - # ) - # - # _, ymin, _ = vec_transform_quat( - # [xpos + width / 2, ymin, 0], - # self.y.local.rotation - # ) - # - # _, ymax, _ = vec_transform_quat( - # [xpos + width / 2, ymax, 0], - # self.y.local.rotation - # ) - min_vals = self._plot_area.map_screen_to_world((xmin, ymin)) max_vals = self._plot_area.map_screen_to_world((xmax, ymax)) @@ -515,6 +508,8 @@ def update_using_camera(self): self.update(bbox, intersection) + self._last_state = state + def update(self, bbox, intersection): """ Update the axes using the given bbox and ruler intersection point From 066094ad22bc04bbe28d4ec08b784413836e6376 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Mar 2026 01:57:14 -0400 Subject: [PATCH 091/101] clean heatmap func --- .../widgets/nd_widget/_nd_positions/_nd_positions.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index ebfe3d476..81aa535c4 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -866,17 +866,19 @@ def _create_graphic(self): self._subplot.add_graphic(self._graphic) def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: - """return [n_rows, n_cols] shape data""" + """return [n_rows, n_cols] shape data from [n_timeseries, n_timepoints, xy] data""" # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense + # data slice is of shape [n_timeseries, n_timepoints, xy], where xy is x-y coordinates of each timeseries x = data_slice[0, :, 0] # get x from just the first row # check if we need to interpolate norm = np.linalg.norm(np.diff(np.diff(x))) / x.size if norm > 1e-6: + print(norm) # x is not uniform upto float32 precision, must interpolate x_uniform = np.linspace(x[0], x[-1], num=x.size) - y_interp = np.zeros(shape=data_slice[..., 1].shape, dtype=np.float32) + y_interp = np.empty(shape=data_slice[..., 1].shape, dtype=np.float32) # this for loop is actually slightly faster than numpy.apply_along_axis() for i in range(data_slice.shape[0]): @@ -890,7 +892,7 @@ def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: # assume all x values are the same across all lines # otherwise a heatmap representation makes no sense anyways - x_stop = data_slice[:, -1, 0][0] + x_stop = x[-1] x_scale = (x_stop - x0) / data_slice.shape[1] return y_interp, x0, x_scale From d1d1f6c6e52108d012bd62215524627387983738 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Wed, 18 Mar 2026 02:05:57 -0400 Subject: [PATCH 092/101] stupid print --- fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 81aa535c4..b0eca548d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -875,7 +875,6 @@ def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: norm = np.linalg.norm(np.diff(np.diff(x))) / x.size if norm > 1e-6: - print(norm) # x is not uniform upto float32 precision, must interpolate x_uniform = np.linspace(x[0], x[-1], num=x.size) y_interp = np.empty(shape=data_slice[..., 1].shape, dtype=np.float32) From 9372ec6fdd1bb00abcaa6946f6535c97c50bddcc Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 19 Mar 2026 04:41:25 -0400 Subject: [PATCH 093/101] docstrings, comments --- examples/ndwidget/ndimage.py | 21 +- examples/ndwidget/timeseries.py | 7 +- fastplotlib/utils/_protocols.py | 22 +- fastplotlib/widgets/nd_widget/_base.py | 271 ++++++-- fastplotlib/widgets/nd_widget/_index.py | 163 ++++- fastplotlib/widgets/nd_widget/_nd_image.py | 237 +++++-- .../nd_widget/_nd_positions/_nd_positions.py | 86 ++- fastplotlib/widgets/nd_widget/_ndw_subplot.py | 81 ++- .../widgets/nd_widget/_repr_formatter.py | 599 ++++++++++++++++++ 9 files changed, 1326 insertions(+), 161 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_repr_formatter.py diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py index 80c010ea1..eafd3c3c3 100644 --- a/examples/ndwidget/ndimage.py +++ b/examples/ndwidget/ndimage.py @@ -13,6 +13,7 @@ data = np.random.rand(1000, 30, 64, 64) +data2 = np.random.rand(1000, 30, 128, 128) # must define a reference range for each dim ref = { @@ -21,8 +22,15 @@ } -ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) -ndw.show() +ndw = fpl.NDWidget( + ref_ranges=ref, + size=(700, 560) +) +ndw2 = fpl.NDWidget( + ref_ranges=ref, + ref_index=ndw.indices, # can create another NDWidget that shared the reference index! So multiple windows are possible + size=(700, 560) +) ndi = ndw[0, 0].add_nd_image( data, @@ -31,7 +39,16 @@ name="4d-image", ) +ndi2 = ndw2[0, 0].add_nd_image( + data2, + ("time", "depth", "m", "n"), # specify all dim names + ("m", "n"), # specify spatial dims IN ORDER, rest are auto slider dims + name="4d-image", +) + # change spatial dims on the fly # ndi.spatial_dims = ("depth", "m", "n") +ndw.show() +ndw2.show() fpl.loop.run() diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py index e506182e3..9d7ba851f 100644 --- a/examples/ndwidget/timeseries.py +++ b/examples/ndwidget/timeseries.py @@ -43,16 +43,17 @@ data, ("freq", "ampl", "n_lines", "angle", "d"), ("n_lines", "angle", "d"), - index_mappings={ + slider_dim_transforms={ "angle": xs, "ampl": lambda x: int(x + 1), "freq": lambda x: int(x + 1), }, - x_range_mode="view-range", + cmap="jet", + x_range_mode="auto", name="nd-sine" ) -nd_lines.graphic.cmap = "tab10" +nd_lines.cmap = "tab10" subplot = ndw.figure[0, 0] subplot.controller.add_camera(subplot.camera, include_state={"x", "width"}) diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py index 7ae63ed67..95d7d2763 100644 --- a/fastplotlib/utils/_protocols.py +++ b/fastplotlib/utils/_protocols.py @@ -1,11 +1,29 @@ -from typing import Protocol, runtime_checkable +from __future__ import annotations +from typing import Any, Protocol, runtime_checkable -ARRAY_LIKE_ATTRS = ["shape", "ndim", "__getitem__"] + +ARRAY_LIKE_ATTRS = [ + "__array__", + "__array_ufunc__", + "dtype", + "shape", + "ndim", + "__getitem__", +] @runtime_checkable class ArrayProtocol(Protocol): + def __array__(self) -> ArrayProtocol: ... + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ... + + def __array_function__(self, func, types, *args, **kwargs): ... + + @property + def dtype(self) -> Any: ... + @property def ndim(self) -> int: ... diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 397f6bd37..2fa60a5ed 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -14,6 +14,8 @@ from ...layouts import Subplot from ...utils import subsample_array, ArrayProtocol from ...graphics import Graphic +from ._repr_formatter import ndp_fmt_text, ndg_fmt_text, ndp_fmt_html, ndg_fmt_html +from ._index import ReferenceIndex # must take arguments: array-like, `axis`: int, `keepdims`: bool WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] @@ -26,21 +28,94 @@ def identity(index: int) -> int: class NDProcessor: def __init__( self, - data, + data: Any, dims: Sequence[Hashable], spatial_dims: Sequence[Hashable] | None, - index_mappings: dict[Hashable, Callable[[Any], int] | ArrayLike] = None, + slider_dim_transforms: dict[Hashable, Callable[[Any], int] | ArrayLike] = None, window_funcs: dict[ Hashable, tuple[WindowFuncCallable | None, int | float | None] ] = None, window_order: tuple[Hashable, ...] = None, spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): + """ + Base class for managing n-dimensional data and producing array slices. + + By default, wraps input data into an ``xarray.DataArray`` and provides an interface + for indexing slider dimensions, applying window functions, spatial functions, and mapping + reference-space values to local array indices. Subclasses must implement + :meth:`get`, which is called whenever the :class:`ReferenceIndex` updates. + + Subclasses can implement any type of data representation, they do not necessarily need to be compatible with + (they dot not have to be xarray compatible). However their ``get()`` method must still return a data slice that + corresponds to the graphical representation they map to. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + data: Any + data object that is managed, usually uses the ArrayProtocol. Custom subclasses can manage any kind of data + object but the corresponding :meth:`get` must return an array-like that maps to a graphical representation. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + spatial_dims: Sequence[str] + Subset of ``dims`` that are spatial (rendered) dimensions **in order**. All remaining dims are treated as + slider dims. See subclass for specific info. + + slider_dim_transforms: dict mapping dim_name -> Callable, an ArrayLike, or None + Per-slider-dim mapping from reference-space values to local array indices. + + You may also provide an array of reference values for the slider dims, ``searchsorted`` is then used + as the transform (ex: a timestamps array). + + If ``None`` and identity mapping is used, i.e. rounds the current reference index value to the nearest + integer for array indexing. + + If a transform is not provided for a dim then the identity mapping is used. + + window_funcs: dict[ + Hashable, tuple[WindowFuncCallable | None, int | float | None] + ] + Per-slider-dim window functions applied around the current slider position. Ex: {"time": (np.mean, 2.5)}. + Each value is a ``(func, window_size)`` pair where: + + * *func* must accept ``axis: int`` and ``keepdims: bool`` kwargs + (ex: ``np.mean``, ``np.max``). The window function **must** return an array that has the same dimensions + as specified in the NDProcessor, therefore the size of any dim along which a window_func was applied + should reduce to ``1``. These dims must not be removed by the window_func. + + * *window_size* is in reference-space units (ex: 2.5 seconds). + + + window_order: tuple[Hashable, ...] + Order in which window functions are applied across dims. Only dims listed + here have their window function applied. window_funcs are ignored for any + dims not specified in ``window_order`` + + spatial_func: + A function applied to the spatial slice *after* window_funcs right before rendering. + + """ self._dims = tuple(dims) self._data = self._validate_data(data) self.spatial_dims = spatial_dims - self.index_mappings = index_mappings + self.slider_dim_transforms = slider_dim_transforms self.window_funcs = window_funcs self.window_order = window_order @@ -48,6 +123,10 @@ def __init__( @property def data(self) -> xr.DataArray: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ return self._data @data.setter @@ -55,15 +134,21 @@ def data(self, data: ArrayProtocol): self._data = self._validate_data(data) def _validate_data(self, data: ArrayProtocol): + # does some basic validation if data is None: + # we allow data to be None, in this case no ndgraphic is rendered + # useful when we want to initialize an NDWidget with no traces for example + # and populate it as components/channels are selected return None if not isinstance(data, ArrayProtocol): + # This is required for xarray compatibility and general array-like requirements raise TypeError("`data` must implement the ArrayProtocol") if data.ndim != len(self.dims): raise IndexError("must specify a dim for every dimension in the data array") + # data can be set, but the dims must still match/have the same meaning return xr.DataArray(data, dims=self.dims) @property @@ -79,11 +164,15 @@ def ndim(self) -> int: @property def dims(self) -> tuple[Hashable, ...]: """dim names""" + # these are read-only and cannot be set after it's created + # the user should create a new NDGraphic if they need different dims + # I can't think of a usecase where we'd want to change the dims, and + # I think that would be complicated and probably and anti-pattern return self._dims @property def spatial_dims(self) -> tuple[Hashable, ...]: - """Spatial dims, **in order**)""" + """Spatial dims, **in order**""" return self._spatial_dims @spatial_dims.setter @@ -109,10 +198,12 @@ def tooltip_format(self, *args) -> str | None: @property def slider_dims(self) -> set[Hashable]: + """Slider dim names, ``set(dims) - set(spatial_dims)""" return set(self.dims) - set(self.spatial_dims) @property def n_slider_dims(self): + """number of slider dims, i.e. len(slider_dims)""" return len(self.slider_dims) @property @@ -195,6 +286,7 @@ def window_order(self, order: tuple[Hashable] | None): @property def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + """get or set the spatial function which is applied on the data slice after the window functions""" return self._spatial_func @spatial_func.setter @@ -207,11 +299,12 @@ def spatial_func( self._spatial_func = func @property - def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: + def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]: + """get or set the slider_dim_transforms, see docstring for details""" return self._index_mappings - @index_mappings.setter - def index_mappings( + @slider_dim_transforms.setter + def slider_dim_transforms( self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None ): if maps is None: @@ -240,20 +333,64 @@ def index_mappings( self._index_mappings = maps def _ref_index_to_array_index(self, dim: str, ref_index: Any) -> int: - # wraps index mappings, clamps between 0 and max array index for this dimension - index = self.index_mappings[dim](ref_index) + # wraps slider_dim_transforms, clamps between 0 and the array size in this dim + # ref-space -> local-array-index transform + index = self.slider_dim_transforms[dim](ref_index) + + # clamp between 0 and array size in this dim return max(min(index, self.shape[dim] - 1), 0) - def _get_slider_dims_indexer(self, indices) -> dict: + def _get_slider_dims_indexer(self, indices: dict[Hashable, Any]) -> dict[Hashable, slice]: + """ + Creates an xarray-compatible indexer dict mapping each slider_dim -> slice object. + + - If a window_func is defined for a dim and the dim appears in ``window_order``, + the slice is defined as: + start: index - half_window + stop: index + half_window + step: 1 + + It then applies the slider_dim_transform to the start and stop to map these values from reference-space to + the local array index, and then finally produces the slice object in local array indices. + + ex: if we have indices = {"time": 50.0}, a window size of 5.0s and the ``slider_dim_transform`` + for time is based on a sampling rate of 10Hz, the window in ref units is [45.0, 55.0], and the final + slice object would be ``slice(450, 550, 1)``. + + - If no window func is specified, the final slice just corresponds to that index as an int array-index. + + This exists separate from ``_apply_window_functions()`` because it is useful for debugging purposes. + + Parameters + ---------- + indices : dict[Hashable, Any], {dim: ref_value} + Reference-space values for each slider dim. Must contain an entry + for every slider dim; raises ``IndexError`` otherwise. + ex: {"time": 46.397, "depth": 23.24} + + Returns + ------- + dict[Hashable, slice] + Indexer compatible for ``xr.DataArray.isel()``, with one ``slice`` per + slider dim. These are array indices mapped from the reference space using + the given ``slider_dim_transform``. + + Raises + ------ + IndexError + If ``indices`` are not provided for every ``slider_dim`` + """ + if set(indices.keys()) != set(self.slider_dims): raise IndexError( f"Must provide an index for all slider dims: {self.slider_dims}, you have provided: {indices.keys()}" ) indexer = dict() + # get only slider dims which are not also spatial dims (example: p dim for positional data) - # since that is dealt with separately + # since `p` dim windowing is dealt with separately for positional data slider_dims = set(self.slider_dims) - set(self.spatial_dims) # go through each slider dim and accumulate slice objects for dim in slider_dims: @@ -277,8 +414,8 @@ def _get_slider_dims_indexer(self, indices) -> dict: stop_ref = index_ref + hw # map start and stop ref to array indices - start = self.index_mappings[dim](start_ref) - stop = self.index_mappings[dim](stop_ref) + start = self.slider_dim_transforms[dim](start_ref) + stop = self.slider_dim_transforms[dim](stop_ref) # clamp within array bounds start = max(min(self.shape[dim] - 1, start), 0) @@ -287,7 +424,7 @@ def _get_slider_dims_indexer(self, indices) -> dict: else: # no window func for this dim, direct indexing # index mapped to array index - index = self.index_mappings[dim](index_ref) + index = self.slider_dim_transforms[dim](index_ref) # clamp within the bounds start = max(min(self.shape[dim] - 1, index), 0) @@ -297,10 +434,29 @@ def _get_slider_dims_indexer(self, indices) -> dict: return indexer - def _apply_window_functions(self, indices) -> xr.DataArray: - """slice with windows at given indices and apply window functions""" + def _apply_window_functions(self, indices: dict[Hashable, Any]) -> xr.DataArray: + """ + Slice the data at the given indices and apply window functions in the order specified by + ``window_order``. + + Parameters + ---------- + indices : dict[Hashable, Any], {dim: ref_value} + Reference-space values for each slider dim. + ex: {"time": 46.397, "depth": 23.24} + + Returns + ------- + xr.DataArray + Data slice after windowed indexing and window function application, + with the same dims as the original data. Dims of size ``1`` are not + squeezed. + + """ indexer = self._get_slider_dims_indexer(indices) + # get the data slice w.r.t. the desired windows, and get the underlying numpy array + # ``.values`` gives the numpy array # there is significant overhead with passing xarray objects to numpy for things like np.mean() # so convert to numpy, apply window functions, then convert back to xarray # creating an xarray object from a numpy array has very little overhead, ~10 microseconds @@ -312,7 +468,11 @@ def _apply_window_functions(self, indices) -> xr.DataArray: continue func, _ = self.window_funcs[dim] - + # ``keepdims=True`` is critical, any "collapsed" dims will be of size ``1``. + # Ex: if `array` is of shape [10, 512, 512] and we applied the np.mean() window func on the first dim + # ``keepdims`` means the resultant shape is [1, 512, 512] and NOT [512, 512] + # this is necessary for applying window functions on multiple dims separately and so that the + # dims names correspond after all the window funcs are applied. array = func(array, axis=self.dims.index(dim), keepdims=True) return xr.DataArray(array, dims=self.dims) @@ -320,7 +480,17 @@ def _apply_window_functions(self, indices) -> xr.DataArray: def get(self, indices: dict[Hashable, Any]): raise NotImplementedError - def __repr__(self): + # TODO: html and pretty text repr # + # def _repr_html_(self) -> str: + # return ndp_fmt_html(self) + # + # def _repr_mimebundle_(self, **kwargs) -> dict: + # return { + # "text/plain": self._repr_text_(), + # "text/html": self._repr_html_(), + # } + + def _repr_text_(self): if self.data is None: return ( f"{self.__class__.__name__}\n" @@ -336,7 +506,7 @@ def __repr__(self): f"dims:\n\t{self.dims}\n" f"spatial_dims:\n\t{self.spatial_dims}\n" f"slider_dims:\n\t{self.slider_dims}\n" - f"index_mappings:\n{textwrap.indent(pformat(self.index_mappings, width=120), prefix=tab)}\n" + f"slider_dim_transforms:\n{textwrap.indent(pformat(self.slider_dim_transforms, width=120), prefix=tab)}\n" ) if len(wf) > 0: @@ -367,6 +537,7 @@ def _create_graphic(self): @property def name(self) -> str | None: + """name given to the NDGraphic""" return self._name @property @@ -388,6 +559,10 @@ def indices(self, new: dict[Hashable, Any]): # aliases for easier access to processor properties @property def data(self) -> Any: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ return self.processor.data @data.setter @@ -395,13 +570,13 @@ def data(self, data: Any): self.processor.data = data # create a new graphic when data has changed if self.graphic is not None: - # it is already None is it was initialized with no data + # it is already None if NDGraphic was initialized with no data self._subplot.delete_graphic(self.graphic) self._graphic = None self._create_graphic() - # force a re-render + # force a render self.indices = self.indices @property @@ -427,18 +602,20 @@ def spatial_dims(self) -> tuple[str, ...]: @property def slider_dims(self) -> set[Hashable]: + """the slider dims""" return self.processor.slider_dims @property - def index_mappings(self) -> dict[Hashable, Callable[[Any], int]]: - return self.processor.index_mappings + def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]: + return self.processor.slider_dim_transforms - @index_mappings.setter - def index_mappings( + @slider_dim_transforms.setter + def slider_dim_transforms( self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None ): - self.processor.index_mappings = maps - # force a re-render + """get or set the slider_dim_transforms, see docstring for details""" + self.processor.slider_dim_transforms = maps + # force a render self.indices = self.indices @property @@ -457,7 +634,7 @@ def window_funcs( ), ): self.processor.window_funcs = window_funcs - # force a re-render + # force a render self.indices = self.indices @property @@ -468,7 +645,7 @@ def window_order(self) -> tuple[Hashable, ...]: @window_order.setter def window_order(self, order: tuple[Hashable] | None): self.processor.window_order = order - # force a re-render + # force a render self.indices = self.indices @property @@ -479,33 +656,31 @@ def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: def spatial_func( self, func: Callable[[xr.DataArray], xr.DataArray] ) -> Callable | None: + """get or set the spatial_func, see docstring for details""" self.processor.spatial_func = func - # force a re-render + # force a render self.indices = self.indices - def __repr__(self): + # def _repr_text_(self) -> str: + # return ndg_fmt_text(self) + # + # def _repr_html_(self) -> str: + # return ndg_fmt_html(self) + # + # def _repr_mimebundle_(self, **kwargs) -> dict: + # return { + # "text/plain": self._repr_text_(), + # "text/html": self._repr_html_(), + # } + + def _repr_text_(self): return f"graphic: {self.graphic.__class__.__name__}\n" f"processor:\n{self.processor}" @contextmanager def block_indices(ndgraphic: NDGraphic): """ - Context manager for pausing Graphic events. - - Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. - - Examples - -------- - - .. code-block:: - - # pass in any number of graphics - with fpl.pause_events(graphic1, graphic2, graphic3): - # enter context manager - # all events are blocked from graphic1, graphic2, graphic3 - - # context manager exited, event states restored. - + Context manager for pausing an NDGraphic from updating indices """ ndgraphic._block_indices = True @@ -518,7 +693,7 @@ def block_indices(ndgraphic: NDGraphic): def block_reentrance(setter): - # decorator to block re-entrant indices setter + # decorator to block re-entrance of indices setter def set_indices_wrapper(self: NDGraphic, new_indices): """ wraps NDGraphic.indices diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index e91319893..6d7b17445 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -1,15 +1,51 @@ from __future__ import annotations from dataclasses import dataclass +from numbers import Number from typing import Sequence, Any, Callable from typing import TYPE_CHECKING + if TYPE_CHECKING: from ._ndwidget import NDWidget @dataclass class RangeContinuous: + """ + A continuous reference range for a single slider dimension. + + Stores the (start, stop, step) in scientific units (ex: seconds, micrometers, + Hz). The imgui slider for this dimension uses these values to determine its + minimum and maximum bounds. The step size is used for the "next" and "previous" buttons. + + Parameters + ---------- + start : int or float + Minimum value of the range, inclusive. + + stop : int or float + Maximum value of the range, exclusive upper bound. + + step : int or float + Step size used for imgui step next/previous buttons + + Raises + ------ + IndexError + If ``start >= stop``. + + Examples + -------- + A time axis sampled at 1 ms resolution over 10 seconds: + + RangeContinuous(start=0, stop=10_000, step=1) + + A depth axis in micrometers with 0.5 µm steps: + + RangeContinuous(start=0.0, stop=500.0, step=0.5) + """ + start: int | float stop: int | float step: int | float @@ -41,6 +77,7 @@ def range(self) -> int | float: @dataclass class RangeDiscrete: + # TODO: not implemented yet, placeholder until we have a clear usecase options: Sequence[Any] def __getitem__(self, index: int): @@ -56,21 +93,70 @@ def __len__(self): class ReferenceIndex: def __init__( self, - ref_ranges: dict[str, tuple], + ref_ranges: dict[ + str, + tuple[Number, Number, Number] | tuple[Any] | RangeContinuous, + ], ): - self._ref_ranges = dict() + """ + Manages the shared reference index for one or more ``NDWidget`` instances. - for name, r in ref_ranges.items(): - if len(r) == 3: - # assume start, stop, step - self._ref_ranges[name] = RangeContinuous(*r) + Stores the current index for each named slider dimension in reference-space + units (ex: seconds, depth in µm, Hz). Whenever an index is updated, every + ``NDGraphic`` in the manged ``NDWidgets`` are requested to render data at + the new indices. - elif len(r) == 1: - # assume just options - self._ref_ranges[name] = RangeDiscrete(*r) + Each key in ``ref_ranges`` defines a slider dimension. When adding an + ``NDGraphic``, every dimension listed in ``dims`` must be either a spatial + dimension (listed in ``spatial_dims``) or a key in ``ref_ranges``. + If a dim is not spatial, it must have a corresponding reference range, + otherwise an error will be raised. - else: - raise ValueError + You can also define conceptually identical but *independent* reference spaces + by using distinct names, ex: ``"time-1"`` and ``"time-2"`` for two recordings + that should be sycned independently. Each ``NDGraphic`` then declares the + specific "time-n" space that corresponds to its data, so the widget keeps the + two timelines decoupled. + + Parameters + ---------- + ref_ranges : dict[str, tuple], or a RangeContinuous + Mapping of dimension names to range specifications. A 3-tuple + ``(start, stop, step)`` creates a :class:`RangeContinuous`. A 1-tuple + ``(options,)`` creates a :class:`RangeDiscrete`. + + Attributes + ---------- + ref_ranges : dict[str, RangeContinuous | RangeDiscrete] + The reference range for each registered slider dimension. + + dims: set[str] + the set of "slider dims" + + Examples + -------- + Single shared time axis: + + ri = ReferenceIndex(ref_ranges={"time": (0, 1000, 1), "depth": (15, 35, 0.5)}) + ri["time"] = 500 # update one dim and re-render + ri.set({"time": 500, "depth": 10}) # update several dims atomically + + Two independent time axes for data from two different recording sessions: + + ri = ReferenceIndex({ + "time-1": (0, 3600, 1), # session 1 — 1 h at 1 s resolution + "time-s": (0, 1800, 1), # session 2 — 30 min at 1 s resolution + }) + + Each ``NDGraphic`` declares matching names for slider dims to indicate that these should be + synced across graphics. + + ndw[0, 0].add_nd_image(data_s1, ("time-s1", "row", "col"), ("row", "col")) + ndw[0, 1].add_nd_image(data_s2, ("time-s2", "row", "col"), ("row", "col")) + + """ + self._ref_ranges = dict() + self.push_dims(ref_ranges) # starting index for all dims self._indices: dict[str, int | float | Any] = { @@ -81,8 +167,17 @@ def __init__( self._ndwidgets: list[NDWidget] = list() + @property + def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: + return self._ref_ranges + + @property + def dims(self) -> set[str]: + return set(self.ref_ranges.keys()) + def _add_ndwidget_(self, ndw: NDWidget): from ._ndwidget import NDWidget + if not isinstance(ndw, NDWidget): raise TypeError @@ -111,31 +206,51 @@ def _render_indices(self): # only provide slider indices to the graphic g.indices = {d: self._indices[d] for d in g.processor.slider_dims} - @property - def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: - return self._ref_ranges - def __getitem__(self, dim): + self._check_has_dim(dim) return self._indices[dim] def __setitem__(self, dim, value): + self._check_has_dim(dim) # set index for given dim and render self._indices[dim] = self._clamp(dim, value) self._render_indices() + def _check_has_dim(self, dim): + if dim not in self.dims: + raise KeyError( + f"provided dimension: {dim} has no associated ReferenceRange in this ReferenceIndex, valid dims in this ReferenceIndex are: {self.dims}" + ) + def pop_dim(self): pass - def push_dim(self, ref_range: RangeContinuous): - # TODO: implement pushing and popping dims - pass + def push_dims(self, ref_ranges: dict[ + str, + tuple[Number, Number, Number] | tuple[Any] | RangeContinuous, + ],): + + for name, r in ref_ranges.items(): + if isinstance(r, (RangeContinuous, RangeDiscrete)): + self._ref_ranges[name] = r + + elif len(r) == 3: + # assume start, stop, step + self._ref_ranges[name] = RangeContinuous(*r) + + elif len(r) == 1: + # assume just options + self._ref_ranges[name] = RangeDiscrete(*r) + + else: + raise ValueError( + f"ref_ranges must be a mapping of dimension names to range specifications, " + f"see the docstring, you have passed: {ref_ranges}" + ) def add_event_handler(self, handler: Callable, event: str = "indices"): """ - Register an event handler. - - Currently the only event that ImageWidget supports is "indices". This event is - emitted whenever the indices of the ImageWidget changes. + Register an event handler that is called whenever the indices change. Parameters ---------- @@ -143,7 +258,7 @@ def add_event_handler(self, handler: Callable, event: str = "indices"): callback function, must take a tuple of int as the only argument. This tuple will be the `indices` event: str, "indices" - the only supported event is "indices" + the only supported valid is "indices" Example ------- @@ -152,7 +267,7 @@ def add_event_handler(self, handler: Callable, event: str = "indices"): def my_handler(indices): print(indices) - # example prints: {"t": 100, "z": 15} if the index has 2 slider dimensions "t" and "z" + # example prints: {"t": 100, "z": 15} if the index has 2 reference spaces "t" and "z" # create an NDWidget ndw = NDWidget(...) diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index 0261a64e5..c6292b68c 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -11,12 +11,13 @@ from ...graphics import ImageGraphic, ImageVolumeGraphic from ...tools import HistogramLUTTool from ._base import NDProcessor, NDGraphic, WindowFuncCallable +from ._index import ReferenceIndex class NDImageProcessor(NDProcessor): def __init__( self, - data: ArrayLike | None, + data: ArrayProtocol | None, dims: Sequence[Hashable], spatial_dims: ( tuple[str, str] | tuple[str, str, str] @@ -26,60 +27,93 @@ def __init__( window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayLike], ArrayLike] = None, compute_histogram: bool = True, - index_mappings=None, + slider_dim_transforms=None, ): """ - An ND image that supports computing window functions, and functions over spatial dimensions. + ``NDProcessor`` subclass for n-dimensional image data. + + Produces 2-D or 3-D spatial slices for an ``ImageGraphic`` or ``ImageVolumeGraphic``. Parameters ---------- - data: ArrayLike + data: ArrayProtocol array-like data, must have 2 or more dimensions - n_display_dims: int, 2 or 3, default 2 - number of display dimensions - - rgb: bool, default False - whether the image data is RGB(A) or not - - window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable, optional - A function or a ``tuple`` of functions that are applied to a rolling window of the data. + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("row", "col")`` + ``("other_dim", "depth", "time", "row", "col")`` + + dims in the array do not need to be in order, for example you can have a weird array where the dims are + interpreted as: ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")`` + thanks to xarray magic =D. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + The 2 or 3 spatial dimensions **in order**: ``(rows, cols)`` or ``(z, rows, cols)``. + This also determines whether an ``ImageGraphic`` or ``ImageVolumeGraphic`` is used for rendering. + The ordering determines how the Image/Volume is rendered. For example, if + you specify ``spatial_dims = ("rows", "cols")`` and then change it to ``("cols", "rows")``, it will display + the transpose. + + rgb_dim : str, optional + Name of an RGB(A) dimension, if present. - You can provide unique window functions for each dimension. If you want to apply a window function - only to a subset of the dimensions, put ``None`` to indicate no window function for a given dimension. - - A "window function" must take ``axis`` argument, which is an ``int`` that specifies the axis along which - the window function is applied. It must also take a ``keepdims`` argument which is a ``bool``. The window - function **must** return an array that has the same number of dimensions as the original ``data`` array, - therefore the size of the dimension along which the window was applied will reduce to ``1``. - - The output array-like type from a window function **must** support a ``.squeeze()`` method, but the - function itself should NOT squeeze the output array. + compute_histogram: bool, default True + Compute a histogram of the data, disable if random-access of data is not blazing-fast (ex: data that uses + video codecs), or if histograms are not useful for this data. - window_sizes: tuple[int | None, ...], optional - ``tuple`` of ``int`` that specifies the window size for each dimension. + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. - window_order: tuple[int, ...] | None, optional - order in which to apply the window functions, by default just applies it from the left-most dim to the - right-most slider dim. + window_funcs : dict, optional + See :class:`NDProcessor`. - spatial_func: Callable[[ArrayLike], ArrayLike] | None, optional - A function that is applied on the _spatial_ dimensions of the data array, i.e. the last 2 or 3 dimensions. - This function is applied after the window functions (if present). + window_order : tuple, optional + See :class:`NDProcessor`. - compute_histogram: bool, default True - Compute a histogram of the data, auto re-computes if window function propties or spatial_func changes. - Disable if slow. + spatial_func : callable, optional + See :class:`NDProcessor`. + See Also + -------- + NDProcessor : Base class with full parameter documentation. + NDImage : The ``NDGraphic`` that wraps this processor. """ + # set as False until data, window funcs stuff and spatial func is all set self._compute_histogram = False + # make sure rgb dim is size 3 or 4 + if rgb_dim is not None: + dim_index = dims.index(rgb_dim) + if data.shape[dim_index] not in (3, 4): + raise IndexError( + f"The size of the RGB(A) dim must be 3 | 4. You have specified an array of shape: {data.shape}, " + f"with dims: {dims}, and specified the ``rgb_dim`` name as: {rgb_dim} which has size " + f"{data.shape[dim_index]} != 3 | 4" + ) + super().__init__( data=data, dims=dims, spatial_dims=spatial_dims, - index_mappings=index_mappings, + slider_dim_transforms=slider_dim_transforms, window_funcs=window_funcs, window_order=window_order, spatial_func=spatial_func, @@ -91,23 +125,27 @@ def __init__( @property def data(self) -> xr.DataArray | None: - """get or set the data array""" + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ return self._data @data.setter - def data(self, data: ArrayLike): - # check that all array-like attributes are present + def data(self, data: ArrayProtocol): self._data = self._validate_data(data) self._recompute_histogram() def _validate_data(self, data: ArrayProtocol): if not isinstance(data, ArrayProtocol): + # check that it's compatible with array and generally array-like raise TypeError( f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" f"{ARRAY_LIKE_ATTRS}, or they must be `None`" ) if data.ndim < 2: + # ndim < 2 makes no sense for image data raise IndexError( f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" ) @@ -116,7 +154,9 @@ def _validate_data(self, data: ArrayProtocol): @property def rgb_dim(self) -> str | None: - """indicates the rgb dim if one exists""" + """ + get or set the RGB(A) dim name, ``None`` if no RGB(A) dim exists + """ return self._rgb @rgb_dim.setter @@ -129,6 +169,7 @@ def rgb_dim(self, rgb: str | None): @property def compute_histogram(self) -> bool: + """get or set whether or not to compute the histogram""" return self._compute_histogram @compute_histogram.setter @@ -213,7 +254,7 @@ def _recompute_histogram(self): if isinstance(sub, xr.DataArray): # can't do the isnan and isinf boolean indexing below on xarray sub = sub.values - + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] self._histogram = np.histogram(sub_real, bins=100) @@ -222,10 +263,10 @@ def _recompute_histogram(self): class NDImage(NDGraphic): def __init__( self, - ref_index, + ref_index: ReferenceIndex, subplot: Subplot, - data: ArrayLike | None, - dims: Sequence[Hashable], + data: ArrayProtocol | None, + dims: Sequence[str], spatial_dims: ( tuple[str, str] | tuple[str, str, str] ), # must be in order! [rows, cols] | [z, rows, cols] @@ -234,9 +275,77 @@ def __init__( window_order: tuple[int, ...] = None, spatial_func: Callable[[ArrayLike], ArrayLike] = None, compute_histogram: bool = True, - index_mappings=None, + slider_dim_transforms=None, name: str = None, ): + """ + ``NDGraphic`` subclass for n-dimensional image rendering. + + Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``. + swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also + owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + ref_index : ReferenceIndex + The shared reference index that delivers slider updates to this graphic. + + subplot : Subplot + parent subplot the NDGraphic is in + + data : array-like or None + n-dimension image data array + + dims : sequence of hashable + Name for every dimension of ``data``, in order. Non-spatial dims must + match keys in ``ref_index``. + + ex: ``("time", "depth", "row", "col")`` — ``"time"`` and ``"depth"`` must + be present in ``ref_index``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + Spatial dimensions **in order**: ``(rows, cols)`` for 2-D images or + ``(z, rows, cols)`` for volumes. Controls whether an ``ImageGraphic`` or + ``ImageVolumeGraphic`` is used. + + rgb_dim : str, optional + Name of the RGB or channel dimension, if present. + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + compute_histogram : bool, default ``True`` + Whether to initialize the ``HistogramLUTTool``. + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + name : str, optional + Name for the underlying graphic. + + See Also + -------- + NDImageProcessor : The processor that backs this graphic. + + """ + + if not (set(dims) - set(spatial_dims)).issubset(ref_index.dims): + raise IndexError( + f"all specified `dims` must either be a spatial dim or a slider dim " + f"specified in the NDWidget ref_ranges, provided dims: {dims}, " + f"spatial_dims: {spatial_dims}. Specified NDWidget ref_ranges: {ref_index.dims}" + ) super().__init__(subplot, name) @@ -251,47 +360,65 @@ def __init__( window_order=window_order, spatial_func=spatial_func, compute_histogram=compute_histogram, - index_mappings=index_mappings, + slider_dim_transforms=slider_dim_transforms, ) self._graphic: ImageGraphic | None = None self._histogram_widget: HistogramLUTTool | None = None + # create a graphic self._create_graphic() @property def processor(self) -> NDImageProcessor: + """NDProcessor that manages the data and produces data slices to display""" return self._processor @property def graphic( self, ) -> ImageGraphic | ImageVolumeGraphic: - """LineStack or ImageGraphic for heatmaps""" + """Underlying Graphic object used to display the current data slice""" return self._graphic def _create_graphic(self): + # Creates an ``ImageGraphic`` or ``ImageVolumeGraphic`` based on the number of spatial dims, + # adds it to the subplot, and resets the camera and histogram. + + if self.processor.data is None: + # no graphic if data is None, useful for initializing in null states when we want to set data later + return + + # determine if we need a 2d image or 3d volume + # remove RGB spatial dim, ex: if we have an RGBA image of shape [512, 512, 4] we want to interpet this as + # 2D for images + # [30, 512, 512, 4] with an rgb dim is an RGBA volume which is also supported match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)): case 2: cls = ImageGraphic case 3: cls = ImageVolumeGraphic + # get the data slice for this index + # this will only have the dims specified by ``spatial_dims`` data_slice = self.processor.get(self.indices) - old_graphic = self._graphic + # create the new graphic new_graphic = cls(data_slice) + old_graphic = self._graphic + # check if we are replacing a graphic + # ex: swapping from 2D <-> 3D representation after ``spatial_dims`` was changed if old_graphic is not None: # carry over some attributes from old graphic attrs = dict.fromkeys(["cmap", "interpolation", "cmap_interpolation"]) for k in attrs: attrs[k] = getattr(old_graphic, k) + # delete the old graphic self._subplot.delete_graphic(old_graphic) - self._subplot.add_graphic(new_graphic) - # set cmap and interpolation + # set any attributes that we're carrying over like cmap for attr, val in attrs.items(): setattr(new_graphic, attr, val) @@ -332,7 +459,7 @@ def _reset_histogram(self): self.graphic.reset_vmin_vmax() def _reset_camera(self): - # set camera to a nice position for 2D or 3D + # set camera to a nice position based on whether it's a 2D ImageGraphic or 3D ImageVolumeGraphic if isinstance(self._graphic, ImageGraphic): # set camera orthogonal to the xy plane, flip y axis self._subplot.camera.set_state( @@ -341,7 +468,7 @@ def _reset_camera(self): "rotation": [0, 0, 0, 1], "scale": [1, -1, 1], "reference_up": [0, 1, 0], - "fov": 0, + "fov": 0, # orthographic projection "depth_range": None, } ) @@ -351,11 +478,12 @@ def _reset_camera(self): self._subplot.auto_scale() else: + # It's not an ImageGraphic, set perspective projection self._subplot.camera.fov = 50 self._subplot.controller = "orbit" - # make sure all 3D dimension camera scales are positive - # MIP rendering doesn't work with negative camera scales + # set all 3D dimension camera scales to positive since positive scales + # are typically used for looking at volumes for dim in ["x", "y", "z"]: if getattr(self._subplot.camera.local, f"scale_{dim}") < 0: setattr(self._subplot.camera.local, f"scale_{dim}", 1) @@ -364,6 +492,7 @@ def _reset_camera(self): @property def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: + """get or set the spatial dims, see docstring for details""" return self.processor.spatial_dims @spatial_dims.setter @@ -375,6 +504,7 @@ def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): @property def indices(self) -> dict[Hashable, Any]: + """get or set the indices, managed by the ReferenceIndex, users usually don't want to set this manually""" return {d: self._ref_index[d] for d in self.processor.slider_dims} @indices.setter @@ -385,6 +515,7 @@ def indices(self, indices): @property def compute_histogram(self) -> bool: + """whether or not to compute the histogram and display the HistogramLUTTool""" return self.processor.compute_histogram @compute_histogram.setter @@ -394,6 +525,7 @@ def compute_histogram(self, v: bool): @property def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + """get or set the spatial_func, see docstring for details""" return self.processor.spatial_func @spatial_func.setter @@ -405,6 +537,7 @@ def spatial_func( self._reset_histogram() def _tooltip_handler(self, graphic, pick_info): + # TODO: need to do this better # get graphic within the collection n_index = np.argwhere(self.graphic.graphics == graphic).item() p_index = pick_info["vertex_index"] diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index b0eca548d..6cb69a83d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -62,7 +62,7 @@ def __init__( spatial_dims: tuple[ Hashable | None, Hashable, Hashable ], # [stack_dim, n_datapoints, spatial_dim], IN ORDER!! - index_mappings: dict[str, Callable[[Any], int] | ArrayLike] = None, + slider_dim_transforms: dict[str, Callable[[Any], int] | ArrayLike] = None, display_window: int | float | None = 100, # window for n_datapoints dim only max_display_datapoints: int = 1_000, datapoints_window_func: tuple[Callable, str, int | float] | None = None, @@ -73,13 +73,21 @@ def __init__( **kwargs, ): """ + ``NDProcessor`` subclass for n-dimensional positional and timeseries data. + + + The *datapoints* dimension is + simultaneously a slider dim and a spatial dim and is handled by a dedicated + :attr:`datapoints_window_func` rather than the general ``window_funcs`` + mechanism. + Parameters ---------- data dims spatial_dims - index_mappings + slider_dim_transforms display_window max_display_datapoints: int, default 1_000 this is approximate since floor division is used to determine the step size of the current display window slice @@ -94,7 +102,7 @@ def __init__( data=data, dims=dims, spatial_dims=spatial_dims, - index_mappings=index_mappings, + slider_dim_transforms=slider_dim_transforms, **kwargs, ) @@ -407,7 +415,7 @@ def _apply_dw_window_func( # display window in array index space if self.display_window is not None: - dw = self.index_mappings[p_dim](self.display_window) + dw = self.slider_dim_transforms[p_dim](self.display_window) # step size based on max number of datapoints to render step = max(1, dw // self.max_display_datapoints) @@ -576,7 +584,7 @@ def __init__( processor: type[NDPositionsProcessor] = NDPositionsProcessor, display_window: int = 10, window_funcs: tuple[WindowFuncCallable | None] | None = None, - index_mappings: tuple[Callable[[Any], int] | None] | None = None, + slider_dim_transforms: tuple[Callable[[Any], int] | None] | None = None, max_display_datapoints: int = 1_000, linear_selector: bool = False, x_range_mode: Literal["fixed", "auto"] | None = None, @@ -594,9 +602,47 @@ def __init__( sizes_each: Sequence[float] = None, # for each individual scatter, shape [l, p] thickness: np.ndarray = None, # for each line, shape [l,] name: str = None, + timeseries: bool = False, graphic_kwargs: dict = None, processor_kwargs: dict = None, ): + """ + Wraps an :class:`NDPositionsProcessor` and supports four interchangeable + graphical representations: ``LineStack``, ``LineCollection``, ``ScatterStack``, + and ``ScatterCollection``, as well as a heatmap view. For timeseries use-cases + it also manages a linear selector and automatically adjusts the view according + to the current x-range of the displayed data. + + Parameters + ---------- + ref_index + subplot + data + dims + spatial_dims + args + graphic_type + processor + display_window + window_funcs + slider_dim_transforms + max_display_datapoints + linear_selector + x_range_mode + colors + cmap + cmap_each + cmap_transform_each + markers + markers_each + sizes + sizes_each + thickness + name + graphic_kwargs + processor_kwargs + """ + super().__init__(subplot, name) self._ref_index = ref_index @@ -617,7 +663,7 @@ def __init__( display_window=display_window, max_display_datapoints=max_display_datapoints, window_funcs=window_funcs, - index_mappings=index_mappings, + slider_dim_transforms=slider_dim_transforms, colors=colors, markers=markers_each, cmap_transform_each=cmap_transform_each, @@ -640,13 +686,26 @@ def __init__( self.x_range_mode = x_range_mode self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) - if linear_selector: - self._linear_selector = LinearSelector( - 0, limits=(-np.inf, np.inf), edge_color="cyan" - ) - self._linear_selector.add_event_handler( - self._linear_selector_handler, "selection" - ) + self._timeseries = timeseries + # TODO: I think this is messy af, NDTimeseriesSubclass??? + if self._timeseries: + # makes some assumptions about positional data that apply only to timeseries representations + # probably don't want to maintain aspect + self._subplot.camera.maintain_aspect = False + + # auto x range modes make no sense for non-timeseries data + self.x_range_mode = x_range_mode + + if linear_selector: + self._linear_selector = LinearSelector( + 0, limits=(-np.inf, np.inf), edge_color="cyan" + ) + self._linear_selector.add_event_handler( + self._linear_selector_handler, "selection" + ) + self._subplot.add_graphic(self._linear_selector) + else: + self._linear_selector = None else: self._linear_selector = None @@ -762,6 +821,7 @@ def indices(self, indices): self.graphic.offset = (x0, *self.graphic.offset[1:]) self.graphic.scale = (x_scale, *self.graphic.scale[1:]) + # TODO: I think this is messy af, NDTimeseriesSubclass??? # x range of the data xr = data_slice[0, 0, 0], data_slice[0, -1, 0] if self.x_range_mode is not None: diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py index 0f53951bd..6666b3fc1 100644 --- a/fastplotlib/widgets/nd_widget/_ndw_subplot.py +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -1,13 +1,25 @@ -from typing import Literal +from collections.abc import Callable +from typing import Literal, Sequence, Hashable + import numpy as np from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic from ...layouts import Subplot +from ...utils import ArrayProtocol from . import NDImage, NDPositions -from ._base import NDGraphic +from ._base import NDGraphic, WindowFuncCallable class NDWSubplot: + """ + Entry point for adding ``NDGraphic`` objects to a subplot of an ``NDWidget``. + + Accessed via ``ndw[row, col]`` or ``ndw["subplot_name"]``. + Each ``add_nd_<...>`` method constructs the appropriate ``NDGraphic``, registers it with the parent + ``ReferenceIndex``, appends it to this subplot and returns the ``NDGraphic`` instance to the user. + + Note: ``NDWSubplot`` is not meant to be constructed directly, it only exists as part of an ``NDWidget`` + """ def __init__(self, ndw, subplot: Subplot): self.ndw = ndw self._subplot = subplot @@ -16,9 +28,11 @@ def __init__(self, ndw, subplot: Subplot): @property def nd_graphics(self) -> tuple[NDGraphic]: + """all the NDGraphic instance in this subplot""" return tuple(self._nd_graphics) def __getitem__(self, key): + # get a specific NDGraphic by index or name if isinstance(key, (int, np.integer)): return self.nd_graphics[key] @@ -29,23 +43,55 @@ def __getitem__(self, key): else: raise KeyError(f"NDGraphc with given key not found: {key}") - def add_nd_image(self, *args, **kwargs): - nd = NDImage(self.ndw.indices, self._subplot, *args, **kwargs) - self._nd_graphics.append(nd) + def add_nd_image( + self, + data: ArrayProtocol | None, + dims: Sequence[Hashable], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + compute_histogram: bool = True, + slider_dim_transforms=None, + name: str = None, + ): + nd = NDImage(self.ndw.indices, self._subplot, data=data, + dims=dims, + spatial_dims=spatial_dims, + rgb_dim=rgb_dim, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + compute_histogram=compute_histogram, + slider_dim_transforms=slider_dim_transforms, + name=name, + ) + self._nd_graphics.append(nd) return nd def add_nd_scatter(self, *args, **kwargs): # TODO: better func signature here, send all kwargs to processor_kwargs - nd = NDPositions(self.ndw.indices, self._subplot, *args, graphic_type=ScatterCollection, **kwargs) - self._nd_graphics.append(nd) + nd = NDPositions( + self.ndw.indices, + self._subplot, + *args, + graphic_type=ScatterCollection, + **kwargs, + ) + self._nd_graphics.append(nd) return nd def add_nd_timeseries( self, *args, - graphic_type: type[LineCollection | LineStack | ScatterStack | ImageGraphic] = LineStack, + graphic_type: type[ + LineCollection | LineStack | ScatterStack | ImageGraphic + ] = LineStack, x_range_mode: Literal["fixed", "auto"] | None = "auto", **kwargs, ): @@ -56,20 +102,21 @@ def add_nd_timeseries( graphic_type=graphic_type, linear_selector=True, x_range_mode=x_range_mode, + timeseries=True, **kwargs, ) - self._nd_graphics.append(nd) - self._subplot.add_graphic(nd._linear_selector) - - # need plot_area to exist before these this can be called - nd.x_range_mode = x_range_mode - - # probably don't want to maintain aspect - self._subplot.camera.maintain_aspect = False + self._nd_graphics.append(nd) return nd def add_nd_lines(self, *args, **kwargs): - nd = NDPositions(self.ndw.indices, self._subplot, *args, graphic_type=LineCollection, **kwargs) + nd = NDPositions( + self.ndw.indices, + self._subplot, + *args, + graphic_type=LineCollection, + **kwargs, + ) + self._nd_graphics.append(nd) return nd diff --git a/fastplotlib/widgets/nd_widget/_repr_formatter.py b/fastplotlib/widgets/nd_widget/_repr_formatter.py new file mode 100644 index 000000000..0569f1004 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_repr_formatter.py @@ -0,0 +1,599 @@ +from __future__ import annotations + +import html +from collections.abc import Callable +from typing import Any + + +_RESET = "\033[0m" +_BOLD = "\033[1m" +_DIM = "\033[2m" + +_C = { + "title": "\033[38;5;75m", # sky-blue + "spatial": "\033[38;5;114m", # sage-green + "slider": "\033[38;5;215m", # soft-orange + "label": "\033[38;5;246m", # mid-grey + "value": "\033[38;5;252m", # near-white + "section": "\033[38;5;68m", # steel-blue + "muted": "\033[38;5;240m", # dark-grey + "warn": "\033[38;5;222m", # amber +} + + +def _c(key: str, text: str) -> str: + return f"{_C[key]}{text}{_RESET}" + + +def _callable_name(f: Callable | None) -> str: + if f is None: + return "—" + module = getattr(f, "__module__", "") or "" + qname = getattr(f, "__qualname__", None) or getattr(f, "__name__", repr(f)) + if module and not module.startswith("__"): + short = module.split(".")[-1] + return f"{short}.{qname}" + return qname + + +def ndprocessor_fmt_txt(processor) -> str: + """ + Returns a colored, ascii box + """ + lines: list[str] = [] + + cls = type(processor).__name__ + lines.append(_c("title", _BOLD + cls) + _RESET) + lines.append(_c("muted", "─" * 72)) + + lines.append(_c("section", " Dimensions")) + + header = ( + f" {'dim':<14}{'size':>6} {'role':<10} {'window_func size':<26} index_mapping" + ) + lines.append(_c("label", header)) + lines.append(_c("muted", " " + "─" * 70)) + + for dim in processor.dims: + size = processor.shape[dim] + is_sp = dim in processor.spatial_dims + role_s = (_c("spatial", f"{'spatial':<10}") if is_sp + else _c("slider", f"{'slider':<10}")) + + # window_func - size column + if not is_sp: + wf, ws = processor.window_funcs.get(dim, (None, None)) + if wf is not None and ws is not None: + win_s = _c("value", f"{_callable_name(wf)}") + _c("muted", f" - {ws}") + else: + win_s = _c("muted", "—") + else: + win_s = "" + + # index_mapping column (slider dims only; skip identity) + if not is_sp: + imap = processor.index_mappings.get(dim) + iname = getattr(imap, "__name__", "") if imap is not None else "" + if iname != "identity" and imap is not None: + idx_s = _c("value", _callable_name(imap)) + else: + idx_s = _c("muted", "—") + else: + idx_s = "" + + # pad win_s to fixed visible width (strip ANSI for measuring) + import re + _ansi_re = re.compile(r"\033\[[^m]*m") + win_visible = len(_ansi_re.sub("", win_s)) + win_pad = win_s + " " * max(0, 26 - win_visible) + + line = ( + f" {_c('value', f'{str(dim):<14}')}" + f"{_c('label', f'{size:>6}')} " + f"{role_s} {win_pad} {idx_s}" + ) + lines.append(line) + + # window order + if processor.window_order: + lines.append("") + order_s = " → ".join(str(d) for d in processor.window_order) + lines.append(f" {_c('section', 'Window order')} {_c('value', order_s)}") + + # spatial func + if processor.spatial_func is not None: + lines.append("") + lines.append( + f" {_c('section', 'Spatial func')} " + f"{_c('value', _callable_name(processor.spatial_func))}" + ) + + lines.append(_c("muted", "─" * 72)) + return "\n".join(lines) + + +def ndgraphic_fmt_txt(ndg) -> str: + """Text repr for NDGraphic.""" + cls = type(ndg).__name__ + gcls = type(ndg.graphic).__name__ if ndg.graphic is not None else "—" + name = ndg.name or "—" + + header = ( + f"{_c('title', _BOLD + cls)}{_RESET} " + f"{_c('muted', '·')} " + f"{_c('section', 'graphic')} {_c('value', gcls)} " + f"{_c('muted', '·')} " + f"{_c('section', 'name')} {_c('value', name)}\n" + ) + + proc_block = ndprocessor_fmt_txt(ndg.processor) + # indent processor block + indented = "\n".join(" " + l for l in proc_block.splitlines()) + return header + indented + +_CSS = """ + +""" + + +def _h(s: Any) -> str: + """html-escape a stringified value""" + return html.escape(str(s)) + + +def _badge(role: str) -> str: + cls = "fpl-badge-spatial" if role == "spatial" else "fpl-badge-slider" + return f'{role}' + + +def _code(s: str) -> str: + return f"{_h(s)}" + + +def _section(title: str, content_html: str, count: str = "", open_: bool = True) -> str: + open_attr = " open" if open_ else "" + count_badge = ( + f'{_h(count)}' if count else "" + ) + return ( + f'
' + f'' + f'{_h(title)}' + f'{count_badge}' + f'' + f'{content_html}' + f'
' + ) + + +def _dim_rows_html(proc) -> str: + rows = [] + for dim in proc.dims: + size = proc.shape[dim] + is_sp = dim in proc.spatial_dims + badge = _badge("spatial" if is_sp else "slider") + + # window_func - size column + if not is_sp: + wf, ws = proc.window_funcs.get(dim, (None, None)) + if wf is not None and ws is not None: + win_td = ( + f'' + f'{_code(_callable_name(wf))}' + f'-' + f'{_code(str(ws))}' + f'' + ) + else: + win_td = '—' + else: + win_td = '' + + # index_mapping column (slider dims only; hide identity) + if not is_sp: + imap = proc.index_mappings.get(dim) + if imap is not None: + idx_td = f'{_code(_callable_name(imap))}' + else: + idx_td = '—' + else: + idx_td = '' + + rows.append( + f'' + f'{_h(str(dim))}' + f'{size:,}' + f'{badge}' + f'{win_td}' + f'{idx_td}' + f'' + ) + + # column header row + header = ( + f'' + f'dim' + f'size' + f'role' + f'window_func - size' + f'index_mapping' + f'' + ) + + table = ( + '' + '' + '' + '' + '' + + header + + "".join(rows) + + "
" + ) + return table + + +def _footer_kv(pairs: list[tuple[str, str]]) -> str: + """Always-visible key/value rows rendered below the dim table.""" + inner = "" + for k, v in pairs: + inner += ( + f'' + f'' + ) + return f'' + + +def _kv_list_html(pairs: list[tuple[str, str]]) -> str: + inner = "" + for k, v in pairs: + inner += ( + f'
{_h(k)}
' + f'
{v}
' + ) + return f'
{inner}
' + + +def _html_processor(proc) -> str: + cls = _h(type(proc).__name__) + + # header + ndim_pill = ( + f'' + f'{proc.ndim}D' + ) + header = ( + f'
' + f'{cls}' + f'{ndim_pill}' + f'
' + ) + + # dims section (always open) + dim_content = _dim_rows_html(proc) + sections = _section("Dimensions", dim_content, + count=str(proc.ndim), open_=True) + + # always-visible footer rows + footer_pairs: list[tuple[str, str]] = [] + + if proc.window_order: + chain = " → ".join( + f'{_h(str(d))}' + if i > 0 else _h(str(d)) + for i, d in enumerate(proc.window_order) + ) + footer_pairs.append(("window order", f'{chain}')) + + if proc.spatial_func is not None: + footer_pairs.append(("spatial func", _code(_callable_name(proc.spatial_func)))) + + if footer_pairs: + sections += _footer_kv(footer_pairs) + + body = f'
{sections}
' + return f'{_CSS}
{header}{body}
' + + +def ndgraphic_fmt_html(ndg) -> str: + cls = _h(type(ndg).__name__) + gcls = _h(type(ndg.graphic).__name__) if ndg.graphic is not None else "—" + name = _h(ndg.name or "—") + + graphic_pill = f'graphic: {gcls}' + name_pill = f'name: {name}' + + header = ( + f'
' + f'{cls}' + f'·' + f'{graphic_pill}{name_pill}' + f'
' + ) + + # embed processor repr (without its own outer box) inside a section + proc_inner = _dim_rows_html(ndg.processor) + sections = _section("Processor · Dimensions", proc_inner, open_=True) + + footer_pairs: list[tuple[str, str]] = [] + + if ndg.processor.window_order: + chain = " → ".join( + f'{_h(str(d))}' + if i > 0 else _h(str(d)) + for i, d in enumerate(ndg.processor.window_order) + ) + footer_pairs.append(("window order", f'{chain}')) + + if ndg.processor.spatial_func is not None: + footer_pairs.append(("spatial func", _code(_callable_name(ndg.processor.spatial_func)))) + + if footer_pairs: + sections += _footer_kv(footer_pairs) + + body = f'
{sections}
' + return f'{_CSS}
{header}{body}
' + +class ReprMixin: + """ + Mixin that provides: + • __repr__ → coloured ANSI text (terminal / plain REPL) + • _repr_html_ → rich HTML (Jupyter) + • _repr_mimebundle_ → both, so Jupyter picks the richest format + + Subclasses must implement _repr_text_() and _repr_html_() themselves OR + rely on the dispatch below which checks the concrete type. + """ + + def _repr_text_(self) -> str: + # lazy import avoids circular; swap for a direct call in your module + if _is_ndgraphic(self): + return ndgraphic_fmt_txt(self) + return ndprocessor_fmt_txt(self) + + def _repr_html_(self) -> str: + return ndgraphic_fmt_html(self) + return _html_processor(self) + + def __repr__(self) -> str: + return self._repr_text_() + + def _repr_mimebundle_(self, **kwargs) -> dict: + return { + "text/plain": self._repr_text_(), + "text/html": self._repr_html_(), + } + + +def _is_ndgraphic(obj) -> bool: + """duck-type check: does this object have a .graphic and .processor?""" + return hasattr(obj, "graphic") and hasattr(obj, "processor") \ No newline at end of file From c354c674e210e3b340df6a5c4f78f0dfd06c08e7 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 19 Mar 2026 04:52:09 -0400 Subject: [PATCH 094/101] remove unused attr, comments --- fastplotlib/widgets/nd_widget/_base.py | 12 ++++++++++-- .../widgets/nd_widget/_nd_positions/_nd_positions.py | 8 +++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 2fa60a5ed..bfbabb2f5 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -529,9 +529,17 @@ def __init__( ): self._subplot = subplot self._name = name - self._block_indices = False self._graphic: Graphic | None = None + # used to indicate that the NDGraphic should ignore any requests to update the indices + # used by block_indices_ctx context manager, usecase is when the LinearSelector on timeseries + # NDGraphic changes the selection, it shouldn't change the graphic that it is on top of! Would + # also cause recursion + # It is also used by the @block_reentrance decorator which is on the ``NDGraphic.indices`` property setter + # this is also to block recursion + self._block_indices = False + + def _create_graphic(self): raise NotImplementedError @@ -678,7 +686,7 @@ def _repr_text_(self): @contextmanager -def block_indices(ndgraphic: NDGraphic): +def block_indices_ctx(ndgraphic: NDGraphic): """ Context manager for pausing an NDGraphic from updating indices """ diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index 6cb69a83d..f1699d1a4 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -27,7 +27,7 @@ NDGraphic, WindowFuncCallable, block_reentrance, - block_indices, + block_indices_ctx, ) from .._index import ReferenceIndex @@ -709,8 +709,6 @@ def __init__( else: self._linear_selector = None - self._pause = False - @property def processor(self) -> NDPositionsProcessor: return self._processor @@ -832,7 +830,7 @@ def indices(self, indices): self._last_x_range[:] = self.graphic._plot_area.x_range if self._linear_selector is not None: - with pause_events(self._linear_selector): + with pause_events(self._linear_selector): # we don't want the linear selector change to update the indices self._linear_selector.limits = xr # linear selector acts on `p` dim self._linear_selector.selection = indices[ @@ -840,7 +838,7 @@ def indices(self, indices): ] def _linear_selector_handler(self, ev): - with block_indices(self): + with block_indices_ctx(self): # linear selector always acts on the `p` dim self._ref_index[self.processor.spatial_dims[1]] = ev.info["value"] From 4e8b8f5d9750779c21fd5b9df5fccd3dc11230b3 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 19 Mar 2026 18:04:27 -0400 Subject: [PATCH 095/101] remove print --- fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py index f1699d1a4..2d3ff2b9a 100644 --- a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -513,7 +513,6 @@ def _get_other_features( if val_sliced.shape[0] == 1: # broadcast across all graphical elements n_graphics = self.shape[self.spatial_dims[0]] - print(val_sliced.shape, n_graphics) val_sliced = np.broadcast_to( val_sliced, shape=(n_graphics, *val_sliced.shape[1:]) ) From 0b1bc0bc31e7174f4c084ff0d4989c5968e6c58b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 20 Mar 2026 03:26:17 -0400 Subject: [PATCH 096/101] add NDGraphic.pause, expose histogram widget --- fastplotlib/widgets/nd_widget/_base.py | 12 ++++++++++++ fastplotlib/widgets/nd_widget/_index.py | 2 +- fastplotlib/widgets/nd_widget/_nd_image.py | 5 +++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index bfbabb2f5..04d8fc745 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -539,10 +539,22 @@ def __init__( # this is also to block recursion self._block_indices = False + # user settable bool to make the graphic unresponsive to change in the ReferenceIndex + self._pause = False + def _create_graphic(self): raise NotImplementedError + @property + def pause(self) -> bool: + """if True, changes in the reference until it is set back to False""" + return self._pause + + @pause.setter + def pause(self, val: bool): + self._pause = bool(val) + @property def name(self) -> str | None: """name given to the NDGraphic""" diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 6d7b17445..9d45c844c 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -201,7 +201,7 @@ def _clamp(self, dim, value): def _render_indices(self): for ndw in self._ndwidgets: for g in ndw.ndgraphics: - if g.data is None: + if g.data is None or g.pause: continue # only provide slider indices to the graphic g.indices = {d: self._indices[d] for d in g.processor.slider_dims} diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py index c6292b68c..be319942d 100644 --- a/fastplotlib/widgets/nd_widget/_nd_image.py +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -523,6 +523,11 @@ def compute_histogram(self, v: bool): self.processor.compute_histogram = v self._reset_histogram() + @property + def histogram_widget(self) -> HistogramLUTTool: + """The histogram lut tool associated with this NDGraphic""" + return self._histogram_widget + @property def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: """get or set the spatial_func, see docstring for details""" From 3e755f67a8b54e029c2e620f8cd138ecd33b8b1b Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 20 Mar 2026 03:26:35 -0400 Subject: [PATCH 097/101] ndg pause in imgui --- fastplotlib/widgets/nd_widget/_ui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py index e5ba7daf8..2855c7063 100644 --- a/fastplotlib/widgets/nd_widget/_ui.py +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -211,6 +211,8 @@ def update(self): elif isinstance(ndg, NDImage): self._draw_nd_image_ui(subplot, ndg) + _, ndg.pause = imgui.checkbox("pause", ndg.pause) + if not open: self._ndgraphic_windows.remove(ndg) From 894de5524f1c8844f21058ed3b5422d164fc3173 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 20 Mar 2026 03:27:03 -0400 Subject: [PATCH 098/101] add helper function to convert heatmap timeseries to postional data shape --- fastplotlib/utils/functions.py | 35 ++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/fastplotlib/utils/functions.py b/fastplotlib/utils/functions.py index a839ed9d0..34062824a 100644 --- a/fastplotlib/utils/functions.py +++ b/fastplotlib/utils/functions.py @@ -477,3 +477,38 @@ def subsample_array( slices = tuple(slices) return np.asarray(arr[slices]) + + +def heatmap_to_positions(heatmap: np.ndarray, xvals: np.ndarray) -> np.ndarray: + """ + + Convert a heatmap of shape [n_rows, n_datapoints] to timeseries x-y data of shape [n_rows, n_datapoints, xy] + + Parameters + ---------- + heatmap: np.ndarray, shape [n_rows, n_datapoints] + timeseries data with a heatmap representation, where each column represents a timepoint. + + xvals: np.ndarray, shape: [n_datapoints,] + x-values for the columns in the heatmap + + Returns + ------- + np.ndarray, shape [n_rows, n_datapoints, 2] + timeseries data where the xy data are explicitly stored for every row + + """ + if heatmap.ndim != 2: + raise ValueError + + if xvals.ndim != 1: + raise ValueError + + if xvals.size != heatmap.shape[1]: + raise ValueError + + ts = np.empty((*heatmap.shape, 2), dtype=np.float32) + ts[..., 0] = xvals + ts[..., 1] = heatmap + + return ts From ef28878f103a3e910415337c1e85bf4e3cbdf1f6 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Fri, 20 Mar 2026 23:19:37 -0400 Subject: [PATCH 099/101] index wans't calling handlers --- fastplotlib/widgets/nd_widget/_index.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py index 9d45c844c..fc51c345c 100644 --- a/fastplotlib/widgets/nd_widget/_index.py +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -188,6 +188,7 @@ def set(self, indices: dict[str, Any]): self._indices[dim] = self._clamp(dim, value) self._render_indices() + self._indices_changed() def _clamp(self, dim, value): if isinstance(self.ref_ranges[dim], RangeContinuous): @@ -215,6 +216,7 @@ def __setitem__(self, dim, value): # set index for given dim and render self._indices[dim] = self._clamp(dim, value) self._render_indices() + self._indices_changed() def _check_has_dim(self, dim): if dim not in self.dims: @@ -289,6 +291,10 @@ def clear_event_handlers(self): """Clear all registered event handlers""" self._indices_changed_handlers.clear() + def _indices_changed(self): + for f in self._indices_changed_handlers: + f(self._indices) + def __iter__(self): for index in self._indices.items(): yield index From 7778403a634d1b1cfa78b4d9eaa0b7beb56e64ab Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sun, 22 Mar 2026 22:46:43 -0400 Subject: [PATCH 100/101] remove --- fastplotlib/widgets/nd_widget/_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py index 04d8fc745..932018f6e 100644 --- a/fastplotlib/widgets/nd_widget/_base.py +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -14,7 +14,6 @@ from ...layouts import Subplot from ...utils import subsample_array, ArrayProtocol from ...graphics import Graphic -from ._repr_formatter import ndp_fmt_text, ndg_fmt_text, ndp_fmt_html, ndg_fmt_html from ._index import ReferenceIndex # must take arguments: array-like, `axis`: int, `keepdims`: bool From 3de9b23fa38342169d9623247924aa869fc10c5a Mon Sep 17 00:00:00 2001 From: Kushal Kolar Date: Wed, 8 Apr 2026 04:22:02 -0400 Subject: [PATCH 101/101] allow image types other than float32 (#1027) --- fastplotlib/graphics/features/_image.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index af0783c71..681075ef2 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -1,5 +1,6 @@ from itertools import product from math import ceil +from warnings import warn import cmap as cmap_lib import numpy as np @@ -104,8 +105,11 @@ def _fix_data(self, data): "it must be of shape [rows, cols], [rows, cols, 3] or [rows, cols, 4]" ) - # let's just cast to float32 always - return data.astype(np.float32) + if data.itemsize == 8: + warn(f"casting {array.dtype} array to float32") + return data.astype(np.float32) + + return data def __iter__(self): self._iter = product(enumerate(self.row_indices), enumerate(self.col_indices))