66import org .jfree .chart .plot .XYPlot ;
77import org .jfree .chart .renderer .xy .XYLineAndShapeRenderer ;
88import org .jfree .data .xy .AbstractXYDataset ;
9- import org .jfree .data .xy .XYDataset ;
109
1110import javax .swing .*;
1211import java .awt .*;
@@ -30,12 +29,10 @@ public static void main(String[] args) {
3029
3130 private static class ContourDataset extends AbstractXYDataset {
3231 private final List <List <Point2D >> contourLines ;
33- private final List <Boolean > isFirstCluster ; // Track which cluster each contour belongs to
32+ private final List <Boolean > isFirstCluster ;
3433
3534 public ContourDataset (double [][] data , int gridSize , int numContours ) {
36- // Calculate KDE
3735 double [][] density = calculateKDE (data , gridSize );
38- // Generate contour lines with cluster information
3936 var result = generateContours (density , numContours , data );
4037 this .contourLines = result .contourLines ;
4138 this .isFirstCluster = result .isFirstCluster ;
@@ -49,28 +46,28 @@ public boolean isFirstCluster(int series) {
4946 public int getSeriesCount () {
5047 return contourLines .size ();
5148 }
52-
49+
5350 @ Override
5451 public Comparable getSeriesKey (int series ) {
5552 return "Contour " + series ;
5653 }
57-
54+
5855 @ Override
5956 public int getItemCount (int series ) {
6057 return contourLines .get (series ).size ();
6158 }
62-
59+
6360 @ Override
6461 public Number getX (int series , int item ) {
6562 return contourLines .get (series ).get (item ).getX ();
6663 }
67-
64+
6865 @ Override
6966 public Number getY (int series , int item ) {
7067 return contourLines .get (series ).get (item ).getY ();
7168 }
7269 }
73-
70+
7471 private static double [][] calculateKDE (double [][] data , int gridSize ) {
7572 // Find data bounds
7673 double minX = Double .POSITIVE_INFINITY , maxX = Double .NEGATIVE_INFINITY ;
@@ -126,8 +123,21 @@ private static class ContourResult {
126123 }
127124 }
128125
126+ private static Point2D interpolate (double x1 , double y1 , double v1 ,
127+ double x2 , double y2 , double v2 ,
128+ double level ) {
129+ if (Math .abs (v1 - v2 ) < 1e-10 ) {
130+ return new Point2D .Double (x1 , y1 );
131+ }
132+ double t = (level - v1 ) / (v2 - v1 );
133+ return new Point2D .Double (
134+ x1 + t * (x2 - x1 ),
135+ y1 + t * (y2 - y1 )
136+ );
137+ }
138+
129139 private static ContourResult generateContours (double [][] density , int numContours , double [][] originalData ) {
130- List <List <Point2D >> contourLines = new ArrayList <>();
140+ List <List <Point2D >> allContourLines = new ArrayList <>();
131141 List <Boolean > isFirstCluster = new ArrayList <>();
132142 double maxDensity = Arrays .stream (density )
133143 .flatMapToDouble (Arrays ::stream )
@@ -137,62 +147,122 @@ private static ContourResult generateContours(double[][] density, int numContour
137147 // For each contour level
138148 for (int i = 1 ; i <= numContours ; i ++) {
139149 double level = maxDensity * i / (numContours + 1 );
140- List <Point2D > points = new ArrayList <>();
150+ List <List <Point2D >> contoursAtLevel = new ArrayList <>();
151+
152+ boolean [][] visited = new boolean [density .length ][density [0 ].length ];
141153
142- // Collect all points for this contour
154+ // Scan for contour starting points
143155 for (int x = 0 ; x < density .length - 1 ; x ++) {
144156 for (int y = 0 ; y < density [0 ].length - 1 ; y ++) {
145- boolean bl = density [x ][y ] >= level ;
146- boolean br = density [x +1 ][y ] >= level ;
147- boolean tr = density [x +1 ][y +1 ] >= level ;
148- boolean tl = density [x ][y +1 ] >= level ;
157+ if (!visited [x ][y ]) {
158+ List <Point2D > contour = traceContour (density , level , x , y , visited );
159+ if (!contour .isEmpty ()) {
160+ contoursAtLevel .add (contour );
161+ }
162+ }
163+ }
164+ }
149165
150- int caseNum = (bl ? 1 : 0 ) + (br ? 2 : 0 ) +
151- (tr ? 4 : 0 ) + (tl ? 8 : 0 );
166+ // Process each separate contour at this level
167+ for (List <Point2D > contour : contoursAtLevel ) {
168+ if (contour .size () >= 4 ) { // Filter out tiny contours
169+ Point2D centerPoint = findContourCenter (contour );
170+ boolean belongsToFirstCluster = determineCluster (centerPoint , originalData );
152171
153- if (caseNum != 0 && caseNum != 15 ) {
154- points .add (new Point2D .Double (x + 0.5 , y + 0.5 ));
155- }
172+ allContourLines .add (contour );
173+ isFirstCluster .add (belongsToFirstCluster );
156174 }
157175 }
176+ }
177+
178+ return new ContourResult (allContourLines , isFirstCluster );
179+ }
180+ private static List <Point2D > traceContour (double [][] density , double level ,
181+ int startX , int startY ,
182+ boolean [][] visited ) {
183+ List <Point2D > contour = new ArrayList <>();
184+ Queue <Point2D > queue = new LinkedList <>();
185+ Set <String > visitedEdges = new HashSet <>();
186+
187+ // Initialize with start point
188+ addContourSegments (density , level , startX , startY , queue , visitedEdges );
189+
190+ while (!queue .isEmpty ()) {
191+ Point2D point = queue .poll ();
192+ contour .add (point );
158193
159- if (!points .isEmpty ()) {
160- // Order points to form a continuous contour
161- List <Point2D > orderedPoints = orderContourPoints (points );
162- contourLines .add (orderedPoints );
194+ // Find grid cell containing this point
195+ int x = (int ) Math .floor (point .getX ());
196+ int y = (int ) Math .floor (point .getY ());
163197
164- // Determine which cluster this contour belongs to
165- Point2D centerPoint = findContourCenter (orderedPoints );
166- boolean belongsToFirstCluster = determineCluster (centerPoint , originalData );
167- isFirstCluster .add (belongsToFirstCluster );
198+ // Mark as visited
199+ if (x >= 0 && x < visited .length - 1 &&
200+ y >= 0 && y < visited [0 ].length - 1 ) {
201+ visited [x ][y ] = true ;
202+
203+ // Add adjacent segments
204+ addContourSegments (density , level , x , y , queue , visitedEdges );
168205 }
169206 }
170207
171- return new ContourResult ( contourLines , isFirstCluster ) ;
208+ return contour ;
172209 }
173- private static List <Point2D > orderContourPoints (List <Point2D > points ) {
174- List <Point2D > ordered = new ArrayList <>();
175- Set <Point2D > remaining = new HashSet <>(points );
176210
177- // Start with the leftmost point
178- Point2D current = points .stream ()
179- .min (Comparator .comparingDouble (Point2D ::getX ))
180- .orElseThrow ();
181- ordered .add (current );
182- remaining .remove (current );
211+ private static void addContourSegments (double [][] density , double level ,
212+ int x , int y ,
213+ Queue <Point2D > queue ,
214+ Set <String > visitedEdges ) {
215+ if (x < 0 || x >= density .length - 1 ||
216+ y < 0 || y >= density [0 ].length - 1 ) {
217+ return ;
218+ }
183219
184- while (!remaining .isEmpty ()) {
185- Point2D finalCurrent = current ;
186- current = remaining .stream ()
187- .min (Comparator .comparingDouble (p ->
188- finalCurrent .distance (p )))
189- .orElseThrow ();
220+ double v00 = density [x ][y ];
221+ double v10 = density [x +1 ][y ];
222+ double v11 = density [x +1 ][y +1 ];
223+ double v01 = density [x ][y +1 ];
190224
191- ordered .add (current );
192- remaining .remove (current );
225+ // For each edge of the cell
226+ List <Point2D > intersections = new ArrayList <>();
227+
228+ // Bottom edge
229+ if ((v00 < level && v10 >= level ) || (v00 >= level && v10 < level )) {
230+ String edge = String .format ("%d,%d,B" , x , y );
231+ if (!visitedEdges .contains (edge )) {
232+ intersections .add (interpolate (x , y , v00 , x +1 , y , v10 , level ));
233+ visitedEdges .add (edge );
234+ }
193235 }
194236
195- return ordered ;
237+ // Right edge
238+ if ((v10 < level && v11 >= level ) || (v10 >= level && v11 < level )) {
239+ String edge = String .format ("%d,%d,R" , x +1 , y );
240+ if (!visitedEdges .contains (edge )) {
241+ intersections .add (interpolate (x +1 , y , v10 , x +1 , y +1 , v11 , level ));
242+ visitedEdges .add (edge );
243+ }
244+ }
245+
246+ // Top edge
247+ if ((v01 < level && v11 >= level ) || (v01 >= level && v11 < level )) {
248+ String edge = String .format ("%d,%d,T" , x , y +1 );
249+ if (!visitedEdges .contains (edge )) {
250+ intersections .add (interpolate (x , y +1 , v01 , x +1 , y +1 , v11 , level ));
251+ visitedEdges .add (edge );
252+ }
253+ }
254+
255+ // Left edge
256+ if ((v00 < level && v01 >= level ) || (v00 >= level && v01 < level )) {
257+ String edge = String .format ("%d,%d,L" , x , y );
258+ if (!visitedEdges .contains (edge )) {
259+ intersections .add (interpolate (x , y , v00 , x , y +1 , v01 , level ));
260+ visitedEdges .add (edge );
261+ }
262+ }
263+
264+ // Add all found intersections to the queue
265+ queue .addAll (intersections );
196266 }
197267
198268 private static Point2D findContourCenter (List <Point2D > points ) {
@@ -217,6 +287,31 @@ private static boolean determineCluster(Point2D point, double[][] originalData)
217287 return dist1 < dist2 ;
218288 }
219289
290+ private static List <Point2D > orderContourPoints (List <Point2D > points ) {
291+ List <Point2D > ordered = new ArrayList <>();
292+ Set <Point2D > remaining = new HashSet <>(points );
293+
294+ // Start with the leftmost point
295+ Point2D current = points .stream ()
296+ .min (Comparator .comparingDouble (Point2D ::getX ))
297+ .orElseThrow ();
298+ ordered .add (current );
299+ remaining .remove (current );
300+
301+ while (!remaining .isEmpty ()) {
302+ Point2D finalCurrent = current ;
303+ current = remaining .stream ()
304+ .min (Comparator .comparingDouble (p ->
305+ finalCurrent .distance (p )))
306+ .orElseThrow ();
307+
308+ ordered .add (current );
309+ remaining .remove (current );
310+ }
311+
312+ return ordered ;
313+ }
314+
220315 private static double calculateSD (double [][] data , int dimension ) {
221316 double mean = 0 ;
222317 for (double [] point : data ) {
@@ -253,31 +348,34 @@ private static double[][] generateSampleData() {
253348
254349 return data ;
255350 }
256-
351+
257352 private static ChartPanel createChartPanel (double [][] data ) {
258- XYDataset dataset = new ContourDataset (data , 50 , 8 );
259-
353+ ContourDataset dataset = new ContourDataset (data , 50 , 8 );
354+
260355 JFreeChart chart = ChartFactory .createXYLineChart (
261- "2D KDE Contour Plot" ,
262- "X" ,
263- "Y" ,
264- dataset
356+ "2D KDE Contour Plot" ,
357+ "X" ,
358+ "Y" ,
359+ dataset
265360 );
266-
361+
267362 XYPlot plot = chart .getXYPlot ();
268363 XYLineAndShapeRenderer renderer = new XYLineAndShapeRenderer ();
269-
270- // Style each contour line differently
364+
365+ // Style contours based on cluster
366+ Color cluster1Color = Color .BLUE ;
367+ Color cluster2Color = Color .ORANGE ;
368+
271369 for (int i = 0 ; i < dataset .getSeriesCount (); i ++) {
272370 renderer .setSeriesLinesVisible (i , true );
273371 renderer .setSeriesShapesVisible (i , false );
274- float hue = ( float ) i / dataset .getSeriesCount ();
275- renderer . setSeriesPaint ( i , Color . getHSBColor ( hue , 0.8f , 0.8f ) );
372+ renderer . setSeriesPaint ( i , dataset .isFirstCluster ( i ) ?
373+ cluster1Color : cluster2Color );
276374 renderer .setSeriesStroke (i , new BasicStroke (2.0f ));
277375 }
278-
376+
279377 plot .setRenderer (renderer );
280-
378+
281379 return new ChartPanel (chart );
282380 }
283- }
381+ }
0 commit comments