-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathdecision_tree.clj
More file actions
202 lines (185 loc) · 7.8 KB
/
decision_tree.clj
File metadata and controls
202 lines (185 loc) · 7.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
(ns lambda-ml.decision-tree
"Decision tree learning using the Classification and Regression Trees (CART)
algorithm.
Example usage;
```
(def data [[0 0 0] [0 1 1] [1 0 1] [1 1 0]])
(def fit
(let [min-split 2
min-leaf 1
max-features 2]
(-> (make-classification-tree gini-impurity min-split min-leaf max-features)
(decision-tree-fit data))))
(decision-tree-predict fit (map butlast data))
;;=> (0 1 1 0)
```"
(:require [lambda-ml.core :as c]
[lambda-ml.data.binary-tree :as bt]))
;; Cost functions
(defn gini-impurity
"Returns the Gini impurity of a seq of labels."
[labels]
(let [total (count labels)]
(->> (vals (frequencies labels))
(map #(/ % total))
(map #(* % (- 1 %)))
(reduce +))))
(defn mean-squared-error
"Returns the mean squared error for a seq of predictions."
[labels predictions]
(->> (map - labels predictions)
(map #(* % %))
(reduce +)
(* (/ 1 (count predictions)))))
(defn classification-weighted-cost
[y1 y2 f g]
(let [n1 (count y1)
n2 (count y2)]
;; Classification cost doesn't take the prediction value into account
(cond-> 0
(> n1 0) (+ (* (/ n1 (+ n1 n2)) (f y1)))
(> n2 0) (+ (* (/ n2 (+ n1 n2)) (f y2))))))
(defn regression-weighted-cost
[y1 y2 f g]
(let [n1 (count y1)
n2 (count y2)]
(cond-> 0
(> n1 0) (+ (* (/ n1 (+ n1 n2))
(f y1 (repeat n1 (g y1)))))
(> n2 0) (+ (* (/ n2 (+ n1 n2))
(f y2 (repeat n2 (g y2))))))))
;; Tree splitting
(defn categorical-partitions
"Given a seq of k distinct values, returns the 2^{k-1}-1 possible binary
partitions of the values into sets."
[vals]
(if (<= (count vals) 1)
[]
(reduce (fn [p [s1 s2]]
(conj p
[(conj s1 (first vals)) s2]
[(conj s2 (first vals)) s1]))
(vector [(hash-set (first vals)) (set (rest vals))])
(categorical-partitions (rest vals)))))
(defn numeric-partitions
"Given a seq of k distinct numeric values, returns k-1 possible binary
partitions of the values by taking the average of consecutive elements in the
sorted seq of values."
[vals]
(loop [partitions []
v (sort vals)]
(if (<= (count v) 1)
partitions
(recur (conj partitions (/ (+ (first v) (second v)) 2))
(rest v)))))
(defn splitters
"Returns a seq of all possible splitters for feature i. A splitter is a
predicate function that evaluates to true if an example belongs in the left
subtree, or false if an example belongs in the right subtree, based on the
splitting criterion."
[x i]
(let [domain (distinct (map #(nth % i) x))
val (first domain)]
(cond (number? val) (->> (numeric-partitions domain)
(map (fn [s]
(with-meta
(fn [x] (<= (nth x i) s))
{:decision (float s)}))))
(or (keyword? val)
(string? val)) (->> (categorical-partitions domain)
(map (fn [[s1 s2]]
(with-meta
(fn [x] (contains? s1 (nth x i)))
{:decision [s1 s2]}))))
:else (throw (IllegalArgumentException. "Invalid feature type")))))
(defn best-splitter
"Returns the splitter for the given data that minimizes a weighted cost
function, or returns nil if no splitter exists."
[model x y]
(let [{cost :cost prediction :prediction weighted :weighted
min-leaf :min-leaf max-features :max-features} model
;; Feature bagging - sample a subset of features to split on
features (-> (range (count (first x)))
(c/sample-without-replacement max-features))
data (map #(conj (vec %1) %2) x y)]
(->> (for [i features]
(let [no-splitter [nil Double/MAX_VALUE i]]
;; Find best splitter for feature i
(->> (splitters x i)
(map (fn [splitter]
(let [[left right] (vals (group-by splitter data))]
;; Either split would have fewer observations than required
(cond (< (count left) min-leaf) no-splitter
(< (count right) min-leaf) no-splitter
:else (let [cost (weighted (map last left) (map last right) cost prediction)
;; Add metadata to splitter
splitter (vary-meta splitter merge {:cost (float cost) :feature i})]
[splitter cost i])))))
(#(if (empty? %) (list no-splitter) %))
(apply min-key second))))
;; Find best splitter amongst all features
(reduce (fn [a b]
(let [[_ c1 i1] a [_ c2 i2] b]
(cond (< c1 c2) a
;; To match the CART algorithm, break ties in cost by
;; choosing splitter for feature with lower index
(= c1 c2) (if (< i1 i2) a b)
:else b))))
(first))))
;; API
(defn decision-tree-fit
"Fits a decision tree to the given training data."
([model data]
(decision-tree-fit model (map butlast data) (map last data)))
([model x y]
(let [{cost :cost prediction :prediction weighted :weighted
min-split :min-split min-leaf :min-leaf max-features :max-features} model
weighted (fn [left right] (weighted left right cost prediction))]
(->> (cond
;; Fewer observations than required to split a node
(< (count y) min-split) (bt/make-tree (prediction y))
;; All observed labels are equivalent
(apply = y) (bt/make-tree (prediction y))
:else
(let [splitter (best-splitter model x y)]
(if (nil? splitter)
(bt/make-tree (prediction y))
(let [data (map #(conj (vec %1) %2) x y)
split (group-by splitter data)
left (get split true)
right (get split false)]
(bt/make-tree splitter
(:parameters (decision-tree-fit model left))
(:parameters (decision-tree-fit model right)))))))
(assoc model :parameters)))))
(defn decision-tree-predict
"Predicts the values of example data using a decision tree."
[model x]
(let [{tree :parameters} model]
(when (not (nil? tree))
(letfn [(predict [t xi]
(let [val (bt/get-value t)]
(cond (bt/leaf? t) val
(val xi) (predict (bt/get-left t) xi)
:else (predict (bt/get-right t) xi))))]
(map #(predict tree %) x)))))
(defn print-decision-tree
"Prints information about a given decision tree."
[model]
(println (dissoc model :parameters))
(when (contains? model :parameters)
(bt/print-tree (:parameters model))))
(defn make-classification-tree
"Returns a classification decision tree model using the given cost function."
[cost min-split min-leaf max-features]
{:cost cost :prediction c/mode :weighted classification-weighted-cost
:min-split min-split
:min-leaf min-leaf
:max-features max-features})
(defn make-regression-tree
"Returns a regression decision tree model using the given cost function."
[cost min-split min-leaf max-features]
{:cost cost :prediction c/mean :weighted regression-weighted-cost
:min-split min-split
:min-leaf min-leaf
:max-features max-features})