LC 1562. 查找大小为 M 的最新分组
https://leetcode-cn.com/problems/find-latest-group-of-size-m/
这题肯定需要使用Find/Union数据结构,不过如何使用是个问题。
我的第一个实现使用了两个Find/Union数据结构
- `first[i]` 表示 `arr[i]` 这个节点所处连续区间[s,e]的s
- `last[i]` 则表示 `arr[i]` 这个节点所处连续区间[s,e]的e
但是关联和查询两个数据结构的时候必须非常小心
- 当我们查询到 `p0=queyFirst(x)` 之后,我们不仅仅需要 `first[x]=p0`, 并且需要设置 `last[p0]=x`. 这样p0下面所有的节点的last才能够更新到最远
- 然后我们在 `queryLast` 的实现里面,我们需要先找到特征点,然后通过特征点找到连续区间的终点。
class Solution:
def findLatestStep(self, arr: List[int], m: int) -> int:
n = len(arr)
first = list(range(n))
last = list(range(n))
mark = [0] * n
def queryFirst(x):
p = x
while first[p] != p:
p = first[p]
while first[x] != p:
x2 = first[x]
first[x] = p
x = x2
return p
def queryLast(x):
x = queryFirst(x)
p = x
while last[p] != p:
p = last[p]
while last[x] != p:
x2 = last[x]
last[x] = p
x = x2
return p
cnt = 0
ans = -1
for step, x in enumerate(arr):
x = x - 1
mark[x] = 1
if x > 0 and mark[x - 1]:
p0 = queryFirst(x - 1)
if (x - p0) == m:
cnt -= 1
first[x] = p0
last[p0] = x
if x < (n - 1) and mark[x + 1]:
p1 = queryLast(x + 1)
if (p1 - x) == m:
cnt -= 1
first[p1] = x
last[x] = p1
p0 = queryFirst(x)
p1 = queryLast(x)
if (p1 - p0 + 1) == m:
cnt += 1
if cnt > 0:
ans = step + 1
return ans
我看到另外一份实现是这样的,只维护一个Find/Union数据结构,它只表示特征点。但是在特征点上附带了长度信息。当两个区间合并的时候,只要把长度相加就行,然后在新的特征点上附带上长度信息。这种实现类似课本里面的写法,只不过课本里面的实现方法,特征点上附带的不是区间长度而树的高度。在合并两个特征点的时候,根据高度选择谁作为新的特征点更合适。
简单来说,维护两个Find/Union的数据结构好像会比较麻烦,维护一个Find/Union数据结构然后在merge的时候将附带信息做合并,是更加简单的方法。
class Solution:
def findLatestStep(self, arr: List[int], m: int) -> int:
n = len(arr)
first = list(range(n))
size = [1] * n
mark = [0] * n
from collections import Counter
cnt = Counter()
def queryFirst(x):
p = x
while first[p] != p:
p = first[p]
# compress.
while first[x] != p:
x2 = first[x]
first[x] = p
x = x2
return p
def merge(a, b):
pa = queryFirst(a)
pb = queryFirst(b)
if pa != pb:
if pa < pb:
pa, pb = pb, pa
cnt[size[pa]] -= 1
cnt[size[pb]] -= 1
size[pb] = size[pa] + size[pb]
first[pa] = pb
cnt[size[pb]] += 1
ans = -1
for step, x in enumerate(arr):
x = x - 1
mark[x] = 1
cnt[1] += 1
if x > 0 and mark[x - 1]:
merge(x, x - 1)
if x < (n - 1) and mark[x + 1]:
merge(x, x + 1)
if cnt[m] > 0:
ans = step + 1
return ans