Skip to content

Commit b775936

Browse files
committed
feat: Add Disjoint Set(Union Find)
1 parent 9c766a6 commit b775936

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

README.md

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
- [BloomFilter](#bloom-filter)
2929
- [RingBuffer(Circular Buffer)](#ring-buffer)
3030
- [SegmentTree](#segment-tree)
31+
- [DisjointSet(UnionFind)](#disjoint-set)
3132
4. [License](#license)
3233

3334
## [Installation](#installation)
@@ -578,6 +579,85 @@ minSt := collections.NewSegmentTree(arr, math.Inf(1), func(a, b float64) float64
578579
- Statistical range queries
579580
- Competitive programming
580581
- Database query optimization
582+
---
583+
### [Disjoint Set](#disjoint-set)
584+
585+
A Disjoint Set (also known as Union-Find) is a data structure that keeps track of elements partitioned into non-overlapping subsets. It provides near-constant-time operations to merge sets and determine if two elements belong to the same set.
586+
587+
#### Type `DisjointSet[T comparable]`
588+
589+
- **Constructor:**
590+
591+
```go
592+
func New[T comparable]() *DisjointSet[T]
593+
```
594+
595+
- **Methods:**
596+
597+
- `MakeSet(x T)`: Creates a new set containing a single element.
598+
- `Find(x T) T`: Returns the representative element of the set containing x.
599+
- `Union(x, y T)`: Merges the sets containing elements x and y.
600+
- `Connected(x, y T) bool`: Returns true if elements x and y are in the same set.
601+
- `Clear()`: Removes all elements from the disjoint set.
602+
- `Len() int`: Returns the number of elements in the disjoint set.
603+
- `IsEmpty() bool`: Returns true if the disjoint set contains no elements.
604+
- `GetSets() map[T][]T`: Returns a map of representatives to their set members.
605+
606+
#### Example Usage:
607+
608+
```go
609+
package main
581610

611+
import (
612+
"fmt"
613+
"github.com/idsulik/go-collections/v2/disjointset"
614+
)
615+
616+
func main() {
617+
// Create a new disjoint set
618+
ds := disjointset.New[string]()
619+
620+
// Create individual sets
621+
ds.MakeSet("A")
622+
ds.MakeSet("B")
623+
ds.MakeSet("C")
624+
ds.MakeSet("D")
625+
626+
// Merge sets
627+
ds.Union("A", "B")
628+
ds.Union("C", "D")
629+
630+
// Check if elements are in the same set
631+
fmt.Println(ds.Connected("A", "B")) // true
632+
fmt.Println(ds.Connected("A", "C")) // false
633+
634+
// Get all sets
635+
sets := ds.GetSets()
636+
for root, elements := range sets {
637+
fmt.Printf("Set with root %v: %v\n", root, elements)
638+
}
639+
}
640+
```
641+
642+
#### Performance Characteristics:
643+
644+
- MakeSet: O(1)
645+
- Find: O(α(n)) amortized (nearly constant)
646+
- Union: O(α(n)) amortized (nearly constant)
647+
- Connected: O(α(n)) amortized (nearly constant)
648+
649+
Where α(n) is the inverse Ackermann function, which grows extremely slowly and is effectively constant for all practical values of n.
650+
651+
#### Use Cases:
652+
653+
- Detecting cycles in graphs
654+
- Finding connected components
655+
- Network connectivity
656+
- Image processing (connected component labeling)
657+
- Kruskal's minimum spanning tree algorithm
658+
- Dynamic connectivity problems
659+
- Online dynamic connectivity
660+
- Percolation analysis
661+
---
582662
## [License](#license)
583663
This project is licensed under the [MIT License](LICENSE) - see the [LICENSE](LICENSE) file for details.

disjointset/disjointset.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package disjointset
2+
3+
// DisjointSet represents a disjoint set data structure
4+
type DisjointSet[T comparable] struct {
5+
parent map[T]T
6+
rank map[T]int
7+
}
8+
9+
// New creates a new DisjointSet instance
10+
func New[T comparable]() *DisjointSet[T] {
11+
return &DisjointSet[T]{
12+
parent: make(map[T]T),
13+
rank: make(map[T]int),
14+
}
15+
}
16+
17+
// MakeSet creates a new set containing a single element
18+
func (ds *DisjointSet[T]) MakeSet(x T) {
19+
if _, exists := ds.parent[x]; !exists {
20+
ds.parent[x] = x
21+
ds.rank[x] = 0
22+
}
23+
}
24+
25+
// Find returns the representative element of the set containing x
26+
// Uses path compression for optimization
27+
func (ds *DisjointSet[T]) Find(x T) T {
28+
if _, exists := ds.parent[x]; !exists {
29+
return x
30+
}
31+
32+
if ds.parent[x] != x {
33+
ds.parent[x] = ds.Find(ds.parent[x]) // Path compression
34+
}
35+
return ds.parent[x]
36+
}
37+
38+
// Union merges the sets containing elements x and y
39+
// Uses union by rank for optimization
40+
func (ds *DisjointSet[T]) Union(x, y T) {
41+
rootX := ds.Find(x)
42+
rootY := ds.Find(y)
43+
44+
if rootX == rootY {
45+
return
46+
}
47+
48+
// Union by rank
49+
if ds.rank[rootX] < ds.rank[rootY] {
50+
ds.parent[rootX] = rootY
51+
} else if ds.rank[rootX] > ds.rank[rootY] {
52+
ds.parent[rootY] = rootX
53+
} else {
54+
ds.parent[rootY] = rootX
55+
ds.rank[rootX]++
56+
}
57+
}
58+
59+
// Connected returns true if elements x and y are in the same set
60+
func (ds *DisjointSet[T]) Connected(x, y T) bool {
61+
return ds.Find(x) == ds.Find(y)
62+
}
63+
64+
// Clear removes all elements from the disjoint set
65+
func (ds *DisjointSet[T]) Clear() {
66+
ds.parent = make(map[T]T)
67+
ds.rank = make(map[T]int)
68+
}
69+
70+
// Len returns the number of elements in the disjoint set
71+
func (ds *DisjointSet[T]) Len() int {
72+
return len(ds.parent)
73+
}
74+
75+
// IsEmpty returns true if the disjoint set contains no elements
76+
func (ds *DisjointSet[T]) IsEmpty() bool {
77+
return len(ds.parent) == 0
78+
}
79+
80+
// GetSets returns a map of representatives to their set members
81+
func (ds *DisjointSet[T]) GetSets() map[T][]T {
82+
sets := make(map[T][]T)
83+
for element := range ds.parent {
84+
root := ds.Find(element)
85+
sets[root] = append(sets[root], element)
86+
}
87+
return sets
88+
}

disjointset/disjointset_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package disjointset
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestDisjointSet(t *testing.T) {
8+
t.Run(
9+
"New DisjointSet", func(t *testing.T) {
10+
ds := New[int]()
11+
if !ds.IsEmpty() {
12+
t.Error("New DisjointSet should be empty")
13+
}
14+
},
15+
)
16+
17+
t.Run(
18+
"MakeSet", func(t *testing.T) {
19+
ds := New[int]()
20+
ds.MakeSet(1)
21+
if ds.Find(1) != 1 {
22+
t.Error("MakeSet should create a set with the element as its own representative")
23+
}
24+
},
25+
)
26+
27+
t.Run(
28+
"Union and Find", func(t *testing.T) {
29+
ds := New[int]()
30+
ds.MakeSet(1)
31+
ds.MakeSet(2)
32+
ds.MakeSet(3)
33+
34+
ds.Union(1, 2)
35+
if !ds.Connected(1, 2) {
36+
t.Error("Elements 1 and 2 should be connected after Union")
37+
}
38+
39+
ds.Union(2, 3)
40+
if !ds.Connected(1, 3) {
41+
t.Error("Elements 1 and 3 should be connected after Union")
42+
}
43+
},
44+
)
45+
46+
t.Run(
47+
"Connected", func(t *testing.T) {
48+
ds := New[string]()
49+
ds.MakeSet("A")
50+
ds.MakeSet("B")
51+
ds.MakeSet("C")
52+
53+
if ds.Connected("A", "B") {
54+
t.Error("Elements should not be connected before Union")
55+
}
56+
57+
ds.Union("A", "B")
58+
if !ds.Connected("A", "B") {
59+
t.Error("Elements should be connected after Union")
60+
}
61+
},
62+
)
63+
64+
t.Run(
65+
"Clear", func(t *testing.T) {
66+
ds := New[int]()
67+
ds.MakeSet(1)
68+
ds.MakeSet(2)
69+
ds.Union(1, 2)
70+
71+
ds.Clear()
72+
if !ds.IsEmpty() {
73+
t.Error("DisjointSet should be empty after Clear")
74+
}
75+
},
76+
)
77+
78+
t.Run(
79+
"GetSets", func(t *testing.T) {
80+
ds := New[int]()
81+
ds.MakeSet(1)
82+
ds.MakeSet(2)
83+
ds.MakeSet(3)
84+
ds.MakeSet(4)
85+
86+
ds.Union(1, 2)
87+
ds.Union(3, 4)
88+
89+
sets := ds.GetSets()
90+
if len(sets) != 2 {
91+
t.Error("Should have exactly 2 distinct sets")
92+
}
93+
94+
for _, set := range sets {
95+
if len(set) != 2 {
96+
t.Error("Each set should contain exactly 2 elements")
97+
}
98+
}
99+
},
100+
)
101+
102+
t.Run(
103+
"Path Compression", func(t *testing.T) {
104+
ds := New[int]()
105+
ds.MakeSet(1)
106+
ds.MakeSet(2)
107+
ds.MakeSet(3)
108+
109+
ds.Union(1, 2)
110+
ds.Union(2, 3)
111+
112+
// After finding 3, the path should be compressed
113+
root := ds.Find(3)
114+
if ds.parent[3] != root {
115+
t.Error("Path compression should make 3 point directly to the root")
116+
}
117+
},
118+
)
119+
}

0 commit comments

Comments
 (0)