LC 1562. 查找大小为 M 的最新分组
https://leetcode-cn.com/problems/find-latest-group-of-size-m/
这题肯定需要使用Find/Union数据结构,不过如何使用是个问题。
我的第一个实现使用了两个Find/Union数据结构
- `first[i]` 表示 `arr[i]` 这个节点所处连续区间[s,e]的s
- `last[i]` 则表示 `arr[i]` 这个节点所处连续区间[s,e]的e
但是关联和查询两个数据结构的时候必须非常小心
- 当我们查询到 `p0=queyFirst(x)` 之后,我们不仅仅需要 `first[x]=p0`, 并且需要设置 `last[p0]=x`. 这样p0下面所有的节点的last才能够更新到最远
- 然后我们在 `queryLast` 的实现里面,我们需要先找到特征点,然后通过特征点找到连续区间的终点。
class Solution: def findLatestStep(self, arr: List[int], m: int) -> int: n = len(arr) first = list(range(n)) last = list(range(n)) mark = [0] * n def queryFirst(x): p = x while first[p] != p: p = first[p] while first[x] != p: x2 = first[x] first[x] = p x = x2 return p def queryLast(x): x = queryFirst(x) p = x while last[p] != p: p = last[p] while last[x] != p: x2 = last[x] last[x] = p x = x2 return p cnt = 0 ans = -1 for step, x in enumerate(arr): x = x - 1 mark[x] = 1 if x > 0 and mark[x - 1]: p0 = queryFirst(x - 1) if (x - p0) == m: cnt -= 1 first[x] = p0 last[p0] = x if x < (n - 1) and mark[x + 1]: p1 = queryLast(x + 1) if (p1 - x) == m: cnt -= 1 first[p1] = x last[x] = p1 p0 = queryFirst(x) p1 = queryLast(x) if (p1 - p0 + 1) == m: cnt += 1 if cnt > 0: ans = step + 1 return ans
我看到另外一份实现是这样的,只维护一个Find/Union数据结构,它只表示特征点。但是在特征点上附带了长度信息。当两个区间合并的时候,只要把长度相加就行,然后在新的特征点上附带上长度信息。这种实现类似课本里面的写法,只不过课本里面的实现方法,特征点上附带的不是区间长度而树的高度。在合并两个特征点的时候,根据高度选择谁作为新的特征点更合适。
简单来说,维护两个Find/Union的数据结构好像会比较麻烦,维护一个Find/Union数据结构然后在merge的时候将附带信息做合并,是更加简单的方法。
class Solution: def findLatestStep(self, arr: List[int], m: int) -> int: n = len(arr) first = list(range(n)) size = [1] * n mark = [0] * n from collections import Counter cnt = Counter() def queryFirst(x): p = x while first[p] != p: p = first[p] # compress. while first[x] != p: x2 = first[x] first[x] = p x = x2 return p def merge(a, b): pa = queryFirst(a) pb = queryFirst(b) if pa != pb: if pa < pb: pa, pb = pb, pa cnt[size[pa]] -= 1 cnt[size[pb]] -= 1 size[pb] = size[pa] + size[pb] first[pa] = pb cnt[size[pb]] += 1 ans = -1 for step, x in enumerate(arr): x = x - 1 mark[x] = 1 cnt[1] += 1 if x > 0 and mark[x - 1]: merge(x, x - 1) if x < (n - 1) and mark[x + 1]: merge(x, x + 1) if cnt[m] > 0: ans = step + 1 return ans