LC 2916. 子数组不同元素数目的平方和 II
https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-ii/
这题有个量级小一点的版本可以作为参考 https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-i/
关于这题里面的数学思路,可以看 题解,这里面说的非常清楚,大致思路就是
- 假设我们已经有A[i], 表示 X[0..i], X[1..i], X[2..i] … X[i-1..i], X[i, i] 这些序列的去重数平方
- 那么我们可以考虑A[i+1] 如果进行更新。假设A[i+1] =x, 并且A[j]=x的话,那么X[0..i],..X[j,i] = X[0..i+1]..X[j,i+1],这些都不用变
- 变化的就是 X[j+1,..i], X[j+2..i] … X[i..i] 每个元素上都需要+1, 然后平方,然后我们就计算增量就行。
里面需要一个数据结构的辅助满足
- 我们需要更新一段区间,让这段区间里面每个元素+1
- 然后我们还需要计算一段区间的和
这个数据结构就是线段树,之前实现过一个简单的,可以更新其中某个元素,但是没有实现过更新区间的功能。实现更新区间的功能,需要有个 `lazy` 的结构。 这个 `lazy` 结构表示:将用于孩子节点的更新。这样一旦我们更新区间的时候,不用立刻更新下面的孩子,除非我们需要去计算孩子的区间和,这个时候再去将这个lazy结构应用上去。 在实现上,我们可以用一个naive实现来做交叉验证,这个非常有效。
这题另外一个优化,就是我们更新的区间,和计算区间的和,是同一个区间。如果我们将两者作为两个操作的话,那么时间就会double,这题就会超时。 我们 必须 将查询和更新放在一个方法里面才能过。
#!/usr/bin/env python # coding:utf-8 # Copyright (C) dirlt from typing import List class RangeSumer: class Base: def __init__(self, n): self.values = [0] * n def update(self, i, j, delta): for k in range(i, j + 1): self.values[k] += delta def query(self, i, j): acc = 0 for k in range(i, j + 1): acc += self.values[k] return acc def __init__(self, n): self.n = n sz = 1 while sz < n: sz <<= 1 self.sum = [0] * (sz << 1) self.lazy = [0] * (sz << 1) self.sz = sz self.base = RangeSumer.Base(n) self.debug = False def dump(self): sz = 1 off = 1 while sz <= self.sz: print(self.sum[off:off + sz], self.lazy[off:off + sz]) off += sz sz = sz << 1 def query_and_update(self, i, j, delta): def do(i, j, k, s, sz): if i <= s <= (s + sz - 1) <= j: res = self.sum[k] self.apply_lazy(k, sz, delta) return res self.push_down(k, sz) mid = s + sz // 2 res = 0 if i < mid: res += do(i, j, 2 * k, s, sz // 2) if j >= mid: res += do(i, j, 2 * k + 1, mid, sz // 2) self.sum[k] = self.sum[2 * k] + self.sum[2 * k + 1] return res ans = do(i, j, 1, 0, self.sz) if self.debug: exp = self.base.query(i, j) self.base.update(i, j, delta) print('query_and_update(%d, %d) = %d' % (i, j, ans)) self.dump() if ans != exp: assert (ans == exp) return ans def push_down(self, k, sz): if self.lazy[k] and sz != 1: v = self.lazy[k] self.apply_lazy(2 * k, sz // 2, v) self.apply_lazy(2 * k + 1, sz // 2, v) self.lazy[k] = 0 def apply_lazy(self, k, sz, delta): self.sum[k] += delta * sz self.lazy[k] += delta def query(self, i, j): return self.query_and_update(i, j, 0) def update(self, i, j, delta): self.query_and_update(i, j, delta) class Solution: def sumCounts(self, nums: List[int]) -> int: n = len(nums) prev = {} ans, acc = 0, 0 MOD = 10 ** 9 + 7 sumer = RangeSumer(n) # sumer.debug = True for i in range(n): p = prev.get(nums[i], -1) prev[nums[i]] = i delta = 2 * sumer.query_and_update(p + 1, i, 1) + (i - p) acc = (acc + delta) % MOD ans = (ans + acc) % MOD return ans % MOD