Skip to content

Commit 5fa80a7

Browse files
committed
basic heatmap implementation with row or column selection
1 parent f0b5a36 commit 5fa80a7

3 files changed

Lines changed: 125 additions & 1 deletion

File tree

fastplotlib/graphics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from .line import Line
33
from .scatter import Scatter
44
from .image import Image
5+
from .heatmap import Heatmap
56

6-
__all__ = ["Image", "Scatter", "Line", "Histogram"]
7+
__all__ = ["Image", "Scatter", "Line", "Histogram", "Heatmap"]

fastplotlib/graphics/heatmap.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import numpy as np
2+
import pygfx
3+
from typing import *
4+
from .image import Image
5+
6+
from ..utils import quick_min_max, get_cmap_texture
7+
8+
9+
default_selection_options = {
10+
"mode": "single",
11+
"orientation": "row",
12+
"callbacks": None,
13+
}
14+
15+
16+
class SelectionOptions:
17+
def __init__(
18+
self,
19+
event: str = "double_click", # click or double_click
20+
event_button: Union[int, str] = 1,
21+
mode: str = "single",
22+
axis: str = "row",
23+
color: Tuple[int, int, int, int] = None,
24+
callbacks: List[callable] = None,
25+
):
26+
self.event = event
27+
self.event_button = event_button
28+
self.mode = mode
29+
self.axis = axis
30+
31+
if color is not None:
32+
self.color = color
33+
34+
else:
35+
self.color = (1, 1, 1, 0.4)
36+
37+
if callbacks is None:
38+
self.callbacks = list()
39+
else:
40+
self.callbacks = callbacks
41+
42+
43+
class Heatmap(Image):
44+
def __init__(
45+
self,
46+
data: np.ndarray,
47+
vmin: int = None,
48+
vmax: int = None,
49+
cmap: str = 'plasma',
50+
selection_options: dict = None,
51+
*args,
52+
**kwargs
53+
):
54+
super().__init__(data, vmin, vmax, cmap)
55+
56+
self.selection_options = SelectionOptions()
57+
self.selection_options.callbacks = list()
58+
59+
if selection_options is not None:
60+
for k in selection_options.keys():
61+
setattr(self.selection_options, k, selection_options[k])
62+
63+
self.world_object.add_event_handler(
64+
self.handle_selection_event, self.selection_options.event
65+
)
66+
67+
self._highlights = list()
68+
69+
def handle_selection_event(self, event):
70+
if not event.button == self.selection_options.event_button:
71+
return
72+
73+
if self.selection_options.mode == "single":
74+
for h in self._highlights:
75+
self.remove_highlight(h)
76+
77+
rval = self.add_highlight(event)
78+
79+
for f in self.selection_options.callbacks:
80+
f(rval)
81+
82+
def remove_highlight(self, h):
83+
self._highlights.remove(h)
84+
self.world_object.remove(h)
85+
86+
def add_highlight(self, event):
87+
index = event.pick_info["index"]
88+
89+
if self.selection_options.axis == "row":
90+
index = index[1]
91+
w = self.data.shape[1]
92+
h = 1
93+
94+
pos = ((self.data.shape[1] / 2) - 0.5, index, 1)
95+
rval = self.data[index, :] # returned to selection.callbacks functions
96+
97+
elif self.selection_options.axis == "column":
98+
index = index[0]
99+
w = 1
100+
h = self.data.shape[0]
101+
102+
pos = (index, (self.data.shape[0] / 2) - 0.5, 1)
103+
rval = self.data[:, index]
104+
105+
geometry = pygfx.plane_geometry(
106+
width=w,
107+
height=h
108+
)
109+
110+
material = pygfx.MeshBasicMaterial(color=self.selection_options.color)
111+
112+
self.selection_graphic = pygfx.Mesh(geometry, material)
113+
self.selection_graphic.position.set(*pos)
114+
115+
self.world_object.add(self.selection_graphic)
116+
self._highlights.append(self.selection_graphic)
117+
118+
return rval

fastplotlib/subplot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pygfx
22
from pygfx import Scene, OrthographicCamera, PerspectiveCamera, PanZoomController, Viewport, AxesHelper, GridHelper
33
from .defaults import camera_types, controller_types
4+
from .graphics import Heatmap
45
from typing import *
56
from wgpu.gui.auto import WgpuCanvas
67
from warnings import warn
@@ -96,6 +97,7 @@ def add_animations(self, funcs: List[callable]):
9697

9798
def add_graphic(self, graphic, center: bool = True):
9899
graphic_names = list()
100+
99101
for g in self._graphics:
100102
graphic_names.append(g.name)
101103

@@ -105,6 +107,9 @@ def add_graphic(self, graphic, center: bool = True):
105107
self._graphics.append(graphic)
106108
self.scene.add(graphic.world_object)
107109

110+
if isinstance(graphic, Heatmap):
111+
self.controller.scale.y = copysign(self.controller.scale.y, -1)
112+
108113
if center:
109114
self.center_graphic(graphic)
110115

0 commit comments

Comments
 (0)