@@ -34,7 +34,20 @@ jupyter:
3434 thumbnail : thumbnail/knn-classification.png
3535---
3636
37- ## Basic Binary Classification with ` plotly.express `
37+ ## Basic binary classification with kNN
38+
39+
40+ ### Display training and test splits
41+
42+ ``` python
43+
44+ ```
45+
46+ ### Visualize predictions on test split
47+
48+ ``` python
49+
50+ ```
3851
3952``` python
4053import numpy as np
@@ -113,7 +126,7 @@ fig.add_trace(
113126 showscale = False ,
114127 colorscale = [' Blue' , ' Red' ],
115128 opacity = 0.4 ,
116- name = ' Confidence '
129+ name = ' Score '
117130 )
118131)
119132fig.show()
@@ -150,7 +163,7 @@ Z = Z.reshape(ll.shape)
150163proba = clf.predict_proba(np.c_[ll.ravel(), ww.ravel()])
151164proba = proba.reshape(ll.shape + (3 ,))
152165
153- fig = px.scatter(df, x = ' sepal_length' , y = ' sepal_width' , color = ' species' , width = 1000 , height = 1000 )
166+ fig = px.scatter(df, x = ' sepal_length' , y = ' sepal_width' , color = ' species' )
154167fig.update_traces(marker_size = 10 , marker_line_width = 1 )
155168fig.add_trace(
156169 go.Heatmap(
@@ -173,77 +186,12 @@ fig.add_trace(
173186fig.show()
174187```
175188
176- ## 3D Classification with ` px.scatter_3d `
177-
178- ``` python
179- import numpy as np
180- import plotly.express as px
181- import plotly.graph_objects as go
182- from sklearn.neighbors import KNeighborsClassifier
183- from sklearn.model_selection import train_test_split
184-
185- df = px.data.iris()
186- features = [" sepal_width" , " sepal_length" , " petal_width" ]
187-
188- X = df[features]
189- y = df.species
190- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3 , random_state = 0 )
191-
192- # Create classifier, run predictions on grid
193- clf = KNeighborsClassifier(15 , weights = ' distance' )
194- clf.fit(X_train, y_train)
195- y_pred = clf.predict(X_test)
196- y_score = clf.predict_proba(X_test)
197- y_score = np.around(y_score.max(axis = 1 ), 4 )
198-
199- fig = px.scatter_3d(
200- X_test,
201- x = ' sepal_length' ,
202- y = ' sepal_width' ,
203- z = ' petal_width' ,
204- symbol = y_pred,
205- color = y_score,
206- labels = {' symbol' : ' prediction' , ' color' : ' score' }
207- )
208- fig.update_layout(legend = dict (x = 0 , y = 0 ))
209- fig.show()
210- ```
211-
212- ## High Dimension Visualization with ` px.scatter_matrix `
213-
214- If you need to visualize classifications that go beyond 3D, you can use the [scatter plot matrix](https://plot.ly/python/splom/).
215-
216- ``` python
217- import numpy as np
218- import plotly.express as px
219- import plotly.graph_objects as go
220- from sklearn.neighbors import KNeighborsClassifier
221- from sklearn.model_selection import train_test_split
222-
223- df = px.data.iris()
224- features = [" sepal_width" , " sepal_length" , " petal_width" , " petal_length" ]
225-
226- X = df[features]
227- y = df.species
228- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3 , random_state = 0 )
229-
230- # Create classifier, run predictions on grid
231- clf = KNeighborsClassifier(15 , weights = ' distance' )
232- clf.fit(X_train, y_train)
233- y_pred = clf.predict(X_test)
234-
235- fig = px.scatter_matrix(X_test, dimensions = features, color = y_pred, labels = {' color' : ' prediction' })
236- fig.show()
237- ```
238-
239189### Reference
240190
241191Learn more about `px`, `go.Contour`, and `go.Heatmap` here:
242192* https://plot.ly/python/plotly-express/
243193* https://plot.ly/python/heatmaps/
244194* https://plot.ly/python/contour-plots/
245- * https://plot.ly/python/3d-scatter-plots/
246- * https://plot.ly/python/splom/
247195
248196This tutorial was inspired by amazing examples from the official scikit-learn docs:
249197* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
0 commit comments