求解RMQ的线段树实现
很早做题的时候就听到过线段树,但是因为课本上没有,所以觉得这个东西用的应该比较少,没有太关注。最近这段时间接触到很多题目,都用到了线段树,并且发现其实线段树可以用来有效地解决许多和区间相关的问题,比如查询区间最大最小值(RMQ)以及查询区间和。
区间最大最小值(RMQ)有好几种办法:
- DP. O(n^2)空间复杂度,O(1)查询复杂度,预处理时间O(n^2). 不支持动态更新
- Square Root Decomposition. n2 = n ** 0.5. 时空复杂度都是O(n2), 预处理时间O(n), 动态更新时间O(n2).
- Sparse Table. 空间O(nlgn), 预处理O(nlgn), 查询O(1) 不支持动态更新
- 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)
区间和有:
- 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)
- Fenwick Tree. 空间O(n), 预处理O(nlgn), 查询(lgn), 支持动态更新O(lgn)
线段树并不需要真的构造树,可以使用树状数组来有效表示,有点类似堆(heap)的实现。IX最下面一层(i>=IN)是所有A的下标,上面则是每个区间段最小值的下标。
# NOTE(yan): IX是构造出来的线段树 class RMQSegmentTree: def __init__(self, A): self.INF = float('inf') self.A = A n = 1 while n < len(A): n = n * 2 IX = [None] * (2 * n) self.IX = IX self.IN = n self._init_index() def _value(self, i): if i is None: return self.INF return self.A[i] def _init_index(self): for i in range(len(self.A)): self.IX[i + self.IN] = i for i in range(self.IN - 1, 0, -1): i0 = self.IX[2 * i] i1 = self.IX[2 * i + 1] v0 = self._value(i0) v1 = self._value(i1) self.IX[i] = i0 if v0 <= v1 else i1 def _update_index(self, i): p = i // 2 while p: i0 = self.IX[2 * p] i1 = self.IX[2 * p + 1] v0 = self._value(i0) v1 = self._value(i1) self.IX[p] = i0 if v0 <= v1 else i1 p = p // 2 def update(self, i, x): self.A[i] = x self._update_index(i + self.IN) def _query(self, i, start, span, left, right): if (start + span) <= left or start >= right: return None if start >= left and (start + span) <= right: return self.IX[i] i0 = self._query(i * 2, start, span // 2, left, right) i1 = self._query(i * 2 + 1, start + span // 2, span // 2, left, right) v0 = self._value(i0) v1 = self._value(i1) return i0 if v0 <= v1 else i1 def query(self, left, right): # [left, right] ans = self._query(1, 0, self.IN, left, right + 1) return ans
为了验证有效性,可以构造随机数组,枚举所有的查询范围,并且随机变动其他的内容。通过和naive实现,来验证线段树的实现是否正确。
import numpy as np def naive_query(A, left, right): return np.argmin(A[left:right + 1]) + left def main(): for size in (10, 16, 20, 32): A = np.random.rand(size) rmq = RMQSegmentTree(A) n = len(A) for left in range(n): for right in range(left + 1, n): for _ in range(10): p = np.random.randint(left, right) v = np.random.rand() rmq.update(p, v) x = naive_query(A, left, right) y = rmq.query(left, right) if x != y: print('F') print('PASS ON SIZE = {}'.format(size))