Skip to content

Commit 58dc945

Browse files
authored
Add CART-based decision tree with Gini impurity (#1)
1 parent 7ed610b commit 58dc945

9 files changed

Lines changed: 536 additions & 47 deletions

File tree

src/lambda_ml/data/binary_tree.clj

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
(ns lambda-ml.data.binary-tree)
2+
3+
(defn make-tree
  "Builds a binary tree node represented as the 3-element vector
  [val left right]. The single-argument arity creates a node whose
  left and right children are both nil."
  ([val]
   [val nil nil])
  ([val left right]
   [val left right]))
8+
9+
(defn get-value
  "Returns the value stored in the node, i.e. element 0 of tree."
  [tree]
  (-> tree (nth 0)))
12+
13+
(defn get-left
  "Returns the left child of the node, i.e. element 1 of tree."
  [tree]
  (-> tree (nth 1)))
16+
17+
(defn get-right
  "Returns the right child of the node, i.e. element 2 of tree."
  [tree]
  (-> tree (nth 2)))
20+
21+
(defn get-path
  "Follows paths, a seq of :left / :right steps, down from the root of
  tree and returns whatever is found there (nil when the path runs off
  the tree). Throws IllegalArgumentException for any other step."
  [tree paths]
  (let [step->index (fn [step]
                      ;; :left and :right map onto the child slots of
                      ;; the [val left right] node vector.
                      (case step
                        :left 1
                        :right 2
                        (throw (IllegalArgumentException. "Invalid tree path"))))]
    (get-in tree (map step->index paths))))
29+
30+
(defn leaf?
  "Returns true when tree has neither a left nor a right child."
  [tree]
  ;; The right child is only inspected when the left one is absent.
  (if (nil? (get-left tree))
    (nil? (get-right tree))
    false))
33+
34+
(defn print-tree
  "Prints tree to *out*, one node per line, children after their parent
  and indented by one tab per level. For each node the value's metadata
  is printed when it has any, otherwise the value itself. The
  single-argument arity starts at level 0."
  ([tree]
   (print-tree tree 0))
  ([tree level]
   (let [indent (apply str (repeat level "\t"))
         node   (get-value tree)]
     (println indent (or (meta node) node)))
   (let [child (get-left tree)]
     (when-not (nil? child)
       (print-tree child (inc level))))
   (let [child (get-right tree)]
     (when-not (nil? child)
       (print-tree child (inc level))))))

src/lambda_ml/data/kd_tree.clj

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
(ns lambda-ml.data.kd-tree)
1+
(ns lambda-ml.data.kd-tree
2+
(:require [lambda-ml.data.binary-tree :as bt]))
23

34
;; K-d tree
45

@@ -17,27 +18,6 @@
1718
(let [dim (fn [node] (nth (f node) (mod depth dims)))
1819
sorted (sort-by dim nodes)
1920
median (quot (count sorted) 2)]
20-
(vector (nth sorted median)
21-
(make-tree dims (take median sorted) f (inc depth))
22-
(make-tree dims (drop (+ median 1) sorted) f (inc depth)))))))
23-
24-
(defn get-value
25-
[tree]
26-
(nth tree 0))
27-
28-
(defn get-left
29-
[tree]
30-
(nth tree 1))
31-
32-
(defn get-right
33-
[tree]
34-
(nth tree 2))
35-
36-
(defn get-path
37-
[tree paths]
38-
(->> paths
39-
(map (fn [path]
40-
(cond (= path :left) 1
41-
(= path :right) 2
42-
:else (throw (IllegalArgumentException. "Invalid tree path")))))
43-
(get-in tree)))
21+
(bt/make-tree (nth sorted median)
22+
(make-tree dims (take median sorted) f (inc depth))
23+
(make-tree dims (drop (+ median 1) sorted) f (inc depth)))))))

src/lambda_ml/decision_tree.clj

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
(ns lambda-ml.decision-tree
2+
(:require [lambda-ml.data.binary-tree :as bt]))
3+
4+
(defn gini-impurity
  "Returns the Gini impurity of the labels in y: the sum over distinct
  labels of p * (1 - p), where p is that label's share of y. Returns 0
  for an empty y."
  [y]
  (let [n (count y)]
    (reduce (fn [acc occurrences]
              (let [p (/ occurrences n)]
                (+ acc (* p (- 1 p)))))
            0
            (vals (frequencies y)))))
11+
12+
(defn weighted-cost
  "Returns, as a float, the cost of splitting into label seqs y1 and y2:
  the costs (f y1) and (f y2) averaged with weights proportional to the
  number of elements in y1 and y2."
  [y1 y2 f]
  (let [size1 (count y1)
        size2 (count y2)
        total (+ size1 size2)
        cost1 (f y1)
        cost2 (f y2)
        share (fn [size cost] (* (/ size total) cost))]
    (float (+ (share size1 cost1)
              (share size2 cost2)))))
