Skip to content

Commit 3947427

Browse files
committed
Added K-Means Clustering in Scala
1 parent 8394f57 commit 3947427

2 files changed

Lines changed: 169 additions & 0 deletions

File tree

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import scala.annotation.tailrec
2+
3+
/**
 * K-means clustering of 3d vectors, driven by a pluggable similarity measure.
 *
 * The entry point is [[kmeans_clustering.kmeans]]; the default similarity measure is
 * [[kmeans_clustering.CosineSimilarity]].
 */
object kmeans_clustering {

  /**
   * Simple 3d Vector class.
   *
   * Equality is direction-based: two vectors are equal when their normalised
   * components are equal, so `Vector(1, 0, 0) == Vector(4, 0, 0)`.
   *
   * @param a first component
   * @param b second component
   * @param c third component
   */
  case class Vector(a: Double, b: Double, c: Double) {
    /**
     * Adds two vectors.
     * @param that the other vector
     * @return the sum
     */
    def +(that: Vector): Vector = Vector(a + that.a, b + that.b, c + that.c)

    /**
     * Subtracts two vectors.
     * @param that the other vector
     * @return the difference
     */
    def -(that: Vector): Vector = Vector(a - that.a, b - that.b, c - that.c)

    /**
     * The cross product of two vectors.
     * @param that the other vector
     * @return the cross product
     */
    def x(that: Vector): Vector = Vector(
      b * that.c - c * that.b,
      c * that.a - a * that.c, // note the order: the middle component is c*a' - a*c'
      a * that.b - b * that.a
    )

    /**
     * The dot product of two vectors.
     * @param that the other vector
     * @return the dot product
     */
    def *(that: Vector): Double = this.a * that.a + this.b * that.b + this.c * that.c

    /**
     * Multiplies the vector with a scalar.
     * @param s the scalar
     * @return the resulting vector
     */
    def *(s: Double): Vector = Vector(this.a * s, this.b * s, this.c * s)

    /**
     * The length of the vector.
     * @return the length
     */
    def length: Double = math.sqrt(this * this)

    /**
     * Normalises the vector.
     *
     * For a vector of length zero every component of the result is NaN (0.0 / 0.0).
     *
     * @return the normalised vector
     */
    def normalised: Vector = {
      val len = length // computed once instead of once per component
      Vector(a / len, b / len, c / len)
    }

    override def toString: String = s"Vector($a, $b, $c)"

    /**
     * Direction-based equality: compares the normalised components of both vectors.
     * Zero-length vectors are equal to each other and to nothing else.
     */
    override def equals(obj: scala.Any): Boolean = obj match {
      case v: Vector =>
        val (va, vb, vc) = v.direction
        val (ta, tb, tc) = this.direction
        va == ta && vb == tb && vc == tc
      case _ => false
    }

    /** Hash code consistent with the direction-based [[equals]]. */
    override def hashCode: Int = direction.##

    /**
     * Canonical direction used by `equals` and `hashCode`: the normalised
     * components, or all zeros for a zero-length vector (avoiding NaN, which is
     * never equal to itself). Adding 0.0 turns -0.0 into 0.0 so that components
     * that compare equal are also bit-identical.
     */
    private def direction: (Double, Double, Double) = {
      val len = length
      if (len == 0.0) (0.0, 0.0, 0.0)
      else (a / len + 0.0, b / len + 0.0, c / len + 0.0)
    }
  }

  /**
   * A Cluster, represented by its centre and the members belonging to this cluster.
   * @param centre the centre of the cluster
   * @param members the members of this cluster
   */
  class Cluster(var centre: Vector, var members: List[Vector]) {
    override def toString: String = s"Cluster(Centre: $centre, members: $members)"
  }

  /**
   * A similarity measure.
   * @tparam T the type
   */
  trait SimilarityMeasure[T] {
    /**
     * Calculates the similarity between two entities.
     * @param x 1st entity
     * @param y 2nd entity
     * @return the similarity
     */
    def s(x: T, y: T): Double
  }

