public class DisjointSet extends TreeNode {

    /** create a node in the set */
    protected DisjointSet(Comparable x) {
	super(x, null, null, null);
	rank = 0;
    }
    /** the rank of this tree node in the set */
    protected int rank; 

    /** return the rank */
    final int getRank() { return rank; }

    /** set the rank */
    final void setRank(int r) {	rank = r; }

    /** @see TreeNode#toString */
    public String toString() { 
	return super.toString() + " (" + rank + ")";
    }

    /** create a singleton set containing x as representative */
    public static DisjointSet makeSet(Comparable x) {
	return new DisjointSet(x);
    }

    /** union two sets */
    public static DisjointSet union(DisjointSet x, DisjointSet y) {
	return link(findSet(x), findSet(y));
    }

    /** link two set nodes using their ranks */
    public static DisjointSet link(DisjointSet x, DisjointSet y) {
	if (x.getRank() > y.getRank()) {
	    y.setParent(x);
	    return x;
	} else {
	    x.setParent(y);
	    if (x.getRank() == y.getRank()) {
		y.setRank(y.getRank() + 1);
	    }
	    return y;
	}
    }

    /** find the representative of a set node; do path compression */
    public static DisjointSet findSet(DisjointSet x) {
	if (x != x.getParent()) {
	    x.setParent(DisjointSet.findSet((DisjointSet) x.getParent()));
	}
	return (DisjointSet)x.getParent();
    }

    /** testing ... */
    public static void main(String[] args) {
	DisjointSet s1 = DisjointSet.makeSet(new Integer(1));
	DisjointSet s2 = DisjointSet.makeSet(new Integer(2));
	DisjointSet s3 = DisjointSet.makeSet(new Integer(3));
	DisjointSet s4 = DisjointSet.union(DisjointSet.union(s1, s2), s3);
	System.out.println(DisjointSet.findSet(s4));
    }
}
