LC 2569. 更新数组后处理求和查询

https://leetcode.cn/problems/handling-sum-queries-after-update/

这题看题解中还有直接模拟超大数的bit count. 按照直觉来说这种办法其实并不容易work, 但是运行时间却比线段树要好很多。

class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        s = sum(nums2)
        x = int(''.join(map(str, nums1[::-1])), 2)

        ans = []
        for op, l, r in queries:
            if op == 1:
                y = 1 << (r - l + 1) - 1
                y <<= l
                x = x ^ y
            elif op == 2:
                s += l * x.bit_count()
            else:
                ans.append(s)

        return ans

题解 里面还给出了延迟线段树的解法,我觉得这个是值得学习学习的。

延迟线段树有几个要点:

  1. 树状数组的要点都是从1开始计算的,分隔点可以是 (l+r)//2
  2. 递归的时候,想要更新的区间可以不进行分隔,但是作用区间需要分隔。
  3. 延迟线段树需要设置 `lazy[i]`, 表示它的子树是否已经处理过。如果子树没有处理过的话,那么先处理两个子树。
class Solution:
    def handleQuery(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:
        n = len(nums1)
        cnt = [0] * (4 * n)
        lazy = [False] * (4 * n)

        def maintain(i):
            cnt[i] = cnt[2 * i] + cnt[2 * i + 1]

        def build(i, l, r):
            if l == r:
                cnt[i] = nums1[l - 1]
                return

            m = (l + r) // 2
            build(2 * i, l, m)
            build(2 * i + 1, m + 1, r)
            maintain(i)
            return

        def flip(i, l, r, L, R):
            def fix(i, l, r):
                cnt[i] = (r - l + 1) - cnt[i]
                lazy[i] = not lazy[i]
                return

            if l <= L and R <= r:
                fix(i, L, R)
                return

            M = (L + R) // 2
            if lazy[i]: # 如果子树没有处理的话,那么需要处理子树先
                lazy[i] = False
                fix(2 * i, L, M)
                fix(2 * i + 1, M + 1, R)
                maintain(i)

            if l <= M: flip(2 * i, l, r, L, M) # 处理区间(l,r)不用拆分
            if (M + 1) <= r: flip(2 * i + 1, l, r, M + 1, R) # 但是作用区间(L, R)需要拆分
            maintain(i)

        build(1, 1, n) # 下标从1开始很关键,否则处理起来很麻烦
        ans, base = [], sum(nums2)
        for op, l, r in queries:
            if op == 1:
                flip(1, l + 1, r + 1, 1, n)
            elif op == 2:
                base += l * cnt[1]
            else:
                ans.append(base)
        return ans