import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Disjoint-set (union-find) structure over a fixed universe of values, using union by size and
 * path compression.
 *
 * <p>Each distinct value (by {@code equals}/{@code hashCode}) is wrapped in its own {@link Node};
 * the internal parent/size maps are keyed by node identity, since {@link Node} does not override
 * {@code equals}. This class is not synchronized.
 *
 * @param <V> the element type
 */
public class UnionFind<V> {

  /** Wrapper that gives each element a unique identity inside the structure. */
  public static class Node<V> {
    public V value;

    public Node(V v) {
      value = v;
    }
  }

  /** Parent link of every node; a root maps to itself. */
  private final Map<Node<V>, Node<V>> parents;
  /** Node associated with each element value. */
  private final Map<V, Node<V>> nodes;
  /** Number of elements in each set; only current roots have an entry. */
  private final Map<Node<V>, Integer> sizeMap;

  /**
   * Creates a structure in which every distinct value of {@code arr} starts as its own singleton
   * set. Duplicate values are collapsed into a single element.
   *
   * @param arr the universe of elements
   */
  public UnionFind(Collection<? extends V> arr) {
    parents = new HashMap<>();
    nodes = new HashMap<>();
    sizeMap = new HashMap<>();
    for (V cur : arr) {
      if (nodes.containsKey(cur)) {
        // Re-registering would orphan the first node as a root that size() keeps counting.
        continue;
      }
      Node<V> n = new Node<>(cur);
      nodes.put(cur, n);
      parents.put(n, n);
      sizeMap.put(n, 1);
    }
  }

  /**
   * Returns the root of the set containing {@code value}, pointing every node on the walked path
   * directly at that root (path compression).
   *
   * @throws IllegalArgumentException if {@code value} was not in the constructor's collection
   */
  private Node<V> find(V value) {
    Node<V> cur = nodes.get(value);
    if (cur == null) {
      throw new IllegalArgumentException("unknown element: " + value);
    }
    Deque<Node<V>> path = new ArrayDeque<>();
    while (cur != parents.get(cur)) {
      path.push(cur);
      cur = parents.get(cur);
    }
    while (!path.isEmpty()) {
      parents.put(path.pop(), cur);
    }
    return cur;
  }

  /**
   * Returns whether {@code a} and {@code b} currently belong to the same set.
   *
   * @throws IllegalArgumentException if either value is not a known element
   */
  public boolean isSameSet(V a, V b) {
    return find(a) == find(b);
  }

  /**
   * Returns the number of elements in the set containing {@code a}.
   *
   * @throws IllegalArgumentException if {@code a} is not a known element
   */
  public int setSize(V a) {
    return sizeMap.get(find(a));
  }

  /**
   * Merges the sets containing {@code a} and {@code b}; does nothing if they are already joined.
   * The smaller set's root is attached under the larger one (ties keep {@code a}'s root).
   *
   * <p>The capitalized name is kept for compatibility with existing callers.
   *
   * @throws IllegalArgumentException if either value is not a known element
   */
  public void Union(V a, V b) {
    Node<V> fa = find(a);
    Node<V> fb = find(b);
    if (fa == fb) {
      return;
    }
    int sa = sizeMap.get(fa);
    int sb = sizeMap.get(fb);
    Node<V> big = sa >= sb ? fa : fb;
    Node<V> small = big == fa ? fb : fa;
    parents.put(small, big);
    // The surviving root must absorb the merged set's size, or union-by-size degrades.
    sizeMap.put(big, sa + sb);
    sizeMap.remove(small);
  }

  /** Returns the current number of disjoint sets. */
  public int size() {
    return sizeMap.size();
  }

  public static void main(String[] args) {
    List<Integer> arr = new ArrayList<>();
    arr.add(1);
    arr.add(11);
    arr.add(111);
    arr.add(1111);
    arr.add(21);
    arr.add(31);
    arr.add(41);
    UnionFind<Integer> unionFind = new UnionFind<>(arr);
    System.out.println("size =" + unionFind.size());
    unionFind.Union(1, 11);
    System.out.println("size =" + unionFind.size());
    unionFind.Union(1111, 11);
    System.out.println("size =" + unionFind.size());
    unionFind.Union(1, 1111);
    System.out.println("size =" + unionFind.size());
  }
}