把树装进数组里面

昨天看到 https://www.youtube.com/watch?v=rF13507mRp8 这个视频,这个作者写的区间树太优雅了。 把树放进数组里面高效且实现简单,遍历起来也特别容易。这个模式总结下来有:

  1. 为叶节点开辟的空间是>=sz的最近的2^n.
  2. 如果叶节点开辟的空间是n的话,那么树节点开辟空间也是n.
  3. 其实树节点只需要n-1个,但是开辟n个的话就可以从下标1作为root.
  4. 下标1作为root好处是,直接使用 2*i, 2*i+1 就可以访问到子树。
class SegmentTree:
    def __init__(self, sz):
        n = 1
        while n < sz:
            n = n * 2
        # 此时的n可以容纳叶子节点,但是我们还需要开辟树节点
        # 并且树节点从1开始标记
        self.arr = [0] * (2 * n)
        self.n = n

    def get(self, idx):
        return self.arr[idx + self.n]

    def update(self, idx, value):
        p = (idx + self.n)
        self.arr[p] = value
        p = p // 2
        while p >= 1:
            i, j = p * 2, p * 2 + 1
            self.arr[p] = max(self.arr[i], self.arr[j])
            p = p // 2

    def total_max(self):
        return self.arr[1]