Skip to content

Commit 16787eb

Browse files
committed
Iterate on the KDE plot some more
Still wacky! And lots of wild code from Claude.ai. But it runs!
1 parent 07bdc30 commit 16787eb

File tree

1 file changed

+161
-63
lines changed

1 file changed

+161
-63
lines changed

src/test/java/org/scijava/ui/swing/plot/KDEContourPlot.java

Lines changed: 161 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import org.jfree.chart.plot.XYPlot;
77
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
88
import org.jfree.data.xy.AbstractXYDataset;
9-
import org.jfree.data.xy.XYDataset;
109

1110
import javax.swing.*;
1211
import java.awt.*;
@@ -30,12 +29,10 @@ public static void main(String[] args) {
3029

3130
private static class ContourDataset extends AbstractXYDataset {
3231
private final List<List<Point2D>> contourLines;
33-
private final List<Boolean> isFirstCluster; // Track which cluster each contour belongs to
32+
private final List<Boolean> isFirstCluster;
3433

3534
public ContourDataset(double[][] data, int gridSize, int numContours) {
36-
// Calculate KDE
3735
double[][] density = calculateKDE(data, gridSize);
38-
// Generate contour lines with cluster information
3936
var result = generateContours(density, numContours, data);
4037
this.contourLines = result.contourLines;
4138
this.isFirstCluster = result.isFirstCluster;
@@ -49,28 +46,28 @@ public boolean isFirstCluster(int series) {
4946
public int getSeriesCount() {
5047
return contourLines.size();
5148
}
52-
49+
5350
@Override
5451
public Comparable getSeriesKey(int series) {
5552
return "Contour " + series;
5653
}
57-
54+
5855
@Override
5956
public int getItemCount(int series) {
6057
return contourLines.get(series).size();
6158
}
62-
59+
6360
@Override
6461
public Number getX(int series, int item) {
6562
return contourLines.get(series).get(item).getX();
6663
}
67-
64+
6865
@Override
6966
public Number getY(int series, int item) {
7067
return contourLines.get(series).get(item).getY();
7168
}
7269
}
73-
70+
7471
private static double[][] calculateKDE(double[][] data, int gridSize) {
7572
// Find data bounds
7673
double minX = Double.POSITIVE_INFINITY, maxX = Double.NEGATIVE_INFINITY;
@@ -126,8 +123,21 @@ private static class ContourResult {
126123
}
127124
}
128125

126+
private static Point2D interpolate(double x1, double y1, double v1,
127+
double x2, double y2, double v2,
128+
double level) {
129+
if (Math.abs(v1 - v2) < 1e-10) {
130+
return new Point2D.Double(x1, y1);
131+
}
132+
double t = (level - v1) / (v2 - v1);
133+
return new Point2D.Double(
134+
x1 + t * (x2 - x1),
135+
y1 + t * (y2 - y1)
136+
);
137+
}
138+
129139
private static ContourResult generateContours(double[][] density, int numContours, double[][] originalData) {
130-
List<List<Point2D>> contourLines = new ArrayList<>();
140+
List<List<Point2D>> allContourLines = new ArrayList<>();
131141
List<Boolean> isFirstCluster = new ArrayList<>();
132142
double maxDensity = Arrays.stream(density)
133143
.flatMapToDouble(Arrays::stream)
@@ -137,62 +147,122 @@ private static ContourResult generateContours(double[][] density, int numContour
137147
// For each contour level
138148
for (int i = 1; i <= numContours; i++) {
139149
double level = maxDensity * i / (numContours + 1);
140-
List<Point2D> points = new ArrayList<>();
150+
List<List<Point2D>> contoursAtLevel = new ArrayList<>();
151+
152+
boolean[][] visited = new boolean[density.length][density[0].length];
141153

142-
// Collect all points for this contour
154+
// Scan for contour starting points
143155
for (int x = 0; x < density.length - 1; x++) {
144156
for (int y = 0; y < density[0].length - 1; y++) {
145-
boolean bl = density[x][y] >= level;
146-
boolean br = density[x+1][y] >= level;
147-
boolean tr = density[x+1][y+1] >= level;
148-
boolean tl = density[x][y+1] >= level;
157+
if (!visited[x][y]) {
158+
List<Point2D> contour = traceContour(density, level, x, y, visited);
159+
if (!contour.isEmpty()) {
160+
contoursAtLevel.add(contour);
161+
}
162+
}
163+
}
164+
}
149165

150-
int caseNum = (bl ? 1 : 0) + (br ? 2 : 0) +
151-
(tr ? 4 : 0) + (tl ? 8 : 0);
166+
// Process each separate contour at this level
167+
for (List<Point2D> contour : contoursAtLevel) {
168+
if (contour.size() >= 4) { // Filter out tiny contours
169+
Point2D centerPoint = findContourCenter(contour);
170+
boolean belongsToFirstCluster = determineCluster(centerPoint, originalData);
152171

153-
if (caseNum != 0 && caseNum != 15) {
154-
points.add(new Point2D.Double(x + 0.5, y + 0.5));
155-
}
172+
allContourLines.add(contour);
173+
isFirstCluster.add(belongsToFirstCluster);
156174
}
157175
}
176+
}
177+
178+
return new ContourResult(allContourLines, isFirstCluster);
179+
}
180+
private static List<Point2D> traceContour(double[][] density, double level,
181+
int startX, int startY,
182+
boolean[][] visited) {
183+
List<Point2D> contour = new ArrayList<>();
184+
Queue<Point2D> queue = new LinkedList<>();
185+
Set<String> visitedEdges = new HashSet<>();
186+
187+
// Initialize with start point
188+
addContourSegments(density, level, startX, startY, queue, visitedEdges);
189+
190+
while (!queue.isEmpty()) {
191+
Point2D point = queue.poll();
192+
contour.add(point);
158193

159-
if (!points.isEmpty()) {
160-
// Order points to form a continuous contour
161-
List<Point2D> orderedPoints = orderContourPoints(points);
162-
contourLines.add(orderedPoints);
194+
// Find grid cell containing this point
195+
int x = (int) Math.floor(point.getX());
196+
int y = (int) Math.floor(point.getY());
163197

164-
// Determine which cluster this contour belongs to
165-
Point2D centerPoint = findContourCenter(orderedPoints);
166-
boolean belongsToFirstCluster = determineCluster(centerPoint, originalData);
167-
isFirstCluster.add(belongsToFirstCluster);
198+
// Mark as visited
199+
if (x >= 0 && x < visited.length - 1 &&
200+
y >= 0 && y < visited[0].length - 1) {
201+
visited[x][y] = true;
202+
203+
// Add adjacent segments
204+
addContourSegments(density, level, x, y, queue, visitedEdges);
168205
}
169206
}
170207

171-
return new ContourResult(contourLines, isFirstCluster);
208+
return contour;
172209
}
173-
private static List<Point2D> orderContourPoints(List<Point2D> points) {
174-
List<Point2D> ordered = new ArrayList<>();
175-
Set<Point2D> remaining = new HashSet<>(points);
176210

177-
// Start with the leftmost point
178-
Point2D current = points.stream()
179-
.min(Comparator.comparingDouble(Point2D::getX))
180-
.orElseThrow();
181-
ordered.add(current);
182-
remaining.remove(current);
211+
private static void addContourSegments(double[][] density, double level,
212+
int x, int y,
213+
Queue<Point2D> queue,
214+
Set<String> visitedEdges) {
215+
if (x < 0 || x >= density.length - 1 ||
216+
y < 0 || y >= density[0].length - 1) {
217+
return;
218+
}
183219

184-
while (!remaining.isEmpty()) {
185-
Point2D finalCurrent = current;
186-
current = remaining.stream()
187-
.min(Comparator.comparingDouble(p ->
188-
finalCurrent.distance(p)))
189-
.orElseThrow();
220+
double v00 = density[x][y];
221+
double v10 = density[x+1][y];
222+
double v11 = density[x+1][y+1];
223+
double v01 = density[x][y+1];
190224

191-
ordered.add(current);
192-
remaining.remove(current);
225+
// For each edge of the cell
226+
List<Point2D> intersections = new ArrayList<>();
227+
228+
// Bottom edge
229+
if ((v00 < level && v10 >= level) || (v00 >= level && v10 < level)) {
230+
String edge = String.format("%d,%d,B", x, y);
231+
if (!visitedEdges.contains(edge)) {
232+
intersections.add(interpolate(x, y, v00, x+1, y, v10, level));
233+
visitedEdges.add(edge);
234+
}
193235
}
194236

195-
return ordered;
237+
// Right edge
238+
if ((v10 < level && v11 >= level) || (v10 >= level && v11 < level)) {
239+
String edge = String.format("%d,%d,R", x+1, y);
240+
if (!visitedEdges.contains(edge)) {
241+
intersections.add(interpolate(x+1, y, v10, x+1, y+1, v11, level));
242+
visitedEdges.add(edge);
243+
}
244+
}
245+
246+
// Top edge
247+
if ((v01 < level && v11 >= level) || (v01 >= level && v11 < level)) {
248+
String edge = String.format("%d,%d,T", x, y+1);
249+
if (!visitedEdges.contains(edge)) {
250+
intersections.add(interpolate(x, y+1, v01, x+1, y+1, v11, level));
251+
visitedEdges.add(edge);
252+
}
253+
}
254+
255+
// Left edge
256+
if ((v00 < level && v01 >= level) || (v00 >= level && v01 < level)) {
257+
String edge = String.format("%d,%d,L", x, y);
258+
if (!visitedEdges.contains(edge)) {
259+
intersections.add(interpolate(x, y, v00, x, y+1, v01, level));
260+
visitedEdges.add(edge);
261+
}
262+
}
263+
264+
// Add all found intersections to the queue
265+
queue.addAll(intersections);
196266
}
197267

198268
private static Point2D findContourCenter(List<Point2D> points) {
@@ -217,6 +287,31 @@ private static boolean determineCluster(Point2D point, double[][] originalData)
217287
return dist1 < dist2;
218288
}
219289

290+
private static List<Point2D> orderContourPoints(List<Point2D> points) {
291+
List<Point2D> ordered = new ArrayList<>();
292+
Set<Point2D> remaining = new HashSet<>(points);
293+
294+
// Start with the leftmost point
295+
Point2D current = points.stream()
296+
.min(Comparator.comparingDouble(Point2D::getX))
297+
.orElseThrow();
298+
ordered.add(current);
299+
remaining.remove(current);
300+
301+
while (!remaining.isEmpty()) {
302+
Point2D finalCurrent = current;
303+
current = remaining.stream()
304+
.min(Comparator.comparingDouble(p ->
305+
finalCurrent.distance(p)))
306+
.orElseThrow();
307+
308+
ordered.add(current);
309+
remaining.remove(current);
310+
}
311+
312+
return ordered;
313+
}
314+
220315
private static double calculateSD(double[][] data, int dimension) {
221316
double mean = 0;
222317
for (double[] point : data) {
@@ -253,31 +348,34 @@ private static double[][] generateSampleData() {
253348

254349
return data;
255350
}
256-
351+
257352
private static ChartPanel createChartPanel(double[][] data) {
258-
XYDataset dataset = new ContourDataset(data, 50, 8);
259-
353+
ContourDataset dataset = new ContourDataset(data, 50, 8);
354+
260355
JFreeChart chart = ChartFactory.createXYLineChart(
261-
"2D KDE Contour Plot",
262-
"X",
263-
"Y",
264-
dataset
356+
"2D KDE Contour Plot",
357+
"X",
358+
"Y",
359+
dataset
265360
);
266-
361+
267362
XYPlot plot = chart.getXYPlot();
268363
XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
269-
270-
// Style each contour line differently
364+
365+
// Style contours based on cluster
366+
Color cluster1Color = Color.BLUE;
367+
Color cluster2Color = Color.ORANGE;
368+
271369
for (int i = 0; i < dataset.getSeriesCount(); i++) {
272370
renderer.setSeriesLinesVisible(i, true);
273371
renderer.setSeriesShapesVisible(i, false);
274-
float hue = (float)i / dataset.getSeriesCount();
275-
renderer.setSeriesPaint(i, Color.getHSBColor(hue, 0.8f, 0.8f));
372+
renderer.setSeriesPaint(i, dataset.isFirstCluster(i) ?
373+
cluster1Color : cluster2Color);
276374
renderer.setSeriesStroke(i, new BasicStroke(2.0f));
277375
}
278-
376+
279377
plot.setRenderer(renderer);
280-
378+
281379
return new ChartPanel(chart);
282380
}
283-
}
381+
}

0 commit comments

Comments
 (0)