求解RMQ的线段树实现

很早做题的时候就听到过线段树,但是因为课本上没有,所以觉得这个东西用的应该比较少,没有太关注。最近这段时间接触到很多题目,都用到了线段树,并且发现其实线段树可以用来有效地解决许多和区间相关的问题,比如查询区间最大最小值(RMQ)以及查询区间和。

区间最大最小值(RMQ)有好几种办法:

  1. DP. O(n^2)空间复杂度,O(1)查询复杂度,预处理时间O(n^2). 不支持动态更新
  2. Square Root Decomposition. n2 = n ** 0.5. 时空复杂度都是O(n2), 预处理时间O(n), 动态更新时间O(n2).
  3. Sparse Table. 空间O(nlgn), 预处理O(nlgn), 查询O(1) 不支持动态更新
  4. 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)

区间和有:

  1. 线段树. 空间O(n), 预处理O(n), 查询O(lgn), 支持动态更新O(lgn)
  2. Fenwick Tree. 空间O(n), 预处理O(nlgn), 查询(lgn), 支持动态更新O(lgn)

线段树并不需要真的构造树,可以使用树状数组来有效表示,有点类似堆(heap)的实现。IX最下面一层(i>=IN)是所有A的下标,上面则是每个区间段最小值的下标。

# NOTE(yan): IX是构造出来的线段树

class RMQSegmentTree:
    def __init__(self, A):
        self.INF = float('inf')
        self.A = A
        n = 1
        while n < len(A):
            n = n * 2
        IX = [None] * (2 * n)
        self.IX = IX
        self.IN = n
        self._init_index()

    def _value(self, i):
        if i is None:
            return self.INF
        return self.A[i]

    def _init_index(self):
        for i in range(len(self.A)):
            self.IX[i + self.IN] = i

        for i in range(self.IN - 1, 0, -1):
            i0 = self.IX[2 * i]
            i1 = self.IX[2 * i + 1]
            v0 = self._value(i0)
            v1 = self._value(i1)
            self.IX[i] = i0 if v0 <= v1 else i1

    def _update_index(self, i):
        p = i // 2
        while p:
            i0 = self.IX[2 * p]
            i1 = self.IX[2 * p + 1]
            v0 = self._value(i0)
            v1 = self._value(i1)
            self.IX[p] = i0 if v0 <= v1 else i1
            p = p // 2

    def update(self, i, x):
        self.A[i] = x
        self._update_index(i + self.IN)

    def _query(self, i, start, span, left, right):
        if (start + span) <= left or start >= right:
            return None
        if start >= left and (start + span) <= right:
            return self.IX[i]
        i0 = self._query(i * 2, start, span // 2, left, right)
        i1 = self._query(i * 2 + 1, start + span // 2, span // 2, left, right)
        v0 = self._value(i0)
        v1 = self._value(i1)
        return i0 if v0 <= v1 else i1

    def query(self, left, right):
        # [left, right]
        ans = self._query(1, 0, self.IN, left, right + 1)
        return ans

为了验证有效性,可以构造随机数组,枚举所有的查询范围,并且随机变动其他的内容。通过和naive实现,来验证线段树的实现是否正确。

import numpy as np

def naive_query(A, left, right):
    return np.argmin(A[left:right + 1]) + left


def main():
    for size in (10, 16, 20, 32):
        A = np.random.rand(size)
        rmq = RMQSegmentTree(A)
        n = len(A)
        for left in range(n):
            for right in range(left + 1, n):
                for _ in range(10):
                    p = np.random.randint(left, right)
                    v = np.random.rand()
                    rmq.update(p, v)
                    x = naive_query(A, left, right)
                    y = rmq.query(left, right)
                    if x != y:
                        print('F')
        print('PASS ON SIZE = {}'.format(size))