diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 470e2e5a5..f17941405 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -49,7 +49,7 @@ jobs: - name: build docs run: | cd docs - RTD_BUILD=1 make html SPHINXOPTS="-W --keep-going" + DOCS_BUILD=1 make html SPHINXOPTS="-W --keep-going" # set environment variable `DOCS_VERSION_DIR` to either the pr-branch name, "dev", or the release version tag - name: set output pr diff --git a/README.md b/README.md index da5ed64f8..c8e64e65e 100644 --- a/README.md +++ b/README.md @@ -63,31 +63,28 @@ Questions, issues, ideas? You are welcome to post an [issue](https://github.com/ To install use pip: -```bash -# with imgui and jupyterlab -pip install -U "fastplotlib[notebook,imgui]" +### With imgui support (recommended) -# minimal install, install glfw, pyqt6 or pyside6 separately -pip install -U fastplotlib +Without jupyterlab support, install desired GUI framework such as glfw, PyQt6, or PySide6 separately. -# with imgui -pip install -U "fastplotlib[imgui]" + pip install -U "fastplotlib[imgui]" -# to use in jupyterlab without imgui -pip install -U "fastplotlib[notebook]" -``` +With jupyterlab support. -We strongly recommend installing ``simplejpeg`` for use in notebooks, you must first install [libjpeg-turbo](https://libjpeg-turbo.org/) + pip install -U "fastplotlib[notebook,imgui]" -- If you use ``conda``, you can get ``libjpeg-turbo`` through conda. -- If you are on linux, you can get it through your distro's package manager. -- For Windows and Mac compiled binaries are available on their release page: https://github.com/libjpeg-turbo/libjpeg-turbo/releases +### Without imgui -Once you have ``libjpeg-turbo``: +Minimal, install desired GUI library such as PyQt6, PySide6, or glfw separately. + + pip install fastplotlib + +With jupyterlab support only. + + pip install -U "fastplotlib[notebook]" + +Fastplotlib is also available on conda-forge. For imgui support you will need to separately install `imgui-bundle`, and for jupyterlab you will need to install `jupyter-rfb` and `simplejpeg` which are all available on conda-forge. -```bash -pip install simplejpeg -``` > **Note:** > `fastplotlib` and `pygfx` are fast evolving projects, the version available through pip might be outdated, you will need to follow the "For developers" instructions below if you want the latest features. You can find the release history here: https://github.com/fastplotlib/fastplotlib/releases diff --git a/docs/source/api/graphics/LineGraphic.rst b/docs/source/api/graphics/LineGraphic.rst index 428e8ef56..867f1bfbb 100644 --- a/docs/source/api/graphics/LineGraphic.rst +++ b/docs/source/api/graphics/LineGraphic.rst @@ -25,6 +25,7 @@ Properties LineGraphic.axes LineGraphic.block_events LineGraphic.cmap + LineGraphic.color_mode LineGraphic.colors LineGraphic.data LineGraphic.deleted diff --git a/docs/source/api/graphics/ScatterGraphic.rst b/docs/source/api/graphics/ScatterGraphic.rst index cf8e1224d..f9dcd2487 100644 --- a/docs/source/api/graphics/ScatterGraphic.rst +++ b/docs/source/api/graphics/ScatterGraphic.rst @@ -25,6 +25,7 @@ Properties ScatterGraphic.axes ScatterGraphic.block_events ScatterGraphic.cmap + ScatterGraphic.color_mode ScatterGraphic.colors ScatterGraphic.data ScatterGraphic.deleted diff --git a/docs/source/user_guide/guide.rst b/docs/source/user_guide/guide.rst index bd0352aa7..c3487de2e 100644 --- a/docs/source/user_guide/guide.rst +++ b/docs/source/user_guide/guide.rst @@ -6,31 +6,38 @@ Installation To install use pip: -.. code-block:: +With imgui support (recommended) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # with imgui and jupyterlab - pip install -U "fastplotlib[notebook,imgui]" +Without jupyterlab support, install desired GUI framework such as glfw, PyQt6, or PySide6 separately. - # minimal install, install glfw, pyqt6 or pyside6 separately - pip install -U fastplotlib +.. code-block:: - # with imgui pip install -U "fastplotlib[imgui]" - # to use in jupyterlab, no imgui - pip install -U "fastplotlib[notebook]" +With jupyterlab support. -We strongly recommend installing ``simplejpeg`` for use in notebooks, you must first install `libjpeg-turbo `_. +.. code-block:: + + pip install -U "fastplotlib[notebook,imgui]" -- If you use ``conda``, you can get ``libjpeg-turbo`` through conda. -- If you are on linux you can get it through your distro's package manager. -- For Windows and Mac compiled binaries are available on their release page: https://github.com/libjpeg-turbo/libjpeg-turbo/releases +Without imgui +^^^^^^^^^^^^^ -Once you have ``libjpeg-turbo``: +Minimal, install desired GUI library such as PyQt6, PySide6, or glfw separately. .. code-block:: - pip install simplejpeg + pip install fastplotlib + +With jupyterlab support only. + +.. code-block:: + + pip install -U "fastplotlib[notebook]" + +Fastplotlib is also available on conda-forge. For imgui support you will need to separately install ``imgui-bundle``, and for jupyterlab you will need to install ``jupyter-rfb`` and ``simplejpeg`` which are all available on conda-forge. + What is ``fastplotlib``? ------------------------ diff --git a/examples/events/cmap_event.py b/examples/events/cmap_event.py index 62913cb29..f01f06d6a 100644 --- a/examples/events/cmap_event.py +++ b/examples/events/cmap_event.py @@ -34,7 +34,7 @@ xs = np.linspace(0, 4 * np.pi, 100) ys = np.sin(xs) -figure["sine"].add_line(np.column_stack([xs, ys])) +figure["sine"].add_line(np.column_stack([xs, ys]), color_mode="vertex") # make a 2D gaussian cloud cloud_data = np.random.normal(0, scale=3, size=1000).reshape(500, 2) diff --git a/examples/gridplot/multigraphic_gridplot.py b/examples/gridplot/multigraphic_gridplot.py index cbf546e2a..0e89efcdc 100644 --- a/examples/gridplot/multigraphic_gridplot.py +++ b/examples/gridplot/multigraphic_gridplot.py @@ -106,7 +106,7 @@ def make_circle(center, radius: float, n_points: int = 75) -> np.ndarray: gaussian_cloud2 = np.random.multivariate_normal(mean, covariance, n_points) # add the scatter graphics to the figure -figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet") +figure["scatter"].add_scatter(data=gaussian_cloud, sizes=2, cmap="jet", color_mode="vertex") figure["scatter"].add_scatter(data=gaussian_cloud2, colors="r", sizes=2) figure.show() diff --git a/examples/guis/imgui_basic.py b/examples/guis/imgui_basic.py index 74d3c3629..7f42eadd6 100644 --- a/examples/guis/imgui_basic.py +++ b/examples/guis/imgui_basic.py @@ -29,10 +29,10 @@ figure = fpl.Figure(size=(700, 560)) # make some scatter points at every 10th point -figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter", uniform_color=True) +figure[0, 0].add_scatter(data[::10], colors="cyan", sizes=15, name="sine-scatter") # place a line above the scatter -figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave", uniform_color=True) +figure[0, 0].add_line(data, thickness=3, colors="r", name="sine-wave") class ImguiExample(EdgeWindow): diff --git a/examples/image/image_reshaping.py b/examples/image/image_reshaping.py new file mode 100644 index 000000000..23264bda1 --- /dev/null +++ b/examples/image/image_reshaping.py @@ -0,0 +1,50 @@ +""" +Image reshaping +=============== + +An example that shows replacement of the image data with new data of a different shape. Under the hood, this creates a +new buffer and a new array of Textures on the GPU that replace the older Textures. Creating a new buffer and textures +has a performance cost, so you should do this only if you need to or if the performance drawback is not a concern for +your use case. + +Note that the vmin-vmax is reset when you replace the buffers. +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data, diagonal sinusoidal bands +xs = np.linspace(0, 2300, 2300, dtype=np.float16) +full_data = np.vstack([np.cos(np.sqrt(xs + (np.pi / 2) * i)) * i for i in range(2_300)]) + +figure = fpl.Figure() + +image = figure[0, 0].add_image(full_data) + +figure.show() + +i, j = 1, 1 + + +def update(): + global i, j + # set the new image data as a subset of the full data + row = np.abs(np.sin(i)) * 2300 + col = np.abs(np.cos(i)) * 2300 + image.data = full_data[: int(row), : int(col)] + + i += 0.01 + j += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/line/line_cmap.py b/examples/line/line_cmap.py index 3d2b5e8c9..6dfc1fe23 100644 --- a/examples/line/line_cmap.py +++ b/examples/line/line_cmap.py @@ -27,7 +27,7 @@ data=sine_data, thickness=10, cmap="plasma", - cmap_transform=sine_data[:, 1] + cmap_transform=sine_data[:, 1], ) # qualitative colormaps, useful for cluster labels or other types of categorical labels @@ -36,7 +36,7 @@ data=cosine_data, thickness=10, cmap="tab10", - cmap_transform=labels + cmap_transform=labels, ) figure.show() diff --git a/examples/line/line_cmap_more.py b/examples/line/line_cmap_more.py index c7c0d80f4..c6e811fb2 100644 --- a/examples/line/line_cmap_more.py +++ b/examples/line/line_cmap_more.py @@ -31,16 +31,35 @@ # set colormap by mapping data using a transform # here we map the color using the y-values of the sine data # i.e., the color is a function of sine(x) -line2 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=sine[:, 1], offset=(0, 4, 0)) +line2 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=sine[:, 1], + offset=(0, 4, 0), +) # make a line and change the cmap afterward, here we are using the cosine instead fot the transform -line3 = figure[0, 0].add_line(sine, thickness=10, cmap="jet", cmap_transform=cosine[:, 1], offset=(0, 6, 0)) +line3 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="jet", + cmap_transform=cosine[:, 1], + offset=(0, 6, 0) +) + # change the cmap line3.cmap = "bwr" # use quantitative colormaps with categorical cmap_transforms labels = [0] * 25 + [1] * 5 + [2] * 50 + [3] * 20 -line4 = figure[0, 0].add_line(sine, thickness=10, cmap="tab10", cmap_transform=labels, offset=(0, 8, 0)) +line4 = figure[0, 0].add_line( + sine, + thickness=10, + cmap="tab10", + cmap_transform=labels, + offset=(0, 8, 0), +) # some text labels for i in range(5): diff --git a/examples/line/line_colorslice.py b/examples/line/line_colorslice.py index b6865eadb..264f944f3 100644 --- a/examples/line/line_colorslice.py +++ b/examples/line/line_colorslice.py @@ -30,7 +30,8 @@ sine = figure[0, 0].add_line( data=sine_data, thickness=5, - colors="magenta" + colors="magenta", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later ) # you can also use colormaps for lines! @@ -56,6 +57,7 @@ data=zeros_data, thickness=8, colors="w", + color_mode="vertex", # initialize with same color across vertices, but we will change the per-vertex colors later offset=(0, 10, 0) ) diff --git a/examples/line_collection/line_collection_slicing.py b/examples/line_collection/line_collection_slicing.py index f829a53c6..98ad97056 100644 --- a/examples/line_collection/line_collection_slicing.py +++ b/examples/line_collection/line_collection_slicing.py @@ -26,6 +26,7 @@ multi_data, thickness=[2, 10, 2, 5, 5, 5, 8, 8, 8, 9, 3, 3, 3, 4, 4], separation=4, + color_mode="vertex", # this will allow us to set per-vertex colors on each line metadatas=list(range(15)), # some metadata names=list("abcdefghijklmno"), # unique name for each line ) diff --git a/examples/machine_learning/kmeans.py b/examples/machine_learning/kmeans.py index f571882ce..4c49844f0 100644 --- a/examples/machine_learning/kmeans.py +++ b/examples/machine_learning/kmeans.py @@ -80,6 +80,7 @@ sizes=5, cmap="tab10", # use a qualitative cmap cmap_transform=kmeans.labels_, # color by the predicted cluster + uniform_size=False, ) # initial index diff --git a/examples/misc/buffer_replace_gc.py b/examples/misc/buffer_replace_gc.py new file mode 100644 index 000000000..e3b0ac104 --- /dev/null +++ b/examples/misc/buffer_replace_gc.py @@ -0,0 +1,91 @@ +""" +Buffer replacement garbage collection test +========================================== + +This is an example that used for a manual test to ensure that GPU VRAM is free when buffers are replaced. + +Use while monitoring VRAM usage with nvidia-smi +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'code' + + +from typing import Literal +import numpy as np +import fastplotlib as fpl +from fastplotlib.ui import EdgeWindow +from imgui_bundle import imgui + + +def generate_dataset(size: int) -> dict[str, np.ndarray]: + return { + "data": np.random.rand(size, 3), + "colors": np.random.rand(size, 4), + # TODO: there's a wgpu bind group issue with edge_colors, will figure out later + # "edge_colors": np.random.rand(size, 4), + "markers": np.random.choice(list("osD+x^v<>*"), size=size), + "sizes": np.random.rand(size) * 5, + "point_rotations": np.random.rand(size) * 180, + } + + +datasets = { + "init": generate_dataset(50_000), + "small": generate_dataset(100), + "large": generate_dataset(5_000_000), +} + + +class UI(EdgeWindow): + def __init__(self, figure): + super().__init__(figure=figure, size=200, location="right", title="UI") + init_data = datasets["init"] + self._figure["line"].add_line( + data=init_data["data"], colors=init_data["colors"], name="line" + ) + self._figure["scatter"].add_scatter( + **init_data, + uniform_size=False, + uniform_marker=False, + uniform_edge_color=False, + point_rotation_mode="vertex", + name="scatter", + ) + + def update(self): + for graphic in ["line", "scatter"]: + if graphic == "line": + features = ["data", "colors"] + + elif graphic == "scatter": + features = list(datasets["init"].keys()) + + for size in ["small", "large"]: + for fea in features: + if imgui.button(f"{size} - {graphic} - {fea}"): + self._replace(graphic, fea, size) + + def _replace( + self, + graphic: Literal["line", "scatter", "image"], + feature: Literal["data", "colors", "markers", "sizes", "point_rotations"], + size: Literal["small", "large"], + ): + new_value = datasets[size][feature] + + setattr(self._figure[graphic][graphic], feature, new_value) + + +figure = fpl.Figure(shape=(3, 1), size=(700, 1600), names=["line", "scatter", "image"]) +ui = UI(figure) +figure.add_gui(ui) + +figure.show() + + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/lorenz_animation.py b/examples/misc/lorenz_animation.py index 20aee5d83..52a77a243 100644 --- a/examples/misc/lorenz_animation.py +++ b/examples/misc/lorenz_animation.py @@ -60,7 +60,12 @@ def lorenz(xyz, *, s=10, r=28, b=2.667): scatter_markers = list() for graphic in lorenz_line: - marker = figure[0, 0].add_scatter(graphic.data.value[0], sizes=16, colors=graphic.colors[0]) + marker = figure[0, 0].add_scatter( + graphic.data.value[0], + sizes=16, + colors=graphic.colors, + edge_colors="w", + ) scatter_markers.append(marker) # initialize time diff --git a/examples/misc/reshape_lines_scatters.py b/examples/misc/reshape_lines_scatters.py new file mode 100644 index 000000000..db8adb29e --- /dev/null +++ b/examples/misc/reshape_lines_scatters.py @@ -0,0 +1,92 @@ +""" +Change number of points in lines and scatters +============================================= + +This example sets lines and scatters with new data of a different shape, i.e. new data with more or fewer datapoints. +Internally, this creates new buffers for the feature that is being set (data, colors, markers, etc.). Note that there +are performance drawbacks to doing this, so it is recommended to maintain the same number of datapoints in a graphic +when possible. You only want to change the number of datapoints when it's really necessary, and you don't want to do +it constantly (such as tens or hundreds of times per second). + +This example is also useful for manually checking that GPU buffers are freed when they're no longer in use. Run this +example while monitoring VRAM usage with `nvidia-smi` +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = 'animate' + + +import numpy as np +import fastplotlib as fpl + +# create some data to start with +xs = np.linspace(0, 10 * np.pi, 100) +ys = np.sin(xs) + +data = np.column_stack([xs, ys]) + +# create a figure, add a line, scatter and line_stack +figure = fpl.Figure(shape=(3, 1), size=(700, 700)) + +line = figure[0, 0].add_line(data) + +scatter = figure[1, 0].add_scatter( + np.random.rand(100, 3), + colors=np.random.rand(100, 4), + markers=np.random.choice(list("osD+x^v<>*"), size=100), + sizes=(np.random.rand(100) + 1) * 3, + edge_colors=np.random.rand(100, 4), + point_rotations=np.random.rand(100) * 180, + uniform_marker=False, + uniform_size=False, + uniform_edge_color=False, + point_rotation_mode="vertex", +) + +line_stack = figure[2, 0].add_line_stack(np.stack([data] * 10), cmap="viridis") + +text = figure[0, 0].add_text(f"n_points: {100}", offset=(0, 1.5, 0), anchor="middle-left") + +figure.show(maintain_aspect=False) + +i = 0 + + +def update(): + # set a new larger or smaller data array on every render + global i + + # create new data + freq = np.abs(np.sin(i)) * 10 + n_points = int((freq * 20_000) + 10) + + xs = np.linspace(0, 10 * np.pi, n_points) + ys = np.sin(xs * freq) + + new_data = np.column_stack([xs, ys]) + + # update line data + line.data = new_data + + # update scatter data, colors, markers, etc. + scatter.data = np.random.rand(n_points, 3) + scatter.colors = np.random.rand(n_points, 4) + scatter.markers = np.random.choice(list("osD+x^v<>*"), size=n_points) + scatter.edge_colors = np.random.rand(n_points, 4) + scatter.point_rotations = np.random.rand(n_points) * 180 + + # update line stack data + line_stack.data = np.stack([new_data] * 10) + + text.text = f"n_points: {n_points}" + + i += 0.01 + + +figure.add_animations(update) + +# NOTE: fpl.loop.run() should not be used for interactive sessions +# See the "JupyterLab and IPython" section in the user guide +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/misc/scatter_animation.py b/examples/misc/scatter_animation.py index d37aea976..549059b65 100644 --- a/examples/misc/scatter_animation.py +++ b/examples/misc/scatter_animation.py @@ -37,7 +37,7 @@ figure = fpl.Figure(size=(700, 560)) subplot_scatter = figure[0, 0] # use an alpha value since this will be a lot of points -scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, colors=colors, alpha=0.6) +scatter = subplot_scatter.add_scatter(data=cloud, sizes=3, uniform_size=False, colors=colors, alpha=0.6) def update_points(subplot): diff --git a/examples/misc/scatter_sizes_animation.py b/examples/misc/scatter_sizes_animation.py index 53a616a68..2092787f3 100644 --- a/examples/misc/scatter_sizes_animation.py +++ b/examples/misc/scatter_sizes_animation.py @@ -20,7 +20,7 @@ figure = fpl.Figure(size=(700, 560)) -figure[0, 0].add_scatter(data, sizes=sizes, name="sine") +figure[0, 0].add_scatter(data, sizes=sizes, uniform_size=False, name="sine") i = 0 diff --git a/examples/ndwidget/README.rst b/examples/ndwidget/README.rst new file mode 100644 index 000000000..28ed4d752 --- /dev/null +++ b/examples/ndwidget/README.rst @@ -0,0 +1,2 @@ +NDWidget Examples +================= diff --git a/examples/ndwidget/ndimage.py b/examples/ndwidget/ndimage.py new file mode 100644 index 000000000..eafd3c3c3 --- /dev/null +++ b/examples/ndwidget/ndimage.py @@ -0,0 +1,54 @@ +""" +NDWidget image +============== + +NDWidget image example +""" + +# test_example = true +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +import fastplotlib as fpl + + +data = np.random.rand(1000, 30, 64, 64) +data2 = np.random.rand(1000, 30, 128, 128) + +# must define a reference range for each dim +ref = { + "time": (0, 1000, 1), + "depth": (0, 30, 1), +} + + +ndw = fpl.NDWidget( + ref_ranges=ref, + size=(700, 560) +) +ndw2 = fpl.NDWidget( + ref_ranges=ref, + ref_index=ndw.indices, # can create another NDWidget that shared the reference index! So multiple windows are possible + size=(700, 560) +) + +ndi = ndw[0, 0].add_nd_image( + data, + ("time", "depth", "m", "n"), # specify all dim names + ("m", "n"), # specify spatial dims IN ORDER, rest are auto slider dims + name="4d-image", +) + +ndi2 = ndw2[0, 0].add_nd_image( + data2, + ("time", "depth", "m", "n"), # specify all dim names + ("m", "n"), # specify spatial dims IN ORDER, rest are auto slider dims + name="4d-image", +) + +# change spatial dims on the fly +# ndi.spatial_dims = ("depth", "m", "n") + +ndw.show() +ndw2.show() +fpl.loop.run() diff --git a/examples/ndwidget/timeseries.py b/examples/ndwidget/timeseries.py new file mode 100644 index 000000000..9d7ba851f --- /dev/null +++ b/examples/ndwidget/timeseries.py @@ -0,0 +1,62 @@ +""" +NDWidget Timeseries +=================== + +NDWidget timeseries example +""" + +# test_example = true +# sphinx_gallery_pygfx_docs = 'screenshot' + +import numpy as np +import fastplotlib as fpl + +# generate some toy timeseries data +n_datapoints = 100_000 # number of datapoints per line +n_freqs = 20 # number of frequencies +n_ampls = 15 # number of amplitudes +n_lines = 8 + +xs = np.linspace(0, 1000 * np.pi, n_datapoints) + +data = np.zeros(shape=(n_freqs, n_ampls, n_lines, n_datapoints, 2), dtype=np.float32) + +for freq in range(data.shape[0]): + for ampl in range(data.shape[1]): + ys = np.sin(xs * (freq + 1)) * (ampl + 1) + np.random.normal( + 0, 0.1, size=n_datapoints + ) + line = np.column_stack([xs, ys]) + data[freq, ampl] = np.stack([line] * n_lines) + + +# must define a reference range, this would often be your time dimension and corresponds to your x-dimension +ref = { + "freq": (1, n_freqs + 1, 1), + "ampl": (1, n_ampls + 1, 1), + "angle": (0, xs[-1], 0.1), +} + +ndw = fpl.NDWidget(ref_ranges=ref, size=(700, 560)) + +nd_lines = ndw[0, 0].add_nd_timeseries( + data, + ("freq", "ampl", "n_lines", "angle", "d"), + ("n_lines", "angle", "d"), + slider_dim_transforms={ + "angle": xs, + "ampl": lambda x: int(x + 1), + "freq": lambda x: int(x + 1), + }, + cmap="jet", + x_range_mode="auto", + name="nd-sine" +) + +nd_lines.cmap = "tab10" + +subplot = ndw.figure[0, 0] +subplot.controller.add_camera(subplot.camera, include_state={"x", "width"}) + +ndw.show(maintain_aspect=False) +fpl.loop.run() diff --git a/examples/notebooks/quickstart.ipynb b/examples/notebooks/quickstart.ipynb index 7b7551588..61bcb6b06 100644 --- a/examples/notebooks/quickstart.ipynb +++ b/examples/notebooks/quickstart.ipynb @@ -719,8 +719,8 @@ "# we will add all the lines to the same subplot\n", "subplot = fig_lines[0, 0]\n", "\n", - "# plot sine wave, use a single color\n", - "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\")\n", + "# plot sine wave, use a single color for now, but we will set per-vertex colors later\n", + "sine = subplot.add_line(data=sine_data, thickness=5, colors=\"magenta\", color_mode=\"vertex\")\n", "\n", "# you can also use colormaps for lines!\n", "cosine = subplot.add_line(data=cosine_data, thickness=12, cmap=\"autumn\")\n", diff --git a/examples/scatter/scatter_iris.py b/examples/scatter/scatter_iris.py index b9df16026..fc228e5bf 100644 --- a/examples/scatter/scatter_iris.py +++ b/examples/scatter/scatter_iris.py @@ -35,6 +35,7 @@ cmap="tab10", cmap_transform=clusters_labels, markers=markers, + uniform_marker=False, ) figure.show() diff --git a/examples/scatter/scatter_size.py b/examples/scatter/scatter_size.py index 30d3e6ea3..2b3899dbe 100644 --- a/examples/scatter/scatter_size.py +++ b/examples/scatter/scatter_size.py @@ -35,7 +35,7 @@ ) # add a set of scalar sizes non_scalar_sizes = np.abs((y_values / np.pi)) # ensure minimum size of 5 -figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, colors="red") +figure["array_size"].add_scatter(data=data, sizes=non_scalar_sizes, uniform_size=False, colors="red") for graph in figure: graph.auto_scale(maintain_aspect=True) diff --git a/examples/scatter/scatter_validate.py b/examples/scatter/scatter_validate.py index abddffee0..45f0a177c 100644 --- a/examples/scatter/scatter_validate.py +++ b/examples/scatter/scatter_validate.py @@ -41,6 +41,7 @@ uniform_edge_color=False, edge_colors=["w"] * 3 + ["orange"] * 3 + ["blue"] * 3 + ["green"], markers=list("osD+x^v<>*"), + uniform_marker=False, edge_width=2.0, sizes=20, uniform_size=True, @@ -64,6 +65,7 @@ sine, markers="s", sizes=xs * 5, + uniform_size=False, offset=(0, 2, 0) ) diff --git a/examples/scatter/spinning_spiral.py b/examples/scatter/spinning_spiral.py index 89e74eaec..4f947970a 100644 --- a/examples/scatter/spinning_spiral.py +++ b/examples/scatter/spinning_spiral.py @@ -34,7 +34,14 @@ canvas_kwargs={"max_fps": 500, "vsync": False} ) -spiral = figure[0, 0].add_scatter(data, cmap="viridis_r", edge_colors=None, alpha=0.5, sizes=sizes) +spiral = figure[0, 0].add_scatter( + data, + cmap="viridis_r", + edge_colors=None, + alpha=0.5, + sizes=sizes, + uniform_size=False, +) # pre-generate normally distributed data to jitter the points before each render jitter = np.random.normal(scale=0.001, size=n * 3).reshape((n, 3)) diff --git a/fastplotlib/__init__.py b/fastplotlib/__init__.py index 6dab91605..bde2c89e3 100644 --- a/fastplotlib/__init__.py +++ b/fastplotlib/__init__.py @@ -19,7 +19,7 @@ else: from .layouts import Figure -from .widgets import ImageWidget +from .widgets import NDWidget, ImageWidget from .utils import config, enumerate_adapters, select_adapter, print_wgpu_report diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index 3d01e4a35..cca2afc21 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -7,7 +7,7 @@ from .mesh import MeshGraphic, SurfaceGraphic, PolygonGraphic from .text import TextGraphic from .line_collection import LineCollection, LineStack - +from .scatter_collection import ScatterCollection, ScatterStack __all__ = [ "Graphic", @@ -22,4 +22,6 @@ "TextGraphic", "LineCollection", "LineStack", + "ScatterCollection", + "ScatterStack", ] diff --git a/fastplotlib/graphics/_axes.py b/fastplotlib/graphics/_axes.py index 5b4c21682..56ca792a4 100644 --- a/fastplotlib/graphics/_axes.py +++ b/fastplotlib/graphics/_axes.py @@ -301,6 +301,8 @@ def __init__( self._basis = None self.basis = basis + self._last_state = self._get_view_state() + @property def world_object(self) -> pygfx.WorldObject: return self._world_object @@ -402,6 +404,14 @@ def intersection(self, intersection: tuple[float, float, float] | None): self._intersection = tuple(float(v) for v in intersection) + def _get_view_state(self) -> tuple[bytes, tuple[int, int], tuple[int, int], bytes]: + viewport = self._plot_area.viewport + cam_matrix = self._plot_area.camera.camera_matrix.tobytes() + scale = self._plot_area.camera.local.scale.tobytes() + + return (cam_matrix, viewport.rect, viewport.logical_size, scale) + + def update_using_bbox(self, bbox): """ Update the w.r.t. the given bbox @@ -444,6 +454,10 @@ def update_using_camera(self): if not self.visible: return + state = self._get_view_state() + if state == self._last_state: + # no changes in the camera or viewport rect + return if self._plot_area.camera.fov == 0: xpos, ypos, width, height = self._plot_area.viewport.rect @@ -453,27 +467,6 @@ def update_using_camera(self): xmin, xmax = xpos, xpos + width ymin, ymax = ypos + height, ypos - # apply quaternion to account for rotation of axes - # xmin, _, _ = vec_transform_quat( - # [xmin, ypos + height / 2, 0], - # self.x.local.rotation - # ) - # - # xmax, _, _ = vec_transform_quat( - # [xmax, ypos + height / 2, 0], - # self.x.local.rotation, - # ) - # - # _, ymin, _ = vec_transform_quat( - # [xpos + width / 2, ymin, 0], - # self.y.local.rotation - # ) - # - # _, ymax, _ = vec_transform_quat( - # [xpos + width / 2, ymax, 0], - # self.y.local.rotation - # ) - min_vals = self._plot_area.map_screen_to_world((xmin, ymin)) max_vals = self._plot_area.map_screen_to_world((xmax, ymax)) @@ -515,6 +508,8 @@ def update_using_camera(self): self.update(bbox, intersection) + self._last_state = state + def update(self, bbox, intersection): """ Update the axes using the given bbox and ruler intersection point diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index 5279cf306..2a02adef4 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -67,7 +67,6 @@ class Graphic: _fpl_support_tooltip: bool = True def __init_subclass__(cls, **kwargs): - # set of all features cls._features = { **cls._features, @@ -178,6 +177,7 @@ def __init__( self._alpha_mode = AlphaMode(alpha_mode) self._visible = Visible(visible) self._block_events = False + self._block_handlers = list() self._axes: Axes = None @@ -274,6 +274,11 @@ def block_events(self) -> bool: def block_events(self, value: bool): self._block_events = value + @property + def block_handlers(self) -> list: + """Used to block event handlers for a graphic and prevent recursion.""" + return self._block_handlers + @property def world_object(self) -> pygfx.WorldObject: """Associated pygfx WorldObject. Always returns a proxy, real object cannot be accessed directly.""" @@ -285,15 +290,8 @@ def _set_world_object(self, wo: pygfx.WorldObject): # add to world object -> graphic mapping if isinstance(wo, pygfx.Group): - for child in wo.children: - if isinstance( - child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line) - ): - # unique 32 bit integer id for each world object - global_id = child.id - WORLD_OBJECT_TO_GRAPHIC[global_id] = self - # store id to pop from dict when graphic is deleted - self._world_object_ids.append(global_id) + # for Graphics which use a pygfx.Group, ImageGraphic and graphic collections + self._add_group_graphic_map(wo) else: global_id = wo.id WORLD_OBJECT_TO_GRAPHIC[global_id] = self @@ -322,6 +320,27 @@ def _set_world_object(self, wo: pygfx.WorldObject): if not all(wo.world.scale == self.scale): self.scale = self.scale + def _add_group_graphic_map(self, wo: pygfx.Group): + # add the children of the group to the WorldObject -> Graphic map + # used by images since they create new WorldObject ImageTiles when a different buffer size is required + # also used by GraphicCollections inititally, but not used for reseting like images + for child in wo.children: + if isinstance(child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line)): + # unique 32 bit integer id for each world object + global_id = child.id + WORLD_OBJECT_TO_GRAPHIC[global_id] = self + # store id to pop from dict when graphic is deleted + self._world_object_ids.append(global_id) + + def _remove_group_graphic_map(self, wo: pygfx.Group): + # remove the children of the group to the WorldObject -> Graphic map + for child in wo.children: + if isinstance(child, (pygfx.Image, pygfx.Volume, pygfx.Points, pygfx.Line)): + # unique 32 bit integer id for each world object + global_id = child.id + WORLD_OBJECT_TO_GRAPHIC.pop(global_id) + self._world_object_ids.remove(global_id) + @property def tooltip_format(self) -> Callable[[dict], str] | None: """ @@ -444,6 +463,9 @@ def _handle_event(self, callback, event: pygfx.Event): if self.block_events: return + if callback in self._block_handlers: + return + if event.type in self._features: # for feature events event._target = self.world_object @@ -501,6 +523,23 @@ def my_handler(event): feature = getattr(self, f"_{t}") feature.remove_event_handler(wrapper) + def _parse_positions(self, position: tuple | np.ndarray) -> np.ndarray: + """ + Converts position data (in the form of tuple or np.ndarray) into a (num_points, 3)-shaped np.ndarray for processing + """ + position = np.asarray(position) + + if not 0 < position.ndim < 3: + raise ValueError(f"position must be of shape (num_points, 3) or (3,)") + + elif position.ndim == 1: + position = position[None, :] + + if position.shape[-1] != 3: + raise ValueError(f"position must be of shape (num_points, 3) or (3,)") + + return position + def map_model_to_world( self, position: tuple[float, float, float] | tuple[float, float] | np.ndarray ) -> np.ndarray: @@ -509,27 +548,18 @@ def map_model_to_world( Parameters ---------- - position: (float, float, float) or (float, float) - (x, y, z) or (x, y) position. If z is not provided then the graphic's offset z is used. + position: tuple of (x, y, z) or np.ndarray of shape (num_points, 3) + The xyz positions we wish to map to model space Returns ------- np.ndarray - (x, y, z) position in world space - + either shape (3,) or (num_points, 3), specifying position in world space """ - - if len(position) == 2: - # use z of the graphic - position = [*position, self.offset[-1]] - - if len(position) != 3: - raise ValueError( - f"position must be tuple or array indicating (x, y, z) position in *model space*" - ) + position = self._parse_positions(position) # apply world transform to project from model space to world space - return la.vec_transform(position, self.world_object.world.matrix) + return la.vec_transform(position, self.world_object.world.matrix).squeeze() def map_world_to_model( self, position: tuple[float, float, float] | tuple[float, float] | np.ndarray @@ -539,26 +569,20 @@ def map_world_to_model( Parameters ---------- - position: (float, float, float) or (float, float) - (x, y, z) or (x, y) position. If z is not provided then 0 is used. + position: tuple of (x, y, z) or np.ndarray of shape (num_points, 3) + The xyz positions we wish to map to model space Returns ------- np.ndarray - (x, y, z) position in world space + either shape (3,) or (num_points, 3), specifying position in model space """ + position = self._parse_positions(position) - if len(position) == 2: - # use z of the graphic - position = [*position, self.offset[-1]] - - if len(position) != 3: - raise ValueError( - f"position must be tuple or array indicating (x, y, z) position in *model space*" - ) - - return la.vec_transform(position, self.world_object.world.inverse_matrix) + return la.vec_transform( + position, self.world_object.world.inverse_matrix + ).squeeze() def format_pick_info(self, ev: pygfx.PointerEvent) -> str: """ diff --git a/fastplotlib/graphics/_positions_base.py b/fastplotlib/graphics/_positions_base.py index af7d7badb..763f5e775 100644 --- a/fastplotlib/graphics/_positions_base.py +++ b/fastplotlib/graphics/_positions_base.py @@ -1,4 +1,6 @@ -from typing import Any, Sequence +from numbers import Real +from typing import Any, Sequence, Literal +from warnings import warn import numpy as np @@ -18,12 +20,20 @@ class PositionsGraphic(Graphic): @property def data(self) -> VertexPositions: - """Get or set the graphic's data""" + """ + Get or set the graphic's data. + + Note that if the number of datapoints does not match the number of + current datapoints a new buffer is automatically allocated. This can + have performance drawbacks when you have a very large number of datapoints. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, value): - self._data[:] = value + self._data.set_value(self, value) @property def colors(self) -> VertexColors | pygfx.Color: @@ -36,11 +46,59 @@ def colors(self) -> VertexColors | pygfx.Color: @colors.setter def colors(self, value: str | np.ndarray | Sequence[float] | Sequence[str]): + self._colors.set_value(self, value) + + @property + def color_mode(self) -> Literal["uniform", "vertex"]: + """ + Get or set the color mode. Note that after setting the color_mode, you will have to set the `colors` + as well for switching between 'uniform' and 'vertex' modes. + """ + return self.world_object.material.color_mode + + @color_mode.setter + def color_mode(self, mode: Literal["uniform", "vertex"]): + valid = ("uniform", "vertex") + if mode not in valid: + raise ValueError(f"`color_mode` must be one of : {valid}") + if mode == "vertex" and isinstance(self._colors, UniformColor): + # uniform -> vertex + # need to make a new vertex buffer and get rid of uniform buffer + new_colors = self._create_colors_buffer(self._colors.value, "vertex") + # we can't clear world_object.material.color so just set the colors buffer on the geometry + # this doesn't really matter anyways since the lingering uniform color takes up just a few bytes + self.world_object.geometry.colors = new_colors._fpl_buffer + + elif mode == "uniform" and isinstance(self._colors, VertexColors): + # vertex -> uniform + # use first vertex color and spit out a warning + warn( + "changing `color_mode` from vertex -> uniform, will use first vertex color " + "for the uniform and discard the remaining color values" + ) + new_colors = self._create_colors_buffer(self._colors.value[0], "uniform") + self.world_object.geometry.colors = None + self.world_object.material.color = new_colors.value + + # clear out cmap + self._cmap.clear_event_handlers() + self._cmap = None + + else: + # no change, return + return + + # restore event handlers onto the new colors feature + new_colors._event_handlers[:] = self._colors._event_handlers + self._colors.clear_event_handlers() + # this should trigger gc + self._colors = new_colors + + # this is created so that cmap can be set later if isinstance(self._colors, VertexColors): - self._colors[:] = value + self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - elif isinstance(self._colors, UniformColor): - self._colors.set_value(self, value) + self.world_object.material.color_mode = mode @property def cmap(self) -> VertexCmap: @@ -53,8 +111,8 @@ def cmap(self) -> VertexCmap: @cmap.setter def cmap(self, name: str): - if self._cmap is None: - raise BufferError("Cannot use cmap with uniform_colors=True") + if self.color_mode == "uniform": + raise ValueError("cannot use `cmap` with `color_mode` = 'uniform'") self._cmap[:] = name @@ -71,14 +129,68 @@ def size_space(self): def size_space(self, value: str): self._size_space.set_value(self, value) + def _create_colors_buffer(self, colors, color_mode) -> UniformColor | VertexColors: + # creates either a UniformColor or VertexColors based on the given `colors` and `color_mode` + # if `color_mode` = "auto", returns {UniformColor | VertexColor} based on what the `colors` arg represents + # if `color_mode` = "uniform", it verifies that the user `colors` input represents just 1 color + # if `color_mode` = "vertex", always returns VertexColors regardless of whether `colors` represents >= 1 colors + + if isinstance(colors, VertexColors): + if color_mode == "uniform": + raise ValueError( + "if a `VertexColors` instance is provided for `colors`, " + "`color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) + # share buffer with existing colors instance + new_colors = colors + # blank colormap instance + self._cmap = VertexCmap(new_colors, cmap_name=None, transform=None) + + else: + # determine if a single or multiple colors were passed and decide color mode + if isinstance(colors, (pygfx.Color, str)) or ( + len(colors) in [3, 4] and all(isinstance(v, Real) for v in colors) + ): + # one color specified as a str or pygfx.Color, or one color specified with RGB(A) values + if color_mode in ("auto", "uniform"): + new_colors = UniformColor(colors) + else: + new_colors = VertexColors( + colors, n_colors=self._data.value.shape[0] + ) + + elif all(isinstance(c, (str, pygfx.Color)) for c in colors): + # sequence of colors + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + + elif len(colors) > 4: + # sequence of multiple colors, must again ensure color_mode is not uniform + if color_mode == "uniform": + raise ValueError( + "You passed `color_mode` = 'uniform', but specified a sequence of multiple colors. Use " + "`color_mode` = 'auto' or 'vertex' for multiple colors." + ) + new_colors = VertexColors(colors, n_colors=self._data.value.shape[0]) + else: + raise ValueError( + "`colors` must be a str, pygfx.Color, array, list or tuple indicating an RGB(A) color, or a " + "sequence of str, pygfx.Color, or array of shape [n_datapoints, 3 | 4]" + ) + + return new_colors + def __init__( self, data: Any, colors: str | np.ndarray | tuple[float] | list[float] | list[str] = "w", - uniform_color: bool = False, cmap: str | VertexCmap = None, cmap_transform: np.ndarray = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", *args, **kwargs, @@ -86,22 +198,31 @@ def __init__( if isinstance(data, VertexPositions): self._data = data else: - self._data = VertexPositions(data, isolated_buffer=isolated_buffer) + self._data = VertexPositions(data) if cmap_transform is not None and cmap is None: raise ValueError("must pass `cmap` if passing `cmap_transform`") + valid = ("auto", "uniform", "vertex") + + # default _cmap is None + self._cmap = None + + if color_mode not in valid: + raise ValueError(f"`color_mode` must be one of {valid}") + if cmap is not None: # if a cmap is specified it overrides colors argument - if uniform_color: - raise TypeError("Cannot use cmap if uniform_color=True") + if color_mode == "uniform": + raise ValueError( + "if a `cmap` is provided, `color_mode` must be 'vertex' or 'auto', not 'uniform'" + ) if isinstance(cmap, str): # make colors from cmap if isinstance(colors, VertexColors): # share buffer with existing colors instance for the cmap self._colors = colors - self._colors._shared += 1 else: # create vertex colors buffer self._colors = VertexColors("w", n_colors=self._data.value.shape[0]) @@ -115,34 +236,18 @@ def __init__( # use existing cmap instance self._cmap = cmap self._colors = cmap._vertex_colors + else: raise TypeError( "`cmap` argument must be a cmap name or an existing `VertexCmap` instance" ) else: # no cmap given - if isinstance(colors, VertexColors): - # share buffer with existing colors instance - self._colors = colors - self._colors._shared += 1 - # blank colormap instance + self._colors = self._create_colors_buffer(colors, color_mode) + + # this is created so that cmap can be set later + if isinstance(self._colors, VertexColors): self._cmap = VertexCmap(self._colors, cmap_name=None, transform=None) - else: - if uniform_color: - if not isinstance(colors, str): # not a single color - if not len(colors) in [3, 4]: # not an RGB(A) array - raise TypeError( - "must pass a single color if using `uniform_colors=True`" - ) - self._colors = UniformColor(colors) - self._cmap = None - else: - self._colors = VertexColors( - colors, n_colors=self._data.value.shape[0] - ) - self._cmap = VertexCmap( - self._colors, cmap_name=None, transform=None - ) self._size_space = SizeSpace(size_space) super().__init__(*args, **kwargs) diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 779310476..68fe54c33 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -1,5 +1,6 @@ +import weakref from warnings import warn -from typing import Literal +from typing import Callable import numpy as np from numpy.typing import NDArray @@ -78,7 +79,7 @@ def block_events(self, val: bool): """ self._block_events = val - def add_event_handler(self, handler: callable): + def add_event_handler(self, handler: Callable): """ Add an event handler. All added event handlers are called when this feature changes. @@ -89,7 +90,7 @@ def add_event_handler(self, handler: callable): Parameters ---------- - handler: callable + handler: Callable a function to call when this feature changes """ @@ -102,7 +103,7 @@ def add_event_handler(self, handler: callable): self._event_handlers.append(handler) - def remove_event_handler(self, handler: callable): + def remove_event_handler(self, handler: Callable): """ Remove a registered event ``handler``. @@ -137,32 +138,28 @@ class BufferManager(GraphicFeature): def __init__( self, - data: NDArray | pygfx.Buffer, - buffer_type: Literal["buffer", "texture", "texture-array"] = "buffer", - isolated_buffer: bool = True, + data: NDArray | pygfx.Buffer | None, **kwargs, ): super().__init__(**kwargs) - if isolated_buffer and not isinstance(data, pygfx.Resource): - # useful if data is read-only, example: memmaps - bdata = np.zeros(data.shape, dtype=data.dtype) - bdata[:] = data[:] - else: - # user's input array is used as the buffer - bdata = data - - if isinstance(data, pygfx.Resource): - # already a buffer, probably used for - # managing another BufferManager, example: VertexCmap manages VertexColors - self._buffer = data - elif buffer_type == "buffer": - self._buffer = pygfx.Buffer(bdata) + + # if data is None, then the BufferManager just provides a view into an existing buffer + # example: VertexCmap is basically a view into VertexColors + if data is not None: + if isinstance(data, pygfx.Resource): + # already a buffer, probably used for + # managing another BufferManager, example: VertexCmap manages VertexColors + self._fpl_buffer = data + else: + # create a buffer + bdata = np.empty(data.shape, dtype=data.dtype) + bdata[:] = data[:] + + self._fpl_buffer = pygfx.Buffer(bdata) else: - raise ValueError( - "`data` must be a pygfx.Buffer instance or `buffer_type` must be one of: 'buffer' or 'texture'" - ) + self._fpl_buffer = None - self._event_handlers: list[callable] = list() + self._event_handlers: list[Callable] = list() @property def value(self) -> np.ndarray: @@ -174,9 +171,10 @@ def set_value(self, graphic, value): self[:] = value @property - def buffer(self) -> pygfx.Buffer | pygfx.Texture: - """managed buffer""" - return self._buffer + def buffer(self) -> pygfx.Buffer: + """managed buffer, returns a weakref proxy""" + # the user should never create their own references to the buffer + return weakref.proxy(self._fpl_buffer) @property def __array_interface__(self): @@ -320,7 +318,7 @@ def __repr__(self): def block_reentrance(set_value): # decorator to block re-entrant set_value methods # useful when creating complex, circular, bidirectional event graphs - def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): + def set_value_wrapper(self: GraphicFeature, graphic_or_key, value, **kwargs): """ wraps GraphicFeature.set_value @@ -336,7 +334,7 @@ def set_value_wrapper(self: GraphicFeature, graphic_or_key, value): try: # block re-execution of set_value until it has *fully* finished executing self._reentrant_block = True - set_value(self, graphic_or_key, value) + set_value(self, graphic_or_key, value, **kwargs) except Exception as exc: # raise original exception raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! diff --git a/fastplotlib/graphics/features/_image.py b/fastplotlib/graphics/features/_image.py index 648f79bc8..af0783c71 100644 --- a/fastplotlib/graphics/features/_image.py +++ b/fastplotlib/graphics/features/_image.py @@ -1,14 +1,13 @@ from itertools import product - from math import ceil +import cmap as cmap_lib import numpy as np import pygfx from ._base import GraphicFeature, GraphicFeatureEvent, block_reentrance from ...utils import ( - make_colors, get_cmap_texture, ) @@ -33,7 +32,7 @@ class TextureArray(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True, property_name: str = "data"): + def __init__(self, data, property_name: str = "data"): super().__init__(property_name=property_name) data = self._fix_data(data) @@ -41,13 +40,9 @@ def __init__(self, data, isolated_buffer: bool = True, property_name: str = "dat shared = pygfx.renderers.wgpu.get_shared() self._texture_limit_2d = shared.device.limits["max-texture-dimension-2d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( @@ -243,8 +238,8 @@ def value(self) -> str: @block_reentrance def set_value(self, graphic, value: str): - new_colors = make_colors(256, value) - graphic._material.map.texture.data[:] = new_colors + colormap = pygfx.cm.create_colormap(cmap_lib.Colormap(value).lut()) + graphic._material.map = colormap graphic._material.map.texture.update_range((0, 0, 0), size=(256, 1, 1)) self._value = value diff --git a/fastplotlib/graphics/features/_mesh.py b/fastplotlib/graphics/features/_mesh.py index 7355acb4e..776d77ce4 100644 --- a/fastplotlib/graphics/features/_mesh.py +++ b/fastplotlib/graphics/features/_mesh.py @@ -51,18 +51,14 @@ class MeshIndices(VertexPositions): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "indices" - ): + def __init__(self, data: Any, property_name: str = "indices"): """ Manages the vertex indices buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim != 2 or data.shape[1] not in (3, 4): diff --git a/fastplotlib/graphics/features/_positions.py b/fastplotlib/graphics/features/_positions.py index 295d22417..507fc1ee0 100644 --- a/fastplotlib/graphics/features/_positions.py +++ b/fastplotlib/graphics/features/_positions.py @@ -39,7 +39,6 @@ def __init__( self, colors: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], n_colors: int, - isolated_buffer: bool = True, property_name: str = "colors", ): """ @@ -57,9 +56,59 @@ def __init__( """ data = parse_colors(colors, n_colors) - super().__init__( - data=data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=data, property_name=property_name) + + def set_value( + self, + graphic, + value: str | pygfx.Color | np.ndarray | Sequence[float] | Sequence[str], + ): + """set the entire array, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + # TODO: Refactor this triage so it's more elegant + + # first make sure it's not representing one color + skip = False + if isinstance(value, np.ndarray): + if (value.shape in ((3,), (4,))) and ( + np.issubdtype(value.dtype, np.floating) + or np.issubdtype(value.dtype, np.integer) + ): + # represents one color + skip = True + elif isinstance(value, (list, tuple)): + if len(value) in (3, 4) and all( + [isinstance(v, (float, int)) for v in value] + ): + # represents one color + skip = True + + # check if the number of elements matches current buffer size + if not skip and self.buffer.data.shape[0] != len(value): + # parse the new colors + new_colors = parse_colors(value, len(value)) + + # create the new buffer, old buffer should get dereferenced + # make sure new buffer is isolated (i.e. allocate a buffer, then set the values) + buff = np.empty(new_colors.shape, dtype=np.float32) + buff[:] = new_colors + self._fpl_buffer = pygfx.Buffer(buff) + graphic.world_object.geometry.colors = self._fpl_buffer + + if len(self._event_handlers) < 1: + return + + event_info = { + "key": slice(None), + "value": new_colors, + "user_value": value, + } + + event = GraphicFeatureEvent(self._property_name, info=event_info) + self._call_event_handlers(event) + return + + self[:] = value @block_reentrance def __setitem__( @@ -231,18 +280,14 @@ class VertexPositions(BufferManager): }, ] - def __init__( - self, data: Any, isolated_buffer: bool = True, property_name: str = "data" - ): + def __init__(self, data: Any, property_name: str = "data"): """ Manages the vertex positions buffer shown in the graphic. Supports fancy indexing if the data array also supports it. """ data = self._fix_data(data) - super().__init__( - data, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data, property_name=property_name) def _fix_data(self, data): if data.ndim == 1: @@ -261,13 +306,42 @@ def _fix_data(self, data): return to_gpu_supported_dtype(data) + def set_value(self, graphic, value): + """Sets the entire array, creates new buffer if necessary""" + if isinstance(value, np.ndarray): + if self.buffer.data.shape[0] != value.shape[0]: + # number of items doesn't match, create a new buffer + + # if data is not 3D + if value.ndim == 1: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 2: + # _fix_data creates a new array so we don't need to re-allocate with np.zeros + bdata = self._fix_data(value) + + elif value.shape[1] == 3: + # need to allocate a buffer to use here + bdata = np.empty(value.shape, dtype=np.float32) + bdata[:] = value[:] + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(bdata) + graphic.world_object.geometry.positions = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, key: int | slice | np.ndarray[int | bool] | tuple[slice, ...], value: np.ndarray | float | list[float], ): - # directly use the key to slice the buffer + # directly use the key to slice the buffer and set the values self.buffer.data[key] = value # _update_range handles parsing the key to @@ -306,7 +380,7 @@ def __init__( provides a way to set colormaps with arbitrary transforms """ - super().__init__(data=vertex_colors.buffer, property_name=property_name) + super().__init__(data=None, property_name=property_name) self._vertex_colors = vertex_colors self._cmap_name = cmap_name @@ -331,6 +405,10 @@ def __init__( # set vertex colors from cmap self._vertex_colors[:] = colors + @property + def buffer(self) -> pygfx.Buffer: + return self._vertex_colors.buffer + @block_reentrance def __setitem__(self, key: slice, cmap_name): if not isinstance(key, slice): diff --git a/fastplotlib/graphics/features/_scatter.py b/fastplotlib/graphics/features/_scatter.py index 16671ef89..e41115ae3 100644 --- a/fastplotlib/graphics/features/_scatter.py +++ b/fastplotlib/graphics/features/_scatter.py @@ -100,6 +100,37 @@ def searchsorted_markers_to_int_array(markers_str_array: np.ndarray[str]): return marker_int_searchsorted_vals[indices] +def parse_markers(markers: str | Sequence[str] | np.ndarray, n_datapoints: int): + # first validate then allocate buffers + + if isinstance(markers, str): + markers = user_input_to_marker(markers) + + elif isinstance(markers, (tuple, list, np.ndarray)): + validate_user_markers_array(markers) + + # allocate buffers + markers_int_array = np.zeros(n_datapoints, dtype=np.int32) + + marker_str_length = max(map(len, list(pygfx.MarkerShape))) + + markers_readable_array = np.empty(n_datapoints, dtype=f" np.ndarray[str]: @@ -200,6 +200,25 @@ def _set_markers_arrays(self, key, value, n_markers): "new markers value must be a str, Sequence or np.ndarray of new marker values" ) + def set_value(self, graphic, value): + """set all the markers, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # need to create a new buffer + markers_int_array, self._markers_readable_array = parse_markers( + value, len(value) + ) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(markers_int_array) + graphic.world_object.geometry.markers = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -414,18 +433,15 @@ def __init__( self, rotations: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "point_rotations", ): """ Manages rotations buffer of scatter points. """ - sizes = self._fix_sizes(rotations, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + sizes = self._fix_rotations(rotations, n_datapoints) + super().__init__(data=sizes, property_name=property_name) - def _fix_sizes( + def _fix_rotations( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, @@ -454,6 +470,22 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all rotations, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != value.shape[0]: + # need to create a new buffer + value = self._fix_rotations(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.rotations = self._fpl_buffer + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, @@ -488,16 +520,13 @@ def __init__( self, sizes: int | float | np.ndarray | Sequence[int | float], n_datapoints: int, - isolated_buffer: bool = True, property_name: str = "sizes", ): """ Manages sizes buffer of scatter points. """ sizes = self._fix_sizes(sizes, n_datapoints) - super().__init__( - data=sizes, isolated_buffer=isolated_buffer, property_name=property_name - ) + super().__init__(data=sizes, property_name=property_name) def _fix_sizes( self, @@ -533,6 +562,24 @@ def _fix_sizes( return sizes + def set_value(self, graphic, value): + """set all sizes, create new buffer if necessary""" + if isinstance(value, (np.ndarray, list, tuple)): + if self.buffer.data.shape[0] != len(value): + # create new buffer + value = self._fix_sizes(value, len(value)) + data = np.empty(shape=(len(value),), dtype=np.float32) + data[:] = value + + # create the new buffer, old buffer should get dereferenced + self._fpl_buffer = pygfx.Buffer(data) + graphic.world_object.geometry.sizes = self._fpl_buffer + + self._emit_event(self._property_name, key=slice(None), value=value) + return + + self[:] = value + @block_reentrance def __setitem__( self, diff --git a/fastplotlib/graphics/features/_selection_features.py b/fastplotlib/graphics/features/_selection_features.py index 9b30dd70c..1f049f0cb 100644 --- a/fastplotlib/graphics/features/_selection_features.py +++ b/fastplotlib/graphics/features/_selection_features.py @@ -118,7 +118,7 @@ def axis(self) -> str: return self._axis @block_reentrance - def set_value(self, selector, value: Sequence[float]): + def set_value(self, selector, value: Sequence[float], *, change: str = "full"): """ Set start, stop range of selector @@ -182,7 +182,9 @@ def set_value(self, selector, value: Sequence[float]): if len(self._event_handlers) < 1: return - event = GraphicFeatureEvent(self._property_name, {"value": self.value}) + event = GraphicFeatureEvent( + self._property_name, {"value": self.value, "change": change} + ) event.get_selected_indices = selector.get_selected_indices event.get_selected_data = selector.get_selected_data diff --git a/fastplotlib/graphics/features/_vectors.py b/fastplotlib/graphics/features/_vectors.py index 9c86d25fc..729562b06 100644 --- a/fastplotlib/graphics/features/_vectors.py +++ b/fastplotlib/graphics/features/_vectors.py @@ -22,7 +22,6 @@ class VectorPositions(GraphicFeature): def __init__( self, positions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "positions", ): """ @@ -111,7 +110,6 @@ class VectorDirections(GraphicFeature): def __init__( self, directions: np.ndarray, - isolated_buffer: bool = True, property_name: str = "directions", ): """Manages vector field positions by managing the mesh instance buffer's full transform matrix""" diff --git a/fastplotlib/graphics/features/_volume.py b/fastplotlib/graphics/features/_volume.py index ec4c4052a..532065fb7 100644 --- a/fastplotlib/graphics/features/_volume.py +++ b/fastplotlib/graphics/features/_volume.py @@ -34,7 +34,7 @@ class TextureArrayVolume(GraphicFeature): }, ] - def __init__(self, data, isolated_buffer: bool = True): + def __init__(self, data): super().__init__(property_name="data") data = self._fix_data(data) @@ -43,13 +43,9 @@ def __init__(self, data, isolated_buffer: bool = True): self._texture_size_limit = shared.device.limits["max-texture-dimension-3d"] - if isolated_buffer: - # useful if data is read-only, example: memmaps - self._value = np.zeros(data.shape, dtype=data.dtype) - self.value[:] = data[:] - else: - # user's input array is used as the buffer - self._value = data + # create a new buffer that will be used for the texture data + self._value = np.zeros(data.shape, dtype=data.dtype) + self.value[:] = data[:] # data start indices for each Texture self._row_indices = np.arange( diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 44bffcedc..8e11f4751 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -1,6 +1,7 @@ import math from typing import * +import numpy as np import pygfx from ..utils import quick_min_max @@ -102,7 +103,6 @@ def __init__( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, **kwargs, ): """ @@ -130,12 +130,6 @@ def __init__( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -143,7 +137,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArray): # share buffer @@ -151,7 +145,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the multiple textures on the GPU that represent this image - self._data = TextureArray(data, isolated_buffer=isolated_buffer) + self._data = TextureArray(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -165,21 +159,28 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) # set map to None for RGB images - if self._data.value.ndim > 2: + if self._data.value.ndim == 3: self._cmap = None _map = None - else: + + elif self._data.value.ndim == 2: # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) _map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + else: + raise ValueError( + f"ImageGraphic `data` must have 2 dimensions for grayscale images, or 3 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) # one common material is used for every Texture chunk self._material = pygfx.ImageBasicMaterial( @@ -189,6 +190,14 @@ def __init__( pick_write=True, ) + # create the _ImageTile world objects, add to group + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_ImageTile]: + tiles = list() # iterate through each texture chunk and create # an _ImageTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -209,17 +218,62 @@ def __init__( img.world.x = data_col_start img.world.y = data_row_start - world_object.add(img) + tiles.append(img) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArray: - """Get or set the image data""" + """ + Get or set the image data. + + Note that if the shape of the new data array does not equal the shape of + current data array, a new set of GPU Textures are automatically created. + This can have performance drawbacks when you have a ver large images. + This is usually fine as long as you don't need to do it hundreds of times + per second. + """ return self._data @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArray(data) + + # cmap based on if rgb or grayscale + if self._data.value.ndim > 2: + self._cmap = None + + # must be None if RGB(A) + self._material.map = None + else: + if self.cmap is None: # have switched from RGBA -> grayscale image + # create default cmap + self._cmap = ImageCmap("plasma") + self._material.map = pygfx.TextureMap( + self._cmap.texture, + filter=self._cmap_interpolation.value, + wrap="clamp-to-edge", + ) + + # remove tiles from the WorldObject -> Graphic map + self._remove_group_graphic_map(self.world_object) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + # add new tiles to WorldObject -> Graphic map + self._add_group_graphic_map(self.world_object) + + return + self._data[:] = data @property @@ -232,8 +286,6 @@ def cmap(self) -> str | None: if self._cmap is not None: return self._cmap.value - return None - @cmap.setter def cmap(self, name: str): if self.data.value.ndim > 2: @@ -269,7 +321,7 @@ def interpolation(self, value: str): @property def cmap_interpolation(self) -> str: - """cmap interpolation method""" + """cmap interpolation method, 'linear' or 'nearest'. Used only for grayscale images""" return self._cmap_interpolation.value @cmap_interpolation.setter diff --git a/fastplotlib/graphics/image_volume.py b/fastplotlib/graphics/image_volume.py index db8f29eaa..3d2d064e8 100644 --- a/fastplotlib/graphics/image_volume.py +++ b/fastplotlib/graphics/image_volume.py @@ -113,7 +113,6 @@ def __init__( substep_size: float = 0.1, emissive: str | tuple | np.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, **kwargs, ): """ @@ -170,11 +169,6 @@ def __init__( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -188,7 +182,7 @@ def __init__( super().__init__(**kwargs) - world_object = pygfx.Group() + group = pygfx.Group() if isinstance(data, TextureArrayVolume): # share existing buffer @@ -196,7 +190,7 @@ def __init__( else: # create new texture array to manage buffer # texture array that manages the textures on the GPU that represent this image volume - self._data = TextureArrayVolume(data, isolated_buffer=isolated_buffer) + self._data = TextureArrayVolume(data) if (vmin is None) or (vmax is None): _vmin, _vmax = quick_min_max(self.data.value) @@ -210,18 +204,24 @@ def __init__( self._vmax = ImageVmax(vmax) self._interpolation = ImageInterpolation(interpolation) + self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - # TODO: I'm assuming RGB volume images aren't supported??? # use TextureMap for grayscale images self._cmap = ImageCmap(cmap) - self._cmap_interpolation = ImageCmapInterpolation(cmap_interpolation) - self._texture_map = pygfx.TextureMap( self._cmap.texture, filter=self._cmap_interpolation.value, wrap="clamp-to-edge", ) + if self._data.value.ndim not in (3, 4): + raise ValueError( + f"ImageVolumeGraphic `data` must have 3 dimensions for grayscale images, " + f"or 4 dimensions for RGB(A) images.\n" + f"You have passed a a data array with: {self._data.value.ndim} dimensions, " + f"and of shape: {self._data.value.shape}" + ) + self._plane = VolumeSlicePlane(plane) self._threshold = VolumeIsoThreshold(threshold) self._step_size = VolumeIsoStepSize(step_size) @@ -237,6 +237,15 @@ def __init__( self._mode = VolumeRenderMode(mode) + # create tiles + for tile in self._create_tiles(): + group.add(tile) + + self._set_world_object(group) + + def _create_tiles(self) -> list[_VolumeTile]: + tiles = list() + # iterate through each texture chunk and create # a _VolumeTile, offset the tile using the data indices for texture, chunk_index, data_slice in self._data: @@ -259,9 +268,9 @@ def __init__( vol.world.x = data_col_start vol.world.y = data_row_start - world_object.add(vol) + tiles.append(vol) - self._set_world_object(world_object) + return tiles @property def data(self) -> TextureArrayVolume: @@ -270,6 +279,21 @@ def data(self) -> TextureArrayVolume: @data.setter def data(self, data): + if isinstance(data, np.ndarray): + # check if a new buffer is required + if self._data.value.shape != data.shape: + # create new TextureArray + self._data = TextureArrayVolume(data) + + # clear image tiles + self.world_object.clear() + + # create new tiles + for tile in self._create_tiles(): + self.world_object.add(tile) + + return + self._data[:] = data @property @@ -283,7 +307,7 @@ def mode(self, mode: str): @property def cmap(self) -> str: - """Get or set colormap name""" + """Get or set colormap name, only used for grayscale images""" return self._cmap.value @cmap.setter diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py index a4f42704f..bba10b10f 100644 --- a/fastplotlib/graphics/line.py +++ b/fastplotlib/graphics/line.py @@ -18,6 +18,7 @@ UniformColor, VertexCmap, SizeSpace, + UniformRotations, ) from ..utils import quick_min_max @@ -36,10 +37,9 @@ def __init__( data: Any, thickness: float = 2.0, colors: str | np.ndarray | Sequence = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray | Sequence = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", **kwargs, ): @@ -61,15 +61,19 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -84,10 +88,9 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) @@ -102,8 +105,8 @@ def __init__( aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") - if uniform_color: - geometry = pygfx.Geometry(positions=self._data.buffer) + if isinstance(self._colors, UniformColor): + geometry = pygfx.Geometry(positions=self._data._fpl_buffer) material = MaterialCls( aa=aa, thickness=self.thickness, @@ -123,7 +126,7 @@ def __init__( depth_compare="<=", ) geometry = pygfx.Geometry( - positions=self._data.buffer, colors=self._colors.buffer + positions=self._data._fpl_buffer, colors=self._colors._fpl_buffer ) world_object: pygfx.Line = pygfx.Line(geometry=geometry, material=material) diff --git a/fastplotlib/graphics/line_collection.py b/fastplotlib/graphics/line_collection.py index d08231f7d..351f3368e 100644 --- a/fastplotlib/graphics/line_collection.py +++ b/fastplotlib/graphics/line_collection.py @@ -1,3 +1,5 @@ +from itertools import repeat +from numbers import Number from typing import * import numpy as np @@ -105,8 +107,11 @@ def thickness(self) -> np.ndarray: return np.asarray([g.thickness for g in self]) @thickness.setter - def thickness(self, values: np.ndarray | list[float]): - if not len(values) == len(self): + def thickness(self, values: float | Sequence[float]): + if isinstance(values, Number): + values = repeat(values, len(self)) + + elif not len(values) == len(self): raise IndexError for g, v in zip(self, values): @@ -128,14 +133,13 @@ def __init__( data: np.ndarray | List[np.ndarray], thickness: float | Sequence[float] = 2.0, colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", - uniform_colors: bool = False, cmap: Sequence[str] | str = None, cmap_transform: np.ndarray | List = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, **kwargs, ): @@ -170,6 +174,9 @@ def __init__( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -320,11 +327,10 @@ def __init__( data=d, thickness=_s, colors=_c, - uniform_color=uniform_colors, cmap=_cmap, + color_mode=color_mode, name=_name, metadata=_m, - isolated_buffer=isolated_buffer, **kwargs_lines, ) @@ -560,7 +566,6 @@ def __init__( names: list[str] = None, metadata: Any = None, metadatas: Sequence[Any] | np.ndarray = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, @@ -634,7 +639,6 @@ def __init__( names=names, metadata=metadata, metadatas=metadatas, - isolated_buffer=isolated_buffer, kwargs_lines=kwargs_lines, **kwargs, ) diff --git a/fastplotlib/graphics/mesh.py b/fastplotlib/graphics/mesh.py index 0e1ac42a3..efe03c57b 100644 --- a/fastplotlib/graphics/mesh.py +++ b/fastplotlib/graphics/mesh.py @@ -38,7 +38,6 @@ def __init__( mapcoords: Any = None, cmap: str | dict | pygfx.Texture | pygfx.TextureMap | np.ndarray = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, **kwargs, ): """ @@ -77,12 +76,6 @@ def __init__( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -93,16 +86,12 @@ def __init__( if isinstance(positions, VertexPositions): self._positions = positions else: - self._positions = VertexPositions( - positions, isolated_buffer=isolated_buffer, property_name="positions" - ) + self._positions = VertexPositions(positions, property_name="positions") if isinstance(positions, MeshIndices): self._indices = indices else: - self._indices = MeshIndices( - indices, isolated_buffer=isolated_buffer, property_name="indices" - ) + self._indices = MeshIndices(indices, property_name="indices") self._cmap = MeshCmap(cmap) @@ -139,7 +128,7 @@ def __init__( ) geometry = pygfx.Geometry( - positions=self._positions.buffer, indices=self._indices._buffer + positions=self._positions.buffer, indices=self._indices._fpl_buffer ) valid_modes = ["basic", "phong", "slice"] diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index a2e696a82..b9cacf908 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -40,12 +40,12 @@ def __init__( self, data: Any, colors: str | np.ndarray | Sequence[float] | Sequence[str] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: np.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: str | np.ndarray | Sequence[str] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: str | np.ndarray | pygfx.Color | Sequence[float] = "black", uniform_edge_color: bool = True, @@ -53,10 +53,9 @@ def __init__( image: np.ndarray = None, point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: float | np.ndarray | Sequence[float] = 1, - uniform_size: bool = False, + sizes: float | np.ndarray | Sequence[float] = 5, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, **kwargs, ): """ @@ -72,18 +71,23 @@ def __init__( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -103,9 +107,10 @@ def __init__( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -125,8 +130,9 @@ def __init__( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -147,17 +153,13 @@ def __init__( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -166,17 +168,16 @@ def __init__( super().__init__( data=data, colors=colors, - uniform_color=uniform_color, cmap=cmap, cmap_transform=cmap_transform, - isolated_buffer=isolated_buffer, + color_mode=color_mode, size_space=size_space, **kwargs, ) n_datapoints = self.data.value.shape[0] - geo_kwargs = {"positions": self._data.buffer} + geo_kwargs = {"positions": self._data._fpl_buffer} aa = kwargs.get("alpha_mode", "auto") in ("blend", "weighted_blend") @@ -214,7 +215,7 @@ def __init__( self._markers = VertexMarkers(markers, n_datapoints) - geo_kwargs["markers"] = self._markers.buffer + geo_kwargs["markers"] = self._markers._fpl_buffer if edge_colors is None: # interpret as no edge color @@ -237,7 +238,7 @@ def __init__( edge_colors, n_datapoints, property_name="edge_colors" ) material_kwargs["edge_color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["edge_colors"] = self._edge_colors.buffer + geo_kwargs["edge_colors"] = self._edge_colors._fpl_buffer self._edge_width = EdgeWidth(edge_width) material_kwargs["edge_width"] = self._edge_width.value @@ -274,12 +275,12 @@ def __init__( self._size_space = SizeSpace(size_space) - if uniform_color: + if isinstance(self._colors, UniformColor): material_kwargs["color_mode"] = pygfx.ColorMode.uniform material_kwargs["color"] = self.colors else: material_kwargs["color_mode"] = pygfx.ColorMode.vertex - geo_kwargs["colors"] = self.colors.buffer + geo_kwargs["colors"] = self.colors._fpl_buffer if uniform_size: material_kwargs["size_mode"] = pygfx.SizeMode.uniform @@ -288,14 +289,14 @@ def __init__( else: material_kwargs["size_mode"] = pygfx.SizeMode.vertex self._sizes = VertexPointSizes(sizes, n_datapoints=n_datapoints) - geo_kwargs["sizes"] = self.sizes.buffer + geo_kwargs["sizes"] = self.sizes._fpl_buffer match point_rotation_mode: case pygfx.enums.RotationMode.vertex: self._point_rotations = VertexRotations( point_rotations, n_datapoints=n_datapoints ) - geo_kwargs["rotations"] = self._point_rotations.buffer + geo_kwargs["rotations"] = self._point_rotations._fpl_buffer case pygfx.enums.RotationMode.uniform: self._point_rotations = UniformRotations(point_rotations) @@ -338,10 +339,8 @@ def markers(self, value: str | np.ndarray[str] | Sequence[str]): raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the markers" ) - if isinstance(self._markers, VertexMarkers): - self._markers[:] = value - elif isinstance(self._markers, UniformMarker): - self._markers.set_value(self, value) + + self._markers.set_value(self, value) @property def edge_colors(self) -> str | pygfx.Color | VertexColors | None: @@ -359,12 +358,7 @@ def edge_colors(self, value: str | np.ndarray | Sequence[str] | Sequence[float]) raise AttributeError( f"scatter plot is: {self.mode}. The mode must be 'markers' to set the edge_colors" ) - - if isinstance(self._edge_colors, VertexColors): - self._edge_colors[:] = value - - elif isinstance(self._edge_colors, UniformEdgeColor): - self._edge_colors.set_value(self, value) + self._edge_colors.set_value(self, value) @property def edge_width(self) -> float | None: @@ -406,11 +400,7 @@ def point_rotations(self, value: float | np.ndarray[float]): f"it be 'uniform' or 'vertex' to set the `point_rotations`" ) - if isinstance(self._point_rotations, VertexRotations): - self._point_rotations[:] = value - - elif isinstance(self._point_rotations, UniformRotations): - self._point_rotations.set_value(self, value) + self._point_rotations.set_value(self, value) @property def image(self) -> TextureArray | None: @@ -437,8 +427,4 @@ def sizes(self) -> VertexPointSizes | float: @sizes.setter def sizes(self, value): - if isinstance(self._sizes, VertexPointSizes): - self._sizes[:] = value - - elif isinstance(self._sizes, UniformSize): - self._sizes.set_value(self, value) + self._sizes.set_value(self, value) diff --git a/fastplotlib/graphics/scatter_collection.py b/fastplotlib/graphics/scatter_collection.py new file mode 100644 index 000000000..f0993dd46 --- /dev/null +++ b/fastplotlib/graphics/scatter_collection.py @@ -0,0 +1,672 @@ +from itertools import repeat +from numbers import Number +from typing import * + +import numpy as np + +import pygfx + +from ..utils import parse_cmap_values +from ._collection_base import CollectionIndexer, GraphicCollection, CollectionFeature +from .scatter import ScatterGraphic +from .selectors import ( + LinearRegionSelector, + LinearSelector, + RectangleSelector, + PolygonSelector, +) + + +class _ScatterCollectionProperties: + """Mix-in class for ScatterCollection properties""" + + @property + def colors(self) -> CollectionFeature: + """get or set colors of scatters in the collection""" + return CollectionFeature(self.graphics, "colors") + + @colors.setter + def colors(self, values: str | np.ndarray | tuple[float] | list[float] | list[str]): + if isinstance(values, str): + # set colors of all scatter to one str color + for g in self: + g.colors = values + return + + elif all(isinstance(v, str) for v in values): + # individual str colors for each scatter + if not len(values) == len(self): + raise IndexError + + for g, v in zip(self.graphics, values): + g.colors = v + + return + + if isinstance(values, np.ndarray): + if values.ndim == 2: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + return + + elif len(values) == 4: + # assume RGBA + self.colors[:] = values + + else: + # assume individual colors for each + for g, v in zip(self, values): + g.colors = v + + @property + def data(self) -> CollectionFeature: + """get or set data of scatters in the collection""" + return CollectionFeature(self.graphics, "data") + + @data.setter + def data(self, values): + for g, v in zip(self, values): + g.data = v + + @property + def cmap(self) -> CollectionFeature: + """ + Get or set a cmap along the scatter collection. + + Optionally set using a tuple ("cmap", ) to set the transform. + Example: + + scatter_collection.cmap = ("jet", sine_transform_vals, 0.7) + + """ + return CollectionFeature(self.graphics, "cmap") + + @cmap.setter + def cmap(self, args): + if isinstance(args, str): + name = args + transform = None + elif len(args) == 1: + name = args[0] + transform = None + elif len(args) == 2: + name, transform = args + else: + raise ValueError( + "Too many values for cmap (note that alpha is deprecated, set alpha on the graphic instead)" + ) + + self.colors = parse_cmap_values( + n_colors=len(self), cmap_name=name, transform=transform + ) + + @property + def markers(self) -> CollectionFeature: + """get or set markers of scatters in the collection""" + return CollectionFeature(self.graphics, "markers") + + @markers.setter + def markers(self, values: str | Sequence[str]): + if isinstance(values, str): + values = repeat(values, len(self)) + + elif len(values) != len(self): + raise IndexError("len(markers) must be the same as the number of ScatterGraphics in the collection") + + for g, v in zip(self, values): + g.markers = v + + @property + def sizes(self) -> CollectionFeature: + """get or set sizes of scatter points in the collection""" + return CollectionFeature(self.graphics, "sizes") + + @sizes.setter + def sizes(self, values): + if isinstance(values, Number): + values = repeat(values, len(self)) + + elif len(values) != len(self): + raise IndexError("len(sizes) must be the same as the number of ScatterGraphics in the collection") + + for g, v in zip(self, values): + g.sizes = v + + +class ScatterCollectionIndexer(CollectionIndexer, _ScatterCollectionProperties): + """Indexer for scatter collections""" + pass + + +class ScatterCollection(GraphicCollection, _ScatterCollectionProperties): + _child_type = ScatterGraphic + _indexer = ScatterCollectionIndexer + + def __init__( + self, + data: np.ndarray | List[np.ndarray], + colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", + cmap: Sequence[str] | str = None, + cmap_transform: np.ndarray | List = None, + sizes: float | Sequence[float] = 5.0, + uniform_size: bool = True, + markers: np.ndarray | Sequence[str] = None, + uniform_marker: bool = True, + edge_width: float = 1.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + **kwargs, + ): + """ + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + + super().__init__(name=name, metadata=metadata, **kwargs) + + if names is not None: + if len(names) != len(data): + raise ValueError( + f"len(names) != len(data)\n{len(names)} != {len(data)}" + ) + + if metadatas is not None: + if len(metadatas) != len(data): + raise ValueError( + f"len(metadata) != len(data)\n{len(metadatas)} != {len(data)}" + ) + + self._cmap_transform = cmap_transform + self._cmap_str = cmap + + # cmap takes priority over colors + if cmap is not None: + # cmap across lines + if isinstance(cmap, str): + colors = parse_cmap_values( + n_colors=len(data), cmap_name=cmap, transform=cmap_transform + ) + single_color = False + cmap = None + + elif isinstance(cmap, (tuple, list)): + if len(cmap) != len(data): + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + single_color = False + else: + raise ValueError( + "cmap argument must be a single cmap or a list of cmaps " + "with the same length as the data" + ) + else: + if isinstance(colors, np.ndarray): + # single color for all lines in the collection as RGBA + if colors.shape in [(3,), (4,)]: + single_color = True + + # colors specified for each line as array of shape [n_lines, RGBA] + elif colors.shape == (len(data), 4): + single_color = False + + else: + raise ValueError( + f"numpy array colors argument must be of shape (4,) or (n_lines, 4)." + f"You have pass the following shape: {colors.shape}" + ) + + elif isinstance(colors, str): + if colors == "random": + colors = np.random.rand(len(data), 3) + single_color = False + else: + # parse string color + single_color = True + colors = pygfx.Color(colors) + + elif isinstance(colors, (tuple, list)): + if len(colors) == 4: + # single color specified as (R, G, B, A) tuple or list + if all([isinstance(c, (float, int)) for c in colors]): + single_color = True + + elif len(colors) == len(data): + # colors passed as list/tuple of colors, such as list of string + single_color = False + + else: + raise ValueError( + "tuple or list colors argument must be a single color represented as [R, G, B, A], " + "or must be a tuple/list of colors represented by a string with the same length as the data" + ) + + self._set_world_object(pygfx.Group()) + + for i, d in enumerate(data): + if cmap is None: + _cmap = None + + if single_color: + _c = colors + else: + _c = colors[i] + else: + _cmap = cmap[i] + _c = None + + if metadatas is not None: + _m = metadatas[i] + else: + _m = None + + if names is not None: + _name = names[i] + else: + _name = None + + if markers is not None: + if isinstance(markers, (tuple, list, np.ndarray)): + markers_ = markers[i] + else: + markers_ = markers + else: + markers_ = "o" + + if sizes is not None: + if isinstance(sizes, (tuple, list, np.ndarray)): + sizes_ = sizes[i] + else: + sizes_ = sizes + else: + sizes_ = 5 + + lg = ScatterGraphic( + data=d, + colors=_c, + sizes=sizes_, + markers=markers_, + cmap=_cmap, + name=_name, + metadata=_m, + uniform_marker=uniform_marker, + uniform_size=uniform_size, + edge_width=edge_width, + **kwargs, + ) + + self.add_graphic(lg) + + def __getitem__(self, item) -> ScatterCollectionIndexer: + return super().__getitem__(item) + + def add_linear_selector( + self, selection: float = None, padding: float = 0.0, axis: str = "x", **kwargs + ) -> LinearSelector: + """ + Adds a linear selector. + + Parameters + ---------- + Parameters + ---------- + selection: float, optional + selected point on the linear selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to :class:`.LinearSelector` + + Returns + ------- + LinearSelector + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init[0] + + selector = LinearSelector( + selection=selection, + limits=limits, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_linear_region_selector( + self, + selection: tuple[float, float] = None, + padding: float = 0.0, + axis: str = "x", + **kwargs, + ) -> LinearRegionSelector: + """ + Add a :class:`.LinearRegionSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float), optional + the starting bounds of the linear region selector, computed from data if not provided + + axis: str, default "x" + axis that the selector resides on + + padding: float, default 0.0 + Extra padding to extend the linear region selector along the orthogonal axis to make it easier to interact with. + + kwargs + passed to ``LinearRegionSelector`` + + Returns + ------- + LinearRegionSelector + linear selection graphic + + """ + + bounds_init, limits, size, center = self._get_linear_selector_init_args( + axis, padding + ) + + if selection is None: + selection = bounds_init + + # create selector + selector = LinearRegionSelector( + selection=selection, + limits=limits, + size=size, + center=center, + axis=axis, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + # PlotArea manages this for garbage collection etc. just like all other Graphics + # so we should only work with a proxy on the user-end + return selector + + def add_rectangle_selector( + self, + selection: tuple[float, float, float] = None, + **kwargs, + ) -> RectangleSelector: + """ + Add a :class:`.RectangleSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: (float, float, float, float), optional + initial (xmin, xmax, ymin, ymax) of the selection + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25px = (xmax - xmin) / 4 + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + if selection is None: + selection = (xmin, value_25px, ymin, ymax) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = RectangleSelector( + selection=selection, + limits=limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def add_polygon_selector( + self, + selection: List[tuple[float, float]] = None, + **kwargs, + ) -> PolygonSelector: + """ + Add a :class:`.PolygonSelector`. Selectors are just ``Graphic`` objects, so you can manage, + remove, or delete them from a plot area just like any other ``Graphic``. + + Parameters + ---------- + selection: List of positions, optional + Initial points for the polygon. If not given or None, you'll start drawing the selection (clicking adds points to the polygon). + """ + bbox = self.world_object.get_world_bounding_box() + + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + + ydata = np.array(self.data[:, 1]) + ymin = np.floor(ydata.min()).astype(int) + + ymax = np.ptp(bbox[:, 1]) + + limits = (xmin, xmax, ymin - (ymax * 1.5 - ymax), ymax * 1.5) + + selector = PolygonSelector( + selection, + limits, + parent=self, + **kwargs, + ) + + self._plot_area.add_graphic(selector, center=False) + + return selector + + def _get_linear_selector_init_args(self, axis, padding): + # use bbox to get size and center + bbox = self.world_object.get_world_bounding_box() + + if axis == "x": + xdata = np.array(self.data[:, 0]) + xmin, xmax = (np.nanmin(xdata), np.nanmax(xdata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + # size from orthogonal axis + size = np.ptp(bbox[:, 1]) * 1.5 + # center on orthogonal axis + center = bbox[:, 1].mean() + + elif axis == "y": + ydata = np.array(self.data[:, 1]) + xmin, xmax = (np.nanmin(ydata), np.nanmax(ydata)) + value_25p = (xmax - xmin) / 4 + + bounds = (xmin, value_25p) + limits = (xmin, xmax) + + size = np.ptp(bbox[:, 0]) * 1.5 + # center on orthogonal axis + center = bbox[:, 0].mean() + + return bounds, limits, size, center + + +axes = {"x": 0, "y": 1, "z": 2} + + +class ScatterStack(ScatterCollection): + def __init__( + self, + data: np.ndarray | List[np.ndarray], + colors: str | Sequence[str] | np.ndarray | Sequence[np.ndarray] = "w", + cmap: Sequence[str] | str = None, + cmap_transform: np.ndarray | List = None, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Sequence[Any] | np.ndarray = None, + separation: float = 0.0, + separation_axis: str = "y", + **kwargs, + ): + """ + Create a stack of :class:`.LineGraphic` that are separated along the "x" or "y" axis. + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + thickness: float or Iterable of float, default 2.0 + | if ``float``, single thickness will be used for all lines + | if ``list`` of ``float``, each value will apply to the individual lines + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + metadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + separation: float, default 0.0 + space in between each line graphic in the stack + + separation_axis: str, default "y" + axis in which the line graphics in the stack should be separated + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + """ + super().__init__( + data=data, + colors=colors, + cmap=cmap, + cmap_transform=cmap_transform, + name=name, + names=names, + metadata=metadata, + metadatas=metadatas, + **kwargs, + ) + + self._sepration_axis = separation_axis + self._separation = separation + + self.separation = separation + + @property + def separation(self) -> float: + """distance between each line in the stack, in world space""" + return self._separation + + @separation.setter + def separation(self, value: float): + separation = float(value) + + axis_zero = 0 + for i, line in enumerate(self.graphics): + if self._sepration_axis == "x": + line.offset = (axis_zero, *line.offset[1:]) + + elif self._sepration_axis == "y": + line.offset = (line.offset[0], axis_zero, line.offset[2]) + + axis_zero = ( + axis_zero + line.data.value[:, axes[self._sepration_axis]].max() + separation + ) + + self._separation = value diff --git a/fastplotlib/graphics/selectors/_linear.py b/fastplotlib/graphics/selectors/_linear.py index 0c956d57b..4ea454ee8 100644 --- a/fastplotlib/graphics/selectors/_linear.py +++ b/fastplotlib/graphics/selectors/_linear.py @@ -45,10 +45,8 @@ def limits(self, values: tuple[float, float]): # using `Real` here allows it to work with builtin `int` and `float` types, and numpy scaler types if len(values) != 2 or not all(map(lambda v: isinstance(v, Real), values)): raise TypeError("limits must be an iterable of two numeric values") - self._limits = tuple( - map(round, values) - ) # if values are close to zero things get weird so round them - self.selection._limits = self._limits + self._limits = np.asarray(values) # if values are close to zero things get weird so round them + self._selection._limits = self._limits @property def edge_color(self) -> pygfx.Color: diff --git a/fastplotlib/graphics/selectors/_linear_region.py b/fastplotlib/graphics/selectors/_linear_region.py index 70a8dffa8..8a8583ae9 100644 --- a/fastplotlib/graphics/selectors/_linear_region.py +++ b/fastplotlib/graphics/selectors/_linear_region.py @@ -472,9 +472,9 @@ def _move_graphic(self, move_info: MoveInfo): if move_info.source == self._edges[0]: # change only left or bottom bound new_min = min(cur_min + delta, cur_max) - self._selection.set_value(self, (new_min, cur_max)) + self._selection.set_value(self, (new_min, cur_max), change="min") elif move_info.source == self._edges[1]: # change only right or top bound new_max = max(cur_max + delta, cur_min) - self._selection.set_value(self, (cur_min, new_max)) + self._selection.set_value(self, (cur_min, new_max), change="max") diff --git a/fastplotlib/graphics/utils.py b/fastplotlib/graphics/utils.py index 6be5aefc4..f32d80809 100644 --- a/fastplotlib/graphics/utils.py +++ b/fastplotlib/graphics/utils.py @@ -1,13 +1,16 @@ from contextlib import contextmanager +from typing import Callable, Iterable from ._base import Graphic @contextmanager -def pause_events(*graphics: Graphic): +def pause_events(*graphics: Graphic, event_handlers: Iterable[Callable] = None): """ Context manager for pausing Graphic events. + Optionally pass in only specific event handlers which are blocked. Other events for the graphic will not be blocked. + Examples -------- @@ -30,8 +33,14 @@ def pause_events(*graphics: Graphic): original_vals = [g.block_events for g in graphics] for g in graphics: - g.block_events = True + if event_handlers is not None: + g.block_handlers.extend([e for e in event_handlers]) + else: + g.block_events = True yield for g, value in zip(graphics, original_vals): - g.block_events = value + if event_handlers is not None: + g.block_handlers.clear() + else: + g.block_events = value diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 28b7c4a49..013ce847c 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -548,7 +548,7 @@ def _render(self, draw=True): # call the animation functions before render self._call_animate_functions(self._animate_funcs_pre) - for subplot in self: + for subplot in self._subplots.ravel(): subplot._render() # overlay render pass @@ -615,14 +615,16 @@ def show( sidecar_kwargs = dict() # flip y-axis if ImageGraphics are present - for subplot in self: + for subplot in self._subplots.ravel(): for g in subplot.graphics: if isinstance(g, ImageGraphic): - subplot.camera.local.scale_y *= -1 + if subplot.camera.local.scale_y == 1: + # if it's 1 it's likely not been touched manually before show was called + subplot.camera.local.scale_y = -1 break if autoscale: - for subplot in self: + for subplot in self._subplots.ravel(): if maintain_aspect is None: _maintain_aspect = subplot.camera.maintain_aspect else: @@ -631,7 +633,7 @@ def show( # set axes visibility if False if not axes_visible: - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.visible = False # parse based on canvas type @@ -655,15 +657,15 @@ def show( elif self.canvas.__class__.__name__ == "OffscreenRenderCanvas": # for test and docs gallery screenshots self._fpl_reset_layout() - for subplot in self: + for subplot in self._subplots.ravel(): subplot.axes.update_using_camera() # render call is blocking only on github actions for some reason, # but not for rtd build, this is a workaround # for CI tests, the render call works if it's in test_examples # but it is necessary for the gallery images too so that's why this check is here - if "RTD_BUILD" in os.environ.keys(): - if os.environ["RTD_BUILD"] == "1": + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": self._render() else: # assume GLFW @@ -779,7 +781,7 @@ def clear_animations(self, removal: str = None): def clear(self): """Clear all Subplots""" - for subplot in self: + for subplot in self._subplots.ravel(): subplot.clear() def export_numpy(self, rgb: bool = False) -> np.ndarray: @@ -938,18 +940,20 @@ def __getitem__(self, index: str | int | tuple[int, int]) -> Subplot: return subplot raise IndexError(f"no subplot with given name: {index}") + if isinstance(index, (int, np.integer)): + return self._subplots.ravel()[index] + if isinstance(self.layout, GridLayout): return self._subplots[index[0], index[1]] - return self._subplots[index] + raise TypeError( + f"Can index figure using subplot name, numerical subplot index, or a " + f"tuple[int, int] if the layout is a grid" + ) def __iter__(self): - self._current_iter = iter(range(len(self))) - return self - - def __next__(self) -> Subplot: - pos = self._current_iter.__next__() - return self._subplots.ravel()[pos] + for subplot in self._subplots.ravel(): + yield subplot def __len__(self): """number of subplots""" @@ -964,6 +968,6 @@ def __repr__(self): return ( f"fastplotlib.{self.__class__.__name__}" f" Subplots:\n" - f"\t{newline.join(subplot.__str__() for subplot in self)}" + f"\t{newline.join(subplot.__str__() for subplot in self._subplots.ravel())}" f"\n" ) diff --git a/fastplotlib/layouts/_graphic_methods_mixin.py b/fastplotlib/layouts/_graphic_methods_mixin.py index 06a4c7517..1fbf337e2 100644 --- a/fastplotlib/layouts/_graphic_methods_mixin.py +++ b/fastplotlib/layouts/_graphic_methods_mixin.py @@ -33,8 +33,7 @@ def add_image( cmap: str = "plasma", interpolation: str = "nearest", cmap_interpolation: str = "linear", - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ImageGraphic: """ @@ -62,12 +61,6 @@ def add_image( cmap_interpolation: str, optional, default "linear" colormap interpolation method, one of "nearest" or "linear" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. - kwargs: additional keyword arguments passed to :class:`.Graphic` @@ -81,8 +74,7 @@ def add_image( cmap, interpolation, cmap_interpolation, - isolated_buffer, - **kwargs, + **kwargs ) def add_image_volume( @@ -100,8 +92,7 @@ def add_image_volume( substep_size: float = 0.1, emissive: str | tuple | numpy.ndarray = (0, 0, 0), shininess: int = 30, - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ImageVolumeGraphic: """ @@ -158,11 +149,6 @@ def add_image_volume( How shiny the specular highlight is; a higher value gives a sharper highlight. Used only if `mode` = "iso" - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then set the data, useful if the - data arrays are ready-only such as memmaps. If False, the input array is itself used as the - buffer - useful if the array is large. - kwargs additional keyword arguments passed to :class:`.Graphic` @@ -183,8 +169,7 @@ def add_image_volume( substep_size, emissive, shininess, - isolated_buffer, - **kwargs, + **kwargs ) def add_line_collection( @@ -192,16 +177,15 @@ def add_line_collection( data: Union[numpy.ndarray, List[numpy.ndarray]], thickness: Union[float, Sequence[float]] = 2.0, colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", - uniform_colors: bool = False, cmap: Union[Sequence[str], str] = None, cmap_transform: Union[numpy.ndarray, List] = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", name: str = None, names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineCollection: """ @@ -235,6 +219,9 @@ def add_line_collection( cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + The color mode for each line in the collection. See `color_mode` in :class:`.LineGraphic` for details. + name: str, optional name of the line collection as a whole @@ -261,16 +248,15 @@ def add_line_collection( data, thickness, colors, - uniform_colors, cmap, cmap_transform, + color_mode, name, names, metadata, metadatas, - isolated_buffer, kwargs_lines, - **kwargs, + **kwargs ) def add_line( @@ -278,12 +264,11 @@ def add_line( data: Any, thickness: float = 2.0, colors: Union[str, numpy.ndarray, Sequence] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: Union[numpy.ndarray, Sequence] = None, - isolated_buffer: bool = True, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", size_space: str = "screen", - **kwargs, + **kwargs ) -> LineGraphic: """ @@ -304,15 +289,19 @@ def add_line( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default ``False`` - if True, uses a uniform buffer for the line color, - basically saves GPU VRAM when the entire line has a single color - cmap: str, optional Apply a colormap to the line instead of assigning colors manually, this overrides any argument passed to "colors". For supported colormaps see the ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + cmap_transform: 1D array-like of numerical values, optional if provided, these values are used to map the colors from the cmap @@ -329,12 +318,11 @@ def add_line( data, thickness, colors, - uniform_color, cmap, cmap_transform, - isolated_buffer, + color_mode, size_space, - **kwargs, + **kwargs ) def add_line_stack( @@ -348,11 +336,10 @@ def add_line_stack( names: list[str] = None, metadata: Any = None, metadatas: Union[Sequence[Any], numpy.ndarray] = None, - isolated_buffer: bool = True, separation: float = 10.0, separation_axis: str = "y", kwargs_lines: list[dict] = None, - **kwargs, + **kwargs ) -> LineStack: """ @@ -425,11 +412,10 @@ def add_line_stack( names, metadata, metadatas, - isolated_buffer, separation, separation_axis, kwargs_lines, - **kwargs, + **kwargs ) def add_mesh( @@ -448,8 +434,7 @@ def add_mesh( | numpy.ndarray ) = None, clim: tuple[float, float] = None, - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> MeshGraphic: """ @@ -488,12 +473,6 @@ def add_mesh( Both 1D and 2D colormaps are supported, though the mapcoords has to match the dimensionality. An image can also be used, this is basically a 2D colormap. - isolated_buffer: bool, default True - If True, initialize a buffer with the same shape as the input data and then - set the data, useful if the data arrays are ready-only such as memmaps. - If False, the input array is itself used as the buffer - useful if the - array is large. In almost all cases this should be ``True``. - **kwargs passed to :class:`.Graphic` @@ -509,8 +488,7 @@ def add_mesh( mapcoords, cmap, clim, - isolated_buffer, - **kwargs, + **kwargs ) def add_polygon( @@ -527,7 +505,7 @@ def add_polygon( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> PolygonGraphic: """ @@ -570,16 +548,100 @@ def add_polygon( PolygonGraphic, data, mode, colors, mapcoords, cmap, clim, **kwargs ) + def add_scatter_collection( + self, + data: Union[numpy.ndarray, List[numpy.ndarray]], + colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", + cmap: Union[Sequence[str], str] = None, + cmap_transform: Union[numpy.ndarray, List] = None, + sizes: Union[float, Sequence[float]] = 5.0, + uniform_size: bool = True, + markers: Union[numpy.ndarray, Sequence[str]] = None, + uniform_marker: bool = True, + edge_width: float = 1.0, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Union[Sequence[Any], numpy.ndarray] = None, + **kwargs + ) -> ScatterCollection: + """ + + Create a collection of :class:`.ScatterGraphic` + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + meatadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + kwargs_lines: list[dict], optional + list of kwargs passed to the individual lines, ``len(kwargs_lines)`` must equal ``len(data)`` + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + + """ + return self._create_graphic( + ScatterCollection, + data, + colors, + cmap, + cmap_transform, + sizes, + uniform_size, + markers, + uniform_marker, + edge_width, + name, + names, + metadata, + metadatas, + **kwargs + ) + def add_scatter( self, data: Any, colors: Union[str, numpy.ndarray, Sequence[float], Sequence[str]] = "w", - uniform_color: bool = False, cmap: str = None, cmap_transform: numpy.ndarray = None, + color_mode: Literal["auto", "uniform", "vertex"] = "auto", mode: Literal["markers", "simple", "gaussian", "image"] = "markers", markers: Union[str, numpy.ndarray, Sequence[str]] = "o", - uniform_marker: bool = False, + uniform_marker: bool = True, custom_sdf: str = None, edge_colors: Union[ str, pygfx.utils.color.Color, numpy.ndarray, Sequence[float] @@ -589,11 +651,10 @@ def add_scatter( image: numpy.ndarray = None, point_rotations: float | numpy.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: Union[float, numpy.ndarray, Sequence[float]] = 1, - uniform_size: bool = False, + sizes: Union[float, numpy.ndarray, Sequence[float]] = 5, + uniform_size: bool = True, size_space: str = "screen", - isolated_buffer: bool = True, - **kwargs, + **kwargs ) -> ScatterGraphic: """ @@ -609,18 +670,23 @@ def add_scatter( specify colors as a single human-readable string, a single RGBA array, or a Sequence (array, tuple, or list) of strings or RGBA arrays - uniform_color: bool, default False - if True, uses a uniform buffer for the scatter point colors. Useful if you need to - save GPU VRAM when all points have the same color. - cmap: str, optional apply a colormap to the scatter instead of assigning colors manually, this - overrides any argument passed to "colors". For supported colormaps see the - ``cmap`` library catalogue: https://cmap-docs.readthedocs.io/en/stable/catalog/ + overrides any argument passed to "colors". + For supported colormaps see the ``cmap`` library catalogue: + https://cmap-docs.readthedocs.io/en/stable/catalog/ cmap_transform: 1D array-like or list of numerical values, optional if provided, these values are used to map the colors from the cmap + color_mode: one of "auto", "uniform", "vertex", default "auto" + "uniform" restricts to a single color for all line datapoints. + "vertex" allows independent colors per vertex. + For most cases you can keep it as "auto" and the `color_mode` is determineed automatically based on the + argument passed to `colors`. if `colors` represents a single color, then the mode is set to "uniform". + If `colors` represents a unique color per-datapoint, or if a cmap is provided, then `color_mode` is set to + "vertex". You can switch between "uniform" and "vertex" `color_mode` after creating the graphic. + mode: one of: "markers", "simple", "gaussian", "image", default "markers" The scatter points mode, cannot be changed after the graphic has been created. @@ -640,9 +706,10 @@ def add_scatter( * Emojis: "❤️♠️♣️♦️💎💍✳️📍". * A string containing the value "custom". In this case, WGSL code defined by ``custom_sdf`` will be used. - uniform_marker: bool, default False - Use the same marker for all points. Only valid when `mode` is "markers". Useful if you need to use - the same marker for all points and want to save GPU RAM. + uniform_marker: bool, default ``True`` + If ``True``, use the same marker for all points. Only valid when `mode` is "markers". + Useful if you need to use the same marker for all points and want to save GPU RAM. If ``False``, you can + set per-vertex markers. custom_sdf: str = None, The SDF code for the marker shape when the marker is set to custom. @@ -662,8 +729,9 @@ def add_scatter( edge_colors: str | np.ndarray | pygfx.Color | Sequence[float], default "black" edge color of the markers, used when `mode` is "markers" - uniform_edge_color: bool, default True - Set the same edge color for all markers. Useful for saving GPU RAM. + uniform_edge_color: bool, default ``True`` + Set the same edge color for all markers. Useful for saving GPU RAM. Set to ``False`` for per-vertex edge + colors edge_width: float = 1.0, Width of the marker edges. used when `mode` is "markers". @@ -684,17 +752,13 @@ def add_scatter( sizes: float or iterable of float, optional, default 1.0 sizes of the scatter points - uniform_size: bool, default False - if True, uses a uniform buffer for the scatter point sizes. Useful if you need to - save GPU VRAM when all points have the same size. + uniform_size: bool, default ``False`` + if ``True``, uses a uniform buffer for the scatter point sizes. Useful if you need to + save GPU VRAM when all points have the same size. Set to ``False`` if you need per-vertex sizes. size_space: str, default "screen" coordinate space in which the size is expressed, one of ("screen", "world", "model") - isolated_buffer: bool, default True - whether the buffers should be isolated from the user input array. - Generally always ``True``, ``False`` is for rare advanced use if you have large arrays. - kwargs passed to :class:`.Graphic` @@ -704,9 +768,9 @@ def add_scatter( ScatterGraphic, data, colors, - uniform_color, cmap, cmap_transform, + color_mode, mode, markers, uniform_marker, @@ -720,8 +784,92 @@ def add_scatter( sizes, uniform_size, size_space, - isolated_buffer, - **kwargs, + **kwargs + ) + + def add_scatter_stack( + self, + data: Union[numpy.ndarray, List[numpy.ndarray]], + colors: Union[str, Sequence[str], numpy.ndarray, Sequence[numpy.ndarray]] = "w", + cmap: Union[Sequence[str], str] = None, + cmap_transform: Union[numpy.ndarray, List] = None, + name: str = None, + names: list[str] = None, + metadata: Any = None, + metadatas: Union[Sequence[Any], numpy.ndarray] = None, + separation: float = 0.0, + separation_axis: str = "y", + **kwargs + ) -> ScatterStack: + """ + + Create a stack of :class:`.LineGraphic` that are separated along the "x" or "y" axis. + + Parameters + ---------- + data: list of array-like + List or array-like of multiple line data to plot + + | if ``list`` each item in the list must be a 1D, 2D, or 3D numpy array + | if array-like, must be of shape [n_lines, n_points_line, y | xy | xyz] + + thickness: float or Iterable of float, default 2.0 + | if ``float``, single thickness will be used for all lines + | if ``list`` of ``float``, each value will apply to the individual lines + + colors: str, RGBA array, Iterable of RGBA array, or Iterable of str, default "w" + | if single ``str`` such as "w", "r", "b", etc, represents a single color for all lines + | if single ``RGBA array`` (tuple or list of size 4), represents a single color for all lines + | if ``list`` of ``str``, represents color for each individual line, example ["w", "b", "r",...] + | if ``RGBA array`` of shape [data_size, 4], represents a single RGBA array for each line + + cmap: Iterable of str or str, optional + | if ``str``, single cmap will be used for all lines + | if ``list`` of ``str``, each cmap will apply to the individual lines + + .. note:: + ``cmap`` overrides any arguments passed to ``colors`` + + cmap_transform: 1D array-like of numerical values, optional + if provided, these values are used to map the colors from the cmap + + name: str, optional + name of the line collection as a whole + + names: list[str], optional + names of the individual lines in the collection, ``len(names)`` must equal ``len(data)`` + + metadata: Any + metadata associated with the collection as a whole + + metadatas: Iterable or array + metadata for each individual line associated with this collection, this is for the user to manage. + ``len(metadata)`` must be same as ``len(data)`` + + separation: float, default 0.0 + space in between each line graphic in the stack + + separation_axis: str, default "y" + axis in which the line graphics in the stack should be separated + + kwargs_collection + kwargs for the collection, passed to GraphicCollection + + + """ + return self._create_graphic( + ScatterStack, + data, + colors, + cmap, + cmap_transform, + name, + names, + metadata, + metadatas, + separation, + separation_axis, + **kwargs ) def add_surface( @@ -738,7 +886,7 @@ def add_surface( | numpy.ndarray ) = None, clim: tuple[float, float] | None = None, - **kwargs, + **kwargs ) -> SurfaceGraphic: """ @@ -792,7 +940,7 @@ def add_text( screen_space: bool = True, offset: tuple[float] = (0, 0, 0), anchor: str = "middle-center", - **kwargs, + **kwargs ) -> TextGraphic: """ @@ -843,7 +991,7 @@ def add_text( screen_space, offset, anchor, - **kwargs, + **kwargs ) def add_vectors( @@ -853,7 +1001,7 @@ def add_vectors( color: Union[str, Sequence[float], numpy.ndarray] = "w", size: float = None, vector_shape_options: dict = None, - **kwargs, + **kwargs ) -> VectorsGraphic: """ @@ -898,5 +1046,5 @@ def add_vectors( color, size, vector_shape_options, - **kwargs, + **kwargs ) diff --git a/fastplotlib/layouts/_imgui_figure.py b/fastplotlib/layouts/_imgui_figure.py index 33cc6d925..15b3d7c45 100644 --- a/fastplotlib/layouts/_imgui_figure.py +++ b/fastplotlib/layouts/_imgui_figure.py @@ -44,6 +44,7 @@ def __init__( canvas_kwargs: dict = None, size: tuple[int, int] = (500, 300), names: list | np.ndarray = None, + std_right_click_menu: type[Popup] = StandardRightClickMenu, ): self._guis: dict[str, EdgeWindow] = {k: None for k in GUI_EDGES} @@ -105,7 +106,7 @@ def __init__( toolbar = SubplotToolbar(subplot=subplot) self._subplot_toolbars[i] = toolbar - self._right_click_menu = StandardRightClickMenu(figure=self) + self._std_right_click_menu = std_right_click_menu(figure=self) self._popups: dict[str, Popup] = {} @@ -118,6 +119,10 @@ def __init__( def default_imgui_font(self) -> imgui.ImFont: return self._default_imgui_font + @property + def std_right_click_menu(self) -> Popup: + return self._std_right_click_menu + @property def guis(self) -> dict[str, EdgeWindow]: """GUI windows added to the Figure""" @@ -158,7 +163,7 @@ def _draw_imgui(self) -> imgui.ImDrawData: for popup in self._popups.values(): popup.update() - self._right_click_menu.update() + self._std_right_click_menu.update() # imgui.end_frame() diff --git a/fastplotlib/layouts/_plot_area.py b/fastplotlib/layouts/_plot_area.py index 5d38ce37d..f90cdcf87 100644 --- a/fastplotlib/layouts/_plot_area.py +++ b/fastplotlib/layouts/_plot_area.py @@ -10,7 +10,7 @@ from ._utils import create_controller from ..graphics._base import Graphic, WORLD_OBJECT_TO_GRAPHIC -from ..graphics import ImageGraphic +from ..graphics import ImageGraphic, MeshGraphic from ..graphics.selectors._base_selector import BaseSelector from ._graphic_methods_mixin import GraphicMethodsMixin from ..legends import Legend @@ -120,11 +120,8 @@ def __init__( self._background = pygfx.Background(None, self._background_material) self.scene.add(self._background) - self._ambient_light = pygfx.AmbientLight() - self._directional_light = pygfx.DirectionalLight() - - self.scene.add(self._ambient_light) - self.scene.add(self._camera.add(self._directional_light)) + self._ambient_light = None + self._directional_light = None self._tooltip = Tooltip() self.get_figure()._fpl_overlay_scene.add(self._tooltip._fpl_world_object) @@ -179,8 +176,9 @@ def camera(self, new_camera: str | pygfx.PerspectiveCamera): # user wants to set completely new camera, remove current camera from controller if isinstance(new_camera, pygfx.PerspectiveCamera): self.controller.remove_camera(self._camera) - # add directional light to new camera - new_camera.add(self._directional_light) + if self._directional_light is not None: + # add directional light to new camera + new_camera.add(self._directional_light) # add new camera to controller self.controller.add_camera(new_camera) @@ -233,7 +231,10 @@ def controller(self, new_controller: str | pygfx.Controller): # pygfx plans on refactoring viewports anyways if self.parent is not None: if self.parent.__class__.__name__.endswith("Figure"): - for subplot in self.parent: + # always use figure._subplots.ravel() in internal fastplotlib code + # otherwise if we use `for subplot in figure`, this could conflict + # with a user's iterator where they are doing `for subplot in figure` !!! + for subplot in self.parent._subplots.ravel(): if subplot.camera in cameras_list: new_controller.register_events(subplot.viewport) subplot._controller = new_controller @@ -290,12 +291,12 @@ def background_color(self, colors: str | tuple[float]): self._background_material.set_colors(*colors) @property - def ambient_light(self) -> pygfx.AmbientLight: + def ambient_light(self) -> pygfx.AmbientLight | None: """the ambient lighting in the scene""" return self._ambient_light @property - def directional_light(self) -> pygfx.DirectionalLight: + def directional_light(self) -> pygfx.DirectionalLight | None: """the directional lighting on the camera in the scene""" return self._directional_light @@ -628,6 +629,13 @@ def add_graphic(self, graphic: Graphic, center: bool = True): if isinstance(graphic, ImageGraphic): self._sort_images_by_depth() + if isinstance(graphic, MeshGraphic): + self._ambient_light = pygfx.AmbientLight() + self._directional_light = pygfx.DirectionalLight() + + self.scene.add(self._ambient_light) + self.scene.add(self._camera.add(self._directional_light)) + def insert_graphic( self, graphic: Graphic, @@ -857,6 +865,42 @@ def _auto_scale_scene( camera.zoom = zoom + @property + def x_range(self) -> tuple[float, float]: + """ + Get or set the x-range currently in view. + Only valid for orthographic projections of the xy plane. + Use camera.set_state() to set the camera position for arbitrary projections. + """ + hw = self.camera.width / 2 + x = self.camera.local.x + return x - hw, x + hw + + @x_range.setter + def x_range(self, xr: tuple[float, float]): + width = xr[1] - xr[0] + x_mid = (xr[0] + xr[1]) / 2 + self.camera.width = width + self.camera.local.x = x_mid + + @property + def y_range(self) -> tuple[float, float]: + """ + Get or set the y-range currently in view. + Only valid for orthographic projections of the xy plane. + Use camera.set_state() to set the camera position for arbitrary projections. + """ + hh = self.camera.height / 2 + y = self.camera.local.y + return y - hh, y + hh + + @y_range.setter + def y_range(self, yr: tuple[float, float]): + height = yr[1] - yr[0] + y_mid = yr[0] + (height / 2) + self.camera.height = height + self.camera.local.y = y_mid + def remove_graphic(self, graphic: Graphic): """ Remove a ``Graphic`` from the scene. Note: This does not garbage collect the graphic, diff --git a/fastplotlib/tools/_histogram_lut.py b/fastplotlib/tools/_histogram_lut.py index d651137da..8edfb046b 100644 --- a/fastplotlib/tools/_histogram_lut.py +++ b/fastplotlib/tools/_histogram_lut.py @@ -6,424 +6,412 @@ import pygfx -from ..utils import subsample_array +from ..utils import subsample_array, RenderQueue from ..graphics import LineGraphic, ImageGraphic, ImageVolumeGraphic, TextGraphic from ..graphics.utils import pause_events from ..graphics._base import Graphic +from ..graphics.features import GraphicFeatureEvent from ..graphics.selectors import LinearRegionSelector -def _get_image_graphic_events(image_graphic: ImageGraphic) -> list[str]: - """Small helper function to return the relevant events for an ImageGraphic""" - events = ["vmin", "vmax"] +def _format_value(value: float): + abs_val = abs(value) + if abs_val < 0.01 or abs_val > 9_999: + return f"{value:.2e}" + else: + return f"{value:.2f}" - if not image_graphic.data.value.ndim > 2: - events.append("cmap") - # if RGB(A), do not add cmap - - return events - - -# TODO: This is a widget, we can think about a BaseWidget class later if necessary class HistogramLUTTool(Graphic): _fpl_support_tooltip = False def __init__( self, - data: np.ndarray, - images: ( - ImageGraphic - | ImageVolumeGraphic - | Sequence[ImageGraphic | ImageVolumeGraphic] - ), - nbins: int = 100, - flank_divisor: float = 5.0, + histogram: tuple[np.ndarray, np.ndarray], + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None = None, **kwargs, ): """ - HistogramLUT tool that can be used to control the vmin, vmax of ImageGraphics or ImageVolumeGraphics. - If used to control multiple images or image volumes it is assumed that they share a representation of - the same data, and that their histogram, vmin, and vmax are identical. For example, displaying a - ImageVolumeGraphic and several images that represent slices of the same volume data. + A histogram tool that allows adjusting the vmin, vmax of images. + Also allows changing the cmap LUT for grayscale images and displays a colorbar. Parameters ---------- - data: np.ndarray - - images: ImageGraphic | ImageVolumeGraphic | tuple[ImageGraphic | ImageVolumeGraphic] - - nbins: int, defaut 100. - Total number of bins used in the histogram + histogram: tuple[np.ndarray, np.ndarray] + [frequency, bin_edges], must be 100 bins - flank_divisor: float, default 5.0. - Fraction of empty histogram bins on the tails of the distribution set `np.inf` for no flanks + images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] + the images that are managed by the histogram tool - kwargs: passed to ``Graphic`` + kwargs: + passed to ``Graphic`` """ - super().__init__(**kwargs) - - self._nbins = nbins - self._flank_divisor = flank_divisor - - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - self._images = images + super().__init__(**kwargs) - self._data = weakref.proxy(data) + if len(histogram) != 2: + raise TypeError - self._scale_factor: float = 1.0 + self._block_reentrance = False + self._images = list() - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) + self._bin_centers_flanked = np.zeros(120, dtype=np.float64) + self._freq_flanked = np.zeros(120, dtype=np.float32) - line_data = np.column_stack([hist_scaled, edges_flanked]) + # 100 points for the histogram, 10 points on each side for the flank + line_data = np.column_stack( + [np.zeros(120, dtype=np.float32), np.arange(0, 120)] + ) - self._histogram_line = LineGraphic( - line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(0, 0, -1) + # line that displays the histogram + self._line = LineGraphic( + line_data, colors=(0.8, 0.8, 0.8), alpha_mode="solid", offset=(1, 0, 0) + ) + self._line.world_object.local.scale_x = -1 + + # vmin, vmax selector + self._selector = LinearRegionSelector( + selection=(10, 110), + limits=(0, 119), + size=1.5, + center=0.5, # frequency data are normalized between 0-1 + axis="y", + parent=self._line, ) - bounds = (edges[0] * self._scale_factor, edges[-1] * self._scale_factor) - limits = (edges_flanked[0], edges_flanked[-1]) - size = 120 # since it's scaled to 100 - origin = (hist_scaled.max() / 2, 0) + self._selector.add_event_handler(self._selector_event_handler, "selection") - self._linear_region_selector = LinearRegionSelector( - selection=bounds, - limits=limits, - size=size, - center=origin[0], - axis="y", - parent=self._histogram_line, + self._colorbar = ImageGraphic( + data=np.zeros([120, 2]), interpolation="linear", offset=(1.5, 0, 0) ) - self._vmin = self.images[0].vmin - self._vmax = self.images[0].vmax + # make the colorbar thin + self._colorbar.world_object.local.scale_x = 0.15 + self._colorbar.add_event_handler(self._open_cmap_picker, "click") - # there will be a small difference with the histogram edges so this makes them both line up exactly - self._linear_region_selector.selection = ( - self._vmin * self._scale_factor, - self._vmax * self._scale_factor, + # colorbar ruler + self._ruler = pygfx.Ruler( + end_pos=(0, 119, 0), + alpha_mode="solid", + render_queue=RenderQueue.axes, + tick_side="right", + tick_marker="tick_right", + tick_format=self._ruler_tick_map, + min_tick_distance=10, ) + self._ruler.local.x = 1.75 - vmin_str, vmax_str = self._get_vmin_vmax_str() + # TODO: need to auto-scale using the text so it appears nicely, will do later + self._ruler.visible = False self._text_vmin = TextGraphic( - text=vmin_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="top-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - + # this is to make sure clicking text doesn't conflict with the selector tool + # since the text appears near the selector tool self._text_vmin.world_object.material.pick_write = False self._text_vmax = TextGraphic( - text=vmax_str, + text="", font_size=16, - offset=(0, 0, 0), anchor="bottom-left", outline_color="black", outline_thickness=0.5, alpha_mode="solid", ) - self._text_vmax.world_object.material.pick_write = False - widget_wo = pygfx.Group() - widget_wo.add( - self._histogram_line.world_object, - self._linear_region_selector.world_object, + # add all the world objects to a pygfx.Group + wo = pygfx.Group() + wo.add( + self._line.world_object, + self._selector.world_object, + self._colorbar.world_object, + self._ruler, self._text_vmin.world_object, self._text_vmax.world_object, ) + self._set_world_object(wo) - self._set_world_object(widget_wo) + # for convenience, a list that stores all the graphics managed by the histogram LUT tool + self._children = [ + self._line, + self._selector, + self._colorbar, + self._text_vmin, + self._text_vmax, + ] - self.world_object.local.scale_x *= -1 + # set histogram + self.histogram = histogram - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) + # set the images + self.images = images - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) + def _fpl_add_plot_area_hook(self, plot_area): + self._plot_area = plot_area - self._linear_region_selector.add_event_handler( - self._linear_region_handler, "selection" - ) + for child in self._children: + # need all of them to call the add_plot_area_hook so that events are connected correctly + # example, the linear region selector needs all the canvas events to be connected + child._fpl_add_plot_area_hook(plot_area) - ig_events = _get_image_graphic_events(self.images[0]) + if hasattr(self._plot_area, "size"): + # if it's in a dock area + self._plot_area.size = 80 - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + # disable the controller in this plot area + self._plot_area.controller.enabled = False + self._plot_area.auto_scale(maintain_aspect=False) - # colorbar for grayscale images - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + # tick text for colorbar ruler doesn't show without this call + self._ruler.update(plot_area.camera, plot_area.canvas.get_logical_size()) - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + def _ruler_tick_map(self, bin_index, *args): + return f"{self._bin_centers_flanked[int(bin_index)]:.2f}" - def _make_colorbar(self, edges_flanked) -> ImageGraphic: - # use the histogram edge values as data for an - # image with 2 columns, this will be our colorbar! - colorbar_data = np.column_stack( - [ - np.linspace( - edges_flanked[0], edges_flanked[-1], ceil(np.ptp(edges_flanked)) - ) - ] - * 2 - ).astype(np.float32) - - colorbar_data /= self._scale_factor - - cbar = ImageGraphic( - data=colorbar_data, - vmin=self.vmin, - vmax=self.vmax, - cmap=self.images[0].cmap, - interpolation="linear", - offset=(-55, edges_flanked[0], -1), - ) + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray]: + """histogram [frequency, bin_centers]. Frequency is flanked by 10 zeros on both sides""" + return self._freq_flanked, self._bin_centers_flanked - cbar.world_object.world.scale_x = 20 - self._cmap = self.images[0].cmap + @histogram.setter + def histogram( + self, histogram: tuple[np.ndarray, np.ndarray], limits: tuple[int, int] = None + ): + """set histogram with pre-compuated [frequency, edges], must have exactly 100 bins""" - return cbar + freq, edges = histogram - def _get_vmin_vmax_str(self) -> tuple[str, str]: - if self.vmin < 0.001 or self.vmin > 99_999: - vmin_str = f"{self.vmin:.2e}" - else: - vmin_str = f"{self.vmin:.2f}" + if freq.max() > 0: + # if the histogram is made from an empty array, then the max freq will be 0 + # we don't want to divide by 0 because then we just get nans + freq = freq / freq.max() - if self.vmax < 0.001 or self.vmax > 99_999: - vmax_str = f"{self.vmax:.2e}" - else: - vmax_str = f"{self.vmax:.2f}" + bin_centers = 0.5 * (edges[1:] + edges[:-1]) - return vmin_str, vmax_str + step = bin_centers[1] - bin_centers[0] - def _fpl_add_plot_area_hook(self, plot_area): - self._plot_area = plot_area - self._linear_region_selector._fpl_add_plot_area_hook(plot_area) - self._histogram_line._fpl_add_plot_area_hook(plot_area) + under_flank = np.linspace(bin_centers[0] - step * 10, bin_centers[0] - step, 10) + over_flank = np.linspace( + bin_centers[-1] + step, bin_centers[-1] + step * 10, 10 + ) + self._bin_centers_flanked[:] = np.concatenate( + [under_flank, bin_centers, over_flank] + ) + + self._freq_flanked[10:110] = freq - self._plot_area.auto_scale() - self._plot_area.controller.enabled = True + self._line.data[:, 0] = self._freq_flanked + self._colorbar.data = np.column_stack( + [self._bin_centers_flanked, self._bin_centers_flanked] + ) - def _calculate_histogram(self, data): + # self.vmin, self.vmax = bin_centers[0], bin_centers[-1] - # get a subsampled view of this array - data_ss = subsample_array(data, max_size=int(1e6)) # 1e6 is default - hist, edges = np.histogram(data_ss, bins=self._nbins) + if hasattr(self, "plot_area"): + self._ruler.update( + self._plot_area.camera, self._plot_area.canvas.get_logical_size() + ) - # used if data ptp <= 10 because event things get weird - # with tiny world objects due to floating point error - # so if ptp <= 10, scale up by a factor - data_interval = edges[-1] - edges[0] - self._scale_factor: int = max(1, 100 * int(10 / data_interval)) + @property + def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic, ...] | None: + """get or set the managed images""" + return tuple(self._images) - edges = edges * self._scale_factor + @images.setter + def images(self, new_images: ImageGraphic | ImageVolumeGraphic | Sequence[ImageGraphic | ImageVolumeGraphic] | None): + self._disconnect_images() + self._images.clear() - bin_width = edges[1] - edges[0] + if new_images is None: + return - flank_nbins = int(self._nbins / self._flank_divisor) - flank_size = flank_nbins * bin_width + if isinstance(new_images, (ImageGraphic, ImageVolumeGraphic)): + new_images = [new_images] - flank_left = np.arange(edges[0] - flank_size, edges[0], bin_width) - flank_right = np.arange( - edges[-1] + bin_width, edges[-1] + flank_size, bin_width - ) + if not all( + [ + isinstance(image, (ImageGraphic, ImageVolumeGraphic)) + for image in new_images + ] + ): + raise TypeError - edges_flanked = np.concatenate((flank_left, edges, flank_right)) + for image in new_images: + if image.cmap is not None: + self._colorbar.visible = True + break + else: + self._colorbar.visible = False - hist_flanked = np.concatenate( - (np.zeros(flank_nbins), hist, np.zeros(flank_nbins)) - ) + self._images = list(new_images) - # scale 0-100 to make it easier to see - # float32 data can produce unnecessarily high values - hist_scale_value = hist_flanked.max() - if np.allclose(hist_scale_value, 0): - hist_scale_value = 1 - hist_scaled = hist_flanked / (hist_scale_value / 100) + # reset vmin, vmax using first image + self.vmin = self._images[0].vmin + self.vmax = self._images[0].vmax - if edges_flanked.size > hist_scaled.size: - # we don't care about accuracy here so if it's off by 1-2 bins that's fine - edges_flanked = edges_flanked[: hist_scaled.size] + if self._images[0].cmap is not None: + self._colorbar.cmap = self._images[0].cmap - return hist, edges, hist_scaled, edges_flanked + # connect event handlers + for image in self._images: + image.add_event_handler(self._image_event_handler, "vmin", "vmax") + image.add_event_handler(self._disconnect_images, "deleted") + if image.cmap is not None: + image.add_event_handler( + self._image_event_handler, "vmin", "vmax", "cmap" + ) - def _linear_region_handler(self, ev): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - selected_ixs = self._linear_region_selector.selection - vmin, vmax = selected_ixs[0], selected_ixs[1] - vmin, vmax = vmin / self._scale_factor, vmax / self._scale_factor - self.vmin, self.vmax = vmin, vmax + def _disconnect_images(self, *args): + """disconnect event handlers of the managed images""" + for image in self._images: + for ev, handlers in image.event_handlers: + if self._image_event_handler in handlers: + image.remove_event_handler(self._image_event_handler, ev) - def _image_cmap_handler(self, ev): - setattr(self, ev.type, ev.info["value"]) + def _image_event_handler(self, ev): + """when the image vmin, vmax, or cmap changes it will update the HistogramLUTTool""" + new_value = ev.info["value"] + setattr(self, ev.type, new_value) @property def cmap(self) -> str: - return self._cmap + """get or set the colormap, only for grayscale images""" + return self._colorbar.cmap @cmap.setter def cmap(self, name: str): - if self._colorbar is None: + if self._block_reentrance: return - with pause_events(*self.images): - for ig in self.images: - ig.cmap = name + if name is None: + return - self._cmap = name + self._block_reentrance = True + try: self._colorbar.cmap = name + with pause_events( + *self._images, event_handlers=[self._image_event_handler] + ): + for image in self._images: + if image.cmap is None: + # rgb(a) images have no cmap + continue + + image.cmap = name + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False + @property def vmin(self) -> float: - return self._vmin + """get or set the vmin, the lower contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[0]) + return self._bin_centers_flanked[index] @vmin.setter def vmin(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - value * self._scale_factor, - self._linear_region_selector.selection[1], - ) - for ig in self.images: - ig.vmin = value + if self._block_reentrance: + return + self._block_reentrance = True + try: + index_min = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (index_min, self._selector.selection[1]) - self._vmin = value - if self._colorbar is not None: - self._colorbar.vmin = value + self._colorbar.vmin = value - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmin.offset = (-120, self._linear_region_selector.selection[0], 0) - self._text_vmin.text = vmin_str + self._text_vmin.text = _format_value(value) + self._text_vmin.offset = (-0.45, self._selector.selection[0], 0) + + for image in self._images: + image.vmin = value + + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False @property def vmax(self) -> float: - return self._vmax + """get or set the vmax, the upper contrast limit""" + # no offset or rotation so we can directly use the world space selection value + index = int(self._selector.selection[1]) + return self._bin_centers_flanked[index] @vmax.setter def vmax(self, value: float): - with pause_events(self._linear_region_selector, *self.images): - # must use world coordinate values directly from selection() - # otherwise the linear region bounds jump to the closest bin edges - self._linear_region_selector.selection = ( - self._linear_region_selector.selection[0], - value * self._scale_factor, - ) - - for ig in self.images: - ig.vmax = value - - self._vmax = value - if self._colorbar is not None: - self._colorbar.vmax = value - - vmin_str, vmax_str = self._get_vmin_vmax_str() - self._text_vmax.offset = (-120, self._linear_region_selector.selection[1], 0) - self._text_vmax.text = vmax_str - - def set_data(self, data, reset_vmin_vmax: bool = True): - hist, edges, hist_scaled, edges_flanked = self._calculate_histogram(data) - - line_data = np.column_stack([hist_scaled, edges_flanked]) - - # set x and y vals - self._histogram_line.data[:, :2] = line_data - - bounds = (edges[0], edges[-1]) - limits = (edges_flanked[0], edges_flanked[-11]) - origin = (hist_scaled.max() / 2, 0) - - if reset_vmin_vmax: - # reset according to the new data - self._linear_region_selector.limits = limits - self._linear_region_selector.selection = bounds - else: - with pause_events(self._linear_region_selector, *self.images): - # don't change the current selection - self._linear_region_selector.limits = limits - - self._data = weakref.proxy(data) - - if self._colorbar is not None: - self._colorbar.clear_event_handlers() - self.world_object.remove(self._colorbar.world_object) - - if self.images[0].cmap is not None: - self._colorbar: ImageGraphic = self._make_colorbar(edges_flanked) - self._colorbar.add_event_handler(self._open_cmap_picker, "click") + if self._block_reentrance: + return - self.world_object.add(self._colorbar.world_object) - else: - self._colorbar = None - self._cmap = None + self._block_reentrance = True + try: + index_max = np.searchsorted(self._bin_centers_flanked, value) + with pause_events( + self._selector, + *self._images, + event_handlers=[ + self._selector_event_handler, + self._image_event_handler, + ], + ): + self._selector.selection = (self._selector.selection[0], index_max) - # reset plotarea dims - self._plot_area.auto_scale() + self._colorbar.vmax = value - @property - def images(self) -> tuple[ImageGraphic | ImageVolumeGraphic]: - return self._images + self._text_vmax.text = _format_value(value) + self._text_vmax.offset = (-0.45, self._selector.selection[1], 0) - @images.setter - def images(self, images): - if isinstance(images, (ImageGraphic, ImageVolumeGraphic)): - images = (images,) - elif isinstance(images, Sequence): - if not all( - [isinstance(ig, (ImageGraphic, ImageVolumeGraphic)) for ig in images] - ): - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) - else: - raise TypeError( - f"`images` argument must be an ImageGraphic, ImageVolumeGraphic, or a " - f"tuple or list or ImageGraphic | ImageVolumeGraphic" - ) + for image in self._images: + image.vmax = value - if self._images is not None: - for ig in self._images: - # cleanup events from current image graphics - ig_events = _get_image_graphic_events(ig) - ig.remove_event_handler(self._image_cmap_handler, *ig_events) + except Exception as exc: + # raise original exception + raise exc # vmax setter has raised. The lines above below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_reentrance = False - self._images = images + def _selector_event_handler(self, ev: GraphicFeatureEvent): + """when the selector's selctor has changed, it will update the vmin, vmax, or both""" + selection = ev.info["value"] + index_min = int(selection[0]) + vmin = self._bin_centers_flanked[index_min] - ig_events = _get_image_graphic_events(self._images[0]) + index_max = int(selection[1]) + vmax = self._bin_centers_flanked[index_max] - for ig in self.images: - ig.add_event_handler(self._image_cmap_handler, *ig_events) + match ev.info["change"]: + case "min": + self.vmin = vmin + case "max": + self.vmax = vmax + case _: + self.vmin, self.vmax = vmin, vmax def _open_cmap_picker(self, ev): + """open imgui cmap picker""" # check if right click if ev.button != 2: return @@ -433,7 +421,11 @@ def _open_cmap_picker(self, ev): self._plot_area.get_figure().open_popup("colormap-picker", pos, lut_tool=self) def _fpl_prepare_del(self): - self._linear_region_selector._fpl_prepare_del() - self._histogram_line._fpl_prepare_del() - del self._histogram_line - del self._linear_region_selector + """cleanup, need to disconnect events and remove image references for proper garbage collection""" + self._disconnect_images() + self._images.clear() + + for i in range(len(self._children)): + g = self._children.pop(0) + g._fpl_prepare_del() + del g diff --git a/fastplotlib/ui/_base.py b/fastplotlib/ui/_base.py index 355edc46d..eefac02f7 100644 --- a/fastplotlib/ui/_base.py +++ b/fastplotlib/ui/_base.py @@ -44,7 +44,7 @@ def __init__( location: Literal["bottom", "right", "top"], title: str, window_flags: enum.IntFlag = imgui.WindowFlags_.no_collapse - | imgui.WindowFlags_.no_resize, + | imgui.WindowFlags_.no_resize | imgui.WindowFlags_.no_title_bar, *args, **kwargs, ): @@ -111,6 +111,15 @@ def __init__( self._title = title self._window_flags = window_flags + self._resize_cursor_set = False + self._resize_blocked = False + self._right_gui_resizing = False + + self._separator_thickness = 14.0 + + self._collapsed = False + self._old_size = self.size + self._x, self._y, self._width, self._height = self.get_rect() self._figure.canvas.add_event_handler(self._set_rect, "resize") @@ -123,8 +132,9 @@ def size(self) -> int | None: @size.setter def size(self, value): if not isinstance(value, int): - raise TypeError + raise TypeError(f"{self.__class__.__name__}.size must be an ") self._size = value + self._set_rect() @property def location(self) -> str: @@ -153,6 +163,7 @@ def height(self) -> int: def _set_rect(self, *args): self._x, self._y, self._width, self._height = self.get_rect() + self._figure._fpl_reset_layout() def get_rect(self) -> tuple[int, int, int, int]: """ @@ -192,25 +203,203 @@ def get_rect(self) -> tuple[int, int, int, int]: return x_pos, y_pos, width, height + def _draw_resize_handle(self): + if self._location == "bottom": + imgui.set_cursor_pos((0, 0)) + imgui.invisible_button("##resize_handle", imgui.ImVec2(imgui.get_window_width(), self._separator_thickness)) + + hovered = imgui.is_item_hovered() + active = imgui.is_item_active() + + # Get the actual screen rect of the button after it's been laid out + rect_min = imgui.get_item_rect_min() + rect_max = imgui.get_item_rect_max() + + elif self._location == "right": + imgui.set_cursor_pos((0, 0)) + screen_pos = imgui.get_cursor_screen_pos() + win_height = imgui.get_window_height() + mouse_pos = imgui.get_mouse_pos() + + rect_min = imgui.ImVec2(screen_pos.x, screen_pos.y) + rect_max = imgui.ImVec2(screen_pos.x + self._separator_thickness, screen_pos.y + win_height) + + hovered = ( + rect_min.x <= mouse_pos.x <= rect_max.x + and rect_min.y <= mouse_pos.y <= rect_max.y + ) + + if hovered and imgui.is_mouse_clicked(0): + self._right_gui_resizing = True + + if not imgui.is_mouse_down(0): + self._right_gui_resizing = False + + active = self._right_gui_resizing + + imgui.set_cursor_pos((self._separator_thickness, 0)) + + if hovered and imgui.is_mouse_double_clicked(0): + if not self._collapsed: + self._old_size = self.size + if self._location == "bottom": + self.size = int(self._separator_thickness) + elif self._location == "right": + self.size = int(self._separator_thickness) + self._collapsed = True + else: + self.size = self._old_size + self._collapsed = False + + if hovered or active: + if not self._resize_cursor_set: + if self._location == "bottom": + self._figure.canvas.set_cursor("ns_resize") + + elif self._location == "right": + self._figure.canvas.set_cursor("ew_resize") + + self._resize_cursor_set = True + imgui.set_tooltip("Drag to resize, double click to expand/collapse") + + elif self._resize_cursor_set: + self._figure.canvas.set_cursor("default") + self._resize_cursor_set = False + + if active and imgui.is_mouse_dragging(0): + if self._location == "bottom": + delta = imgui.get_mouse_drag_delta(0).y + + elif self._location == "right": + delta = imgui.get_mouse_drag_delta(0).x + + imgui.reset_mouse_drag_delta(0) + px, py, pw, ph = self._figure.get_pygfx_render_area() + + if self._location == "bottom": + new_render_size = ph + delta + elif self._location == "right": + new_render_size = pw + delta + + # check if the new size would make the pygfx render area too small + if (delta < 0) and (new_render_size < 150): + print("not enough render area") + self._resize_blocked = True + + if self._resize_blocked: + # check if cursor has returned + if self._location == "bottom": + _min, pos, _max = rect_min.y, imgui.get_mouse_pos().y, rect_max.y + + elif self._location == "right": + _min, pos, _max = rect_min.x, imgui.get_mouse_pos().x, rect_max.x + + if ((_min - 5) <= pos <= (_max + 5)) and delta > 0: + # if the mouse cursor is back on the bar and the delta > 0, i.e. render area increasing + self._resize_blocked = False + + if not self._resize_blocked: + self.size = max(30, round(self.size - delta)) + self._collapsed = False + + draw_list = imgui.get_window_draw_list() + + line_color = ( + imgui.get_color_u32(imgui.ImVec4(0.9, 0.9, 0.9, 1.0)) + if (hovered or active) + else imgui.get_color_u32(imgui.ImVec4(0.5, 0.5, 0.5, 0.8)) + ) + bg_color = ( + imgui.get_color_u32(imgui.ImVec4(0.2, 0.2, 0.2, 0.8)) + if (hovered or active) + else imgui.get_color_u32(imgui.ImVec4(0.15, 0.15, 0.15, 0.6)) + ) + + # Background bar + draw_list.add_rect_filled( + imgui.ImVec2(rect_min.x, rect_min.y), + imgui.ImVec2(rect_max.x, rect_max.y), + bg_color, + ) + + # Three grip dots centered on the line + dot_spacing = 7.0 + dot_radius = 2 + if self._location == "bottom": + mid_y = (rect_min.y + rect_max.y) * 0.5 + center_x = (rect_min.x + rect_max.x) * 0.5 + for i in (-1, 0, 1): + cx = center_x + i * dot_spacing + draw_list.add_circle_filled(imgui.ImVec2(cx, mid_y), dot_radius, line_color) + + imgui.set_cursor_pos((0, imgui.get_cursor_pos_y() - imgui.get_style().item_spacing.y)) + + elif self._location == "right": + mid_x = (rect_min.x + rect_max.x) * 0.5 + center_y = (rect_min.y + rect_max.y) * 0.5 + for i in (-1, 0, 1): + cy = center_y + i * dot_spacing + draw_list.add_circle_filled( + imgui.ImVec2(mid_x, cy), dot_radius, line_color + ) + + def _draw_title(self, title: str): + padding = imgui.ImVec2(10, 4) + text_size = imgui.calc_text_size(title) + win_width = imgui.get_window_width() + box_size = imgui.ImVec2(win_width, text_size.y + padding.y * 2) + + box_screen_pos = imgui.get_cursor_screen_pos() + + draw_list = imgui.get_window_draw_list() + + # Background — use imgui's default title bar color + draw_list.add_rect_filled( + imgui.ImVec2(box_screen_pos.x, box_screen_pos.y), + imgui.ImVec2(box_screen_pos.x + box_size.x, box_screen_pos.y + box_size.y), + imgui.get_color_u32(imgui.Col_.title_bg_active), + ) + + # Centered text + text_pos = imgui.ImVec2( + box_screen_pos.x + (win_width - text_size.x) * 0.5, + box_screen_pos.y + padding.y, + ) + draw_list.add_text( + text_pos, imgui.get_color_u32(imgui.ImVec4(1, 1, 1, 1)), title + ) + + imgui.dummy(imgui.ImVec2(win_width, box_size.y)) + def draw_window(self): """helps simplify using imgui by managing window creation & position, and pushing/popping the ID""" # window position & size x, y, w, h = self.get_rect() imgui.set_next_window_size((self.width, self.height)) imgui.set_next_window_pos((self.x, self.y)) - # imgui.set_next_window_pos((x, y)) - # imgui.set_next_window_size((w, h)) flags = self._window_flags # begin window imgui.begin(self._title, p_open=None, flags=flags) + self._draw_resize_handle() + # push ID to prevent conflict between multiple figs with same UI imgui.push_id(self._id_counter) + # collapse the UI if the separator state is collapsed + # otherwise the UI renders partially on the separator for "right" guis and it looks weird + main_height = 1.0 if self._collapsed else 0.0 + imgui.begin_child("##main_ui", imgui.ImVec2(0, main_height)) + + self._draw_title(self._title) + + imgui.indent(6.0) # draw stuff from subclass into window self.update() + imgui.end_child() + # pop ID imgui.pop_id() diff --git a/fastplotlib/ui/right_click_menus/_colormap_picker.py b/fastplotlib/ui/right_click_menus/_colormap_picker.py index a80e5b2aa..9df26dcdc 100644 --- a/fastplotlib/ui/right_click_menus/_colormap_picker.py +++ b/fastplotlib/ui/right_click_menus/_colormap_picker.py @@ -154,7 +154,8 @@ def update(self): self._texture_height = (imgui.get_font_size()) - 2 if imgui.menu_item("Reset vmin-vmax", "", False)[0]: - self._lut_tool.images[0].reset_vmin_vmax() + for image in self._lut_tool.images: + image.reset_vmin_vmax() # add all the cmap options for cmap_type in COLORMAP_NAMES.keys(): diff --git a/fastplotlib/ui/right_click_menus/_standard_menu.py b/fastplotlib/ui/right_click_menus/_standard_menu.py index bb9e5bdef..9c659f4a7 100644 --- a/fastplotlib/ui/right_click_menus/_standard_menu.py +++ b/fastplotlib/ui/right_click_menus/_standard_menu.py @@ -31,6 +31,8 @@ def __init__(self, figure): # whether the right click menu is currently open or not self.is_open: bool = False + self._controller_window_open: bool | PlotArea = False + def get_subplot(self) -> PlotArea | bool | None: """get the subplot that a click occurred in""" if self._last_right_click_pos is None: @@ -47,6 +49,10 @@ def cleanup(self): """called when the popup disappears""" self.is_open = False + def _extra_menu(self): + # extra menu items, optional, implement in subclass + pass + def update(self): if imgui.is_mouse_down(1) and not self._mouse_down: # mouse button was pressed down, store this position @@ -147,39 +153,55 @@ def update(self): imgui.separator() # controller options - if imgui.begin_menu("Controller"): - _, enabled = imgui.menu_item( - "Enabled", "", self.get_subplot().controller.enabled - ) + if imgui.menu_item("Controller Options", "", False)[0]: + self._controller_window_open = self.get_subplot() - self.get_subplot().controller.enabled = enabled + self._extra_menu() - changed, damping = imgui.slider_float( - "Damping", - v=self.get_subplot().controller.damping, - v_min=0.0, - v_max=10.0, - ) - - if changed: - self.get_subplot().controller.damping = damping - - imgui.separator() - imgui.text("Controller type:") - # switching between different controllers - for name, controller_type_iter in controller_types.items(): - current_type = type(self.get_subplot().controller) + imgui.end_popup() - clicked, _ = imgui.menu_item( - label=name, - shortcut="", - p_selected=current_type is controller_type_iter, - ) + if self._controller_window_open: + self._draw_controller_window() + + def _draw_controller_window(self): + subplot = self._controller_window_open + + imgui.set_next_window_size((0, 0)) + _, keep_open = imgui.begin(f"Controller", True) + imgui.text(f"subplot: {subplot.name}") + _, enabled = imgui.menu_item( + "Enabled", "", subplot.controller.enabled + ) + + subplot.controller.enabled = enabled + + changed, damping = imgui.slider_float( + "Damping", + v=subplot.controller.damping, + v_min=0.0, + v_max=10.0, + ) + + if changed: + subplot.controller.damping = damping + + imgui.separator() + imgui.text("Controller type:") + # switching between different controllers + for name, controller_type_iter in controller_types.items(): + current_type = type(subplot.controller) + + clicked, _ = imgui.menu_item( + label=name, + shortcut="", + p_selected=current_type is controller_type_iter, + ) - if clicked and (current_type is not controller_type_iter): - # menu item was clicked and the desired controller isn't the current one - self.get_subplot().controller = name + if clicked and (current_type is not controller_type_iter): + # menu item was clicked and the desired controller isn't the current one + subplot.controller = name - imgui.end_menu() + if not keep_open: + self._controller_window_open = False - imgui.end_popup() + imgui.end() diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..6f0059f6a 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import ArrayProtocol, ARRAY_LIKE_ATTRS @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..95d7d2763 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + + +ARRAY_LIKE_ATTRS = [ + "__array__", + "__array_ufunc__", + "dtype", + "shape", + "ndim", + "__getitem__", +] + + +@runtime_checkable +class ArrayProtocol(Protocol): + def __array__(self) -> ArrayProtocol: ... + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): ... + + def __array_function__(self, func, types, *args, **kwargs): ... + + @property + def dtype(self) -> Any: ... + + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/utils/functions.py b/fastplotlib/utils/functions.py index a839ed9d0..34062824a 100644 --- a/fastplotlib/utils/functions.py +++ b/fastplotlib/utils/functions.py @@ -477,3 +477,38 @@ def subsample_array( slices = tuple(slices) return np.asarray(arr[slices]) + + +def heatmap_to_positions(heatmap: np.ndarray, xvals: np.ndarray) -> np.ndarray: + """ + + Convert a heatmap of shape [n_rows, n_datapoints] to timeseries x-y data of shape [n_rows, n_datapoints, xy] + + Parameters + ---------- + heatmap: np.ndarray, shape [n_rows, n_datapoints] + timeseries data with a heatmap representation, where each column represents a timepoint. + + xvals: np.ndarray, shape: [n_datapoints,] + x-values for the columns in the heatmap + + Returns + ------- + np.ndarray, shape [n_rows, n_datapoints, 2] + timeseries data where the xy data are explicitly stored for every row + + """ + if heatmap.ndim != 2: + raise ValueError + + if xvals.ndim != 1: + raise ValueError + + if xvals.size != heatmap.shape[1]: + raise ValueError + + ts = np.empty((*heatmap.shape, 2), dtype=np.float32) + ts[..., 0] = xvals + ts[..., 1] = heatmap + + return ts diff --git a/fastplotlib/widgets/__init__.py b/fastplotlib/widgets/__init__.py index 766620ea6..4347f6c80 100644 --- a/fastplotlib/widgets/__init__.py +++ b/fastplotlib/widgets/__init__.py @@ -1,3 +1,12 @@ +from .nd_widget import ( + NDWidget, + NDProcessor, + NDGraphic, + NDPositionsProcessor, + NDPositions, + NDImageProcessor, + NDImage, +) from .image_widget import ImageWidget -__all__ = ["ImageWidget"] +__all__ = ["NDWidget", "ImageWidget"] diff --git a/fastplotlib/widgets/image_widget/_widget.py b/fastplotlib/widgets/image_widget/_widget.py index 86a01b083..6d262678d 100644 --- a/fastplotlib/widgets/image_widget/_widget.py +++ b/fastplotlib/widgets/image_widget/_widget.py @@ -358,6 +358,11 @@ def __init__( passed to each ImageGraphic in the ImageWidget figure subplots """ + warn( + "`ImageWidget` is deprecated and will be removed in a" + " future release, please migrate to NDWidget", + DeprecationWarning + ) self._initialized = False if figure_kwargs is None: diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py new file mode 100644 index 000000000..378f7dfcd --- /dev/null +++ b/fastplotlib/widgets/nd_widget/__init__.py @@ -0,0 +1,24 @@ +from ...layouts import IMGUI + +try: + import imgui_bundle +except ImportError: + HAS_XARRAY = False +else: + HAS_XARRAY = True + + +if IMGUI and HAS_XARRAY: + from ._base import NDProcessor, NDGraphic + from ._nd_positions import NDPositions, NDPositionsProcessor, ndp_extras + from ._nd_image import NDImageProcessor, NDImage + from ._ndwidget import NDWidget + +else: + + class NDWidget: + def __init__(self, *args, **kwargs): + raise ModuleNotFoundError( + "NDWidget requires `imgui-bundle` and `xarray` to be installed.\n" + "pip install imgui-bundle" + ) diff --git a/fastplotlib/widgets/nd_widget/_base.py b/fastplotlib/widgets/nd_widget/_base.py new file mode 100644 index 000000000..932018f6e --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_base.py @@ -0,0 +1,738 @@ +from collections.abc import Callable, Hashable, Sequence +from contextlib import contextmanager +import inspect +from numbers import Real +from pprint import pformat +import textwrap +from typing import Literal, Any, Type +from warnings import warn + +import xarray as xr +import numpy as np +from numpy.typing import ArrayLike + +from ...layouts import Subplot +from ...utils import subsample_array, ArrayProtocol +from ...graphics import Graphic +from ._index import ReferenceIndex + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +def identity(index: int) -> int: + return round(index) + + +class NDProcessor: + def __init__( + self, + data: Any, + dims: Sequence[Hashable], + spatial_dims: Sequence[Hashable] | None, + slider_dim_transforms: dict[Hashable, Callable[[Any], int] | ArrayLike] = None, + window_funcs: dict[ + Hashable, tuple[WindowFuncCallable | None, int | float | None] + ] = None, + window_order: tuple[Hashable, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + """ + Base class for managing n-dimensional data and producing array slices. + + By default, wraps input data into an ``xarray.DataArray`` and provides an interface + for indexing slider dimensions, applying window functions, spatial functions, and mapping + reference-space values to local array indices. Subclasses must implement + :meth:`get`, which is called whenever the :class:`ReferenceIndex` updates. + + Subclasses can implement any type of data representation, they do not necessarily need to be compatible with + (they dot not have to be xarray compatible). However their ``get()`` method must still return a data slice that + corresponds to the graphical representation they map to. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + data: Any + data object that is managed, usually uses the ArrayProtocol. Custom subclasses can manage any kind of data + object but the corresponding :meth:`get` must return an array-like that maps to a graphical representation. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + spatial_dims: Sequence[str] + Subset of ``dims`` that are spatial (rendered) dimensions **in order**. All remaining dims are treated as + slider dims. See subclass for specific info. + + slider_dim_transforms: dict mapping dim_name -> Callable, an ArrayLike, or None + Per-slider-dim mapping from reference-space values to local array indices. + + You may also provide an array of reference values for the slider dims, ``searchsorted`` is then used + as the transform (ex: a timestamps array). + + If ``None`` and identity mapping is used, i.e. rounds the current reference index value to the nearest + integer for array indexing. + + If a transform is not provided for a dim then the identity mapping is used. + + window_funcs: dict[ + Hashable, tuple[WindowFuncCallable | None, int | float | None] + ] + Per-slider-dim window functions applied around the current slider position. Ex: {"time": (np.mean, 2.5)}. + Each value is a ``(func, window_size)`` pair where: + + * *func* must accept ``axis: int`` and ``keepdims: bool`` kwargs + (ex: ``np.mean``, ``np.max``). The window function **must** return an array that has the same dimensions + as specified in the NDProcessor, therefore the size of any dim along which a window_func was applied + should reduce to ``1``. These dims must not be removed by the window_func. + + * *window_size* is in reference-space units (ex: 2.5 seconds). + + + window_order: tuple[Hashable, ...] + Order in which window functions are applied across dims. Only dims listed + here have their window function applied. window_funcs are ignored for any + dims not specified in ``window_order`` + + spatial_func: + A function applied to the spatial slice *after* window_funcs right before rendering. + + """ + self._dims = tuple(dims) + self._data = self._validate_data(data) + self.spatial_dims = spatial_dims + + self.slider_dim_transforms = slider_dim_transforms + + self.window_funcs = window_funcs + self.window_order = window_order + self.spatial_func = spatial_func + + @property + def data(self) -> xr.DataArray: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + # does some basic validation + if data is None: + # we allow data to be None, in this case no ndgraphic is rendered + # useful when we want to initialize an NDWidget with no traces for example + # and populate it as components/channels are selected + return None + + if not isinstance(data, ArrayProtocol): + # This is required for xarray compatibility and general array-like requirements + raise TypeError("`data` must implement the ArrayProtocol") + + if data.ndim != len(self.dims): + raise IndexError("must specify a dim for every dimension in the data array") + + # data can be set, but the dims must still match/have the same meaning + return xr.DataArray(data, dims=self.dims) + + @property + def shape(self) -> dict[Hashable, int]: + """interpreted shape of the data""" + return {d: n for d, n in zip(self.dims, self.data.shape)} + + @property + def ndim(self) -> int: + """number of dims""" + return self.data.ndim + + @property + def dims(self) -> tuple[Hashable, ...]: + """dim names""" + # these are read-only and cannot be set after it's created + # the user should create a new NDGraphic if they need different dims + # I can't think of a usecase where we'd want to change the dims, and + # I think that would be complicated and probably and anti-pattern + return self._dims + + @property + def spatial_dims(self) -> tuple[Hashable, ...]: + """Spatial dims, **in order**""" + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: Sequence[Hashable]): + for dim in sdims: + if dim not in self.dims: + raise KeyError + + self._spatial_dims = tuple(sdims) + + @property + def tooltip(self) -> bool: + """ + whether or not a custom tooltip formatter method exists + """ + return False + + def tooltip_format(self, *args) -> str | None: + """ + Override in subclass to format custom tooltips + """ + return None + + @property + def slider_dims(self) -> set[Hashable]: + """Slider dim names, ``set(dims) - set(spatial_dims)""" + return set(self.dims) - set(self.spatial_dims) + + @property + def n_slider_dims(self): + """number of slider dims, i.e. len(slider_dims)""" + return len(self.slider_dims) + + @property + def window_funcs( + self, + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: + """get or set window functions, see docstring for details""" + return self._window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: ( + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] + | None + ), + ): + if window_funcs is None: + # tuple of (None, None) makes the checks easier in _apply_window_funcs + self._window_funcs = {d: (None, None) for d in self.slider_dims} + return + + for k in window_funcs.keys(): + if k not in self.slider_dims: + raise KeyError + + func = window_funcs[k][0] + size = window_funcs[k][1] + + if func is None: + pass + elif callable(func): + sig = inspect.signature(func) + + if "axis" not in sig.parameters or "keepdims" not in sig.parameters: + raise TypeError( + f"Each window function must take an `axis` and `keepdims` argument, " + f"you passed: {func} with the following function signature: {sig}" + ) + else: + raise TypeError( + f"`window_funcs` must be a dict mapping dim names to a tuple of the window function callable and " + f"window size, {'name': (func, size), ...}.\nYou have passed: {window_funcs}" + ) + + if size is None: + pass + + elif not isinstance(size, Real): + raise TypeError + + elif size < 0: + raise ValueError + + # fill in rest with None + for d in self.slider_dims: + if d not in window_funcs.keys(): + window_funcs[d] = (None, None) + + self._window_funcs = window_funcs + + @property + def window_order(self) -> tuple[Hashable, ...]: + """get or set dimension order in which window functions are applied""" + return self._window_order + + @window_order.setter + def window_order(self, order: tuple[Hashable] | None): + if order is None: + self._window_order = tuple() + return + + if not set(order).issubset(self.slider_dims): + raise ValueError( + f"each dimension in `window_order` must be a slider dim. You passed order: {order} " + f"and the slider dims are: {self.slider_dims}" + ) + + self._window_order = tuple(order) + + @property + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + """get or set the spatial function which is applied on the data slice after the window functions""" + return self._spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: + if not callable(func) and func is not None: + raise TypeError + + self._spatial_func = func + + @property + def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]: + """get or set the slider_dim_transforms, see docstring for details""" + return self._index_mappings + + @slider_dim_transforms.setter + def slider_dim_transforms( + self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None + ): + if maps is None: + self._index_mappings = {d: identity for d in self.dims} + return + + for d in maps.keys(): + if d not in self.dims: + raise KeyError( + f"`index_mapping` provided for non-existent dimension: {d}, existing dims are: {self.dims}" + ) + + if isinstance(maps[d], ArrayProtocol): + # create a searchsorted mapping function automatically + maps[d] = maps[d].searchsorted + + elif maps[d] is None: + # assign identity mapping + maps[d] = identity + + for d in self.dims: + # fill in any unspecified maps with identity + if d not in maps.keys(): + maps[d] = identity + + self._index_mappings = maps + + def _ref_index_to_array_index(self, dim: str, ref_index: Any) -> int: + # wraps slider_dim_transforms, clamps between 0 and the array size in this dim + + # ref-space -> local-array-index transform + index = self.slider_dim_transforms[dim](ref_index) + + # clamp between 0 and array size in this dim + return max(min(index, self.shape[dim] - 1), 0) + + def _get_slider_dims_indexer(self, indices: dict[Hashable, Any]) -> dict[Hashable, slice]: + """ + Creates an xarray-compatible indexer dict mapping each slider_dim -> slice object. + + - If a window_func is defined for a dim and the dim appears in ``window_order``, + the slice is defined as: + start: index - half_window + stop: index + half_window + step: 1 + + It then applies the slider_dim_transform to the start and stop to map these values from reference-space to + the local array index, and then finally produces the slice object in local array indices. + + ex: if we have indices = {"time": 50.0}, a window size of 5.0s and the ``slider_dim_transform`` + for time is based on a sampling rate of 10Hz, the window in ref units is [45.0, 55.0], and the final + slice object would be ``slice(450, 550, 1)``. + + - If no window func is specified, the final slice just corresponds to that index as an int array-index. + + This exists separate from ``_apply_window_functions()`` because it is useful for debugging purposes. + + Parameters + ---------- + indices : dict[Hashable, Any], {dim: ref_value} + Reference-space values for each slider dim. Must contain an entry + for every slider dim; raises ``IndexError`` otherwise. + ex: {"time": 46.397, "depth": 23.24} + + Returns + ------- + dict[Hashable, slice] + Indexer compatible for ``xr.DataArray.isel()``, with one ``slice`` per + slider dim. These are array indices mapped from the reference space using + the given ``slider_dim_transform``. + + Raises + ------ + IndexError + If ``indices`` are not provided for every ``slider_dim`` + """ + + if set(indices.keys()) != set(self.slider_dims): + raise IndexError( + f"Must provide an index for all slider dims: {self.slider_dims}, you have provided: {indices.keys()}" + ) + + indexer = dict() + + # get only slider dims which are not also spatial dims (example: p dim for positional data) + # since `p` dim windowing is dealt with separately for positional data + slider_dims = set(self.slider_dims) - set(self.spatial_dims) + # go through each slider dim and accumulate slice objects + for dim in slider_dims: + # index for this dim in reference space + index_ref = indices[dim] + + if dim not in self.window_funcs.keys(): + wf, ws = None, None + else: + # get window func and size in reference units + wf, ws = self.window_funcs[dim] + + # if a window function exists for this dim, and it's specified in the window order + if (wf is not None) and (ws is not None) and (dim in self.window_order): + # half window in reference units + hw = ws / 2 + + # start in reference units + start_ref = index_ref - hw + # stop in ref units + stop_ref = index_ref + hw + + # map start and stop ref to array indices + start = self.slider_dim_transforms[dim](start_ref) + stop = self.slider_dim_transforms[dim](stop_ref) + + # clamp within array bounds + start = max(min(self.shape[dim] - 1, start), 0) + stop = max(min(self.shape[dim] - 1, stop), 0) + indexer[dim] = slice(start, stop, 1) + else: + # no window func for this dim, direct indexing + # index mapped to array index + index = self.slider_dim_transforms[dim](index_ref) + + # clamp within the bounds + start = max(min(self.shape[dim] - 1, index), 0) + + # stop index is just the start index + 1 + indexer[dim] = slice(start, start + 1, 1) + + return indexer + + def _apply_window_functions(self, indices: dict[Hashable, Any]) -> xr.DataArray: + """ + Slice the data at the given indices and apply window functions in the order specified by + ``window_order``. + + Parameters + ---------- + indices : dict[Hashable, Any], {dim: ref_value} + Reference-space values for each slider dim. + ex: {"time": 46.397, "depth": 23.24} + + Returns + ------- + xr.DataArray + Data slice after windowed indexing and window function application, + with the same dims as the original data. Dims of size ``1`` are not + squeezed. + + """ + indexer = self._get_slider_dims_indexer(indices) + + # get the data slice w.r.t. the desired windows, and get the underlying numpy array + # ``.values`` gives the numpy array + # there is significant overhead with passing xarray objects to numpy for things like np.mean() + # so convert to numpy, apply window functions, then convert back to xarray + # creating an xarray object from a numpy array has very little overhead, ~10 microseconds + array = self.data.isel(indexer).values + + # apply window funcs in the specified order + for dim in self.window_order: + if self.window_funcs[dim] is None: + continue + + func, _ = self.window_funcs[dim] + # ``keepdims=True`` is critical, any "collapsed" dims will be of size ``1``. + # Ex: if `array` is of shape [10, 512, 512] and we applied the np.mean() window func on the first dim + # ``keepdims`` means the resultant shape is [1, 512, 512] and NOT [512, 512] + # this is necessary for applying window functions on multiple dims separately and so that the + # dims names correspond after all the window funcs are applied. + array = func(array, axis=self.dims.index(dim), keepdims=True) + + return xr.DataArray(array, dims=self.dims) + + def get(self, indices: dict[Hashable, Any]): + raise NotImplementedError + + # TODO: html and pretty text repr # + # def _repr_html_(self) -> str: + # return ndp_fmt_html(self) + # + # def _repr_mimebundle_(self, **kwargs) -> dict: + # return { + # "text/plain": self._repr_text_(), + # "text/html": self._repr_html_(), + # } + + def _repr_text_(self): + if self.data is None: + return ( + f"{self.__class__.__name__}\n" + f"data is None, dims: {self.dims}" + ) + tab = "\t" + + wf = {k: v for k, v in self.window_funcs.items() if v != (None, None)} + + r = ( + f"{self.__class__.__name__}\n" + f"shape:\n\t{self.shape}\n" + f"dims:\n\t{self.dims}\n" + f"spatial_dims:\n\t{self.spatial_dims}\n" + f"slider_dims:\n\t{self.slider_dims}\n" + f"slider_dim_transforms:\n{textwrap.indent(pformat(self.slider_dim_transforms, width=120), prefix=tab)}\n" + ) + + if len(wf) > 0: + r += ( + f"window_funcs:\n{textwrap.indent(pformat(wf, width=120), prefix=tab)}\n" + f"window_order:\n\t{self.window_order}\n" + ) + + if self.spatial_func is not None: + r += f"spatial_func:\n\t{self.spatial_func}\n" + + return r + + +class NDGraphic: + def __init__( + self, + subplot: Subplot, + name: str | None, + ): + self._subplot = subplot + self._name = name + self._graphic: Graphic | None = None + + # used to indicate that the NDGraphic should ignore any requests to update the indices + # used by block_indices_ctx context manager, usecase is when the LinearSelector on timeseries + # NDGraphic changes the selection, it shouldn't change the graphic that it is on top of! Would + # also cause recursion + # It is also used by the @block_reentrance decorator which is on the ``NDGraphic.indices`` property setter + # this is also to block recursion + self._block_indices = False + + # user settable bool to make the graphic unresponsive to change in the ReferenceIndex + self._pause = False + + + def _create_graphic(self): + raise NotImplementedError + + @property + def pause(self) -> bool: + """if True, changes in the reference until it is set back to False""" + return self._pause + + @pause.setter + def pause(self, val: bool): + self._pause = bool(val) + + @property + def name(self) -> str | None: + """name given to the NDGraphic""" + return self._name + + @property + def processor(self) -> NDProcessor: + raise NotImplementedError + + @property + def graphic(self) -> Graphic: + raise NotImplementedError + + @property + def indices(self) -> dict[Hashable, Any]: + raise NotImplementedError + + @indices.setter + def indices(self, new: dict[Hashable, Any]): + raise NotImplementedError + + # aliases for easier access to processor properties + @property + def data(self) -> Any: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ + return self.processor.data + + @data.setter + def data(self, data: Any): + self.processor.data = data + # create a new graphic when data has changed + if self.graphic is not None: + # it is already None if NDGraphic was initialized with no data + self._subplot.delete_graphic(self.graphic) + self._graphic = None + + self._create_graphic() + + # force a render + self.indices = self.indices + + @property + def shape(self) -> dict[Hashable, int]: + """interpreted shape of the data""" + return self.processor.shape + + @property + def ndim(self) -> int: + """number of dims""" + return self.processor.ndim + + @property + def dims(self) -> tuple[Hashable, ...]: + """dim names""" + return self.processor.dims + + @property + def spatial_dims(self) -> tuple[str, ...]: + # number of spatial dims for positional data is always 3 + # for image is 2 or 3, so it must be implemented in subclass + raise NotImplementedError + + @property + def slider_dims(self) -> set[Hashable]: + """the slider dims""" + return self.processor.slider_dims + + @property + def slider_dim_transforms(self) -> dict[Hashable, Callable[[Any], int]]: + return self.processor.slider_dim_transforms + + @slider_dim_transforms.setter + def slider_dim_transforms( + self, maps: dict[Hashable, Callable[[Any], int] | ArrayLike | None] | None + ): + """get or set the slider_dim_transforms, see docstring for details""" + self.processor.slider_dim_transforms = maps + # force a render + self.indices = self.indices + + @property + def window_funcs( + self, + ) -> dict[Hashable, tuple[WindowFuncCallable | None, int | float | None]]: + """get or set window functions, see docstring for details""" + return self.processor.window_funcs + + @window_funcs.setter + def window_funcs( + self, + window_funcs: ( + dict[Hashable, tuple[WindowFuncCallable | None, int | float | None] | None] + | None + ), + ): + self.processor.window_funcs = window_funcs + # force a render + self.indices = self.indices + + @property + def window_order(self) -> tuple[Hashable, ...]: + """get or set dimension order in which window functions are applied""" + return self.processor.window_order + + @window_order.setter + def window_order(self, order: tuple[Hashable] | None): + self.processor.window_order = order + # force a render + self.indices = self.indices + + @property + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: + """get or set the spatial_func, see docstring for details""" + self.processor.spatial_func = func + # force a render + self.indices = self.indices + + # def _repr_text_(self) -> str: + # return ndg_fmt_text(self) + # + # def _repr_html_(self) -> str: + # return ndg_fmt_html(self) + # + # def _repr_mimebundle_(self, **kwargs) -> dict: + # return { + # "text/plain": self._repr_text_(), + # "text/html": self._repr_html_(), + # } + + def _repr_text_(self): + return f"graphic: {self.graphic.__class__.__name__}\n" f"processor:\n{self.processor}" + + +@contextmanager +def block_indices_ctx(ndgraphic: NDGraphic): + """ + Context manager for pausing an NDGraphic from updating indices + """ + ndgraphic._block_indices = True + + try: + yield + except Exception as e: + raise e from None # indices setter has raised, the line above and the lines below are probably more relevant! + finally: + ndgraphic._block_indices = False + + +def block_reentrance(setter): + # decorator to block re-entrance of indices setter + def set_indices_wrapper(self: NDGraphic, new_indices): + """ + wraps NDGraphic.indices + + self: NDGraphic instance + + new_indices: new indices to set + """ + # set_value is already in the middle of an execution, block re-entrance + if self._block_indices: + return + try: + # block re-execution of set_value until it has *fully* finished executing + self._block_indices = True + setter(self, new_indices) + except Exception as exc: + # raise original exception + raise exc # set_value has raised. The line above and the lines 2+ steps below are probably more relevant! + finally: + # set_value has finished executing, now allow future executions + self._block_indices = False + + return set_indices_wrapper diff --git a/fastplotlib/widgets/nd_widget/_index.py b/fastplotlib/widgets/nd_widget/_index.py new file mode 100644 index 000000000..fc51c345c --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_index.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from dataclasses import dataclass +from numbers import Number +from typing import Sequence, Any, Callable + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ._ndwidget import NDWidget + + +@dataclass +class RangeContinuous: + """ + A continuous reference range for a single slider dimension. + + Stores the (start, stop, step) in scientific units (ex: seconds, micrometers, + Hz). The imgui slider for this dimension uses these values to determine its + minimum and maximum bounds. The step size is used for the "next" and "previous" buttons. + + Parameters + ---------- + start : int or float + Minimum value of the range, inclusive. + + stop : int or float + Maximum value of the range, exclusive upper bound. + + step : int or float + Step size used for imgui step next/previous buttons + + Raises + ------ + IndexError + If ``start >= stop``. + + Examples + -------- + A time axis sampled at 1 ms resolution over 10 seconds: + + RangeContinuous(start=0, stop=10_000, step=1) + + A depth axis in micrometers with 0.5 µm steps: + + RangeContinuous(start=0.0, stop=500.0, step=0.5) + """ + + start: int | float + stop: int | float + step: int | float + + def __post_init__(self): + if self.start >= self.stop: + raise IndexError( + f"start must be less than stop, {self.start} !< {self.stop}" + ) + + def __getitem__(self, index: int): + """return the value at the index w.r.t. the step size""" + # if index is negative, turn to positive index + if index < 0: + raise ValueError("negative indexing not supported") + + val = self.start + (self.step * index) + if not self.start <= val <= self.stop: + raise IndexError( + f"index: {index} value: {val} out of bounds: [{self.start}, {self.stop}]" + ) + + return val + + @property + def range(self) -> int | float: + return self.stop - self.start + + +@dataclass +class RangeDiscrete: + # TODO: not implemented yet, placeholder until we have a clear usecase + options: Sequence[Any] + + def __getitem__(self, index: int): + if index > len(self.options): + raise IndexError + + return self.options[index] + + def __len__(self): + return len(self.options) + + +class ReferenceIndex: + def __init__( + self, + ref_ranges: dict[ + str, + tuple[Number, Number, Number] | tuple[Any] | RangeContinuous, + ], + ): + """ + Manages the shared reference index for one or more ``NDWidget`` instances. + + Stores the current index for each named slider dimension in reference-space + units (ex: seconds, depth in µm, Hz). Whenever an index is updated, every + ``NDGraphic`` in the manged ``NDWidgets`` are requested to render data at + the new indices. + + Each key in ``ref_ranges`` defines a slider dimension. When adding an + ``NDGraphic``, every dimension listed in ``dims`` must be either a spatial + dimension (listed in ``spatial_dims``) or a key in ``ref_ranges``. + If a dim is not spatial, it must have a corresponding reference range, + otherwise an error will be raised. + + You can also define conceptually identical but *independent* reference spaces + by using distinct names, ex: ``"time-1"`` and ``"time-2"`` for two recordings + that should be sycned independently. Each ``NDGraphic`` then declares the + specific "time-n" space that corresponds to its data, so the widget keeps the + two timelines decoupled. + + Parameters + ---------- + ref_ranges : dict[str, tuple], or a RangeContinuous + Mapping of dimension names to range specifications. A 3-tuple + ``(start, stop, step)`` creates a :class:`RangeContinuous`. A 1-tuple + ``(options,)`` creates a :class:`RangeDiscrete`. + + Attributes + ---------- + ref_ranges : dict[str, RangeContinuous | RangeDiscrete] + The reference range for each registered slider dimension. + + dims: set[str] + the set of "slider dims" + + Examples + -------- + Single shared time axis: + + ri = ReferenceIndex(ref_ranges={"time": (0, 1000, 1), "depth": (15, 35, 0.5)}) + ri["time"] = 500 # update one dim and re-render + ri.set({"time": 500, "depth": 10}) # update several dims atomically + + Two independent time axes for data from two different recording sessions: + + ri = ReferenceIndex({ + "time-1": (0, 3600, 1), # session 1 — 1 h at 1 s resolution + "time-s": (0, 1800, 1), # session 2 — 30 min at 1 s resolution + }) + + Each ``NDGraphic`` declares matching names for slider dims to indicate that these should be + synced across graphics. + + ndw[0, 0].add_nd_image(data_s1, ("time-s1", "row", "col"), ("row", "col")) + ndw[0, 1].add_nd_image(data_s2, ("time-s2", "row", "col"), ("row", "col")) + + """ + self._ref_ranges = dict() + self.push_dims(ref_ranges) + + # starting index for all dims + self._indices: dict[str, int | float | Any] = { + name: rr.start for name, rr in self._ref_ranges.items() + } + + self._indices_changed_handlers = set() + + self._ndwidgets: list[NDWidget] = list() + + @property + def ref_ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: + return self._ref_ranges + + @property + def dims(self) -> set[str]: + return set(self.ref_ranges.keys()) + + def _add_ndwidget_(self, ndw: NDWidget): + from ._ndwidget import NDWidget + + if not isinstance(ndw, NDWidget): + raise TypeError + + self._ndwidgets.append(ndw) + + def set(self, indices: dict[str, Any]): + for dim, value in indices.items(): + self._indices[dim] = self._clamp(dim, value) + + self._render_indices() + self._indices_changed() + + def _clamp(self, dim, value): + if isinstance(self.ref_ranges[dim], RangeContinuous): + return max( + min(value, self.ref_ranges[dim].stop - self.ref_ranges[dim].step), + self.ref_ranges[dim].start, + ) + + return value + + def _render_indices(self): + for ndw in self._ndwidgets: + for g in ndw.ndgraphics: + if g.data is None or g.pause: + continue + # only provide slider indices to the graphic + g.indices = {d: self._indices[d] for d in g.processor.slider_dims} + + def __getitem__(self, dim): + self._check_has_dim(dim) + return self._indices[dim] + + def __setitem__(self, dim, value): + self._check_has_dim(dim) + # set index for given dim and render + self._indices[dim] = self._clamp(dim, value) + self._render_indices() + self._indices_changed() + + def _check_has_dim(self, dim): + if dim not in self.dims: + raise KeyError( + f"provided dimension: {dim} has no associated ReferenceRange in this ReferenceIndex, valid dims in this ReferenceIndex are: {self.dims}" + ) + + def pop_dim(self): + pass + + def push_dims(self, ref_ranges: dict[ + str, + tuple[Number, Number, Number] | tuple[Any] | RangeContinuous, + ],): + + for name, r in ref_ranges.items(): + if isinstance(r, (RangeContinuous, RangeDiscrete)): + self._ref_ranges[name] = r + + elif len(r) == 3: + # assume start, stop, step + self._ref_ranges[name] = RangeContinuous(*r) + + elif len(r) == 1: + # assume just options + self._ref_ranges[name] = RangeDiscrete(*r) + + else: + raise ValueError( + f"ref_ranges must be a mapping of dimension names to range specifications, " + f"see the docstring, you have passed: {ref_ranges}" + ) + + def add_event_handler(self, handler: Callable, event: str = "indices"): + """ + Register an event handler that is called whenever the indices change. + + Parameters + ---------- + handler: Callable + callback function, must take a tuple of int as the only argument. This tuple will be the `indices` + + event: str, "indices" + the only supported valid is "indices" + + Example + ------- + + .. code-block:: py + + def my_handler(indices): + print(indices) + # example prints: {"t": 100, "z": 15} if the index has 2 reference spaces "t" and "z" + + # create an NDWidget + ndw = NDWidget(...) + + # add event handler + ndw.indices.add_event_handler(my_handler) + + """ + if event != "indices": + raise ValueError("`indices` is the only event supported by `GlobalIndex`") + + self._indices_changed_handlers.add(handler) + + def remove_event_handler(self, handler: Callable): + """Remove a registered event handler""" + self._indices_changed_handlers.remove(handler) + + def clear_event_handlers(self): + """Clear all registered event handlers""" + self._indices_changed_handlers.clear() + + def _indices_changed(self): + for f in self._indices_changed_handlers: + f(self._indices) + + def __iter__(self): + for index in self._indices.items(): + yield index + + def __len__(self): + return len(self._indices) + + def __eq__(self, other): + return self._indices == other + + def __repr__(self): + return f"Global Index: {self._indices}" + + def __str__(self): + return str(self._indices) + + +# TODO: Not sure if we'll actually do this here, just a placeholder for now +class SelectionVector: + @property + def selection(self): + pass + + @property + def graphics(self): + pass + + def add_graphic(self): + pass + + def remove_graphic(self): + pass diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py new file mode 100644 index 000000000..be319942d --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -0,0 +1,549 @@ +from collections.abc import Hashable, Sequence +import inspect +from typing import Callable, Any + +import numpy as np +from numpy.typing import ArrayLike +import xarray as xr + +from ...layouts import Subplot +from ...utils import subsample_array, ArrayProtocol, ARRAY_LIKE_ATTRS +from ...graphics import ImageGraphic, ImageVolumeGraphic +from ...tools import HistogramLUTTool +from ._base import NDProcessor, NDGraphic, WindowFuncCallable +from ._index import ReferenceIndex + + +class NDImageProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol | None, + dims: Sequence[Hashable], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + slider_dim_transforms=None, + ): + """ + ``NDProcessor`` subclass for n-dimensional image data. + + Produces 2-D or 3-D spatial slices for an ``ImageGraphic`` or ``ImageVolumeGraphic``. + + Parameters + ---------- + data: ArrayProtocol + array-like data, must have 2 or more dimensions + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("channels", "time", "xy")`` + ``("keypoints", "time", "xyz")`` + + A custom subclass's ``data`` object doesn't necessarily need to have these dims, but the ``get()`` method + must operate as if these dimensions exist and return an array that matches the spatial dimensions. + + dims: Sequence[str] + names for each dimension in ``data``. Dimensions not listed in + ``spatial_dims`` are treated as slider dimensions and **must** appear as + keys in the parent ``NDWidget``'s ``ref_ranges`` + Examples:: + ``("time", "depth", "row", "col")`` + ``("row", "col")`` + ``("other_dim", "depth", "time", "row", "col")`` + + dims in the array do not need to be in order, for example you can have a weird array where the dims are + interpreted as: ``("col", "depth", "row", "time")``, and then specify spatial_dims as ``("row", "col")`` + thanks to xarray magic =D. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + The 2 or 3 spatial dimensions **in order**: ``(rows, cols)`` or ``(z, rows, cols)``. + This also determines whether an ``ImageGraphic`` or ``ImageVolumeGraphic`` is used for rendering. + The ordering determines how the Image/Volume is rendered. For example, if + you specify ``spatial_dims = ("rows", "cols")`` and then change it to ``("cols", "rows")``, it will display + the transpose. + + rgb_dim : str, optional + Name of an RGB(A) dimension, if present. + + compute_histogram: bool, default True + Compute a histogram of the data, disable if random-access of data is not blazing-fast (ex: data that uses + video codecs), or if histograms are not useful for this data. + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + See Also + -------- + NDProcessor : Base class with full parameter documentation. + NDImage : The ``NDGraphic`` that wraps this processor. + """ + + # set as False until data, window funcs stuff and spatial func is all set + self._compute_histogram = False + + # make sure rgb dim is size 3 or 4 + if rgb_dim is not None: + dim_index = dims.index(rgb_dim) + if data.shape[dim_index] not in (3, 4): + raise IndexError( + f"The size of the RGB(A) dim must be 3 | 4. You have specified an array of shape: {data.shape}, " + f"with dims: {dims}, and specified the ``rgb_dim`` name as: {rgb_dim} which has size " + f"{data.shape[dim_index]} != 3 | 4" + ) + + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + slider_dim_transforms=slider_dim_transforms, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + ) + + self.rgb_dim = rgb_dim + self._compute_histogram = compute_histogram + self._recompute_histogram() + + @property + def data(self) -> xr.DataArray | None: + """ + get or set managed data. If setting with new data, the new data is interpreted + to have the same dims (i.e. same dim names and ordering of dims). + """ + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + self._recompute_histogram() + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + # check that it's compatible with array and generally array-like + raise TypeError( + f"`data` arrays must have all of the following attributes to be sufficiently array-like:\n" + f"{ARRAY_LIKE_ATTRS}, or they must be `None`" + ) + + if data.ndim < 2: + # ndim < 2 makes no sense for image data + raise IndexError( + f"Image data must have a minimum of 2 dimensions, you have passed an array of shape: {data.shape}" + ) + + return xr.DataArray(data, dims=self.dims) + + @property + def rgb_dim(self) -> str | None: + """ + get or set the RGB(A) dim name, ``None`` if no RGB(A) dim exists + """ + return self._rgb + + @rgb_dim.setter + def rgb_dim(self, rgb: str | None): + if rgb is not None: + if rgb not in self.dims: + raise KeyError + + self._rgb = rgb + + @property + def compute_histogram(self) -> bool: + """get or set whether or not to compute the histogram""" + return self._compute_histogram + + @compute_histogram.setter + def compute_histogram(self, compute: bool): + if compute: + if not self._compute_histogram: + # compute a histogram + self._recompute_histogram() + self._compute_histogram = True + else: + self._compute_histogram = False + self._histogram = None + + @property + def histogram(self) -> tuple[np.ndarray, np.ndarray] | None: + """ + an estimate of the histogram of the data, (histogram_values, bin_edges). + + returns `None` if `compute_histogram` is `False` + """ + return self._histogram + + def get(self, indices: dict[str, Any]) -> ArrayLike | None: + """ + Get the data at the given index, process data through the window functions. + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + + Parameters + ---------- + indices: tuple[int, ...] + Get the processed data at this index. Must provide a value for each dimension. + Example: get((100, 5)) + + """ + if len(self.slider_dims) > 0: + # there are dims in addition to the spatial dims + window_output = self._apply_window_functions(indices).squeeze() + else: + # no slider dims, use all the data + window_output = self.data + + if window_output.ndim != len(self.spatial_dims): + raise ValueError + + # apply spatial_func + if self.spatial_func is not None: + spatial_out = self._spatial_func(window_output) + if spatial_out.ndim != len(self.spatial_dims): + raise ValueError + + return spatial_out.transpose(*self.spatial_dims).values + + return window_output.transpose(*self.spatial_dims).values + + def _recompute_histogram(self): + """ + + Returns + ------- + (histogram_values, bin_edges) + + """ + if not self._compute_histogram or self.data is None: + self._histogram = None + return + + if self.spatial_func is not None: + # don't subsample spatial dims if a spatial function is used + # spatial functions often operate on the spatial dims, ex: a gaussian kernel + # so their results require the full spatial resolution, the histogram of a + # spatially subsampled image will be very different + ignore_dims = [self.dims.index(dim) for dim in self.spatial_dims] + else: + ignore_dims = None + + # TODO: account for window funcs + + sub = subsample_array(self.data, ignore_dims=ignore_dims) + + if isinstance(sub, xr.DataArray): + # can't do the isnan and isinf boolean indexing below on xarray + sub = sub.values + + sub_real = sub[~(np.isnan(sub) | np.isinf(sub))] + + self._histogram = np.histogram(sub_real, bins=100) + + +class NDImage(NDGraphic): + def __init__( + self, + ref_index: ReferenceIndex, + subplot: Subplot, + data: ArrayProtocol | None, + dims: Sequence[str], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayLike], ArrayLike] = None, + compute_histogram: bool = True, + slider_dim_transforms=None, + name: str = None, + ): + """ + ``NDGraphic`` subclass for n-dimensional image rendering. + + Wraps an :class:`NDImageProcessor` and manages either an ``ImageGraphic`` or``ImageVolumeGraphic``. + swaps automatically when :attr:`spatial_dims` is reassigned at runtime. Also + owns a ``HistogramLUTTool`` for interactive vmin, vmax adjustment. + + Every dimension that is *not* listed in ``spatial_dims`` becomes a slider + dimension. Each slider dim must have a ``ReferenceRange`` defined in the + ``ReferenceIndex`` of the parent ``NDWidget``. The widget uses this to direct + a change in the ``ReferenceIndex`` and update the graphics. + + Parameters + ---------- + ref_index : ReferenceIndex + The shared reference index that delivers slider updates to this graphic. + + subplot : Subplot + parent subplot the NDGraphic is in + + data : array-like or None + n-dimension image data array + + dims : sequence of hashable + Name for every dimension of ``data``, in order. Non-spatial dims must + match keys in ``ref_index``. + + ex: ``("time", "depth", "row", "col")`` — ``"time"`` and ``"depth"`` must + be present in ``ref_index``. + + spatial_dims : tuple[str, str] | tuple[str, str, str] + Spatial dimensions **in order**: ``(rows, cols)`` for 2-D images or + ``(z, rows, cols)`` for volumes. Controls whether an ``ImageGraphic`` or + ``ImageVolumeGraphic`` is used. + + rgb_dim : str, optional + Name of the RGB or channel dimension, if present. + + window_funcs : dict, optional + See :class:`NDProcessor`. + + window_order : tuple, optional + See :class:`NDProcessor`. + + spatial_func : callable, optional + See :class:`NDProcessor`. + + compute_histogram : bool, default ``True`` + Whether to initialize the ``HistogramLUTTool``. + + slider_dim_transforms : dict, optional + See :class:`NDProcessor`. + + name : str, optional + Name for the underlying graphic. + + See Also + -------- + NDImageProcessor : The processor that backs this graphic. + + """ + + if not (set(dims) - set(spatial_dims)).issubset(ref_index.dims): + raise IndexError( + f"all specified `dims` must either be a spatial dim or a slider dim " + f"specified in the NDWidget ref_ranges, provided dims: {dims}, " + f"spatial_dims: {spatial_dims}. Specified NDWidget ref_ranges: {ref_index.dims}" + ) + + super().__init__(subplot, name) + + self._ref_index = ref_index + + self._processor = NDImageProcessor( + data, + dims=dims, + spatial_dims=spatial_dims, + rgb_dim=rgb_dim, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + compute_histogram=compute_histogram, + slider_dim_transforms=slider_dim_transforms, + ) + + self._graphic: ImageGraphic | None = None + self._histogram_widget: HistogramLUTTool | None = None + + # create a graphic + self._create_graphic() + + @property + def processor(self) -> NDImageProcessor: + """NDProcessor that manages the data and produces data slices to display""" + return self._processor + + @property + def graphic( + self, + ) -> ImageGraphic | ImageVolumeGraphic: + """Underlying Graphic object used to display the current data slice""" + return self._graphic + + def _create_graphic(self): + # Creates an ``ImageGraphic`` or ``ImageVolumeGraphic`` based on the number of spatial dims, + # adds it to the subplot, and resets the camera and histogram. + + if self.processor.data is None: + # no graphic if data is None, useful for initializing in null states when we want to set data later + return + + # determine if we need a 2d image or 3d volume + # remove RGB spatial dim, ex: if we have an RGBA image of shape [512, 512, 4] we want to interpet this as + # 2D for images + # [30, 512, 512, 4] with an rgb dim is an RGBA volume which is also supported + match len(self.processor.spatial_dims) - int(bool(self.processor.rgb_dim)): + case 2: + cls = ImageGraphic + case 3: + cls = ImageVolumeGraphic + + # get the data slice for this index + # this will only have the dims specified by ``spatial_dims`` + data_slice = self.processor.get(self.indices) + + # create the new graphic + new_graphic = cls(data_slice) + + old_graphic = self._graphic + # check if we are replacing a graphic + # ex: swapping from 2D <-> 3D representation after ``spatial_dims`` was changed + if old_graphic is not None: + # carry over some attributes from old graphic + attrs = dict.fromkeys(["cmap", "interpolation", "cmap_interpolation"]) + for k in attrs: + attrs[k] = getattr(old_graphic, k) + + # delete the old graphic + self._subplot.delete_graphic(old_graphic) + + # set any attributes that we're carrying over like cmap + for attr, val in attrs.items(): + setattr(new_graphic, attr, val) + + self._graphic = new_graphic + + self._subplot.add_graphic(self._graphic) + + self._reset_camera() + self._reset_histogram() + + def _reset_histogram(self): + # reset histogram + if self.graphic is None: + return + + if not self.processor.compute_histogram: + # hide right dock if histogram not desired + self._subplot.docks["right"].size = 0 + return + + if self.processor.histogram: + if self._histogram_widget: + # histogram widget exists, update it + self._histogram_widget.histogram = self.processor.histogram + self._histogram_widget.images = self.graphic + if self._subplot.docks["right"].size < 1: + self._subplot.docks["right"].size = 80 + else: + # make hist tool + self._histogram_widget = HistogramLUTTool( + histogram=self.processor.histogram, + images=self.graphic, + name=f"hist-{hex(id(self.graphic))}", + ) + self._subplot.docks["right"].add_graphic(self._histogram_widget) + self._subplot.docks["right"].size = 80 + + self.graphic.reset_vmin_vmax() + + def _reset_camera(self): + # set camera to a nice position based on whether it's a 2D ImageGraphic or 3D ImageVolumeGraphic + if isinstance(self._graphic, ImageGraphic): + # set camera orthogonal to the xy plane, flip y axis + self._subplot.camera.set_state( + { + "position": [0, 0, -1], + "rotation": [0, 0, 0, 1], + "scale": [1, -1, 1], + "reference_up": [0, 1, 0], + "fov": 0, # orthographic projection + "depth_range": None, + } + ) + + self._subplot.controller = "panzoom" + self._subplot.axes.intersection = None + self._subplot.auto_scale() + + else: + # It's not an ImageGraphic, set perspective projection + self._subplot.camera.fov = 50 + self._subplot.controller = "orbit" + + # set all 3D dimension camera scales to positive since positive scales + # are typically used for looking at volumes + for dim in ["x", "y", "z"]: + if getattr(self._subplot.camera.local, f"scale_{dim}") < 0: + setattr(self._subplot.camera.local, f"scale_{dim}", 1) + + self._subplot.auto_scale() + + @property + def spatial_dims(self) -> tuple[str, str] | tuple[str, str, str]: + """get or set the spatial dims, see docstring for details""" + return self.processor.spatial_dims + + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str] | tuple[str, str, str]): + self.processor.spatial_dims = dims + + # shape has probably changed, recreate graphic + self._create_graphic() + + @property + def indices(self) -> dict[Hashable, Any]: + """get or set the indices, managed by the ReferenceIndex, users usually don't want to set this manually""" + return {d: self._ref_index[d] for d in self.processor.slider_dims} + + @indices.setter + def indices(self, indices): + data_slice = self.processor.get(indices) + + self.graphic.data = data_slice + + @property + def compute_histogram(self) -> bool: + """whether or not to compute the histogram and display the HistogramLUTTool""" + return self.processor.compute_histogram + + @compute_histogram.setter + def compute_histogram(self, v: bool): + self.processor.compute_histogram = v + self._reset_histogram() + + @property + def histogram_widget(self) -> HistogramLUTTool: + """The histogram lut tool associated with this NDGraphic""" + return self._histogram_widget + + @property + def spatial_func(self) -> Callable[[xr.DataArray], xr.DataArray] | None: + """get or set the spatial_func, see docstring for details""" + return self.processor.spatial_func + + @spatial_func.setter + def spatial_func( + self, func: Callable[[xr.DataArray], xr.DataArray] + ) -> Callable | None: + self.processor.spatial_func = func + self.processor._recompute_histogram() + self._reset_histogram() + + def _tooltip_handler(self, graphic, pick_info): + # TODO: need to do this better + # get graphic within the collection + n_index = np.argwhere(self.graphic.graphics == graphic).item() + p_index = pick_info["vertex_index"] + return self.processor.tooltip_format(n_index, p_index) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py new file mode 100644 index 000000000..60703f8c2 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/__init__.py @@ -0,0 +1,23 @@ +import importlib + +from ._nd_positions import NDPositions, NDPositionsProcessor + +class Extras: + pass + +ndp_extras = Extras() + + +for optional in ["pandas", "zarr"]: + try: + importlib.import_module(optional) + except ImportError: + pass + else: + module = importlib.import_module(f"._{optional}", "fastplotlib.widgets.nd_widget._nd_positions") + cls = getattr(module, f"NDPP_{optional.capitalize()}") + setattr( + ndp_extras, + f"NDPP_{optional.capitalize()}", + cls + ) diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py new file mode 100644 index 000000000..2d3ff2b9a --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_nd_positions.py @@ -0,0 +1,1144 @@ +from collections.abc import Callable, Hashable, Sequence +from functools import partial +from typing import Literal, Any, Type +from warnings import warn + +import numpy as np +from numpy.lib.stride_tricks import sliding_window_view +from numpy.typing import ArrayLike +import xarray as xr + +from ....layouts import Subplot +from ....graphics import ( + Graphic, + ImageGraphic, + LineGraphic, + LineStack, + LineCollection, + ScatterGraphic, + ScatterCollection, + ScatterStack, +) +from ....graphics.features.utils import parse_colors +from ....graphics.utils import pause_events +from ....graphics.selectors import LinearSelector +from .._base import ( + NDProcessor, + NDGraphic, + WindowFuncCallable, + block_reentrance, + block_indices_ctx, +) +from .._index import ReferenceIndex + +# types for the other features +FeatureCallable = Callable[[np.ndarray, slice], np.ndarray] +ColorsType = np.ndarray | FeatureCallable | None +MarkersType = Sequence[str] | np.ndarray | FeatureCallable | None +SizesType = Sequence[float] | np.ndarray | FeatureCallable | None + + +def default_cmap_transform_each(p: int, data_slice: np.ndarray, s: slice): + # create a cmap transform based on the `p` dim size + n_displayed = data_slice.shape[1] + + # linspace that's just normalized 0 - 1 within `p` dim size + return np.linspace( + start=s.start / p, + stop=s.stop / p, + num=n_displayed, + endpoint=False, # since we use a slice object for the displayed data, the last point isn't included + ) + + +class NDPositionsProcessor(NDProcessor): + _other_features = ["colors", "markers", "cmap_transform_each", "sizes"] + + def __init__( + self, + data: Any, + dims: Sequence[Hashable], + # TODO: allow stack_dim to be None and auto-add new dim of size 1 in get logic + spatial_dims: tuple[ + Hashable | None, Hashable, Hashable + ], # [stack_dim, n_datapoints, spatial_dim], IN ORDER!! + slider_dim_transforms: dict[str, Callable[[Any], int] | ArrayLike] = None, + display_window: int | float | None = 100, # window for n_datapoints dim only + max_display_datapoints: int = 1_000, + datapoints_window_func: tuple[Callable, str, int | float] | None = None, + colors: ColorsType = None, + markers: MarkersType = None, + cmap_transform_each: np.ndarray = None, + sizes: SizesType = None, + **kwargs, + ): + """ + ``NDProcessor`` subclass for n-dimensional positional and timeseries data. + + + The *datapoints* dimension is + simultaneously a slider dim and a spatial dim and is handled by a dedicated + :attr:`datapoints_window_func` rather than the general ``window_funcs`` + mechanism. + + + Parameters + ---------- + data + dims + spatial_dims + slider_dim_transforms + display_window + max_display_datapoints: int, default 1_000 + this is approximate since floor division is used to determine the step size of the current display window slice + datapoints_window_func: + Important note: if used, display_window is approximate and not exact due to padding from the window size + kwargs + """ + self._display_window = display_window + self._max_display_datapoints = max_display_datapoints + + super().__init__( + data=data, + dims=dims, + spatial_dims=spatial_dims, + slider_dim_transforms=slider_dim_transforms, + **kwargs, + ) + + self._datapoints_window_func = datapoints_window_func + + self.colors = colors + self.markers = markers + self.cmap_transform_each = cmap_transform_each + self.sizes = sizes + + def _check_shape_feature( + self, prop: str, check_shape: tuple[int, int] + ) -> tuple[int, int]: + # this function exists because it's used repeatedly for colors, markers, etc. + # shape for [l, p] dims must match, or l must be 1 + shape = tuple([self.shape[dim] for dim in self.spatial_dims[:2]]) + + if check_shape[1] != shape[1]: + raise IndexError( + f"shape of first two dims of {prop} must must be [l, p] or [1, p].\n" + f"required `p` dim shape is: {shape[1]}, {check_shape[1]} was provided" + ) + + if check_shape[0] != 1 and check_shape[0] != shape[0]: + raise IndexError( + f"shape of first two dims of {prop} must must be [l, p] or [1, p]\n" + f"required `l` dim shape is {shape[0]} | 1, {check_shape[0]} was provided" + ) + + return shape + + @property + def colors(self) -> ColorsType: + """ + A callable that dynamically creates colors for the current display window, or array of colors per-datapoint. + + Array must be of shape [l, p, 4] for unique colors per line/scatter, or [1, p, 4] for identical colors per + line/scatter. + + Callable must return an array of shape [l, pw, 4] or [1, pw, 4], where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + return self._colors + + @colors.setter + def colors(self, new): + if callable(new): + # custom callable that creates the colors + self._colors = new + return + + if new is None: + self._colors = None + return + + # as array so we can check shape + new = np.asarray(new) + if new.ndim == 2: + # only [p, 4] provided, broadcast to [1, p, 4] + new = new[None] + + shape = self._check_shape_feature("colors", new.shape[:2]) + + if new.shape[0] == 1: + # same colors across all graphical elements + self._colors = parse_colors(new[0], n_colors=shape[1])[None] + + else: + # colors specified for each individual line/scatter + new_ = np.zeros(shape=(*self.data.shape[:2], 4), dtype=np.float32) + for i in range(shape[0]): + new_[i] = parse_colors(new[i], n_colors=shape[1]) + + self._colors = new_ + + @property + def markers(self) -> MarkersType: + """ + A callable that dynamically creates markers for the current display window, or array of markers per-datapoint. + + Array must be of shape [l, p] for unique markers per line/scatter, or [p,] or [1, p] for identical markers per + line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + return self._markers + + @markers.setter + def markers(self, new: MarkersType): + if callable(new): + # custom callable that creates the markers dynamically + self._markers = new + return + + if new is None: + self._markers = None + return + + # as array so we can check shape + new = np.asarray(new) + + # if 1-dim, assume it's specifying markers over `p` dim, so set `l` dim to 1 + if new.ndim == 1: + new = new[None] + + self._check_shape_feature("markers", new.shape[:2]) + + self._markers = np.asarray(new) + + @property + def cmap_transform_each(self) -> np.ndarray | FeatureCallable | None: + return self._cmap_transform_each + + @cmap_transform_each.setter + def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): + """ + A callable that dynamically creates cmap transforms for the current display window, or array + of transforms per-datapoint. + + Array must be of shape [l, p] for unique transforms per line/scatter, or [p,] or [1, p] for identical markers + per line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + if callable(new): + self._cmap_transform_each = new + return + + if new is None: + self._cmap_transform_each = None + return + + new = np.asarray(new) + + # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1 + if new.ndim == 1: + new = new[None] + + self._check_shape_feature("cmap_transform_each", new.shape) + + self._cmap_transform_each = new + + @property + def sizes(self) -> SizesType: + return self._sizes + + @sizes.setter + def sizes(self, new: SizesType): + """ + A callable that dynamically creates sizes for the current display window, or array of sizes per-datapoint. + + Array must be of shape [l, p] for unique sizes per line/scatter, or [p,] or [1, p] for identical markers per + line/scatter. + + Callable must return an array of shape [l, pw], [1, pw], or [pw,] where pw is the number of currently displayed + datapoints given the current display window. The callable receives the current data slice array, as well as the + slice object that corresponds to the current display window. + """ + if callable(new): + # custom callable + self._sizes = new + return + + if new is None: + self._sizes = None + return + + new = np.array(new) + # if 1-dim, assume it's specifying sizes over `p` dim, set `l` dim to 1 + if new.ndim == 1: + new = new[None] + + self._check_shape_feature("sizes", new.shape) + + self._sizes = new + + @property + def spatial_dims(self) -> tuple[str, str, str]: + return self._spatial_dims + + @spatial_dims.setter + def spatial_dims(self, sdims: tuple[str, str, str]): + if len(sdims) != 3: + raise IndexError + + if not all([d in self.dims for d in sdims]): + raise KeyError + + self._spatial_dims = tuple(sdims) + + @property + def slider_dims(self) -> set[Hashable]: + # append `p` dim to slider dims + return tuple([*super().slider_dims, self.spatial_dims[1]]) + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + @property + def max_display_datapoints(self) -> int: + return self._max_display_datapoints + + @max_display_datapoints.setter + def max_display_datapoints(self, n: int): + if not isinstance(n, (int, np.integer)): + raise TypeError + if n < 2: + raise ValueError + + self._max_display_datapoints = n + + # TODO: validation for datapoints_window_func and size + @property + def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None: + """ + Callable, str indicating which dims to apply window function along, window_size in reference space: + 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' + '""" + return self._datapoints_window_func + + @datapoints_window_func.setter + def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]): + if len(funcs) != 3: + raise TypeError + + self._datapoints_window_func = tuple(funcs) + + def _get_dw_slice(self, indices: dict[str, Any]) -> slice: + # given indices, return slice required to obtain display window + + # n_datapoints dim name + # display_window acts on this dim + p_dim = self.spatial_dims[1] + + if self.display_window is None: + # just return everything + return slice(0, self.shape[p_dim]) + + if self.display_window == 0: + # just map p dimension at this index and return + index = self._ref_index_to_array_index(p_dim, indices[p_dim]) + return slice(index, index + 1) + + # half window size, in reference units + hw = self.display_window / 2 + + if self.datapoints_window_func is not None: + # add half datapoints_window_func size here, assumes the reference space is somewhat continuous + # and the display_window and datapoints window size map to their actual size values + hw += self.datapoints_window_func[2] / 2 + + # display window is in reference units, apply display window and then map to array indices + # start in reference units + start_ref = indices[p_dim] - hw + # stop in reference units + stop_ref = indices[p_dim] + hw + + # map to array indices + start = self._ref_index_to_array_index(p_dim, start_ref) + stop = self._ref_index_to_array_index(p_dim, stop_ref) + + if start >= stop: + stop = start + 1 + + w = stop - start + + # get step size + step = max(1, w // self.max_display_datapoints) + + return slice(start, stop, step) + + def _apply_dw_window_func( + self, array: xr.DataArray | np.ndarray + ) -> xr.DataArray | np.ndarray: + """ + Takes array where display window has already been applied and applies window functions on the `p` dim. + + Parameters + ---------- + array: np.ndarray + array of shape: [l, display_window, 2 | 3] + + Returns + ------- + np.ndarray + array with window functions applied along `p` dim + """ + if self.display_window == 0: + # can't apply window func when there is only 1 datapoint + return array + + p_dim = self.spatial_dims[1] + + # display window in array index space + if self.display_window is not None: + dw = self.slider_dim_transforms[p_dim](self.display_window) + + # step size based on max number of datapoints to render + step = max(1, dw // self.max_display_datapoints) + + # apply window function on the `p` n_datapoints dim + if ( + self.datapoints_window_func is not None + # if there are too many points to efficiently compute the window func, skip + # applying a window func also requires making a copy so that's a further performance hit + and (dw < self.max_display_datapoints * 2) + ): + # get windows + + # graphic_data will be of shape: [n, p, 2 | 3] + # where: + # n - number of lines, scatters, heatmap rows + # p - number of datapoints/samples + + # ws is in ref units + wf, apply_dims, ws = self.datapoints_window_func + + # map ws in ref units to array index + # min window size is 3 + ws = max(self._ref_index_to_array_index(p_dim, ws), 3) + + if ws % 2 == 0: + # odd size windows are easier to handle + ws += 1 + + hw = ws // 2 + start, stop = hw, array.shape[1] - hw + + # apply user's window func + # result will be of shape [n, p, 2 | 3] + if apply_dims == "all": + # windows will be of shape [n, p, 1 | 2 | 3, ws] + windows = sliding_window_view(array, ws, axis=-2) + return wf(windows, axis=-1)[:, ::step] + + # map user dims str to tuple of numerical dims + dims = tuple(map({"x": 0, "y": 1, "z": 2}.get, apply_dims)) + + # windows will be of shape [n, (p - ws + 1), 1 | 2 | 3, ws] + windows = sliding_window_view(array[..., dims], ws, axis=-2).squeeze() + + # make a copy because we need to modify it + array = array[:, start:stop].copy() + + # this reshape is required to reshape wf outputs of shape [n, p] -> [n, p, 1] only when necessary + array[..., dims] = wf(windows, axis=-1).reshape( + *array.shape[:-1], len(dims) + ) + + return array[:, ::step] + + step = max(1, array.shape[1] // self.max_display_datapoints) + + return array[:, ::step] + + def _apply_spatial_func( + self, array: xr.DataArray | np.ndarray + ) -> xr.DataArray | np.ndarray: + if self.spatial_func is not None: + return self.spatial_func(array) + + return array + + def _finalize_(self, array: xr.DataArray | np.ndarray) -> xr.DataArray | np.ndarray: + return self._apply_spatial_func(self._apply_dw_window_func(array)) + + def _get_other_features( + self, data_slice: np.ndarray, dw_slice: slice + ) -> dict[str, np.ndarray]: + other = dict.fromkeys(self._other_features) + for attr in self._other_features: + val = getattr(self, attr) + + if val is None: + continue + + if callable(val): + # if it's a callable, give it the data and display window slice, it must return the appropriate + # type of array for that graphic feature + val_sliced = val(data_slice, dw_slice) + + else: + # if no l dim, broadcast to [1, p] + if val.ndim == 1: + val = val[None] + + # apply current display window slice + val_sliced = val[:, dw_slice] + + # check if l dim size is 1 + if val_sliced.shape[0] == 1: + # broadcast across all graphical elements + n_graphics = self.shape[self.spatial_dims[0]] + val_sliced = np.broadcast_to( + val_sliced, shape=(n_graphics, *val_sliced.shape[1:]) + ) + + other[attr] = val_sliced + + return other + + def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]: + """ + slices through all slider dims and outputs an array that can be used to set graphic data + + Note that we do not use __getitem__ here since the index is a tuple specifying a single integer + index for each dimension. Slices are not allowed, therefore __getitem__ is not suitable here. + """ + + if len(self.slider_dims) > 1: + # there are slider dims in addition to the datapoints_dim + window_output = self._apply_window_functions(indices).squeeze() + else: + # no slider dims, use all the data + window_output = self.data + + # verify window output only has the spatial dims + if not set(window_output.dims) == set(self.spatial_dims): + raise IndexError + + # get slice obj for display window + dw_slice = self._get_dw_slice(indices) + + # data that will be used for the graphical representation + # a copy is made, if there were no window functions then this is a view of the original data + p_dim = self.spatial_dims[1] + + # slice the datapoints to be displayed in the graphic using the display window slice + # transpose to match spatial dims order, get numpy array, this is a view + graphic_data = window_output.isel({p_dim: dw_slice}).transpose( + *self.spatial_dims + ) + + data = self._finalize_(graphic_data).values + other = self._get_other_features(data, dw_slice) + + return { + "data": data, + **other, + } + + +class NDPositions(NDGraphic): + def __init__( + self, + ref_index: ReferenceIndex, + subplot: Subplot, + data: Any, + dims: Sequence[str], + spatial_dims: tuple[str, str, str], + *args, + graphic_type: Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ScatterStack + | ImageGraphic + ], + processor: type[NDPositionsProcessor] = NDPositionsProcessor, + display_window: int = 10, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + slider_dim_transforms: tuple[Callable[[Any], int] | None] | None = None, + max_display_datapoints: int = 1_000, + linear_selector: bool = False, + x_range_mode: Literal["fixed", "auto"] | None = None, + colors: ( + Sequence[str] | np.ndarray | Callable[[slice, np.ndarray], np.ndarray] + ) = None, + # TODO: cleanup how this cmap stuff works, require a cmap to be set per-graphic + # before allowing cmaps_transform, validate that stuff makes sense etc. + cmap: str = None, # across the line/scatter collection + cmap_each: Sequence[str] = None, # for each individual line/scatter + cmap_transform_each: np.ndarray = None, # for each individual line/scatter + markers: np.ndarray = None, # across the scatter collection, shape [l,] + markers_each: Sequence[str] = None, # for each individual scatter, shape [l, p] + sizes: np.ndarray = None, # across the scatter collection, shape [l,] + sizes_each: Sequence[float] = None, # for each individual scatter, shape [l, p] + thickness: np.ndarray = None, # for each line, shape [l,] + name: str = None, + timeseries: bool = False, + graphic_kwargs: dict = None, + processor_kwargs: dict = None, + ): + """ + Wraps an :class:`NDPositionsProcessor` and supports four interchangeable + graphical representations: ``LineStack``, ``LineCollection``, ``ScatterStack``, + and ``ScatterCollection``, as well as a heatmap view. For timeseries use-cases + it also manages a linear selector and automatically adjusts the view according + to the current x-range of the displayed data. + + Parameters + ---------- + ref_index + subplot + data + dims + spatial_dims + args + graphic_type + processor + display_window + window_funcs + slider_dim_transforms + max_display_datapoints + linear_selector + x_range_mode + colors + cmap + cmap_each + cmap_transform_each + markers + markers_each + sizes + sizes_each + thickness + name + graphic_kwargs + processor_kwargs + """ + + super().__init__(subplot, name) + + self._ref_index = ref_index + + if processor_kwargs is None: + processor_kwargs = dict() + + if graphic_kwargs is None: + self._graphic_kwargs = dict() + else: + self._graphic_kwargs = graphic_kwargs + + self._processor = processor( + data, + dims, + spatial_dims, + *args, + display_window=display_window, + max_display_datapoints=max_display_datapoints, + window_funcs=window_funcs, + slider_dim_transforms=slider_dim_transforms, + colors=colors, + markers=markers_each, + cmap_transform_each=cmap_transform_each, + sizes=sizes_each, + **processor_kwargs, + ) + + self._cmap = cmap + self._sizes = sizes + self._markers = markers + self._thickness = thickness + + self.cmap_each = cmap_each + self.cmap_transform_each = cmap_transform_each + + self._graphic_type = graphic_type + self._create_graphic() + + self._x_range_mode = None + self.x_range_mode = x_range_mode + self._last_x_range = np.array([0.0, 0.0], dtype=np.float32) + + self._timeseries = timeseries + # TODO: I think this is messy af, NDTimeseriesSubclass??? + if self._timeseries: + # makes some assumptions about positional data that apply only to timeseries representations + # probably don't want to maintain aspect + self._subplot.camera.maintain_aspect = False + + # auto x range modes make no sense for non-timeseries data + self.x_range_mode = x_range_mode + + if linear_selector: + self._linear_selector = LinearSelector( + 0, limits=(-np.inf, np.inf), edge_color="cyan" + ) + self._linear_selector.add_event_handler( + self._linear_selector_handler, "selection" + ) + self._subplot.add_graphic(self._linear_selector) + else: + self._linear_selector = None + else: + self._linear_selector = None + + @property + def processor(self) -> NDPositionsProcessor: + return self._processor + + @property + def graphic( + self, + ) -> ( + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ScatterStack + | ImageGraphic + | None + ): + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def graphic_type( + self, + ) -> Type[ + LineGraphic + | LineCollection + | LineStack + | ScatterGraphic + | ScatterCollection + | ScatterStack + | ImageGraphic + ]: + return self._graphic_type + + @graphic_type.setter + def graphic_type(self, graphic_type): + if type(self.graphic) is graphic_type: + return + + self._subplot.delete_graphic(self._graphic) + self._graphic_type = graphic_type + self._create_graphic() + + @property + def spatial_dims(self) -> tuple[str, str, str]: + return self.processor.spatial_dims + + @spatial_dims.setter + def spatial_dims(self, dims: tuple[str, str, str]): + self.processor.spatial_dims = dims + # force re-render + self.indices = self.indices + + @property + def indices(self) -> dict[Hashable, Any]: + return {d: self._ref_index[d] for d in self.processor.slider_dims} + + @indices.setter + @block_reentrance + def indices(self, indices): + if self.data is None: + return + + new_features = self.processor.get(indices) + data_slice = new_features["data"] + + # TODO: set other graphic features, colors, sizes, markers, etc. + + if isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + self.graphic.data[:, : data_slice.shape[-1]] = data_slice + + elif isinstance(self.graphic, (LineCollection, ScatterCollection)): + for l, g in enumerate(self.graphic.graphics): + new_data = data_slice[l] + if g.data.value.shape[0] != new_data.shape[0]: + # will replace buffer internally + g.data = new_data + else: + # if data are only xy, set only xy + g.data[:, : new_data.shape[1]] = new_data + + for feature in ["colors", "sizes", "markers"]: + value = new_features[feature] + + match value: + case None: + pass + case _: + if feature == "colors": + g.color_mode = "vertex" + + setattr(g, feature, value[l]) + + if self.cmap_each is not None: + match new_features["cmap_transform_each"]: + case None: + pass + case _: + setattr( + getattr(g, "cmap"), # ind_graphic.cmap + "transform", + new_features["cmap_transform_each"], + ) + + elif isinstance(self.graphic, ImageGraphic): + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self.graphic.data = image_data + self.graphic.offset = (x0, *self.graphic.offset[1:]) + self.graphic.scale = (x_scale, *self.graphic.scale[1:]) + + # TODO: I think this is messy af, NDTimeseriesSubclass??? + # x range of the data + xr = data_slice[0, 0, 0], data_slice[0, -1, 0] + if self.x_range_mode is not None: + self.graphic._plot_area.x_range = xr + + # if the update_from_view is polling, this prevents it from being called by setting the new last xrange + # in theory, but this doesn't seem to fully work yet, not a big deal right now can check later + self._last_x_range[:] = self.graphic._plot_area.x_range + + if self._linear_selector is not None: + with pause_events(self._linear_selector): # we don't want the linear selector change to update the indices + self._linear_selector.limits = xr + # linear selector acts on `p` dim + self._linear_selector.selection = indices[ + self.processor.spatial_dims[1] + ] + + def _linear_selector_handler(self, ev): + with block_indices_ctx(self): + # linear selector always acts on the `p` dim + self._ref_index[self.processor.spatial_dims[1]] = ev.info["value"] + + def _tooltip_handler(self, graphic, pick_info): + if isinstance(self.graphic, (LineCollection, ScatterCollection)): + # get graphic within the collection + n_index = np.argwhere(self.graphic.graphics == graphic).item() + p_index = pick_info["vertex_index"] + return self.processor.tooltip_format(n_index, p_index) + + def _create_graphic(self): + if self.data is None: + return + + new_features = self.processor.get(self.indices) + data_slice = new_features["data"] + + # store any cmap, sizes, thickness, etc. to assign to new graphic + graphic_attrs = dict() + for attr in ["cmap", "markers", "sizes", "thickness"]: + if attr in new_features.keys(): + if new_features[attr] is not None: + # markers and sizes defined for each line via processor takes priority + continue + + val = getattr(self, attr) + if val is not None: + graphic_attrs[attr] = val + + if issubclass(self._graphic_type, ImageGraphic): + # `d` dim must only have xy data to be interpreted as a heatmap, xyz can't become a timeseries heatmap + if self.processor.shape[self.processor.spatial_dims[-1]] != 2: + raise ValueError + + image_data, x0, x_scale = self._create_heatmap_data(data_slice) + self._graphic = self._graphic_type( + image_data, offset=(x0, 0, -1), scale=(x_scale, 1, 1) + ) + + else: + if issubclass(self._graphic_type, (LineStack, ScatterStack)): + kwargs = {"separation": 0.0, **self._graphic_kwargs} + else: + kwargs = self._graphic_kwargs + self._graphic = self._graphic_type(data_slice, **kwargs) + + for attr in graphic_attrs.keys(): + if hasattr(self._graphic, attr): + setattr(self._graphic, attr, graphic_attrs[attr]) + + if isinstance(self._graphic, (LineCollection, ScatterCollection)): + for l, g in enumerate(self.graphic.graphics): + for feature in ["colors", "sizes", "markers"]: + value = new_features[feature] + + match value: + case None: + pass + case _: + if feature == "colors": + g.color_mode = "vertex" + + setattr(g, feature, value[l]) + + if self.cmap_each is not None: + g.color_mode = "vertex" + g.cmap = self.cmap_each[l] + match new_features["cmap_transform_each"]: + case None: + pass + case _: + setattr( + getattr(g, "cmap"), # indv_graphic.cmap + "transform", + new_features["cmap_transform_each"], + ) + + if self.processor.tooltip: + if isinstance(self._graphic, (LineCollection, ScatterCollection)): + for g in self._graphic.graphics: + g.tooltip_format = partial(self._tooltip_handler, g) + + self._subplot.add_graphic(self._graphic) + + def _create_heatmap_data(self, data_slice) -> tuple[np.ndarray, float, float]: + """return [n_rows, n_cols] shape data from [n_timeseries, n_timepoints, xy] data""" + # assumes x vals in every row is the same, otherwise a heatmap representation makes no sense + # data slice is of shape [n_timeseries, n_timepoints, xy], where xy is x-y coordinates of each timeseries + x = data_slice[0, :, 0] # get x from just the first row + + # check if we need to interpolate + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + + if norm > 1e-6: + # x is not uniform upto float32 precision, must interpolate + x_uniform = np.linspace(x[0], x[-1], num=x.size) + y_interp = np.empty(shape=data_slice[..., 1].shape, dtype=np.float32) + + # this for loop is actually slightly faster than numpy.apply_along_axis() + for i in range(data_slice.shape[0]): + y_interp[i] = np.interp(x_uniform, x, data_slice[i, :, 1]) + + else: + # x is sufficiently uniform + y_interp = data_slice[..., 1] + + x0 = data_slice[0, 0, 0] + + # assume all x values are the same across all lines + # otherwise a heatmap representation makes no sense anyways + x_stop = x[-1] + x_scale = (x_stop - x0) / data_slice.shape[1] + + return y_interp, x0, x_scale + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + self.processor.display_window = dw + + # force re-render + self.indices = self.indices + + @property + def datapoints_window_func(self) -> tuple[Callable, str, int | float] | None: + """ + Callable, str indicating which dims to apply window function along, window_size in reference space: + 'all', 'x', 'y', 'z', 'xyz', 'xy', 'xz', 'yz' + '""" + return self.processor.datapoints_window_func + + @datapoints_window_func.setter + def datapoints_window_func(self, funcs: tuple[Callable, str, int | float]): + self.processor.datapoints_window_func = funcs + + @property + def x_range_mode(self) -> Literal["fixed", "auto"] | None: + """x-range using a fixed window from the display window, or by polling the camera (auto)""" + return self._x_range_mode + + @x_range_mode.setter + def x_range_mode(self, mode: Literal[None, "fixed", "auto"]): + if self._x_range_mode == "auto": + # old mode was auto + self._subplot.remove_animation(self._update_from_view_range) + + if mode == "auto": + self._subplot.add_animations(self._update_from_view_range) + + self._x_range_mode = mode + + def _update_from_view_range(self): + if self._graphic is None: + return + + xr = self._subplot.x_range + + # the floating point error near zero gets nasty here + if np.allclose(xr, self._last_x_range, atol=1e-14): + return + + last_width = abs(self._last_x_range[1] - self._last_x_range[0]) + self._last_x_range[:] = xr + + new_width = abs(xr[1] - xr[0]) + new_index = (xr[0] + xr[1]) / 2 + + if (new_index == self._ref_index[self.processor.spatial_dims[1]]) and ( + last_width == new_width + ): + return + + self.processor.display_window = new_width + # set the `p` dim on the global index vector + self._ref_index[self.processor.spatial_dims[1]] = new_index + + @property + def cmap(self) -> str | None: + return self._cmap + + @cmap.setter + def cmap(self, new: str | None): + if new is None: + # just set a default + if isinstance(self.graphic, (LineCollection, ScatterCollection)): + self.graphic.colors = "w" + else: + self.graphic.cmap = "plasma" + + self._cmap = None + return + + self._graphic.cmap = new + self._cmap = new + # force a re-render + self.indices = self.indices + + @property + def cmap_each(self) -> np.ndarray[str] | None: + # per-line/scatter + return self._cmap_each + + @cmap_each.setter + def cmap_each(self, new: Sequence[str] | None): + if new is None: + self._cmap_each = None + return + + if isinstance(new, str): + new = [new] + + new = np.asarray(new) + + if new.ndim != 1: + raise ValueError + + l_dim_size = self.processor.shape[self.processor.spatial_dims[0]] + # same cmap for all if size == 1, or specific cmap for each in `l` dim + if new.size != 1 and new.size != l_dim_size: + raise ValueError + + self._cmap_each = np.broadcast_to(new, shape=(l_dim_size,)) + + @property + def cmap_transform_each(self) -> np.ndarray | None: + # PER line/scatter, only allowed after `cmaps` is set. + return self.processor.cmap_transform_each + + @cmap_transform_each.setter + def cmap_transform_each(self, new: np.ndarray | FeatureCallable | None): + if new is None: + self.processor.cmap_transform_each = None + + if self.cmap_each is None: + self.processor.cmap_transform_each = None + warn("must set `cmap_each` before `cmap_transform_each`") + return + + if new is None and self.cmap_each is not None: + # default transform is just a transform based on the `p` dim size + new = partial(default_cmap_transform_each, self.shape[self.spatial_dims[1]]) + + self.processor.cmap_transform_each = new + + @property + def markers(self) -> str | Sequence[str] | None: + return self._markers + + @markers.setter + def markers(self, new: str | None): + if not isinstance(self.graphic, ScatterCollection): + self._markers = None + return + + if new is None: + # just set a default + new = "circle" + + self.graphic.markers = new + self._markers = new + # force a re-render + self.indices = self.indices + + @property + def sizes(self) -> float | Sequence[float] | None: + return self._sizes + + @sizes.setter + def sizes(self, new: float | Sequence[float] | None): + if not isinstance(self.graphic, ScatterCollection): + self._sizes = None + return + + if new is None: + # just set a default + new = 5.0 + + self.graphic.sizes = new + self._sizes = new + # force a re-render + self.indices = self.indices + + @property + def thickness(self) -> float | Sequence[float] | None: + return self._thickness + + @thickness.setter + def thickness(self, new: float | Sequence[float] | None): + if not isinstance(self.graphic, LineCollection): + self._thickness = None + return + + if new is None: + # just set a default + new = 2.0 + + self.graphic.thickness = new + self._thickness = new + # force a re-render + self.indices = self.indices diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py new file mode 100644 index 000000000..1b94e1cbc --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_pandas.py @@ -0,0 +1,98 @@ +from typing import Any + +import numpy as np +import pandas as pd + +from ._nd_positions import NDPositionsProcessor + + +class NDPP_Pandas(NDPositionsProcessor): + def __init__( + self, + data: pd.DataFrame, + spatial_dims: tuple[str, str, str], # [l, p, d] dims in order + columns: list[tuple[str, str] | tuple[str, str, str]], + tooltip_columns: list[str] = None, + **kwargs, + ): + self._columns = columns + + if tooltip_columns is not None: + if len(tooltip_columns) != len(self.columns): + raise ValueError + self._tooltip_columns = tooltip_columns + self._tooltip = True + else: + self._tooltip_columns = None + self._tooltip = False + + self._dims = spatial_dims + + super().__init__( + data=data, + dims=spatial_dims, + spatial_dims=spatial_dims, + **kwargs, + ) + + self._dw_slice = None + + @property + def data(self) -> pd.DataFrame: + return self._data + + def _validate_data(self, data: pd.DataFrame): + if not isinstance(data, pd.DataFrame): + raise TypeError + + return data + + @property + def columns(self) -> list[tuple[str, str] | tuple[str, str, str]]: + return self._columns + + @property + def dims(self) -> tuple[str, str, str]: + return self._dims + + @property + def shape(self) -> dict[str, int]: + # n_graphical_elements, n_timepoints, 2 + return {self.dims[0]: len(self.columns), self.dims[1]: self.data.index.size, self.dims[2]: 2} + + @property + def ndim(self) -> int: + return len(self.shape) + + @property + def tooltip(self) -> bool: + return self._tooltip + + def tooltip_format(self, n: int, p: int): + # datapoint index w.r.t. full data + p += self._dw_slice.start + return str(self.data[self._tooltip_columns[n]][p]) + + def get(self, indices: dict[str, Any]) -> dict[str, np.ndarray]: + # TODO: LOD by using a step size according to max_p + # TODO: Also what to do if display_window is None and data + # hasn't changed when indices keeps getting set, cache? + + # assume no additional slider dims + self._dw_slice = self._get_dw_slice(indices) + gdata_shape = len(self.columns), self._dw_slice.stop - self._dw_slice.start, 3 + + graphic_data = np.zeros(shape=gdata_shape, dtype=np.float32) + + for i, col in enumerate(self.columns): + graphic_data[i, :, :len(col)] = np.column_stack( + [self.data[c][self._dw_slice] for c in col] + ) + + data = self._finalize_(graphic_data) + other = self._get_other_features(data, self._dw_slice) + + return { + "data": data, + **other, + } diff --git a/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py new file mode 100644 index 000000000..fb3bb7015 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions/_zarr.py @@ -0,0 +1,4 @@ +# placeholder + +class NDPP_Zarr: + pass diff --git a/fastplotlib/widgets/nd_widget/_ndw_subplot.py b/fastplotlib/widgets/nd_widget/_ndw_subplot.py new file mode 100644 index 000000000..6666b3fc1 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_ndw_subplot.py @@ -0,0 +1,122 @@ +from collections.abc import Callable +from typing import Literal, Sequence, Hashable + +import numpy as np + +from ... import ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic +from ...layouts import Subplot +from ...utils import ArrayProtocol +from . import NDImage, NDPositions +from ._base import NDGraphic, WindowFuncCallable + + +class NDWSubplot: + """ + Entry point for adding ``NDGraphic`` objects to a subplot of an ``NDWidget``. + + Accessed via ``ndw[row, col]`` or ``ndw["subplot_name"]``. + Each ``add_nd_<...>`` method constructs the appropriate ``NDGraphic``, registers it with the parent + ``ReferenceIndex``, appends it to this subplot and returns the ``NDGraphic`` instance to the user. + + Note: ``NDWSubplot`` is not meant to be constructed directly, it only exists as part of an ``NDWidget`` + """ + def __init__(self, ndw, subplot: Subplot): + self.ndw = ndw + self._subplot = subplot + + self._nd_graphics = list() + + @property + def nd_graphics(self) -> tuple[NDGraphic]: + """all the NDGraphic instance in this subplot""" + return tuple(self._nd_graphics) + + def __getitem__(self, key): + # get a specific NDGraphic by index or name + if isinstance(key, (int, np.integer)): + return self.nd_graphics[key] + + for g in self.nd_graphics: + if g.name == key: + return g + + else: + raise KeyError(f"NDGraphc with given key not found: {key}") + + def add_nd_image( + self, + data: ArrayProtocol | None, + dims: Sequence[Hashable], + spatial_dims: ( + tuple[str, str] | tuple[str, str, str] + ), # must be in order! [rows, cols] | [z, rows, cols] + rgb_dim: str | None = None, + window_funcs: tuple[WindowFuncCallable | None, ...] | WindowFuncCallable = None, + window_order: tuple[int, ...] = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] = None, + compute_histogram: bool = True, + slider_dim_transforms=None, + name: str = None, + ): + nd = NDImage(self.ndw.indices, self._subplot, data=data, + dims=dims, + spatial_dims=spatial_dims, + rgb_dim=rgb_dim, + window_funcs=window_funcs, + window_order=window_order, + spatial_func=spatial_func, + compute_histogram=compute_histogram, + slider_dim_transforms=slider_dim_transforms, + name=name, + ) + + self._nd_graphics.append(nd) + return nd + + def add_nd_scatter(self, *args, **kwargs): + # TODO: better func signature here, send all kwargs to processor_kwargs + nd = NDPositions( + self.ndw.indices, + self._subplot, + *args, + graphic_type=ScatterCollection, + **kwargs, + ) + + self._nd_graphics.append(nd) + return nd + + def add_nd_timeseries( + self, + *args, + graphic_type: type[ + LineCollection | LineStack | ScatterStack | ImageGraphic + ] = LineStack, + x_range_mode: Literal["fixed", "auto"] | None = "auto", + **kwargs, + ): + nd = NDPositions( + self.ndw.indices, + self._subplot, + *args, + graphic_type=graphic_type, + linear_selector=True, + x_range_mode=x_range_mode, + timeseries=True, + **kwargs, + ) + + self._nd_graphics.append(nd) + return nd + + def add_nd_lines(self, *args, **kwargs): + nd = NDPositions( + self.ndw.indices, + self._subplot, + *args, + graphic_type=LineCollection, + **kwargs, + ) + + self._nd_graphics.append(nd) + return nd diff --git a/fastplotlib/widgets/nd_widget/_ndwidget.py b/fastplotlib/widgets/nd_widget/_ndwidget.py new file mode 100644 index 000000000..9ddfa8986 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_ndwidget.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from typing import Any, Optional + +from ._index import RangeContinuous, RangeDiscrete, ReferenceIndex +from ._ndw_subplot import NDWSubplot +from ._ui import NDWidgetUI, RightClickMenu +from ...layouts import ImguiFigure, Subplot + + +class NDWidget: + def __init__(self, ref_ranges: dict[str, tuple], ref_index: Optional[ReferenceIndex] = None, **kwargs): + if ref_index is None: + self._indices = ReferenceIndex(ref_ranges) + else: + self._indices = ref_index + + self._indices._add_ndwidget_(self) + + self._figure = ImguiFigure(std_right_click_menu=RightClickMenu, **kwargs) + self._figure.std_right_click_menu.set_nd_widget(self) + + self._subplots_nd: dict[Subplot, NDWSubplot] = dict() + for subplot in self.figure: + self._subplots_nd[subplot] = NDWSubplot(self, subplot) + + # hard code the expected height so that the first render looks right in tests, docs etc. + ui_size = 57 + (50 * len(self.indices)) + + self._sliders_ui = NDWidgetUI(self.figure, ui_size, self) + self.figure.add_gui(self._sliders_ui) + + @property + def figure(self) -> ImguiFigure: + return self._figure + + @property + def indices(self) -> ReferenceIndex: + return self._indices + + @indices.setter + def indices(self, new_indices: dict[str, int | float | Any]): + self._indices.set(new_indices) + + @property + def ranges(self) -> dict[str, RangeContinuous | RangeDiscrete]: + return self._indices.ref_ranges + + @property + def ndgraphics(self): + gs = list() + for subplot in self._subplots_nd.values(): + gs.extend(subplot.nd_graphics) + + return tuple(gs) + + def __getitem__(self, key: str | tuple[int, int] | Subplot): + if not isinstance(key, Subplot): + key = self.figure[key] + return self._subplots_nd[key] + + def show(self, **kwargs): + return self.figure.show(**kwargs) + + def close(self): + self.figure.close() diff --git a/fastplotlib/widgets/nd_widget/_repr_formatter.py b/fastplotlib/widgets/nd_widget/_repr_formatter.py new file mode 100644 index 000000000..0569f1004 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_repr_formatter.py @@ -0,0 +1,599 @@ +from __future__ import annotations + +import html +from collections.abc import Callable +from typing import Any + + +_RESET = "\033[0m" +_BOLD = "\033[1m" +_DIM = "\033[2m" + +_C = { + "title": "\033[38;5;75m", # sky-blue + "spatial": "\033[38;5;114m", # sage-green + "slider": "\033[38;5;215m", # soft-orange + "label": "\033[38;5;246m", # mid-grey + "value": "\033[38;5;252m", # near-white + "section": "\033[38;5;68m", # steel-blue + "muted": "\033[38;5;240m", # dark-grey + "warn": "\033[38;5;222m", # amber +} + + +def _c(key: str, text: str) -> str: + return f"{_C[key]}{text}{_RESET}" + + +def _callable_name(f: Callable | None) -> str: + if f is None: + return "—" + module = getattr(f, "__module__", "") or "" + qname = getattr(f, "__qualname__", None) or getattr(f, "__name__", repr(f)) + if module and not module.startswith("__"): + short = module.split(".")[-1] + return f"{short}.{qname}" + return qname + + +def ndprocessor_fmt_txt(processor) -> str: + """ + Returns a colored, ascii box + """ + lines: list[str] = [] + + cls = type(processor).__name__ + lines.append(_c("title", _BOLD + cls) + _RESET) + lines.append(_c("muted", "─" * 72)) + + lines.append(_c("section", " Dimensions")) + + header = ( + f" {'dim':<14}{'size':>6} {'role':<10} {'window_func size':<26} index_mapping" + ) + lines.append(_c("label", header)) + lines.append(_c("muted", " " + "─" * 70)) + + for dim in processor.dims: + size = processor.shape[dim] + is_sp = dim in processor.spatial_dims + role_s = (_c("spatial", f"{'spatial':<10}") if is_sp + else _c("slider", f"{'slider':<10}")) + + # window_func - size column + if not is_sp: + wf, ws = processor.window_funcs.get(dim, (None, None)) + if wf is not None and ws is not None: + win_s = _c("value", f"{_callable_name(wf)}") + _c("muted", f" - {ws}") + else: + win_s = _c("muted", "—") + else: + win_s = "" + + # index_mapping column (slider dims only; skip identity) + if not is_sp: + imap = processor.index_mappings.get(dim) + iname = getattr(imap, "__name__", "") if imap is not None else "" + if iname != "identity" and imap is not None: + idx_s = _c("value", _callable_name(imap)) + else: + idx_s = _c("muted", "—") + else: + idx_s = "" + + # pad win_s to fixed visible width (strip ANSI for measuring) + import re + _ansi_re = re.compile(r"\033\[[^m]*m") + win_visible = len(_ansi_re.sub("", win_s)) + win_pad = win_s + " " * max(0, 26 - win_visible) + + line = ( + f" {_c('value', f'{str(dim):<14}')}" + f"{_c('label', f'{size:>6}')} " + f"{role_s} {win_pad} {idx_s}" + ) + lines.append(line) + + # window order + if processor.window_order: + lines.append("") + order_s = " → ".join(str(d) for d in processor.window_order) + lines.append(f" {_c('section', 'Window order')} {_c('value', order_s)}") + + # spatial func + if processor.spatial_func is not None: + lines.append("") + lines.append( + f" {_c('section', 'Spatial func')} " + f"{_c('value', _callable_name(processor.spatial_func))}" + ) + + lines.append(_c("muted", "─" * 72)) + return "\n".join(lines) + + +def ndgraphic_fmt_txt(ndg) -> str: + """Text repr for NDGraphic.""" + cls = type(ndg).__name__ + gcls = type(ndg.graphic).__name__ if ndg.graphic is not None else "—" + name = ndg.name or "—" + + header = ( + f"{_c('title', _BOLD + cls)}{_RESET} " + f"{_c('muted', '·')} " + f"{_c('section', 'graphic')} {_c('value', gcls)} " + f"{_c('muted', '·')} " + f"{_c('section', 'name')} {_c('value', name)}\n" + ) + + proc_block = ndprocessor_fmt_txt(ndg.processor) + # indent processor block + indented = "\n".join(" " + l for l in proc_block.splitlines()) + return header + indented + +_CSS = """ + +""" + + +def _h(s: Any) -> str: + """html-escape a stringified value""" + return html.escape(str(s)) + + +def _badge(role: str) -> str: + cls = "fpl-badge-spatial" if role == "spatial" else "fpl-badge-slider" + return f'{role}' + + +def _code(s: str) -> str: + return f"{_h(s)}" + + +def _section(title: str, content_html: str, count: str = "", open_: bool = True) -> str: + open_attr = " open" if open_ else "" + count_badge = ( + f'{_h(count)}' if count else "" + ) + return ( + f'
' + f'' + f'{_h(title)}' + f'{count_badge}' + f'' + f'{content_html}' + f'
' + ) + + +def _dim_rows_html(proc) -> str: + rows = [] + for dim in proc.dims: + size = proc.shape[dim] + is_sp = dim in proc.spatial_dims + badge = _badge("spatial" if is_sp else "slider") + + # window_func - size column + if not is_sp: + wf, ws = proc.window_funcs.get(dim, (None, None)) + if wf is not None and ws is not None: + win_td = ( + f'' + f'{_code(_callable_name(wf))}' + f'-' + f'{_code(str(ws))}' + f'' + ) + else: + win_td = '—' + else: + win_td = '' + + # index_mapping column (slider dims only; hide identity) + if not is_sp: + imap = proc.index_mappings.get(dim) + if imap is not None: + idx_td = f'{_code(_callable_name(imap))}' + else: + idx_td = '—' + else: + idx_td = '' + + rows.append( + f'' + f'{_h(str(dim))}' + f'{size:,}' + f'{badge}' + f'{win_td}' + f'{idx_td}' + f'' + ) + + # column header row + header = ( + f'' + f'dim' + f'size' + f'role' + f'window_func - size' + f'index_mapping' + f'' + ) + + table = ( + '' + '' + '' + '' + '' + + header + + "".join(rows) + + "
" + ) + return table + + +def _footer_kv(pairs: list[tuple[str, str]]) -> str: + """Always-visible key/value rows rendered below the dim table.""" + inner = "" + for k, v in pairs: + inner += ( + f'' + f'' + ) + return f'' + + +def _kv_list_html(pairs: list[tuple[str, str]]) -> str: + inner = "" + for k, v in pairs: + inner += ( + f'
{_h(k)}
' + f'
{v}
' + ) + return f'
{inner}
' + + +def _html_processor(proc) -> str: + cls = _h(type(proc).__name__) + + # header + ndim_pill = ( + f'' + f'{proc.ndim}D' + ) + header = ( + f'
' + f'{cls}' + f'{ndim_pill}' + f'
' + ) + + # dims section (always open) + dim_content = _dim_rows_html(proc) + sections = _section("Dimensions", dim_content, + count=str(proc.ndim), open_=True) + + # always-visible footer rows + footer_pairs: list[tuple[str, str]] = [] + + if proc.window_order: + chain = " → ".join( + f'{_h(str(d))}' + if i > 0 else _h(str(d)) + for i, d in enumerate(proc.window_order) + ) + footer_pairs.append(("window order", f'{chain}')) + + if proc.spatial_func is not None: + footer_pairs.append(("spatial func", _code(_callable_name(proc.spatial_func)))) + + if footer_pairs: + sections += _footer_kv(footer_pairs) + + body = f'
{sections}
' + return f'{_CSS}
{header}{body}
' + + +def ndgraphic_fmt_html(ndg) -> str: + cls = _h(type(ndg).__name__) + gcls = _h(type(ndg.graphic).__name__) if ndg.graphic is not None else "—" + name = _h(ndg.name or "—") + + graphic_pill = f'graphic: {gcls}' + name_pill = f'name: {name}' + + header = ( + f'
' + f'{cls}' + f'·' + f'{graphic_pill}{name_pill}' + f'
' + ) + + # embed processor repr (without its own outer box) inside a section + proc_inner = _dim_rows_html(ndg.processor) + sections = _section("Processor · Dimensions", proc_inner, open_=True) + + footer_pairs: list[tuple[str, str]] = [] + + if ndg.processor.window_order: + chain = " → ".join( + f'{_h(str(d))}' + if i > 0 else _h(str(d)) + for i, d in enumerate(ndg.processor.window_order) + ) + footer_pairs.append(("window order", f'{chain}')) + + if ndg.processor.spatial_func is not None: + footer_pairs.append(("spatial func", _code(_callable_name(ndg.processor.spatial_func)))) + + if footer_pairs: + sections += _footer_kv(footer_pairs) + + body = f'
{sections}
' + return f'{_CSS}
{header}{body}
' + +class ReprMixin: + """ + Mixin that provides: + • __repr__ → coloured ANSI text (terminal / plain REPL) + • _repr_html_ → rich HTML (Jupyter) + • _repr_mimebundle_ → both, so Jupyter picks the richest format + + Subclasses must implement _repr_text_() and _repr_html_() themselves OR + rely on the dispatch below which checks the concrete type. + """ + + def _repr_text_(self) -> str: + # lazy import avoids circular; swap for a direct call in your module + if _is_ndgraphic(self): + return ndgraphic_fmt_txt(self) + return ndprocessor_fmt_txt(self) + + def _repr_html_(self) -> str: + return ndgraphic_fmt_html(self) + return _html_processor(self) + + def __repr__(self) -> str: + return self._repr_text_() + + def _repr_mimebundle_(self, **kwargs) -> dict: + return { + "text/plain": self._repr_text_(), + "text/html": self._repr_html_(), + } + + +def _is_ndgraphic(obj) -> bool: + """duck-type check: does this object have a .graphic and .processor?""" + return hasattr(obj, "graphic") and hasattr(obj, "processor") \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/_ui.py b/fastplotlib/widgets/nd_widget/_ui.py new file mode 100644 index 000000000..2855c7063 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_ui.py @@ -0,0 +1,287 @@ +import os +from time import perf_counter + +import numpy as np +from imgui_bundle import imgui, imgui_ctx, icons_fontawesome_6 as fa + +from ...graphics import ( + ScatterCollection, + ScatterStack, + LineCollection, + LineStack, + ImageGraphic, + ImageVolumeGraphic, +) +from ...utils import quick_min_max +from ...layouts import Subplot +from ...ui import EdgeWindow, StandardRightClickMenu +from ._index import RangeContinuous +from ._base import NDGraphic +from ._nd_positions import NDPositions +from ._nd_image import NDImage + +position_graphic_types = [ScatterCollection, ScatterStack, LineCollection, LineStack, ImageGraphic] + + +class NDWidgetUI(EdgeWindow): + def __init__(self, figure, size, ndwidget): + super().__init__( + figure=figure, + size=size, + title="NDWidget controls", + location="bottom", + window_flags=imgui.WindowFlags_.no_collapse + | imgui.WindowFlags_.no_resize + | imgui.WindowFlags_.no_title_bar, + ) + self._ndwidget = ndwidget + + ref_ranges = self._ndwidget.ranges + + # whether or not a dimension is in play mode + self._playing = {dim: False for dim in ref_ranges.keys()} + + # approximate framerate for playing + self._fps = {dim: 20 for dim in ref_ranges.keys()} + + # framerate converted to frame time + self._frame_time = {dim: 1 / 20 for dim in ref_ranges.keys()} + + # last timepoint that a frame was displayed from a given dimension + self._last_frame_time = {dim: perf_counter() for dim in ref_ranges.keys()} + + # loop playback + self._loop = {dim: False for dim in ref_ranges.keys()} + + # auto-plays the ImageWidget's left-most dimension in docs galleries + if "DOCS_BUILD" in os.environ.keys(): + if os.environ["DOCS_BUILD"] == "1": + self._playing[0] = True + self._loop = True + + self._max_display_windows: dict[NDGraphic, float | int] = dict() + + def _set_index(self, dim, index): + if index >= self._ndwidget.ranges[dim].stop: + if self._loop[dim]: + index = self._ndwidget.ranges[dim].start + else: + index = self._ndwidget.ranges[dim].stop + self._playing[dim] = False + + self._ndwidget.indices[dim] = index + + def update(self): + now = perf_counter() + + for dim, current_index in self._ndwidget.indices: + # push id since we have the same buttons for each dim + imgui.push_id(f"{self._id_counter}_{dim}") + + rr = self._ndwidget.ranges[dim] + + if self._playing[dim]: + # show pause button if playing + if imgui.button(label=fa.ICON_FA_PAUSE): + # if pause button clicked, then set playing to false + self._playing[dim] = False + + # if in play mode and enough time has elapsed w.r.t. the desired framerate, increment the index + if now - self._last_frame_time[dim] >= self._frame_time[dim]: + self._set_index(dim, current_index + rr.step) + self._last_frame_time[dim] = now + + else: + # we are not playing, so display play button + if imgui.button(label=fa.ICON_FA_PLAY): + # if play button is clicked, set last frame time to 0 so that index increments on next render + self._last_frame_time[dim] = 0 + # set playing to True since play button was clicked + self._playing[dim] = True + + imgui.same_line() + # step back one frame button + if imgui.button(label=fa.ICON_FA_BACKWARD_STEP) and not self._playing[dim]: + self._set_index(dim, current_index - rr.step) + + imgui.same_line() + # step forward one frame button + if imgui.button(label=fa.ICON_FA_FORWARD_STEP) and not self._playing[dim]: + self._set_index(dim, current_index + rr.step) + + imgui.same_line() + # stop button + if imgui.button(label=fa.ICON_FA_STOP): + self._playing[dim] = False + self._last_frame_time[dim] = 0 + self._ndwidget.indices[dim] = rr.start + + imgui.same_line() + # loop checkbox + _, self._loop[dim] = imgui.checkbox( + label=fa.ICON_FA_ROTATE, v=self._loop[dim] + ) + if imgui.is_item_hovered(0): + imgui.set_tooltip("loop playback") + + imgui.same_line() + imgui.text("framerate :") + imgui.same_line() + imgui.set_next_item_width(100) + # framerate int entry + fps_changed, value = imgui.input_int( + label="fps", v=self._fps[dim], step_fast=5 + ) + if imgui.is_item_hovered(0): + imgui.set_tooltip( + "framerate is approximate and less reliable as it approaches your monitor refresh rate" + ) + if fps_changed: + if value < 1: + value = 1 + if value > 50: + value = 50 + self._fps[dim] = value + self._frame_time[dim] = 1 / value + + imgui.text(str(dim)) + imgui.same_line() + # so that slider occupies full width + imgui.set_next_item_width(self.width * 0.85) + + if isinstance(rr, RangeContinuous): + changed, new_index = imgui.slider_float( + v=current_index, + v_min=rr.start, + v_max=rr.stop - rr.step, + label=f"##{dim}", + ) + + # TODO: refactor all this stuff, make fully fledged UI + if changed: + self._ndwidget.indices[dim] = new_index + + elif imgui.is_item_hovered(): + if imgui.is_key_pressed(imgui.Key.right_arrow): + self._set_index(dim, current_index + rr.step) + + elif imgui.is_key_pressed(imgui.Key.left_arrow): + self._set_index(dim, current_index - rr.step) + + imgui.pop_id() + + +class RightClickMenu(StandardRightClickMenu): + def __init__(self, figure): + self._ndwidget = None + self._ndgraphic_windows = set() + + super().__init__(figure=figure) + + def set_nd_widget(self, ndw): + self._ndwidget = ndw + + def _extra_menu(self): + if self._ndwidget is None: + return + + if imgui.begin_menu("ND Graphics"): + subplot = self.get_subplot() + for ndg in self._ndwidget[subplot].nd_graphics: + name = ndg.name if ndg.name is not None else hex(id(ndg)) + if imgui.menu_item( + f"{name}", "", False + )[0]: + self._ndgraphic_windows.add(ndg) + + imgui.end_menu() + + def update(self): + super().update() + + for ndg in list(self._ndgraphic_windows): # set -> list so we can change size during iteration + name = ndg.name if ndg.name is not None else hex(id(ndg)) + subplot = ndg.graphic._plot_area + imgui.set_next_window_size((0, 0)) + _, open = imgui.begin(f"subplot: {subplot.name}, {name}", True) + + if isinstance(ndg, NDPositions): + self._draw_nd_pos_ui(subplot, ndg) + + elif isinstance(ndg, NDImage): + self._draw_nd_image_ui(subplot, ndg) + + _, ndg.pause = imgui.checkbox("pause", ndg.pause) + + if not open: + self._ndgraphic_windows.remove(ndg) + + imgui.end() + + def _draw_nd_image_ui(self, subplot, nd_image: NDImage): + _min, _max = quick_min_max(nd_image.graphic.data.value) + changed, vmin = imgui.slider_float( + "vmin", nd_image.graphic.vmin, v_min=_min, v_max=_max + ) + if changed: + nd_image.graphic.vmin = vmin + + changed, vmax = imgui.slider_float( + "vmax", nd_image.graphic.vmax, v_min=_min, v_max=_max + ) + if changed: + nd_image.graphic.vmax = vmax + + changed, new_gamma = imgui.slider_float( + "gamma", nd_image.graphic._material.gamma, 0.01, 5 + ) + if changed: + nd_image.graphic._material.gamma = new_gamma + + def _draw_nd_pos_ui(self, subplot: Subplot, nd_graphic: NDPositions): + for i, cls in enumerate(position_graphic_types): + if imgui.radio_button(cls.__name__, type(nd_graphic.graphic) is cls): + nd_graphic.graphic_type = cls + subplot.auto_scale() + + changed, val = imgui.checkbox( + "use display window", nd_graphic.display_window is not None + ) + + p_dim = nd_graphic.processor.spatial_dims[1] + + if changed: + if not val: + nd_graphic.display_window = None + else: + # pick a value 10% of the reference range + nd_graphic.display_window = self._ndwidget.ranges[p_dim].range * 0.1 + + if nd_graphic.display_window is not None: + if isinstance(nd_graphic.display_window, (int, np.integer)): + slider = imgui.slider_int + input_ = imgui.input_int + type_ = int + else: + slider = imgui.slider_float + input_ = imgui.input_float + type_ = float + + changed, new = slider( + "display window", + v=nd_graphic.display_window, + v_min=type_(0), + v_max=type_(self._ndwidget.ranges[p_dim].stop * 0.1), + ) + + if changed: + nd_graphic.display_window = new + + options = [None, "fixed", "auto"] + changed, option = imgui.combo( + "x-range mode", + options.index(nd_graphic.x_range_mode), + [str(o) for o in options], + ) + if changed: + nd_graphic.x_range_mode = options[option] diff --git a/pyproject.toml b/pyproject.toml index 73dfd7ee3..b91b168c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ keywords = [ requires-python = ">= 3.10" dependencies = [ "numpy>=1.23.0", - "pygfx==0.15.3", + "pygfx==0.16.0", "wgpu", # Let pygfx constrain the wgpu version "cmap>=0.1.3", # (this comment keeps this list multiline in VSCode) @@ -46,6 +46,7 @@ notebook = [ "jupyter-rfb>=0.5.1", "ipywidgets>=8.0.0,<9", "sidecar", + "simplejpeg", ] tests = [ "pytest", @@ -58,7 +59,8 @@ tests = [ "ome-zarr", ] imgui = ["wgpu[imgui]"] -dev = ["fastplotlib[docs,notebook,tests,imgui]"] +ndwidget = ["wgpu[imgui]", "xarray"] +dev = ["fastplotlib[docs,notebook,tests,imgui,ndwidget]"] [project.urls] Homepage = "https://www.fastplotlib.org/" diff --git a/tests/test_colors_buffer_manager.py b/tests/test_colors_buffer_manager.py index 7b1aef16a..f9d56189e 100644 --- a/tests/test_colors_buffer_manager.py +++ b/tests/test_colors_buffer_manager.py @@ -48,10 +48,10 @@ def test_int(test_graphic): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -98,10 +98,10 @@ def test_tuple(test_graphic, slice_method): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors global EVENT_RETURN_VALUE @@ -190,10 +190,10 @@ def test_slice(color_input, slice_method: dict, test_graphic: bool): data = generate_positions_spiral_data("xyz") if test_graphic == "line": - graphic = fig[0, 0].add_line(data=data) + graphic = fig[0, 0].add_line(data=data, color_mode="vertex") elif test_graphic == "scatter": - graphic = fig[0, 0].add_scatter(data=data) + graphic = fig[0, 0].add_scatter(data=data, color_mode="vertex") colors = graphic.colors diff --git a/tests/test_markers_buffer_manager.py b/tests/test_markers_buffer_manager.py index 65ead392e..488bed194 100644 --- a/tests/test_markers_buffer_manager.py +++ b/tests/test_markers_buffer_manager.py @@ -46,10 +46,10 @@ def test_create_buffer(test_graphic): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) vertex_markers = scatter.markers assert isinstance(vertex_markers, VertexMarkers) - assert vertex_markers.buffer is scatter.world_object.geometry.markers + assert vertex_markers._fpl_buffer is scatter.world_object.geometry.markers else: vertex_markers = VertexMarkers(MARKERS1, len(data)) @@ -68,7 +68,7 @@ def test_int(test_graphic, index: int): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers else: @@ -108,7 +108,7 @@ def test_slice(test_graphic, slice_method): if test_graphic: fig = fpl.Figure() - scatter = fig[0, 0].add_scatter(data, markers=MARKERS1) + scatter = fig[0, 0].add_scatter(data, markers=MARKERS1, uniform_marker=False) scatter.add_event_handler(event_handler, "markers") vertex_markers = scatter.markers diff --git a/tests/test_point_rotations_buffer_manager.py b/tests/test_point_rotations_buffer_manager.py index ec5fdbe0f..50ee88984 100644 --- a/tests/test_point_rotations_buffer_manager.py +++ b/tests/test_point_rotations_buffer_manager.py @@ -35,7 +35,7 @@ def test_create_buffer(test_graphic): scatter = fig[0, 0].add_scatter(data, point_rotation_mode="vertex", point_rotations=ROTATIONS1) vertex_rotations = scatter.point_rotations assert isinstance(vertex_rotations, VertexRotations) - assert vertex_rotations.buffer is scatter.world_object.geometry.rotations + assert vertex_rotations._fpl_buffer is scatter.world_object.geometry.rotations else: vertex_rotations = VertexRotations(ROTATIONS1, len(data)) diff --git a/tests/test_positions_data_buffer_manager.py b/tests/test_positions_data_buffer_manager.py index e2582d4ba..cc550abf0 100644 --- a/tests/test_positions_data_buffer_manager.py +++ b/tests/test_positions_data_buffer_manager.py @@ -57,7 +57,7 @@ def test_int(test_graphic): graphic = fig[0, 0].add_scatter(data=data) points = graphic.data - assert graphic.data.buffer is graphic.world_object.geometry.positions + assert graphic.data._fpl_buffer is graphic.world_object.geometry.positions global EVENT_RETURN_VALUE graphic.add_event_handler(event_handler, "data") else: diff --git a/tests/test_positions_graphics.py b/tests/test_positions_graphics.py index 31c001888..4bc93b626 100644 --- a/tests/test_positions_graphics.py +++ b/tests/test_positions_graphics.py @@ -37,12 +37,12 @@ def test_sizes_slice(): @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("b")]) -@pytest.mark.parametrize("uniform_color", [True, False]) -def test_uniform_color(graphic_type, colors, uniform_color): +@pytest.mark.parametrize("color_mode", ["uniform", "vertex"]) +def test_color_mode(graphic_type, colors, color_mode): fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -54,7 +54,7 @@ def test_uniform_color(graphic_type, colors, uniform_color): elif graphic_type == "scatter": graphic = fig[0, 0].add_scatter(data=data, **kwargs) - if uniform_color: + if color_mode == "uniform": assert isinstance(graphic._colors, UniformColor) assert isinstance(graphic.colors, pygfx.Color) if colors is None: @@ -130,17 +130,17 @@ def test_positions_graphics_data( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["vertex"]) def test_positions_graphic_vertex_colors( graphic_type, colors, - uniform_color, + color_mode, ): # test different ways of passing vertex colors fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -153,10 +153,9 @@ def test_positions_graphic_vertex_colors( graphic = fig[0, 0].add_scatter(data=data, **kwargs) # color per vertex - # uniform colors is default False, or set to False - assert isinstance(graphic._colors, VertexColors) - assert isinstance(graphic.colors, VertexColors) - assert len(graphic.colors) == len(graphic.data) + assert isinstance(graphic._colors, VertexColors) + assert isinstance(graphic.colors, VertexColors) + assert len(graphic.colors) == len(graphic.data) if colors is None: # default @@ -179,7 +178,7 @@ def test_positions_graphic_vertex_colors( @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [None, *generate_color_inputs("r")]) -@pytest.mark.parametrize("uniform_color", [None, False]) +@pytest.mark.parametrize("color_mode", ["auto", "vertex"]) @pytest.mark.parametrize("cmap", ["jet"]) @pytest.mark.parametrize( "cmap_transform", [None, [3, 5, 2, 1, 0, 6, 9, 7, 4, 8], np.arange(9, -1, -1)] @@ -187,7 +186,7 @@ def test_positions_graphic_vertex_colors( def test_cmap( graphic_type, colors, - uniform_color, + color_mode, cmap, cmap_transform, ): @@ -195,7 +194,7 @@ def test_cmap( fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "cmap_transform", "colors", "uniform_color"]: + for kwarg in ["cmap", "cmap_transform", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -220,7 +219,8 @@ def test_cmap( # make sure buffer is identical # cmap overrides colors argument - assert graphic.colors.buffer is graphic.cmap.buffer + # use __repr__.__self__ to get the real reference from the cmap feature instead of the weakref proxy + assert graphic.colors._fpl_buffer is graphic.cmap.buffer.__repr__.__self__ npt.assert_almost_equal(graphic.cmap.value, truth) npt.assert_almost_equal(graphic.colors.value, truth) @@ -261,14 +261,14 @@ def test_cmap( "colors", [None, *generate_color_inputs("multi")] ) # cmap arg overrides colors @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color): +def test_incompatible_cmap_color_args(graphic_type, cmap, colors, color_mode): # test incompatible cmap args fig = fpl.Figure() kwargs = dict() - for kwarg in ["cmap", "colors", "uniform_color"]: + for kwarg in ["cmap", "colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -276,24 +276,24 @@ def test_incompatible_cmap_color_args(graphic_type, cmap, colors, uniform_color) data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("graphic_type", ["line", "scatter"]) @pytest.mark.parametrize("colors", [*generate_color_inputs("multi")]) @pytest.mark.parametrize( - "uniform_color", [True] # none of these will work with a uniform buffer + "color_mode", ["uniform"] # none of these will work with a uniform buffer ) -def test_incompatible_color_args(graphic_type, colors, uniform_color): +def test_incompatible_color_args(graphic_type, colors, color_mode): # test incompatible color args fig = fpl.Figure() kwargs = dict() - for kwarg in ["colors", "uniform_color"]: + for kwarg in ["colors", "color_mode"]: if locals()[kwarg] is not None: # add to dict of arguments that will be passed kwargs[kwarg] = locals()[kwarg] @@ -301,16 +301,15 @@ def test_incompatible_color_args(graphic_type, colors, uniform_color): data = generate_positions_spiral_data("xy") if graphic_type == "line": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_line(data=data, **kwargs) elif graphic_type == "scatter": - with pytest.raises(TypeError): + with pytest.raises(ValueError): graphic = fig[0, 0].add_scatter(data=data, **kwargs) @pytest.mark.parametrize("sizes", [None, 5.0, np.linspace(3, 8, 10, dtype=np.float32)]) -@pytest.mark.parametrize("uniform_size", [None, False]) -def test_sizes(sizes, uniform_size): +def test_sizes(sizes): # test scatter sizes fig = fpl.Figure() @@ -322,7 +321,7 @@ def test_sizes(sizes, uniform_size): data = generate_positions_spiral_data("xy") - graphic = fig[0, 0].add_scatter(data=data, **kwargs) + graphic = fig[0, 0].add_scatter(data=data, uniform_size=False, **kwargs) assert isinstance(graphic.sizes, VertexPointSizes) assert isinstance(graphic._sizes, VertexPointSizes) diff --git a/tests/test_replace_buffer.py b/tests/test_replace_buffer.py new file mode 100644 index 000000000..a9d0ffe41 --- /dev/null +++ b/tests/test_replace_buffer.py @@ -0,0 +1,155 @@ +import gc +import weakref + +import pytest +import numpy as np +from itertools import product + +import fastplotlib as fpl +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic + +# These are only de-referencing tests for positions graphics, and ImageGraphic +# they do not test that VRAM gets free, for now this can only be checked manually +# with the tests in examples/misc/buffer_replace_gc.py + + +@pytest.mark.parametrize("graphic_type", ["line", "scatter"]) +@pytest.mark.parametrize("new_buffer_size", [50, 150]) +def test_replace_positions_buffer(graphic_type, new_buffer_size): + fig = fpl.Figure() + + # create some data with an initial shape + orig_datapoints = 100 + + xs = np.linspace(0, 2 * np.pi, orig_datapoints) + ys = np.sin(xs) + zs = np.cos(xs) + + data = np.column_stack([xs, ys, zs]) + + # add add_line or add_scatter method + adder = getattr(fig[0, 0], f"add_{graphic_type}") + + if graphic_type == "scatter": + kwargs = { + "markers": np.random.choice(list("osD+x^v<>*"), size=orig_datapoints), + "uniform_marker": False, + "sizes": np.abs(ys), + "uniform_size": False, + # TODO: skipping edge_colors for now since that causes a WGPU bind group error that we will figure out later + # anyways I think changing buffer sizes in combination with per-vertex edge colors is a literal edge-case + "point_rotations": zs * 180, + "point_rotation_mode": "vertex", + } + else: + kwargs = dict() + + # add a line or scatter graphic + graphic = adder(data=data, colors=np.random.rand(orig_datapoints, 4), **kwargs) + + fig.show() + + # weakrefs to the original buffers + # these should raise a ReferenceError when the corresponding feature is replaced with data of a different shape + orig_data_buffer = weakref.proxy(graphic.data._fpl_buffer) + orig_colors_buffer = weakref.proxy(graphic.colors._fpl_buffer) + + buffers = [orig_data_buffer, orig_colors_buffer] + + # extra buffers for the scatters + if graphic_type == "scatter": + for attr in ["markers", "sizes", "point_rotations"]: + buffers.append(weakref.proxy(getattr(graphic, attr)._fpl_buffer)) + + # create some new data that requires a different buffer shape + xs = np.linspace(0, 15 * np.pi, new_buffer_size) + ys = np.sin(xs) + zs = np.cos(xs) + + new_data = np.column_stack([xs, ys, zs]) + + # set data that requires a larger buffer and check that old buffer is no longer referenced + graphic.data = new_data + graphic.colors = np.random.rand(new_buffer_size, 4) + + if graphic_type == "scatter": + # changes values so that new larger buffers must be allocated + graphic.markers = np.random.choice(list("osD+x^v<>*"), size=new_buffer_size) + graphic.sizes = np.abs(zs) + graphic.point_rotations = ys * 180 + + # make sure old original buffers are de-referenced + for i in range(len(buffers)): + with pytest.raises(ReferenceError) as fail: + buffers[i] + pytest.fail( + f"GC failed for buffer: {buffers[i]}, " + f"with referrers: {gc.get_referrers(buffers[i].__repr__.__self__)}" + ) + + +# test all combination of dims that require TextureArrays of shapes 1x1, 1x2, 1x3, 2x3, 3x3 etc. +@pytest.mark.parametrize( + "new_buffer_size", list(product(*[[(500, 1), (1200, 2), (2200, 3)]] * 2)) +) +def test_replace_image_buffer(new_buffer_size): + # make an image with some starting shape + orig_size = (1_500, 1_500) + + data = np.random.rand(*orig_size) + + fig = fpl.Figure() + image = fig[0, 0].add_image(data) + + # the original Texture buffers that represent the individual image tiles + orig_buffers = [ + weakref.proxy(image.data.buffer.ravel()[i]) + for i in range(image.data.buffer.size) + ] + orig_shape = image.data.buffer.shape + + fig.show() + + # dimensions for a new image + new_dims = [v[0] for v in new_buffer_size] + + # the number of tiles required in each dim/shape of the TextureArray + new_shape = tuple(v[1] for v in new_buffer_size) + + # make the new data and set the image + new_data = np.random.rand(*new_dims) + image.data = new_data + + # test that old Texture buffers are de-referenced + for i in range(len(orig_buffers)): + with pytest.raises(ReferenceError) as fail: + orig_buffers[i] + pytest.fail( + f"GC failed for buffer: {orig_buffers[i]}, of shape: {orig_shape}" + f"with referrers: {gc.get_referrers(orig_buffers[i].__repr__.__self__)}" + ) + + # check new texture array + check_texture_array( + data=new_data, + ta=image.data, + buffer_size=np.prod(new_shape), + buffer_shape=new_shape, + row_indices_size=new_shape[0], + col_indices_size=new_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (new_data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), + ) + + # check that new image tiles are arranged correctly + check_image_graphic(image.data, image) diff --git a/tests/test_scatter_graphic.py b/tests/test_scatter_graphic.py index a61681f24..930d8c495 100644 --- a/tests/test_scatter_graphic.py +++ b/tests/test_scatter_graphic.py @@ -133,7 +133,7 @@ def test_edge_colors(edge_colors): npt.assert_almost_equal(scatter.edge_colors.value, MULTI_COLORS_TRUTH) assert ( - scatter.edge_colors.buffer is scatter.world_object.geometry.edge_colors + scatter.edge_colors._fpl_buffer is scatter.world_object.geometry.edge_colors ) # test changes, don't need to test extensively here since it's tested in the main VertexColors test diff --git a/tests/test_texture_array.py b/tests/test_texture_array.py index 6220f2fe5..01abb9a97 100644 --- a/tests/test_texture_array.py +++ b/tests/test_texture_array.py @@ -2,14 +2,9 @@ from numpy import testing as npt import pytest -import pygfx - import fastplotlib as fpl from fastplotlib.graphics.features import TextureArray -from fastplotlib.graphics.image import _ImageTile - - -MAX_TEXTURE_SIZE = 1024 +from .utils_textures import MAX_TEXTURE_SIZE, check_texture_array, check_image_graphic def make_data(n_rows: int, n_cols: int) -> np.ndarray: @@ -25,50 +20,6 @@ def make_data(n_rows: int, n_cols: int) -> np.ndarray: return np.vstack([sine * i for i in range(n_rows)]).astype(np.float32) -def check_texture_array( - data: np.ndarray, - ta: TextureArray, - buffer_size: int, - buffer_shape: tuple[int, int], - row_indices_size: int, - col_indices_size: int, - row_indices_values: np.ndarray, - col_indices_values: np.ndarray, -): - - npt.assert_almost_equal(ta.value, data) - - assert ta.buffer.size == buffer_size - assert ta.buffer.shape == buffer_shape - - assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) - - assert ta.row_indices.size == row_indices_size - assert ta.col_indices.size == col_indices_size - npt.assert_array_equal(ta.row_indices, row_indices_values) - npt.assert_array_equal(ta.col_indices, col_indices_values) - - # make sure chunking is correct - for texture, chunk_index, data_slice in ta: - assert ta.buffer[chunk_index] is texture - chunk_row, chunk_col = chunk_index - - data_row_start_index = chunk_row * MAX_TEXTURE_SIZE - data_col_start_index = chunk_col * MAX_TEXTURE_SIZE - - data_row_stop_index = min( - data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE - ) - data_col_stop_index = min( - data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE - ) - - row_slice = slice(data_row_start_index, data_row_stop_index) - col_slice = slice(data_col_start_index, data_col_stop_index) - - assert data_slice == (row_slice, col_slice) - - def check_set_slice(data, ta, row_slice, col_slice): ta[row_slice, col_slice] = 1 npt.assert_almost_equal(ta[row_slice, col_slice], 1) @@ -85,17 +36,6 @@ def make_image_graphic(data) -> fpl.ImageGraphic: return fig[0, 0].add_image(data) -def check_image_graphic(texture_array, graphic): - # make sure each ImageTile has the right texture - for (texture, chunk_index, data_slice), img in zip( - texture_array, graphic.world_object.children - ): - assert isinstance(img, _ImageTile) - assert img.geometry.grid is texture - assert img.world.x == data_slice[1].start - assert img.world.y == data_slice[0].start - - @pytest.mark.parametrize("test_graphic", [False, True]) def test_small_texture(test_graphic): # tests TextureArray with dims that requires only 1 texture @@ -162,15 +102,27 @@ def test_wide(test_graphic): else: ta = TextureArray(data) + ta_shape = (2, 3) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(2, 3), - row_indices_size=2, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -189,15 +141,27 @@ def test_tall(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 2) + check_texture_array( data, ta=ta, - buffer_size=6, - buffer_shape=(3, 2), - row_indices_size=3, - col_indices_size=2, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: @@ -216,15 +180,27 @@ def test_square(test_graphic): else: ta = TextureArray(data) + ta_shape = (3, 3) + check_texture_array( data, ta=ta, - buffer_size=9, - buffer_shape=(3, 3), - row_indices_size=3, - col_indices_size=3, - row_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), - col_indices_values=np.array([0, MAX_TEXTURE_SIZE, 2 * MAX_TEXTURE_SIZE]), + buffer_size=np.prod(ta_shape), + buffer_shape=ta_shape, + row_indices_size=ta_shape[0], + col_indices_size=ta_shape[1], + row_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[0] - 1) // MAX_TEXTURE_SIZE) + ] + ), + col_indices_values=np.array( + [ + i * MAX_TEXTURE_SIZE + for i in range(0, 1 + (data.shape[1] - 1) // MAX_TEXTURE_SIZE) + ] + ), ) if test_graphic: diff --git a/tests/utils_textures.py b/tests/utils_textures.py new file mode 100644 index 000000000..f40a7371c --- /dev/null +++ b/tests/utils_textures.py @@ -0,0 +1,64 @@ +import numpy as np +import pygfx +from numpy import testing as npt + +from fastplotlib.graphics.features import TextureArray +from fastplotlib.graphics.image import _ImageTile + + +MAX_TEXTURE_SIZE = 1024 + + +def check_texture_array( + data: np.ndarray, + ta: TextureArray, + buffer_size: int, + buffer_shape: tuple[int, int], + row_indices_size: int, + col_indices_size: int, + row_indices_values: np.ndarray, + col_indices_values: np.ndarray, +): + + npt.assert_almost_equal(ta.value, data) + + assert ta.buffer.size == buffer_size + assert ta.buffer.shape == buffer_shape + + assert all([isinstance(texture, pygfx.Texture) for texture in ta.buffer.ravel()]) + + assert ta.row_indices.size == row_indices_size + assert ta.col_indices.size == col_indices_size + npt.assert_array_equal(ta.row_indices, row_indices_values) + npt.assert_array_equal(ta.col_indices, col_indices_values) + + # make sure chunking is correct + for texture, chunk_index, data_slice in ta: + assert ta.buffer[chunk_index] is texture + chunk_row, chunk_col = chunk_index + + data_row_start_index = chunk_row * MAX_TEXTURE_SIZE + data_col_start_index = chunk_col * MAX_TEXTURE_SIZE + + data_row_stop_index = min( + data.shape[0], data_row_start_index + MAX_TEXTURE_SIZE + ) + data_col_stop_index = min( + data.shape[1], data_col_start_index + MAX_TEXTURE_SIZE + ) + + row_slice = slice(data_row_start_index, data_row_stop_index) + col_slice = slice(data_col_start_index, data_col_stop_index) + + assert data_slice == (row_slice, col_slice) + + +def check_image_graphic(texture_array, graphic): + # make sure each ImageTile has the right texture + for (texture, chunk_index, data_slice), img in zip( + texture_array, graphic.world_object.children + ): + assert isinstance(img, _ImageTile) + assert img.geometry.grid is texture + assert img.world.x == data_slice[1].start + assert img.world.y == data_slice[0].start