求解RMQ的线段树实现
很早做题的时候就听到过线段树,但是因为课本上没有,所以觉得这个东西用的应该比较少,没有太关注。最近这段时间接触到很多题目,都用到了线段树,并且发现其实线段树可以用来有效地解决许多和区间相关的问题,比如查询区间最大最小值(RMQ)以及查询区间和。
区间最大最小值(RMQ)有好几种办法:
- DP. O(n^2)空间复杂度,O(1)查询复杂度,预处理时间O(n^2). 不支持动态更新
- Square Root Decomposition. n2 = n ** 0.5. 时空复杂度都是O(n2), 预处理时间O(n), 动态更新时间O(n2).
- Sparse Table. 空间O(nlgn), 预处理O(nlgn), 查询O(1) 不支持动态更新
- 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)
区间和有:
- 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)
- Fenwick Tree. 空间O(n), 预处理O(nlgn), 查询(lgn), 支持动态更新O(lgn)
线段树并不需要真的构造树,可以使用树状数组来有效表示,有点类似堆(heap)的实现。IX最下面一层(i>=IN)是所有A的下标,上面则是每个区间段最小值的下标。
# NOTE(yan): IX是构造出来的线段树
class RMQSegmentTree:
def __init__(self, A):
self.INF = float('inf')
self.A = A
n = 1
while n < len(A):
n = n * 2
IX = [None] * (2 * n)
self.IX = IX
self.IN = n
self._init_index()
def _value(self, i):
if i is None:
return self.INF
return self.A[i]
def _init_index(self):
for i in range(len(self.A)):
self.IX[i + self.IN] = i
for i in range(self.IN - 1, 0, -1):
i0 = self.IX[2 * i]
i1 = self.IX[2 * i + 1]
v0 = self._value(i0)
v1 = self._value(i1)
self.IX[i] = i0 if v0 <= v1 else i1
def _update_index(self, i):
p = i // 2
while p:
i0 = self.IX[2 * p]
i1 = self.IX[2 * p + 1]
v0 = self._value(i0)
v1 = self._value(i1)
self.IX[p] = i0 if v0 <= v1 else i1
p = p // 2
def update(self, i, x):
self.A[i] = x
self._update_index(i + self.IN)
def _query(self, i, start, span, left, right):
if (start + span) <= left or start >= right:
return None
if start >= left and (start + span) <= right:
return self.IX[i]
i0 = self._query(i * 2, start, span // 2, left, right)
i1 = self._query(i * 2 + 1, start + span // 2, span // 2, left, right)
v0 = self._value(i0)
v1 = self._value(i1)
return i0 if v0 <= v1 else i1
def query(self, left, right):
# [left, right]
ans = self._query(1, 0, self.IN, left, right + 1)
return ans
为了验证有效性,可以构造随机数组,枚举所有的查询范围,并且随机变动其他的内容。通过和naive实现,来验证线段树的实现是否正确。
import numpy as np
def naive_query(A, left, right):
return np.argmin(A[left:right + 1]) + left
def main():
for size in (10, 16, 20, 32):
A = np.random.rand(size)
rmq = RMQSegmentTree(A)
n = len(A)
for left in range(n):
for right in range(left + 1, n):
for _ in range(10):
p = np.random.randint(left, right)
v = np.random.rand()
rmq.update(p, v)
x = naive_query(A, left, right)
y = rmq.query(left, right)
if x != y:
print('F')
print('PASS ON SIZE = {}'.format(size))