20+
21+
(defn categorical-partitions
  "Given a seq of k distinct values, returns the 2^{k-1}-1 possible binary
  partitions of the values, each as a pair of sets [s1 s2]. Returns a single
  trivial partition (second set empty) when k = 1."
  [vals]
  (let [head (first vals)
        tail (rest vals)
        ;; The partition that isolates the first value from all others.
        base [#{head} (set tail)]]
    (if (> (count vals) 2)
      ;; Every partition of the remaining values yields two partitions of
      ;; the full set: the first value joins either side.
      (into [base]
            (mapcat (fn [[s1 s2]]
                      [[(conj s1 head) s2]
                       [(conj s2 head) s1]])
                    (categorical-partitions tail)))
      [base])))
35+
36+
(defn numeric-partitions
  "Given a seq of k distinct numeric values, returns a vector of the k-1
  possible binary partition points, each the average of two consecutive
  elements in the sorted seq of values. Returns the same seq when k = 1 and
  an empty vector when k = 0."
  [vals]
  (if (= (count vals) 1)
    vals
    (let [sorted (sort vals)]
      ;; Pair every element with its successor; an empty or exhausted seq
      ;; simply produces no further midpoints.
      (mapv (fn [lo hi] (/ (+ lo hi) 2))
            sorted
            (rest sorted)))))
49+
50+
(defn splitters
  "Returns a seq of all possible splitters for feature i of the examples in
  x. A splitter is a predicate on an example: true sends the example to the
  left subtree, false to the right. Numeric features get one threshold
  splitter (value <= threshold) per numeric partition point; string features
  get one set-membership splitter per categorical partition. Each splitter
  carries its splitting criterion as :decision metadata. The feature type is
  decided from the first distinct value; any other type throws
  IllegalStateException."
  [x i]
  (let [feature (fn [example] (nth example i))
        domain  (distinct (map feature x))
        sample  (first domain)]
    (cond
      (number? sample)
      (for [threshold (numeric-partitions domain)]
        (with-meta
          (fn [example] (<= (feature example) threshold))
          {:decision (float threshold)}))

      (string? sample)
      (for [[members others] (categorical-partitions domain)]
        (with-meta
          (fn [example] (contains? members (feature example)))
          {:decision [members others]}))

      :else
      (throw (IllegalStateException. "Invalid feature type")))))
68+
69+
(defn best-splitter
  "Returns the splitter for the given data that minimizes cost function f.
  x is a seq of examples (feature seqs) and y the matching seq of labels.
  The returned splitter has :cost and :feature merged into its metadata."
  [f x y]
  (let [;; Examples with their label appended as the last element. Built once
        ;; and shared by every candidate splitter.
        data   (map #(conj (vec %1) %2) x y)
        labels (fn [rows] (map last rows))
        score  (fn [i splitter]
                 ;; Look the two sides up by the predicate's result rather
                 ;; than relying on the order of the grouped map's entries.
                 (let [split (group-by splitter data)
                       cost  (weighted-cost (labels (get split true))
                                            (labels (get split false))
                                            f)]
                   ;; Add metadata to splitter
                   [(vary-meta splitter merge {:cost cost :feature i}) cost i]))]
    (->> (for [i (range (count (first x)))]
           ;; Find best splitter for feature i
           (->> (splitters x i)
                (map #(score i %))
                (apply min-key second)))
         ;; Find best splitter amongst all features
         (reduce (fn [a b]
                   (let [[_ c1 i1] a [_ c2 i2] b]
                     (cond (< c1 c2) a
                           ;; To match the CART algorithm, break ties in cost by
                           ;; choosing splitter for feature with lower index
                           (= c1 c2) (if (< i1 i2) a b)
                           :else b))))
         (first))))
91+
92+
(defn decision-tree-fit
  "Fits a decision tree to the given training data and returns model with
  the fitted tree stored under :parameters. model must carry the cost
  function under :cost. The two-argument arity takes rows whose last element
  is the label; the three-argument arity takes examples x and labels y
  separately.

  A node becomes a leaf when all of its labels agree, when it has no
  examples at all (leaf value nil), or when the best splitter leaves one
  side empty; in that last case the leaf holds the most frequent label."
  ([model data]
   (decision-tree-fit model (map butlast data) (map last data)))
  ([model x y]
   (->> (if (or (empty? y) (apply = y))
          ;; Pure (or empty) node. The empty check also keeps `=` from being
          ;; applied to zero arguments.
          (bt/make-tree (first y))
          (let [{cost :cost} model
                splitter (best-splitter cost x y)
                data (map #(conj (vec %1) %2) x y)
                split (group-by splitter data)
                left (get split true)
                right (get split false)]
            (if (or (empty? left) (empty? right))
              ;; No splitter separates these examples (e.g. identical
              ;; features with conflicting labels). Recursing would repeat
              ;; the same call forever, so stop with the majority label.
              (bt/make-tree (key (apply max-key val (frequencies y))))
              (bt/make-tree splitter
                            (:parameters (decision-tree-fit model left))
                            (:parameters (decision-tree-fit model right))))))
        (assoc model :parameters))))
109+
110+
(defn decision-tree-predict
  "Predicts the values of example data using a decision tree. Returns a seq
  with one prediction per example in x, or nil when model has no fitted tree
  under :parameters."
  [model x]
  (let [{:keys [parameters]} model
        ;; Walk down from the root: at each internal node the stored
        ;; splitter decides between the left (truthy) and right subtree.
        classify (fn [example]
                   (loop [node parameters]
                     (let [decision (bt/get-value node)]
                       (cond (bt/leaf? node) decision
                             (decision example) (recur (bt/get-left node))
                             :else (recur (bt/get-right node))))))]
    (when-not (nil? parameters)
      (map classify x))))
121+
122+
(defn print-decision-tree
  "Prints information about a given decision tree: first the model map
  without its :parameters entry, then, if the model has a :parameters key,
  the tree stored there."
  [model]
  (let [summary (dissoc model :parameters)]
    (println summary))
  (if (contains? model :parameters)
    (-> model :parameters bt/print-tree)
    nil))
128+
129+
(defn make-decision-tree
  "Returns an unfitted decision tree model: a map holding the given cost
  function under :cost."
  [cost]
  (array-map :cost cost))

0 commit comments

Comments
 (0)