LC 2916. 子数组不同元素数目的平方和 II

https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-ii/

这题有个量级小一点的版本可以作为参考 https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-i/

关于这题里面的数学思路,可以看 题解,这里面说的非常清楚,大致思路就是

里面需要一个数据结构的辅助满足

这个数据结构就是线段树,之前实现过一个简单的,可以更新其中某个元素,但是没有实现过更新区间的功能。实现更新区间的功能,需要有个 `lazy` 的结构。 这个 `lazy` 结构表示:将用于孩子节点的更新。这样一旦我们更新区间的时候,不用立刻更新下面的孩子,除非我们需要去计算孩子的区间和,这个时候再去将这个lazy结构应用上去。 在实现上,我们可以用一个naive实现来做交叉验证,这个非常有效。

这题另外一个优化,就是我们更新的区间,和计算区间的和,是同一个区间。如果我们将两者作为两个操作的话,那么时间就会double,这题就会超时。 我们 必须 将查询和更新放在一个方法里面才能过。

#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt

from typing import List


class RangeSumer:
    class Base:
        def __init__(self, n):
            self.values = [0] * n

        def update(self, i, j, delta):
            for k in range(i, j + 1):
                self.values[k] += delta

        def query(self, i, j):
            acc = 0
            for k in range(i, j + 1):
                acc += self.values[k]
            return acc

    def __init__(self, n):
        self.n = n
        sz = 1
        while sz < n:
            sz <<= 1
        self.sum = [0] * (sz << 1)
        self.lazy = [0] * (sz << 1)
        self.sz = sz
        self.base = RangeSumer.Base(n)
        self.debug = False

    def dump(self):
        sz = 1
        off = 1
        while sz <= self.sz:
            print(self.sum[off:off + sz], self.lazy[off:off + sz])
            off += sz
            sz = sz << 1

    def query_and_update(self, i, j, delta):
        def do(i, j, k, s, sz):

            if i <= s <= (s + sz - 1) <= j:
                res = self.sum[k]
                self.apply_lazy(k, sz, delta)
                return res

            self.push_down(k, sz)
            mid = s + sz // 2
            res = 0
            if i < mid:
                res += do(i, j, 2 * k, s, sz // 2)
            if j >= mid:
                res += do(i, j, 2 * k + 1, mid, sz // 2)

            self.sum[k] = self.sum[2 * k] + self.sum[2 * k + 1]
            return res

        ans = do(i, j, 1, 0, self.sz)
        if self.debug:
            exp = self.base.query(i, j)
            self.base.update(i, j, delta)
            print('query_and_update(%d, %d) = %d' % (i, j, ans))
            self.dump()

            if ans != exp:
                assert (ans == exp)
        return ans

    def push_down(self, k, sz):
        if self.lazy[k] and sz != 1:
            v = self.lazy[k]
            self.apply_lazy(2 * k, sz // 2, v)
            self.apply_lazy(2 * k + 1, sz // 2, v)
            self.lazy[k] = 0

    def apply_lazy(self, k, sz, delta):
        self.sum[k] += delta * sz
        self.lazy[k] += delta

    def query(self, i, j):
        return self.query_and_update(i, j, 0)

    def update(self, i, j, delta):
        self.query_and_update(i, j, delta)


class Solution:
    def sumCounts(self, nums: List[int]) -> int:
        n = len(nums)
        prev = {}
        ans, acc = 0, 0
        MOD = 10 ** 9 + 7
        sumer = RangeSumer(n)
        # sumer.debug = True
        for i in range(n):
            p = prev.get(nums[i], -1)
            prev[nums[i]] = i
            delta = 2 * sumer.query_and_update(p + 1, i, 1) + (i - p)
            acc = (acc + delta) % MOD
            ans = (ans + acc) % MOD
        return ans % MOD