Union Find

class QuickFind:
    """Quick-find union-find: ``id[x]`` is the component id of ``x``.

    Connectivity queries are O(1) because they compare ids directly;
    ``union`` relabels every member of one component and is O(N).
    """

    def __init__(self, n):
        """Create ``n`` singleton components labelled ``0 .. n-1``. O(N)"""
        self.n = n
        self.id = {i: i for i in range(n)}

    def root(self, a):
        """Return the component id of ``a``. O(1)

        Quick-find stores the component id directly, so this is a single
        lookup. It is provided so that QuickFind offers the same ``root``
        method as QuickUnion and WeightedUnionFind.
        """
        return self.id[a]

    def union(self, a, b):
        """Merge the component of ``a`` into the component of ``b``. O(N)"""
        aid = self.id[a]
        bid = self.id[b]
        if aid == bid:
            # Already connected: skip the full relabelling pass.
            return

        # Only values are reassigned (no keys added or removed), so mutating
        # the dict while iterating its items view is safe.
        for key, val in self.id.items():
            if val == aid:
                self.id[key] = bid

    def find(self, a, b):
        """Alias for ``is_connected``. O(1)"""
        return self.is_connected(a, b)

    def is_connected(self, a, b):
        """Return True when ``a`` and ``b`` carry the same component id."""
        return self.id[a] == self.id[b]

if __name__ == "__main__":

    uf = QuickFind(10)

    # Hand-build two components by writing component ids directly:
    # {0, 1, 2, 5, 6, 7} labelled 1 and {3, 4, 8, 9} labelled 8.
    for i in (0, 1, 2, 5, 6, 7):
        uf.id[i] = 1
    for i in (3, 4, 8, 9):
        uf.id[i] = 8

    print(uf.find(1, 8))  # False: different component ids
    uf.union(1, 8)        # every element labelled 1 is relabelled 8
    print(uf.find(1, 8))  # True
    print(uf.find(2, 8))  # True
    # Quick-find stores the component id directly in id[x].
    print(uf.id[2])       # 8

class QuickUnion:
    """Quick-union: every element stores a parent link in ``id``.

    An element whose parent is itself is a root. Two elements are in the same
    component exactly when they share a root. No balancing is done, so
    a tree can degenerate into a chain of length N.
    """

    def __init__(self, n):
        """Start with ``n`` singleton trees, each element its own root. O(N)"""
        self.n = n
        self.id = {node: node for node in range(n)}

    def root(self, a):
        """Follow parent links from ``a`` up to its root and return it.

        Cost is the depth of ``a``'s tree: O(N) in the worst case.
        """
        node = a
        while self.id[node] != node:
            node = self.id[node]
        return node

    def is_connected(self, a, b):
        """Return True when ``a`` and ``b`` share a root."""
        return self.root(a) == self.root(b)

    def find(self, a, b):
        """Alias for ``is_connected``. O(N)"""
        return self.is_connected(a, b)

    def union(self, a, b):
        """Hang the root of ``a``'s tree under the root of ``b``'s. O(N)"""
        root_of_a, root_of_b = self.root(a), self.root(b)
        if root_of_a == root_of_b:
            return
        self.id[root_of_a] = root_of_b

if __name__ == "__main__":

    uf = QuickUnion(10)

    # Hand-build two flat trees by writing parent links directly:
    # {0, 1, 2, 5, 6, 7} rooted at 1 and {3, 4, 8, 9} rooted at 8.
    for i in (0, 1, 2, 5, 6, 7):
        uf.id[i] = 1
    for i in (3, 4, 8, 9):
        uf.id[i] = 8

    print(uf.find(1, 8))  # False: different roots
    uf.union(1, 8)        # root 1 is re-parented under root 8
    print(uf.find(1, 8))  # True
    print(uf.find(2, 8))  # True: 2 -> 1 -> 8
    print(uf.id[2])       # 1: 2's direct parent is unchanged
    print(uf.root(2))     # 8


class WeightedUnionFind:
    """Weighted quick-union with path halving.

    ``id`` holds parent links (a root is its own parent). ``sz[r]`` counts
    the elements in the tree rooted at ``r``. ``union`` updates ``sz`` only
    for the surviving root, so entries for non-root elements go stale.
    """

    def __init__(self, n):
        self.n = n
        self.id = {node: node for node in range(n)}
        self.sz = dict.fromkeys(range(n), 1)

    def root(self, a):
        """O(lg(N)): return the root of ``a``'s tree.

        While walking up, each visited node is re-pointed at its grandparent
        (path halving), which flattens the tree for later lookups.
        """
        parent = self.id
        node = a
        while parent[node] != node:
            grandparent = parent[parent[node]]
            parent[node] = grandparent
            node = grandparent
        return node

    def is_connected(self, a, b):
        """O(lg(N)): True when ``a`` and ``b`` share a root."""
        return self.root(a) == self.root(b)

    def find(self, a, b):
        """O(lg(N)): alias for ``is_connected``."""
        return self.is_connected(a, b)

    def union(self, a, b):
        """O(lg(N)): merge the trees containing ``a`` and ``b``.

        The smaller tree is attached beneath the larger one. On a size tie,
        ``b``'s root goes under ``a``'s root.
        """
        bigger, smaller = self.root(a), self.root(b)
        if bigger == smaller:
            return
        if self.sz[bigger] < self.sz[smaller]:
            bigger, smaller = smaller, bigger
        self.id[smaller] = bigger
        self.sz[bigger] += self.sz[smaller]

if __name__ == "__main__":

    uf = WeightedUnionFind(10)

    # Hand-build two flat trees by writing parent links directly:
    # {0, 1, 2, 5, 6, 7} rooted at 1 and {3, 4, 8, 9} rooted at 8.
    for i in (0, 1, 2, 5, 6, 7):
        uf.id[i] = 1
    for i in (3, 4, 8, 9):
        uf.id[i] = 8
    # Keep root sizes in sync with the hand-built trees: union() compares
    # sz[root] to decide which tree goes beneath the other.
    uf.sz[1] = 6
    uf.sz[8] = 4

    print(uf.find(1, 8))  # False: different roots
    uf.union(1, 8)        # smaller tree (root 8) goes under root 1
    print(uf.find(1, 8))  # True
    print(uf.find(2, 8))  # True
    print(uf.id[2])       # 1
    print(uf.root(2))     # 1

Continue reading

Advertisements