Skip to content

Commit eb7a386

Browse files
author
Michael Galarnyk
committed
Added new decision tree for titantic
1 parent e68efe1 commit eb7a386

4 files changed

Lines changed: 335 additions & 2 deletions

File tree

Kaggle/.DS_Store

0 Bytes
Binary file not shown.

Kaggle/Titanic/Titanic_DecisionTree.ipynb

Lines changed: 319 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,326 @@
3030
"outputs": [],
3131
"source": [
3232
"import pandas as pd\n",
33-
"url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\"\n",
33+
"from sklearn import tree \n",
34+
"import numpy as np\n",
35+
"\n",
36+
"url = \"https://raw.githubusercontent.com/mGalarnyk/Python_Tutorials/master/Kaggle/Titanic/train.csv\"\n",
3437
"# load dataset into Pandas DataFrame\n",
35-
"df = pd.read_csv(url, names=['sepal length','sepal width','petal length','petal width','target'])"
38+
"df = pd.read_csv(url)"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": 2,
44+
"metadata": {
45+
"collapsed": false
46+
},
47+
"outputs": [],
48+
"source": [
49+
"# Change sex to binary\n",
50+
"df['Sex'] = df['Sex'].map( {'female': 0, 'male': 1} ).astype(int)"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 3,
56+
"metadata": {
57+
"collapsed": false
58+
},
59+
"outputs": [],
60+
"source": [
61+
"# Take subset of the dataset\n",
62+
"\n",
63+
"df = df[['Sex', 'Age', 'Fare', 'Survived']]"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 4,
69+
"metadata": {
70+
"collapsed": false
71+
},
72+
"outputs": [
73+
{
74+
"data": {
75+
"text/plain": [
76+
"1 577\n",
77+
"0 314\n",
78+
"Name: Sex, dtype: int64"
79+
]
80+
},
81+
"execution_count": 4,
82+
"metadata": {},
83+
"output_type": "execute_result"
84+
}
85+
],
86+
"source": [
87+
"df.Sex.value_counts(dropna = False)"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": 5,
93+
"metadata": {
94+
"collapsed": false
95+
},
96+
"outputs": [
97+
{
98+
"data": {
99+
"text/plain": [
100+
"8.0500 43\n",
101+
"13.0000 42\n",
102+
"7.8958 38\n",
103+
"7.7500 34\n",
104+
"26.0000 31\n",
105+
"10.5000 24\n",
106+
"7.9250 18\n",
107+
"7.7750 16\n",
108+
"26.5500 15\n",
109+
"0.0000 15\n",
110+
"7.2292 15\n",
111+
"7.8542 13\n",
112+
"8.6625 13\n",
113+
"7.2500 13\n",
114+
"7.2250 12\n",
115+
"16.1000 9\n",
116+
"9.5000 9\n",
117+
"24.1500 8\n",
118+
"15.5000 8\n",
119+
"56.4958 7\n",
120+
"52.0000 7\n",
121+
"14.5000 7\n",
122+
"14.4542 7\n",
123+
"69.5500 7\n",
124+
"7.0500 7\n",
125+
"31.2750 7\n",
126+
"46.9000 6\n",
127+
"30.0000 6\n",
128+
"7.7958 6\n",
129+
"39.6875 6\n",
130+
" ..\n",
131+
"7.1417 1\n",
132+
"42.4000 1\n",
133+
"211.5000 1\n",
134+
"12.2750 1\n",
135+
"61.1750 1\n",
136+
"8.4333 1\n",
137+
"51.4792 1\n",
138+
"7.8875 1\n",
139+
"8.6833 1\n",
140+
"7.5208 1\n",
141+
"34.6542 1\n",
142+
"28.7125 1\n",
143+
"25.5875 1\n",
144+
"7.7292 1\n",
145+
"12.2875 1\n",
146+
"8.6542 1\n",
147+
"8.7125 1\n",
148+
"61.3792 1\n",
149+
"6.9500 1\n",
150+
"9.8417 1\n",
151+
"8.3000 1\n",
152+
"13.7917 1\n",
153+
"9.4750 1\n",
154+
"13.4167 1\n",
155+
"26.3875 1\n",
156+
"8.4583 1\n",
157+
"9.8375 1\n",
158+
"8.3625 1\n",
159+
"14.1083 1\n",
160+
"17.4000 1\n",
161+
"Name: Fare, Length: 248, dtype: int64"
162+
]
163+
},
164+
"execution_count": 5,
165+
"metadata": {},
166+
"output_type": "execute_result"
167+
}
168+
],
169+
"source": [
170+
"df.Fare.value_counts(dropna = False)"
171+
]
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": 6,
176+
"metadata": {
177+
"collapsed": false
178+
},
179+
"outputs": [
180+
{
181+
"data": {
182+
"text/plain": [
183+
"8.0500 43\n",
184+
"13.0000 42\n",
185+
"7.8958 38\n",
186+
"7.7500 34\n",
187+
"26.0000 31\n",
188+
"10.5000 24\n",
189+
"7.9250 18\n",
190+
"7.7750 16\n",
191+
"26.5500 15\n",
192+
"0.0000 15\n",
193+
"7.2292 15\n",
194+
"7.8542 13\n",
195+
"8.6625 13\n",
196+
"7.2500 13\n",
197+
"7.2250 12\n",
198+
"16.1000 9\n",
199+
"9.5000 9\n",
200+
"24.1500 8\n",
201+
"15.5000 8\n",
202+
"56.4958 7\n",
203+
"52.0000 7\n",
204+
"14.5000 7\n",
205+
"14.4542 7\n",
206+
"69.5500 7\n",
207+
"7.0500 7\n",
208+
"31.2750 7\n",
209+
"46.9000 6\n",
210+
"30.0000 6\n",
211+
"7.7958 6\n",
212+
"39.6875 6\n",
213+
" ..\n",
214+
"7.1417 1\n",
215+
"42.4000 1\n",
216+
"211.5000 1\n",
217+
"12.2750 1\n",
218+
"61.1750 1\n",
219+
"8.4333 1\n",
220+
"51.4792 1\n",
221+
"7.8875 1\n",
222+
"8.6833 1\n",
223+
"7.5208 1\n",
224+
"34.6542 1\n",
225+
"28.7125 1\n",
226+
"25.5875 1\n",
227+
"7.7292 1\n",
228+
"12.2875 1\n",
229+
"8.6542 1\n",
230+
"8.7125 1\n",
231+
"61.3792 1\n",
232+
"6.9500 1\n",
233+
"9.8417 1\n",
234+
"8.3000 1\n",
235+
"13.7917 1\n",
236+
"9.4750 1\n",
237+
"13.4167 1\n",
238+
"26.3875 1\n",
239+
"8.4583 1\n",
240+
"9.8375 1\n",
241+
"8.3625 1\n",
242+
"14.1083 1\n",
243+
"17.4000 1\n",
244+
"Name: Fare, Length: 248, dtype: int64"
245+
]
246+
},
247+
"execution_count": 6,
248+
"metadata": {},
249+
"output_type": "execute_result"
250+
}
251+
],
252+
"source": [
253+
"df.Fare.value_counts(dropna = False)"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": 7,
259+
"metadata": {
260+
"collapsed": false
261+
},
262+
"outputs": [],
263+
"source": [
264+
"# Impute age with mean \n",
265+
"df.loc[df.Age.isna(), 'Age'] = np.ceil(df.Age.mean())"
266+
]
267+
},
268+
{
269+
"cell_type": "code",
270+
"execution_count": 24,
271+
"metadata": {
272+
"collapsed": false
273+
},
274+
"outputs": [],
275+
"source": [
276+
"clf = tree.DecisionTreeClassifier(max_depth=2) \n",
277+
"clf = clf.fit(df[['Sex', 'Age', 'Fare']], df[['Survived']]) \n",
278+
"tree.export_graphviz(clf,\n",
279+
" out_file=\"decisionTreeTitantic.dot\",\n",
280+
" feature_names=['Sex', 'Age', 'Fare'],\n",
281+
" class_names=['Dead', 'Alive'],\n",
282+
" filled = True)"
283+
]
284+
},
285+
{
286+
"cell_type": "code",
287+
"execution_count": 25,
288+
"metadata": {
289+
"collapsed": true
290+
},
291+
"outputs": [],
292+
"source": [
293+
"!dot -Tpng decisionTreeTitantic.dot -o decisionTreeTitantic.png"
294+
]
295+
},
296+
{
297+
"cell_type": "code",
298+
"execution_count": 19,
299+
"metadata": {
300+
"collapsed": true
301+
},
302+
"outputs": [],
303+
"source": []
304+
},
305+
{
306+
"cell_type": "code",
307+
"execution_count": 20,
308+
"metadata": {
309+
"collapsed": false
310+
},
311+
"outputs": [
312+
{
313+
"data": {
314+
"text/plain": [
315+
"array(['setosa', 'versicolor', 'virginica'], dtype='|S10')"
316+
]
317+
},
318+
"execution_count": 20,
319+
"metadata": {},
320+
"output_type": "execute_result"
321+
}
322+
],
323+
"source": [
324+
"iris.target_names"
325+
]
326+
},
327+
{
328+
"cell_type": "code",
329+
"execution_count": 21,
330+
"metadata": {
331+
"collapsed": false
332+
},
333+
"outputs": [
334+
{
335+
"data": {
336+
"text/plain": [
337+
"array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
338+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
339+
" 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
340+
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
341+
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
342+
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
343+
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])"
344+
]
345+
},
346+
"execution_count": 21,
347+
"metadata": {},
348+
"output_type": "execute_result"
349+
}
350+
],
351+
"source": [
352+
"iris.target"
36353
]
37354
},
38355
{
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
digraph Tree {
2+
node [shape=box, style="filled", color="black"] ;
3+
0 [label="Sex <= 0.5\ngini = 0.473\nsamples = 891\nvalue = [549, 342]\nclass = Dead", fillcolor="#e5813960"] ;
4+
1 [label="Fare <= 48.2\ngini = 0.383\nsamples = 314\nvalue = [81, 233]\nclass = Alive", fillcolor="#399de5a6"] ;
5+
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
6+
2 [label="gini = 0.447\nsamples = 225\nvalue = [76, 149]\nclass = Alive", fillcolor="#399de57d"] ;
7+
1 -> 2 ;
8+
3 [label="gini = 0.106\nsamples = 89\nvalue = [5, 84]\nclass = Alive", fillcolor="#399de5f0"] ;
9+
1 -> 3 ;
10+
4 [label="Age <= 6.5\ngini = 0.306\nsamples = 577\nvalue = [468, 109]\nclass = Dead", fillcolor="#e58139c4"] ;
11+
0 -> 4 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
12+
5 [label="gini = 0.444\nsamples = 24\nvalue = [8, 16]\nclass = Alive", fillcolor="#399de57f"] ;
13+
4 -> 5 ;
14+
6 [label="gini = 0.28\nsamples = 553\nvalue = [460, 93]\nclass = Dead", fillcolor="#e58139cb"] ;
15+
4 -> 6 ;
16+
}
75.1 KB
Loading

0 commit comments

Comments
 (0)