"""
K-Means Clustering of MNIST Dataset
===================================

Example showing how you can perform K-Means clustering on an MNIST-style dataset of
handwritten digits: scikit-learn's ``load_digits`` data, 1797 images of 8x8 pixels each.
"""
7+
8+ # test_example = false
9+ # sphinx_gallery_pygfx_docs = 'screenshot'
10+
11+ import fastplotlib as fpl
12+ import numpy as np
13+ from sklearn .datasets import load_digits
14+ from sklearn .cluster import KMeans
15+ from sklearn .decomposition import PCA
16+
# fetch the handwritten-digits dataset bundled with scikit-learn
mnist = load_digits()

# flattened images and their integer class labels
data = mnist['data']  # (1797, 64)
labels = mnist['target']  # (1797,)

# preview figure: a single row of five subplots, one digit per subplot
# NOTE: this only gives a feel for the dataset if you have not seen it before,
# the more interesting visualization comes further down :D
fig_data = fpl.Figure(shape=(1, 5), size=(900, 300))

# walk the subplots together with the matching image rows and labels;
# zip stops as soon as the figure runs out of subplots
for subplot, flat_image, label in zip(fig_data, data, labels):
    # every row stores 64 pixel values, i.e. one 8x8 image
    subplot.add_image(flat_image.reshape(8, 8), cmap="gray", interpolation="linear")
    # show the label above the image
    subplot.set_title(f"Label: {label}")
    # hide the axes and the toolbar
    subplot.axes.visible = False
    subplot.toolbar = False

fig_data.show()
40+
# use one PCA component (and, below, one cluster) per distinct digit label
n_digits = np.unique(labels).size  # 10

# project the data from 64 dimensions down to n_digits dimensions
pca = PCA(n_components=n_digits)
reduced_data = pca.fit_transform(data)  # (1797, 10)

# K-Means on the lower-dimension data, keeping the best of 4 runs;
# fit() returns the fitted estimator, so construction and fitting are chained
kmeans = KMeans(n_clusters=n_digits, n_init=4).fit(reduced_data)

# centers of the fitted clusters
centroids = kmeans.cluster_centers_
53+
# figure with two subplots in one row: [0, 0] holds the K-Means scatter (3d camera,
# fly controller) and [0, 1] holds the image of the selected digit (2d camera, panzoom)
figure = fpl.Figure(
    shape=(1, 2),
    size=(700, 400),
    cameras=["3d", "2d"],
    controller_types=[["fly", "panzoom"]],
)

# hide the axes in both subplots
figure[0, 0].axes.visible = False
figure[0, 1].axes.visible = False

# plain string literal: nothing is interpolated, so no f-prefix is needed
figure[0, 0].set_title("K-means clustering of PCA-reduced data")

# plot the centroids, using their first three PCA components as the coordinates
figure[0, 0].add_scatter(
    data=centroids[:, :3],
    colors="white",
    sizes=15,
)
# plot the down-projected data, again using the first three PCA components
digit_scatter = figure[0, 0].add_scatter(
    data=reduced_data[:, :3],
    sizes=5,
    cmap="tab10",  # use a qualitative cmap
    cmap_transform=kmeans.labels_,  # color by the predicted cluster
)

# index of the initially selected data point
ix = 0

# plot the image of the initially selected digit; it is named "digit" so it can be
# looked up from the subplot later
digit_img = figure[0, 1].add_image(
    data=data[ix].reshape(8, 8),
    cmap="gray",
    name="digit",
    interpolation="linear",
)

# make the initially selected data point stand out by color and size
digit_scatter.colors[ix] = "magenta"
digit_scatter.sizes[ix] = 10
96+
# event handler that moves the selection to the data point the pointer enters
@digit_scatter.add_event_handler("pointer_enter")
def update(ev):
    """Highlight the scatter point picked by the event and show its digit image."""
    # reapply the "tab10" colormap and the base size of 5 to the whole scatter,
    # which clears the previous highlight
    digit_scatter.cmap = "tab10"
    digit_scatter.sizes = 5

    # index of the vertex reported by the pick info of this event
    picked = ev.pick_info["vertex_index"]

    # enlarge and recolor the new selection
    digit_scatter.sizes[picked] = 10
    digit_scatter.colors[picked] = "magenta"

    # swap the displayed image for the one belonging to the picked point
    figure[0, 1]["digit"].data = data[picked].reshape(8, 8)
112+
figure.show()

# NOTE: the `if __name__ == "__main__"` guard below is NOT the way to use fastplotlib
# interactively; see the fastplotlib docs for interactive use in ipython and jupyter
if __name__ == "__main__":
    # when executed as a script: print the module docstring, then run the event loop
    print(__doc__)
    fpl.loop.run()
0 commit comments