diff --git a/examples/garbage_collection.ipynb b/examples/garbage_collection.ipynb new file mode 100644 index 000000000..85744e6e0 --- /dev/null +++ b/examples/garbage_collection.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "1ef0578e-09e1-45ff-bd34-84472db3885e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from fastplotlib import Plot\n", + "import numpy as np\n", + "import sys\n", + "\n", + "import weakref\n", + "import gc\n", + "import os, psutil\n", + "process = psutil.Process(os.getpid())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bb6bc6f-7786-4d23-9eb1-e30bbc66c798", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def print_process_ram_mb():\n", + " print(process.memory_info().rss / 1024 / 1024)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b376676e-a7fe-4424-9ba6-fde5be03b649", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print_process_ram_mb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b23ba640-88ec-40d9-b53c-c8cbb3e39b0b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plot = Plot()\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17f46f21-b29d-4dd3-9496-989bbb240f50", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print_process_ram_mb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27627cd4-c363-4eab-a121-f6c8abbbe5ae", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "graphic = \"scatter\"" + ] + }, + { + "cell_type": "markdown", + "id": "d9c10edc-169a-4dd2-bd5b-8a1b67baf3a9", + "metadata": {}, + "source": [ + "### Run the following cells repeatedly to add and remove the graphic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e26d392f-6afd-4e89-a685-d618065d3caf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if graphic == \"line\":\n", + " a = np.random.rand(10_000_000)\n", + " g = plot.add_line(a)\n", + " \n", + "elif graphic == \"heatmap\":\n", + " a = np.random.rand(20_000, 20_000)\n", + " g = plot.add_heatmap(a)\n", + "\n", + "elif graphic == \"line_collection\":\n", + " a = np.random.rand(500, 50_000)\n", + " g = plot.add_line_collection(a)\n", + " \n", + "elif graphic == \"image\":\n", + " a = np.random.rand(7_000, 7_000)\n", + " g = plot.add_image(a)\n", + "\n", + "elif graphic == \"scatter\":\n", + " a = np.random.rand(10_000_000, 3)\n", + " g = plot.add_scatter(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31e3027a-56cf-4f7b-ba78-aed4f78eef47", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print_process_ram_mb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6518795c-98cf-405d-94ab-786ac3b2e1d6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "g" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee2481c9-82e3-4043-85fd-21a0cdf21187", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plot.auto_scale()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53170858-ae72-4451-8647-7d5b1f9da75e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(process.memory_info().rss / 1024 / 1024)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4e0f73b-c58a-40e7-acf5-07a1f70d2821", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plot.delete_graphic(plot.graphics[0])" + ] + }, + { + "cell_type": "markdown", + "id": "47baa487-c66b-4c40-aa11-d819902870e3", + "metadata": {}, + "source": [ + "If there is no serious system memory leak, this value shouldn't really increase after repeated cycles" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56c64498-229e-48b7-9fb1-f7c327fff2ae", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(process.memory_info().rss / 1024 / 1024)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd6a26c1-ea81-469d-ae7a-95839b1f9d5a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from wgpu.gui.auto import WgpuCanvas, run\n", + "import pygfx as gfx\n", + "import subprocess\n", + "\n", + "canvas = WgpuCanvas()\n", + "renderer = gfx.WgpuRenderer(canvas)\n", + "scene = gfx.Scene()\n", + "camera = gfx.OrthographicCamera(5000, 5000)\n", + "camera.position.x = 2048\n", + "camera.position.y = 2048\n", + "\n", + "\n", + "def make_image():\n", + " data = np.random.rand(4096, 4096).astype(np.float32)\n", + "\n", + " return gfx.Image(\n", + " gfx.Geometry(grid=gfx.Texture(data, dim=2)),\n", + " gfx.ImageBasicMaterial(clim=(0, 1)),\n", + " )\n", + "\n", + "\n", + "class Graphic:\n", + " def __init__(self):\n", + " data = np.random.rand(4096, 4096).astype(np.float32)\n", + " self.wo = gfx.Image(\n", + " gfx.Geometry(grid=gfx.Texture(data, dim=2)),\n", + " gfx.ImageBasicMaterial(clim=(0, 1)),\n", + " )\n", + "\n", + "\n", + "def draw():\n", + " renderer.render(scene, camera)\n", + " canvas.request_draw()\n", + "\n", + "\n", + "def print_nvidia(msg):\n", + " print(msg)\n", + " print(\n", + " subprocess.check_output([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=memory.used\"]).decode().split(\"\\n\")[1]\n", + " )\n", + " print()\n", + "\n", + "\n", + "def add_img(*args):\n", + " print_nvidia(\"Before creating image\")\n", + " img = make_image()\n", + " print_nvidia(\"After creating image\")\n", + " scene.add(img)\n", + " img.add_event_handler(remove_img, \"click\")\n", + " draw()\n", + " print_nvidia(\"After add image to scene\")\n", + "\n", + "\n", + "def remove_img(*args):\n", + " img = scene.children[0]\n", + " scene.remove(img)\n", + " draw()\n", + " print_nvidia(\"After remove image from scene\")\n", + " del img\n", + " draw()\n", + " print_nvidia(\"After del image\")\n", + "\n", + "\n", + "renderer.add_event_handler(add_img, \"double_click\")\n", + "canvas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2599f430-8b00-4490-9e11-774897be6e77", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from wgpu.gui.auto import WgpuCanvas, run\n", + "import pygfx as gfx\n", + "import subprocess\n", + "\n", + "canvas = WgpuCanvas()\n", + "renderer = gfx.WgpuRenderer(canvas)\n", + "scene = gfx.Scene()\n", + "camera = gfx.OrthographicCamera(5000, 5000)\n", + "camera.position.x = 2048\n", + "camera.position.y = 2048\n", + "\n", + "\n", + "def make_image():\n", + " data = np.random.rand(4096, 4096).astype(np.float32)\n", + "\n", + " return gfx.Image(\n", + " gfx.Geometry(grid=gfx.Texture(data, dim=2)),\n", + " gfx.ImageBasicMaterial(clim=(0, 1)),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ec10f26-6544-4ad3-80c1-aa34617dc826", + "metadata": {}, + "outputs": [], + "source": [ + "import weakref" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acc819a3-cd50-4fdd-a0b5-c442d80847e2", + "metadata": {}, + "outputs": [], + "source": [ + "img = make_image()\n", + "img_ref = weakref.ref(img)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f89da335-3372-486b-b773-9f103d6a9bbd", + "metadata": {}, + "outputs": [], + "source": [ + "img_ref()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c22904ad-d674-43e6-83bb-7a2f7b277c06", + "metadata": {}, + "outputs": [], + "source": [ + "del img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "573566d7-eb91-4690-958c-d00dd495b3e4", + "metadata": {}, + "outputs": [], + "source": [ + "import gc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaef3e89-2bfd-43af-9b8f-824a3f89b85f", + "metadata": {}, + "outputs": [], + "source": [ + "img_ref()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3380f35e-fcc9-43f6-80d2-7e9348cd13b4", + "metadata": {}, + "outputs": [], + "source": [ + "draw()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a27bf7c7-f3ef-4ae8-8ecf-31507f8c0449", + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " subprocess.check_output([\"nvidia-smi\", \"--format=csv\", \"--query-gpu=memory.used\"]).decode().split(\"\\n\")[1]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4bf2711-8a83-4d9c-a4f7-f50de7ae1715", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index 255a2fec7..309b68d9f 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -1,4 +1,5 @@ from typing import * +import weakref from warnings import warn import numpy as np @@ -14,6 +15,11 @@ from dataclasses import dataclass +# dict that holds all world objects for a given python kernel/session +# Graphic objects only use proxies to WorldObjects +WORLD_OBJECTS: Dict[str, WorldObject] = dict() #: {hex id str: WorldObject} + + PYGFX_EVENTS = [ "key_down", "key_up", @@ -44,7 +50,8 @@ def __init_subclass__(cls, **kwargs): class Graphic(BaseGraphic): def __init__( - self, name: str = None): + self, name: str = None + ): """ Parameters @@ -58,10 +65,16 @@ def __init__( self.registered_callbacks = dict() self.present = PresentFeature(parent=self) + # store hex id str of Graphic instance mem location + self.loc: str = hex(id(self)) + @property def world_object(self) -> WorldObject: - """Associated pygfx WorldObject.""" - return self._world_object + """Associated pygfx WorldObject. Always returns a proxy, real object cannot be accessed directly.""" + return weakref.proxy(WORLD_OBJECTS[hex(id(self))]) + + def _set_world_object(self, wo: WorldObject): + WORLD_OBJECTS[hex(id(self))] = wo @property def position(self) -> Vector3: @@ -75,7 +88,7 @@ def visible(self) -> bool: return self.world_object.visible @visible.setter - def visible(self, v) -> bool: + def visible(self, v: bool): """Access or change the visibility.""" self.world_object.visible = v @@ -100,6 +113,9 @@ def __repr__(self): else: return rval + def __del__(self): + del WORLD_OBJECTS[self.loc] + class Interaction(ABC): """Mixin class that makes graphics interactive""" @@ -216,8 +232,13 @@ def _event_handler(self, event): # for now we only have line collections so this works else: - for i, item in enumerate(self._graphics): - if item.world_object is event.pick_info["world_object"]: + # get index of world object that made this event + for i, item in enumerate(self.graphics): + wo = WORLD_OBJECTS[item.loc] + # we only store hex id of worldobject, but worldobject `pick_info` is always the real object + # so if pygfx worldobject triggers an event by itself, such as `click`, etc., this will be + # the real world object in the pick_info and not the proxy + if wo is event.pick_info["world_object"]: indices = i target_info.target._set_feature(feature=target_info.feature, new_data=target_info.new_data, indices=indices) else: @@ -264,22 +285,21 @@ class PreviouslyModifiedData: indices: Any +COLLECTION_GRAPHICS: dict[str, Graphic] = dict() + + class GraphicCollection(Graphic): """Graphic Collection base class""" def __init__(self, name: str = None): super(GraphicCollection, self).__init__(name) - self._graphics: List[Graphic] = list() - - @property - def world_object(self) -> Group: - """Returns the underling pygfx WorldObject.""" - return self._world_object + self._graphics: List[str] = list() @property def graphics(self) -> Tuple[Graphic]: - """returns the Graphics within this collection""" - return tuple(self._graphics) + """The Graphics within this collection. Always returns a proxy to the Graphics.""" + proxies = [weakref.proxy(COLLECTION_GRAPHICS[loc]) for loc in self._graphics] + return tuple(proxies) def add_graphic(self, graphic: Graphic, reset_index: True): """Add a graphic to the collection""" @@ -289,17 +309,31 @@ def add_graphic(self, graphic: Graphic, reset_index: True): f"You can only add {self.child_type} to a {self.__class__.__name__}, " f"you are trying to add a {graphic.__class__.__name__}." ) - self._graphics.append(graphic) + + loc = hex(id(graphic)) + COLLECTION_GRAPHICS[loc] = graphic + + self._graphics.append(loc) if reset_index: self._reset_index() self.world_object.add(graphic.world_object) def remove_graphic(self, graphic: Graphic, reset_index: True): """Remove a graphic from the collection""" - self._graphics.remove(graphic) + self._graphics.remove(graphic.loc) + if reset_index: self._reset_index() - self.world_object.remove(graphic) + + self.world_object.remove(graphic.world_object) + + def __del__(self): + self.world_object.clear() + + for loc in self._graphics: + del COLLECTION_GRAPHICS[loc] + + super().__del__() def _reset_index(self): for new_index, graphic in enumerate(self._graphics): @@ -312,7 +346,7 @@ def __getitem__(self, key): if isinstance(key, slice): key = cleanup_slice(key, upper_bound=len(self)) selection_indices = range(key.start, key.stop, key.step) - selection = self._graphics[key] + selection = self.graphics[key] # fancy-ish indexing elif isinstance(key, (tuple, list, np.ndarray)): @@ -324,7 +358,7 @@ def __getitem__(self, key): selection = list() for ix in key: - selection.append(self._graphics[ix]) + selection.append(self.graphics[ix]) selection_indices = key else: @@ -365,7 +399,7 @@ def __init__( selection_indices: Union[list, range] the corresponding indices from the parent GraphicCollection that were selected """ - self._parent = parent + self._parent = weakref.proxy(parent) self._selection = selection self._selection_indices = selection_indices diff --git a/fastplotlib/graphics/features/_base.py b/fastplotlib/graphics/features/_base.py index 80029180e..da6a177a0 100644 --- a/fastplotlib/graphics/features/_base.py +++ b/fastplotlib/graphics/features/_base.py @@ -2,6 +2,7 @@ from inspect import getfullargspec from warnings import warn from typing import * +import weakref import numpy as np from pygfx import Buffer, Texture @@ -71,7 +72,7 @@ def __init__(self, parent, data: Any, collection_index: int = None): if part of a collection, index of this graphic within the collection """ - self._parent = parent + self._parent = weakref.proxy(parent) self._data = to_gpu_supported_dtype(data) diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 83cae3de8..cb4cf1587 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -100,11 +100,13 @@ def __init__( self.cmap = ImageCmapFeature(self, cmap) material = pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=self.cmap()) - self._world_object: pygfx.Image = pygfx.Image( + world_object = pygfx.Image( geometry, material ) + self._set_world_object(world_object) + self.data = ImageDataFeature(self, data) # TODO: we need to organize and do this better if isolated_buffer: @@ -272,7 +274,8 @@ def __init__( start_ixs = [list(map(lambda c: c * chunk_size, chunk)) for chunk in chunks] stop_ixs = [list(map(lambda c: c + chunk_size, chunk)) for chunk in start_ixs] - self._world_object = pygfx.Group() + world_object = pygfx.Group() + self._set_world_object(world_object) if (vmin is None) or (vmax is None): vmin, vmax = quick_min_max(data) diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py index 926f5729c..0b1e579bc 100644 --- a/fastplotlib/graphics/line.py +++ b/fastplotlib/graphics/line.py @@ -85,12 +85,14 @@ def __init__( self.thickness = ThicknessFeature(self, thickness) - self._world_object: pygfx.Line = pygfx.Line( + world_object: pygfx.Line = pygfx.Line( # self.data.feature_data because data is a Buffer geometry=pygfx.Geometry(positions=self.data(), colors=self.colors()), material=material(thickness=self.thickness(), vertex_colors=True) ) + self._set_world_object(world_object) + if z_position is not None: self.world_object.position.z = z_position diff --git a/fastplotlib/graphics/line_collection.py b/fastplotlib/graphics/line_collection.py index 07fc9cad7..3bff6f7c5 100644 --- a/fastplotlib/graphics/line_collection.py +++ b/fastplotlib/graphics/line_collection.py @@ -157,7 +157,7 @@ def __init__( "or must be a str of tuple/list with the same length as the data" ) - self._world_object = pygfx.Group() + self._set_world_object(pygfx.Group()) for i, d in enumerate(data): if isinstance(z_position, list): @@ -343,6 +343,6 @@ def __init__( ) axis_zero = 0 - for i, line in enumerate(self._graphics): + for i, line in enumerate(self.graphics): getattr(line.position, f"set_{separation_axis}")(axis_zero) axis_zero = axis_zero + line.data()[:, axes[separation_axis]].max() + separation diff --git a/fastplotlib/graphics/line_slider.py b/fastplotlib/graphics/line_slider.py index 8755af51a..f19db9cda 100644 --- a/fastplotlib/graphics/line_slider.py +++ b/fastplotlib/graphics/line_slider.py @@ -74,7 +74,7 @@ def __init__( else: material = pygfx.LineMaterial - colors_inner = np.repeat([Color("w")], 2, axis=0).astype(np.float32) + colors_inner = np.repeat([Color(color)], 2, axis=0).astype(np.float32) colors_outer = np.repeat([Color([1., 1., 1., 0.25])], 2, axis=0).astype(np.float32) line_inner = pygfx.Line( @@ -88,17 +88,19 @@ def __init__( material=material(thickness=thickness + 4, vertex_colors=True) ) - self._world_object = pygfx.Group() + world_object = pygfx.Group() - self._world_object.add(line_outer) - self._world_object.add(line_inner) + world_object.add(line_outer) + world_object.add(line_inner) + + self._set_world_object(world_object) self.position.x = x_pos self.slider = slider self.slider.observe(self.set_position, "value") - self.name = name + super().__init__(name=name) def set_position(self, change): self.position.x = change["new"] diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index 016d1cac9..b53985de0 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -72,9 +72,11 @@ def __init__( super(ScatterGraphic, self).__init__(*args, **kwargs) - self._world_object: pygfx.Points = pygfx.Points( + world_object = pygfx.Points( pygfx.Geometry(positions=self.data(), sizes=sizes, colors=self.colors()), material=pygfx.PointsMaterial(vertex_colors=True, vertex_sizes=True) ) + self._set_world_object(world_object) + self.world_object.position.z = z_position diff --git a/fastplotlib/graphics/text.py b/fastplotlib/graphics/text.py index 665c53606..8225bb300 100644 --- a/fastplotlib/graphics/text.py +++ b/fastplotlib/graphics/text.py @@ -38,11 +38,13 @@ def __init__( """ super(TextGraphic, self).__init__(name=name) - self._world_object = pygfx.Text( + world_object = pygfx.Text( pygfx.TextGeometry(text=text, font_size=size, screen_space=False), pygfx.TextMaterial(color=face_color, outline_color=outline_color, outline_thickness=outline_thickness) ) + self._set_world_object(world_object) + self.world_object.position.set(*position) self.name = None diff --git a/fastplotlib/layouts/_base.py b/fastplotlib/layouts/_base.py index ce35135c7..c98c010ea 100644 --- a/fastplotlib/layouts/_base.py +++ b/fastplotlib/layouts/_base.py @@ -1,11 +1,21 @@ +from warnings import warn +from typing import * +import weakref + import numpy as np + from pygfx import Scene, OrthographicCamera, PerspectiveCamera, PanZoomController, OrbitController, \ Viewport, WgpuRenderer from wgpu.gui.auto import WgpuCanvas -from warnings import warn -from ..graphics._base import Graphic + +from ..graphics._base import Graphic, GraphicCollection from ..graphics.line_slider import LineSlider -from typing import * + + +# dict to store Graphic instances +# this is the only place where the real references to Graphics are stored in a Python session +# {hex id str: Graphic} +GRAPHICS: Dict[str, Graphic] = dict() class PlotArea: @@ -74,7 +84,9 @@ def __init__( self.renderer.add_event_handler(self.set_viewport_rect, "resize") - self._graphics: List[Graphic] = list() + # list of hex id strings for all graphics managed by this PlotArea + # the real Graphic instances are stored in the ``GRAPHICS`` dict + self._graphics: List[str] = list() # hacky workaround for now to exclude from bbox calculations self._sliders: List[LineSlider] = list() @@ -129,8 +141,13 @@ def controller(self) -> Union[PanZoomController, OrbitController]: @property def graphics(self) -> Tuple[Graphic]: - """returns the Graphics in the plot area""" - return tuple(self._graphics) + """Graphics in the plot area. Always returns a proxy to the Graphic instances.""" + proxies = list() + for loc in self._graphics: + p = weakref.proxy(GRAPHICS[loc]) + proxies.append(p) + + return tuple(proxies) def get_rect(self) -> Tuple[float, float, float, float]: """allows setting the region occupied by the viewport w.r.t. the parent""" @@ -154,7 +171,8 @@ def add_graphic(self, graphic: Graphic, center: bool = True): Parameters ---------- graphic: Graphic or GraphicCollection - Add a Graphic or a GraphicCollection to the plot area + Add a Graphic or a GraphicCollection to the plot area. + Note: this must be a real Graphic instance and not a proxy center: bool, default True Center the camera on the newly added Graphic @@ -170,10 +188,14 @@ def add_graphic(self, graphic: Graphic, center: bool = True): # TODO: need to refactor LineSlider entirely if isinstance(graphic, LineSlider): - self._sliders.append(graphic) + self._sliders.append(graphic) # don't manage garbage collection of LineSliders for now else: - self._graphics.append(graphic) + # store in GRAPHICS dict + loc = graphic.loc + GRAPHICS[loc] = graphic + self._graphics.append(loc) # add hex id string for referencing this graphic instance + # add world object to scene self.scene.add(graphic.world_object) if center: @@ -185,7 +207,7 @@ def add_graphic(self, graphic: Graphic, center: bool = True): def _check_graphic_name_exists(self, name): graphic_names = list() - for g in self._graphics: + for g in self.graphics: graphic_names.append(g.name) if name in graphic_names: @@ -287,8 +309,9 @@ def auto_scale(self, maintain_aspect: bool = False, zoom: float = 0.8): def remove_graphic(self, graphic: Graphic): """ - Remove a graphic from the scene. Note: This does not garbage collect the graphic, - you can add it back to the scene after removing it. + Remove a ``Graphic`` from the scene. Note: This does not garbage collect the graphic, + you can add it back to the scene after removing it. Use ``delete_graphic()`` to + delete and garbage collect a ``Graphic``. Parameters ---------- @@ -296,15 +319,67 @@ def remove_graphic(self, graphic: Graphic): The graphic to remove from the scene """ + self.scene.remove(graphic.world_object) + def delete_graphic(self, graphic: Graphic): + """ + Delete the graphic, garbage collects and frees GPU VRAM. + + Parameters + ---------- + graphic: Graphic or GraphicCollection + The graphic to delete + + """ + + # graphic_loc = hex(id(graphic.__repr__.__self__)) + + # get location + graphic_loc = graphic.loc + + if graphic_loc not in self._graphics: + raise KeyError(f"Graphic with following address not found in plot area: {graphic_loc}") + + # remove from scene if necessary + if graphic.world_object in self.scene.children: + self.scene.remove(graphic.world_object) + + # remove from list of addresses + self._graphics.remove(graphic_loc) + + # for GraphicCollection objects + # if isinstance(graphic, GraphicCollection): + # # clear Group + # graphic.world_object.clear() + # graphic.clear() + # delete all child world objects in the collection + # for g in graphic.graphics: + # subloc = hex(id(g)) + # del WORLD_OBJECTS[subloc] + + # get mem location of graphic + # loc = hex(id(graphic)) + # delete world object + #del WORLD_OBJECTS[graphic_loc] + + del GRAPHICS[graphic_loc] + + def clear(self): + """ + Clear the Plot or Subplot. Also performs garbage collection, i.e. runs ``delete_graphic`` on all graphics. + """ + + for g in self.graphics: + self.delete_graphic(g) + def __getitem__(self, name: str): - for graphic in self._graphics: + for graphic in self.graphics: if graphic.name == name: return graphic graphic_names = list() - for g in self._graphics: + for g in self.graphics: graphic_names.append(g.name) raise IndexError(f"no graphic of given name, the current graphics are:\n {graphic_names}") @@ -322,5 +397,5 @@ def __repr__(self): return f"{self}\n" \ f" parent: {self.parent}\n" \ f" Graphics:\n" \ - f"\t{newline.join(graphic.__repr__() for graphic in self._graphics)}" \ + f"\t{newline.join(graphic.__repr__() for graphic in self.graphics)}" \ f"\n" diff --git a/fastplotlib/layouts/_subplot.py b/fastplotlib/layouts/_subplot.py index 41d065648..7bb1f0540 100644 --- a/fastplotlib/layouts/_subplot.py +++ b/fastplotlib/layouts/_subplot.py @@ -2,6 +2,7 @@ import numpy as np from math import copysign from functools import partial +import weakref from inspect import signature, getfullargspec from warnings import warn @@ -112,7 +113,7 @@ def __init__( if self.name is not None: self.set_title(self.name) - def _create_graphic(self, graphic_class, *args, **kwargs): + def _create_graphic(self, graphic_class, *args, **kwargs) -> weakref.proxy: if "center" in kwargs.keys(): center = kwargs.pop("center") else: @@ -124,7 +125,8 @@ def _create_graphic(self, graphic_class, *args, **kwargs): graphic = graphic_class(*args, **kwargs) self.add_graphic(graphic, center=center) - return graphic + # only return a proxy to the real graphic + return weakref.proxy(graphic) def set_title(self, text: Any): """Sets the name of a subplot to 'top' viewport if defined."""