@@ -78,6 +78,8 @@ fig.show()
7878
7979## Model generalization on unseen data
8080
81+ Easily color your plot based on a predefined data split.
82+
8183``` python
8284import numpy as np
8385import plotly.express as px
@@ -106,6 +108,8 @@ fig.show()
106108
107109## Comparing different kNN models' parameters
108110
111+ Compare the performance of two different models on the same dataset. This can be easily combined with discrete color legends from ` px ` .
112+
109113``` python
110114import numpy as np
111115import plotly.express as px
@@ -114,14 +118,16 @@ from sklearn.neighbors import KNeighborsRegressor
114118
115119df = px.data.tips()
116120X = df.total_bill.values.reshape(- 1 , 1 )
121+ x_range = np.linspace(X.min(), X.max(), 100 )
117122
123+ # Model #1
118124knn_dist = KNeighborsRegressor(10 , weights = ' distance' )
119- knn_uni = KNeighborsRegressor(10 , weights = ' uniform' )
120125knn_dist.fit(X, df.tip)
121- knn_uni.fit(X, df.tip)
122-
123- x_range = np.linspace(X.min(), X.max(), 100 )
124126y_dist = knn_dist.predict(x_range.reshape(- 1 , 1 ))
127+
128+ # Model #2
129+ knn_uni = KNeighborsRegressor(10 , weights = ' uniform' )
130+ knn_uni.fit(X, df.tip)
125131y_uni = knn_uni.predict(x_range.reshape(- 1 , 1 ))
126132
127133fig = px.scatter(df, x = ' total_bill' , y = ' tip' , color = ' sex' , opacity = 0.65 )
@@ -132,6 +138,8 @@ fig.show()
132138
133139## 3D regression surface with ` px.scatter_3d ` and ` go.Surface `
134140
141+ Visualize the regression surface of your model whenever you have more than one variable in your ` X ` .
142+
135143``` python
136144import numpy as np
137145import plotly.express as px
@@ -229,7 +237,7 @@ model = LinearRegression()
229237model.fit(X, y)
230238y_pred = model.predict(X)
231239
232- fig = px.scatter(x = y , y = y_pred , labels = {' x' : ' y true ' , ' y' : ' y pred ' })
240+ fig = px.scatter(x = y_pred , y = y , labels = {' x' : ' prediction ' , ' y' : ' actual ' })
233241fig.add_shape(
234242 type = " line" , line = dict (dash = ' dash' ),
235243 x0 = y.min(), y0 = y.min(),
@@ -238,7 +246,9 @@ fig.add_shape(
238246fig.show()
239247```
240248
241- ### Augmented prediction error analysis using ` plotly.express `
249+ ### Enhanced prediction error analysis using ` plotly.express `
250+
251+ Add marginal histograms to quickly diagnose any prediction bias your model might have. The built-in ` OLS ` functionality lets you visualize how well your model generalizes by comparing it with the theoretically optimal fit (black dotted line).
242252
243253``` python
244254import plotly.express as px
@@ -254,6 +264,7 @@ df['split'] = 'train'
254264df.loc[test_idx, ' split' ] = ' test'
255265
256266X = df[[' sepal_width' , ' sepal_length' ]]
267+ y = df[' petal_width' ]
257268X_train = df.loc[train_idx, [' sepal_width' , ' sepal_length' ]]
258269y_train = df.loc[train_idx, ' petal_width' ]
259270
@@ -263,7 +274,7 @@ model.fit(X_train, y_train)
263274df[' prediction' ] = model.predict(X)
264275
265276fig = px.scatter(
266- df, x = ' petal_width ' , y = ' prediction ' ,
277+ df, x = ' prediction ' , y = ' petal_width ' ,
267278 marginal_x = ' histogram' , marginal_y = ' histogram' ,
268279 color = ' split' , trendline = ' ols'
269280)
0 commit comments