5497. 查找大小为 M 的最新分组

https://leetcode-cn.com/problems/find-latest-group-of-size-m/

这题肯定需要使用Find/Union数据结构,不过如何使用是个问题。

我的第一个实现使用了两个Find/Union数据结构

但是关联和查询两个数据结构的时候必须非常小心

class Solution:
    def findLatestStep(self, arr: List[int], m: int) -> int:
        n = len(arr)
        first = list(range(n))
        last = list(range(n))
        mark = [0] * n

        def queryFirst(x):
            p = x
            while first[p] != p:
                p = first[p]

            while first[x] != p:
                x2 = first[x]
                first[x] = p
                x = x2

            return p

        def queryLast(x):
            x = queryFirst(x)
            p = x
            while last[p] != p:
                p = last[p]

            while last[x] != p:
                x2 = last[x]
                last[x] = p
                x = x2
            return p

        cnt = 0
        ans = -1

        for step, x in enumerate(arr):
            x = x - 1
            mark[x] = 1

            if x > 0 and mark[x - 1]:
                p0 = queryFirst(x - 1)
                if (x - p0) == m:
                    cnt -= 1
                first[x] = p0
                last[p0] = x

            if x < (n - 1) and mark[x + 1]:
                p1 = queryLast(x + 1)
                if (p1 - x) == m:
                    cnt -= 1
                first[p1] = x
                last[x] = p1

            p0 = queryFirst(x)
            p1 = queryLast(x)
            if (p1 - p0 + 1) == m:
                cnt += 1
            if cnt > 0:
                ans = step + 1

        return ans

我看到另外一份实现是这样的,只维护一个Find/Union数据结构,它只表示特征点。但是在特征点上附带了长度信息。当两个区间合并的时候,只要把长度相加就行,然后在新的特征点上附带上长度信息。这种实现类似课本里面的写法,只不过课本里面的实现方法,特征点上附带的不是区间长度而树的高度。在合并两个特征点的时候,根据高度选择谁作为新的特征点更合适。

简单来说,维护两个Find/Union的数据结构好像会比较麻烦,维护一个Find/Union数据结构然后在merge的时候将附带信息做合并,是更加简单的方法。

class Solution:
    def findLatestStep(self, arr: List[int], m: int) -> int:
        n = len(arr)
        first = list(range(n))
        size = [1] * n
        mark = [0] * n
        from collections import Counter
        cnt = Counter()

        def queryFirst(x):
            p = x
            while first[p] != p:
                p = first[p]

            # compress.
            while first[x] != p:
                x2 = first[x]
                first[x] = p
                x = x2

            return p

        def merge(a, b):
            pa = queryFirst(a)
            pb = queryFirst(b)
            if pa != pb:
                if pa < pb:
                    pa, pb = pb, pa
                cnt[size[pa]] -= 1
                cnt[size[pb]] -= 1
                size[pb] = size[pa] + size[pb]
                first[pa] = pb
                cnt[size[pb]] += 1

        ans = -1

        for step, x in enumerate(arr):
            x = x - 1
            mark[x] = 1
            cnt[1] += 1

            if x > 0 and mark[x - 1]:
                merge(x, x - 1)

            if x < (n - 1) and mark[x + 1]:
                merge(x, x + 1)

            if cnt[m] > 0:
                ans = step + 1

        return ans