Skip to content

Commit 0e02ff0

Browse files
committed
Add an almost-working multiple bivariate KDE plot
1 parent 2e1a105 commit 0e02ff0

File tree

1 file changed

+214
-0
lines changed

1 file changed

+214
-0
lines changed
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
package org.scijava.ui.swing.plot;
2+
3+
import org.jfree.chart.ChartFactory;
4+
import org.jfree.chart.ChartPanel;
5+
import org.jfree.chart.JFreeChart;
6+
import org.jfree.chart.plot.XYPlot;
7+
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
8+
import org.jfree.data.xy.AbstractXYDataset;
9+
import org.jfree.data.xy.XYDataset;
10+
11+
import javax.swing.*;
12+
import java.awt.*;
13+
import java.awt.geom.Point2D;
14+
import java.util.*;
15+
import java.util.List;
16+
17+
public class KDEContourPlot {
18+
public static void main(String[] args) {
19+
// Generate sample data
20+
double[][] data = generateSampleData();
21+
22+
// Create and display the plot
23+
JFrame frame = new JFrame("2D KDE Contour Plot");
24+
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
25+
frame.getContentPane().add(createChartPanel(data));
26+
frame.pack();
27+
frame.setSize(800, 600);
28+
frame.setVisible(true);
29+
}
30+
31+
private static class ContourDataset extends AbstractXYDataset {
32+
private final List<List<Point2D>> contourLines;
33+
34+
public ContourDataset(double[][] data, int gridSize, int numContours) {
35+
// Calculate KDE
36+
double[][] density = calculateKDE(data, gridSize);
37+
// Generate contour lines
38+
this.contourLines = generateContours(density, numContours);
39+
}
40+
41+
@Override
42+
public int getSeriesCount() {
43+
return contourLines.size();
44+
}
45+
46+
@Override
47+
public Comparable getSeriesKey(int series) {
48+
return "Contour " + series;
49+
}
50+
51+
@Override
52+
public int getItemCount(int series) {
53+
return contourLines.get(series).size();
54+
}
55+
56+
@Override
57+
public Number getX(int series, int item) {
58+
return contourLines.get(series).get(item).getX();
59+
}
60+
61+
@Override
62+
public Number getY(int series, int item) {
63+
return contourLines.get(series).get(item).getY();
64+
}
65+
}
66+
67+
private static double[][] calculateKDE(double[][] data, int gridSize) {
68+
// Find data bounds
69+
double minX = Double.POSITIVE_INFINITY, maxX = Double.NEGATIVE_INFINITY;
70+
double minY = Double.POSITIVE_INFINITY, maxY = Double.NEGATIVE_INFINITY;
71+
for (double[] point : data) {
72+
minX = Math.min(minX, point[0]);
73+
maxX = Math.max(maxX, point[0]);
74+
minY = Math.min(minY, point[1]);
75+
maxY = Math.max(maxY, point[1]);
76+
}
77+
78+
// Add padding
79+
double padX = (maxX - minX) * 0.1;
80+
double padY = (maxY - minY) * 0.1;
81+
minX -= padX; maxX += padX;
82+
minY -= padY; maxY += padY;
83+
84+
// Calculate bandwidth using Silverman's rule
85+
double sdX = calculateSD(data, 0);
86+
double sdY = calculateSD(data, 1);
87+
double n = data.length;
88+
double bandwidthX = 1.06 * sdX * Math.pow(n, -0.2);
89+
double bandwidthY = 1.06 * sdY * Math.pow(n, -0.2);
90+
91+
// Calculate KDE on grid
92+
double[][] density = new double[gridSize][gridSize];
93+
for (int i = 0; i < gridSize; i++) {
94+
for (int j = 0; j < gridSize; j++) {
95+
double x = minX + (maxX - minX) * i / (gridSize - 1);
96+
double y = minY + (maxY - minY) * j / (gridSize - 1);
97+
98+
double sum = 0;
99+
for (double[] point : data) {
100+
double zx = (x - point[0]) / bandwidthX;
101+
double zy = (y - point[1]) / bandwidthY;
102+
sum += Math.exp(-0.5 * (zx * zx + zy * zy)) /
103+
(2 * Math.PI * bandwidthX * bandwidthY);
104+
}
105+
density[i][j] = sum / n;
106+
}
107+
}
108+
109+
return density;
110+
}
111+
112+
private static List<List<Point2D>> generateContours(double[][] density, int numContours) {
113+
List<List<Point2D>> contourLines = new ArrayList<>();
114+
double maxDensity = Arrays.stream(density)
115+
.flatMapToDouble(Arrays::stream)
116+
.max()
117+
.orElse(1.0);
118+
119+
// For each contour level
120+
for (int i = 1; i <= numContours; i++) {
121+
double level = maxDensity * i / (numContours + 1);
122+
List<Point2D> contourLine = new ArrayList<>();
123+
124+
// Simple marching squares implementation
125+
for (int x = 0; x < density.length - 1; x++) {
126+
for (int y = 0; y < density[0].length - 1; y++) {
127+
// Check if contour passes through this cell
128+
boolean bl = density[x][y] >= level;
129+
boolean br = density[x+1][y] >= level;
130+
boolean tr = density[x+1][y+1] >= level;
131+
boolean tl = density[x][y+1] >= level;
132+
133+
int caseNum = (bl ? 1 : 0) + (br ? 2 : 0) +
134+
(tr ? 4 : 0) + (tl ? 8 : 0);
135+
136+
if (caseNum != 0 && caseNum != 15) {
137+
// Add interpolated points for this cell
138+
contourLine.add(new Point2D.Double(x + 0.5, y + 0.5));
139+
}
140+
}
141+
}
142+
143+
if (!contourLine.isEmpty()) {
144+
contourLines.add(contourLine);
145+
}
146+
}
147+
148+
return contourLines;
149+
}
150+
151+
private static double calculateSD(double[][] data, int dimension) {
152+
double mean = 0;
153+
for (double[] point : data) {
154+
mean += point[dimension];
155+
}
156+
mean /= data.length;
157+
158+
double variance = 0;
159+
for (double[] point : data) {
160+
double diff = point[dimension] - mean;
161+
variance += diff * diff;
162+
}
163+
variance /= (data.length - 1);
164+
165+
return Math.sqrt(variance);
166+
}
167+
168+
private static double[][] generateSampleData() {
169+
Random rand = new Random(42);
170+
int n = 1000;
171+
double[][] data = new double[n][2];
172+
173+
for (int i = 0; i < n; i++) {
174+
if (rand.nextDouble() < 0.6) {
175+
// First cluster
176+
data[i][0] = rand.nextGaussian() * 0.5 + 2;
177+
data[i][1] = rand.nextGaussian() * 0.5 + 2;
178+
} else {
179+
// Second cluster
180+
data[i][0] = rand.nextGaussian() * 0.3 + 4;
181+
data[i][1] = rand.nextGaussian() * 0.3 + 4;
182+
}
183+
}
184+
185+
return data;
186+
}
187+
188+
private static ChartPanel createChartPanel(double[][] data) {
189+
XYDataset dataset = new ContourDataset(data, 50, 8);
190+
191+
JFreeChart chart = ChartFactory.createXYLineChart(
192+
"2D KDE Contour Plot",
193+
"X",
194+
"Y",
195+
dataset
196+
);
197+
198+
XYPlot plot = chart.getXYPlot();
199+
XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer();
200+
201+
// Style each contour line differently
202+
for (int i = 0; i < dataset.getSeriesCount(); i++) {
203+
renderer.setSeriesLinesVisible(i, true);
204+
renderer.setSeriesShapesVisible(i, false);
205+
float hue = (float)i / dataset.getSeriesCount();
206+
renderer.setSeriesPaint(i, Color.getHSBColor(hue, 0.8f, 0.8f));
207+
renderer.setSeriesStroke(i, new BasicStroke(2.0f));
208+
}
209+
210+
plot.setRenderer(renderer);
211+
212+
return new ChartPanel(chart);
213+
}
214+
}

0 commit comments

Comments
 (0)