  /**
   * The cosine similarity uses the cosine of the angle between two vectors to measure the similarity.
   * The result is NaN when either vector has length zero.
   */
  implicit object CosineSimilarity extends SimilarityMeasure[Vector] {
    def s(x: Vector, y: Vector): Double = (x * y) / (x.length * y.length)
  }

  /**
   * Clusters a list of vectors and groups them into c clusters.
   *
   * The first `c` vectors of `values` are used as the initial cluster centres.
   *
   * @param values the vectors to be clustered
   * @param c the number of clusters
   * @param sim the similarity measure that is used for the similarity calculations
   * @return the list of found clusters
   * @throws IllegalArgumentException if `values` is non-empty and `c` is not positive
   */
  def kmeans(values: List[Vector], c: Int)(implicit sim: SimilarityMeasure[Vector]): List[Cluster] = {
    // Without at least one centre there is no cluster to assign a vector to.
    require(c > 0 || values.isEmpty, s"cannot group ${values.size} vectors into $c clusters")
    assert(values.size >= c, s"need at least $c vectors, got ${values.size}")
    recalc(values.take(c), values)
  }

  /**
   * One k-means iteration: assigns every vector to the most similar centre,
   * moves each centre to the midpoint of its members and repeats until no
   * centre changes (as judged by `Vector.equals`).
   *
   * @param centres the current cluster centres
   * @param values the vectors to be clustered
   * @param sim the similarity measure
   * @return the clusters once the centres are stable
   */
  @tailrec
  private def recalc(centres: List[Vector], values: List[Vector])(implicit sim: SimilarityMeasure[Vector]): List[Cluster] = {
    // Setting up new, empty clusters around the current centres.
    val clusters = centres.map(centre => new Cluster(centre, Nil))

    // Assigning each vector to the cluster with the highest similarity.
    // The strict `<` keeps the earliest cluster on ties (and on NaN similarities).
    for (v <- values) {
      val scored = clusters.map(cluster => (cluster, sim.s(cluster.centre, v)))
      val (nearest, _) = scored.reduceLeft((best, candidate) => if (best._2 < candidate._2) candidate else best)
      nearest.members = v :: nearest.members
    }

    // Updating the centre of every cluster; a cluster without members keeps its centre.
    for (cluster <- clusters if cluster.members.nonEmpty) {
      cluster.centre = midpoint(cluster.members)
    }

    val newCentres = clusters.map(_.centre)

    // If no centre changed in this run we return the list of clusters.
    if (centres.diff(newCentres).isEmpty) {
      clusters
    } else {
      recalc(newCentres, values)
    }
  }

  /**
   * The component-wise mean of the given vectors.
   * @param members the vectors to average; expected to be non-empty
   * @return the midpoint
   */
  private def midpoint(members: List[Vector]): Vector = members.foldRight(Vector(0, 0, 0))(_ + _) * (1.0 / members.size)

}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import org.scalatest.{ Matchers, FlatSpec }
2+
import kmeans_clustering._
3+
4+
/**
 * Specification for `kmeans_clustering.kmeans`.
 *
 * The centre comparison relies on the `equals` defined by `kmeans_clustering.Vector`.
 */
class kmeans_clustering_test extends FlatSpec with Matchers {
  "A list of vectors" should "be clustered into 2 clusters" in {
    // Two vectors lie along the first axis and two along the third,
    // so two clusters are expected.
    val vectors =
      Vector(1, 0, 0) :: Vector(0, 0, 1) ::
        Vector(4, 0.1, 0) :: Vector(0, 0.1, 6.3) :: Nil

    val clusters = kmeans(vectors, 2)

    clusters.size shouldBe 2
    clusters.map(_.centre) should contain allOf (Vector(2.5, 0.05, 0.0), Vector(0.0, 0.05, 3.65))
  }
}

0 commit comments

Comments
 (0)