Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed a few bugs
  • Loading branch information
amandelpie committed Jul 18, 2022
commit 00157a907cf8d0c553cf25197100e0f02507c5b0
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import org.utbot.summary.clustering.dbscan.neighbor.RangeQuery
* Keeps the information about clusters produced by [DBSCANTrainer].
*
* @property [k] Number of clusters.
* @property [clusterLabels] Labels of clusters in the range [0; k).
* @property [clusterLabels] It contains labels of clusters in the range ```[0; k)```
* or [Int.MIN_VALUE] if point could not be assigned to any cluster.
*/
data class DBSCANModel(
val k: Int = 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.utbot.summary.clustering.dbscan.neighbor.LinearRangeQuery
import org.utbot.summary.clustering.dbscan.neighbor.Neighbor
import org.utbot.summary.clustering.dbscan.neighbor.RangeQuery

private const val NOISE = -3
private const val NOISE = Int.MIN_VALUE
private const val CLUSTER_PART = -2
private const val UNDEFINED = -1

Expand All @@ -29,68 +29,71 @@ class DBSCANTrainer<T>(val eps: Float, val minSamples: Int, val metric: Metric<T

/** Builds a clustering model based on the given data. */
fun fit(data: Array<T>): DBSCANModel {
require(data.isNotEmpty()) { "Nothing to learn, data is empty."}

if (rangeQuery is LinearRangeQuery) {
rangeQuery.data = data
rangeQuery.metric = metric
} // TODO: could be refactored if we add some new variants of RangeQuery

val numberOfClusters = 0
val labels = IntArray(data.size) { _ -> UNDEFINED }

var k = 0
Comment thread
amandelpie marked this conversation as resolved.
Outdated

for (i in data.indices) {
if (labels[i] == UNDEFINED) {
val neigbors = rangeQuery.findNeighbors(data[i], eps).toMutableList()
if (neigbors.size < minSamples) {
val neighbors = rangeQuery.findNeighbors(data[i], eps).toMutableList()
if (neighbors.size < minSamples) {
labels[i] = NOISE
} else {
k++
labels[i] = k
expandCluster(neigbors, labels, k)
expandCluster(neighbors, labels, k)
k++
}
}
}

return DBSCANModel(k = numberOfClusters, clusterLabels = labels)
return DBSCANModel(k = k, clusterLabels = labels)
}

private fun expandCluster(
neigbors: MutableList<Neighbor<T>>,
neighbors: MutableList<Neighbor<T>>,
labels: IntArray,
k: Int
) {
neigbors.forEach { // Neighbors to expand.
neighbors.forEach { // Neighbors to expand.
if (labels[it.index] == UNDEFINED) {
labels[it.index] = CLUSTER_PART // All neighbors of a cluster point became cluster points.
}
}

for (j in neigbors.indices) { // Process every seed point Q.
val q = neigbors[j]
// NOTE: the size of neighbors could grow from iteration to iteration and the classical for-loop in Kotlin could not be used
var j = 0
while (j < neighbors.count()) // Process every seed point Q.
{
val q = neighbors[j]
val idx = q.index


if (labels[idx] == NOISE) { // Change Noise to border point.
labels[idx] = k
}

if (labels[idx] == UNDEFINED || labels[idx] == CLUSTER_PART) {
labels[idx] = k


val qNeighbors = rangeQuery.findNeighbors(q.key, eps)

if (qNeighbors.size >= minSamples) { // Density check (if Q is a core point).
mergeTwoGroupsInCluster(qNeighbors, labels, neigbors)
mergeTwoGroupsInCluster(qNeighbors, labels, neighbors)
}
}
j++
}
}

private fun mergeTwoGroupsInCluster(
qNeighbors: List<Neighbor<T>>,
labels: IntArray,
neigbors: MutableList<Neighbor<T>>
neighbors: MutableList<Neighbor<T>>
) {
for (qNeighbor in qNeighbors) {
val label = labels[qNeighbor.index]
Expand All @@ -99,7 +102,7 @@ class DBSCANTrainer<T>(val eps: Float, val minSamples: Int, val metric: Metric<T
}

if (label == UNDEFINED || label == NOISE) {
neigbors.add(qNeighbor)
neighbors.add(qNeighbor)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,31 @@ import org.junit.jupiter.api.Test

import org.junit.jupiter.api.Assertions.*
import org.utbot.summary.clustering.dbscan.neighbor.LinearRangeQuery
import java.lang.IllegalArgumentException
import kotlin.math.sqrt

internal class DBSCANTrainerTest {
@Test
fun emptyData() {
val testData = arrayOf<Point>()

val dbscan = DBSCANTrainer(
eps = 0.3f,
minSamples = 10,
metric = TestEuclideanMetric(),
rangeQuery = LinearRangeQuery()
)

val exception = assertThrows(IllegalArgumentException::class.java) {
dbscan.fit(testData)
}

assertEquals(
"Nothing to learn, data is empty.",
exception.message
)
}


@Test
fun fit() {
Expand Down Expand Up @@ -180,7 +202,7 @@ internal class DBSCANTrainerTest {


val dbscan = DBSCANTrainer(
eps = 0.5f,
eps = 0.3f,
minSamples = 10,
metric = TestEuclideanMetric(),
rangeQuery = LinearRangeQuery()
Expand All @@ -190,10 +212,10 @@ internal class DBSCANTrainerTest {
val clusterLabels = dbscanModel.clusterLabels

assertEquals(150, clusterLabels.size)
assertEquals(50, clusterLabels.count { it == 1 })
assertEquals(50, clusterLabels.count { it == 2 })
assertEquals(50, clusterLabels.count { it == 3 })

assertEquals(27, clusterLabels.count { it == 0 })
assertEquals(35, clusterLabels.count { it == 1 })
assertEquals(18, clusterLabels.count { it == 2 })
assertEquals(70, clusterLabels.count { it == Int.MIN_VALUE })
}


Expand Down