LC 2916. 子数组不同元素数目的平方和 II
https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-ii/
这题有个量级小一点的版本可以作为参考 https://leetcode.cn/problems/subarrays-distinct-element-sum-of-squares-i/
关于这题里面的数学思路,可以看 题解,这里面说的非常清楚,大致思路就是
- 假设我们已经有A[i], 表示 X[0..i], X[1..i], X[2..i] … X[i-1..i], X[i, i] 这些序列的去重数平方
- 那么我们可以考虑A[i+1] 如果进行更新。假设A[i+1] =x, 并且A[j]=x的话,那么X[0..i],..X[j,i] = X[0..i+1]..X[j,i+1],这些都不用变
- 变化的就是 X[j+1,..i], X[j+2..i] … X[i..i] 每个元素上都需要+1, 然后平方,然后我们就计算增量就行。
里面需要一个数据结构的辅助满足
- 我们需要更新一段区间,让这段区间里面每个元素+1
- 然后我们还需要计算一段区间的和
这个数据结构就是线段树,之前实现过一个简单的,可以更新其中某个元素,但是没有实现过更新区间的功能。实现更新区间的功能,需要有个 `lazy` 的结构。 这个 `lazy` 结构表示:将用于孩子节点的更新。这样一旦我们更新区间的时候,不用立刻更新下面的孩子,除非我们需要去计算孩子的区间和,这个时候再去将这个lazy结构应用上去。 在实现上,我们可以用一个naive实现来做交叉验证,这个非常有效。
这题另外一个优化,就是我们更新的区间,和计算区间的和,是同一个区间。如果我们将两者作为两个操作的话,那么时间就会double,这题就会超时。 我们 必须 将查询和更新放在一个方法里面才能过。
#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt
from typing import List
class RangeSumer:
class Base:
def __init__(self, n):
self.values = [0] * n
def update(self, i, j, delta):
for k in range(i, j + 1):
self.values[k] += delta
def query(self, i, j):
acc = 0
for k in range(i, j + 1):
acc += self.values[k]
return acc
def __init__(self, n):
self.n = n
sz = 1
while sz < n:
sz <<= 1
self.sum = [0] * (sz << 1)
self.lazy = [0] * (sz << 1)
self.sz = sz
self.base = RangeSumer.Base(n)
self.debug = False
def dump(self):
sz = 1
off = 1
while sz <= self.sz:
print(self.sum[off:off + sz], self.lazy[off:off + sz])
off += sz
sz = sz << 1
def query_and_update(self, i, j, delta):
def do(i, j, k, s, sz):
if i <= s <= (s + sz - 1) <= j:
res = self.sum[k]
self.apply_lazy(k, sz, delta)
return res
self.push_down(k, sz)
mid = s + sz // 2
res = 0
if i < mid:
res += do(i, j, 2 * k, s, sz // 2)
if j >= mid:
res += do(i, j, 2 * k + 1, mid, sz // 2)
self.sum[k] = self.sum[2 * k] + self.sum[2 * k + 1]
return res
ans = do(i, j, 1, 0, self.sz)
if self.debug:
exp = self.base.query(i, j)
self.base.update(i, j, delta)
print('query_and_update(%d, %d) = %d' % (i, j, ans))
self.dump()
if ans != exp:
assert (ans == exp)
return ans
def push_down(self, k, sz):
if self.lazy[k] and sz != 1:
v = self.lazy[k]
self.apply_lazy(2 * k, sz // 2, v)
self.apply_lazy(2 * k + 1, sz // 2, v)
self.lazy[k] = 0
def apply_lazy(self, k, sz, delta):
self.sum[k] += delta * sz
self.lazy[k] += delta
def query(self, i, j):
return self.query_and_update(i, j, 0)
def update(self, i, j, delta):
self.query_and_update(i, j, delta)
class Solution:
def sumCounts(self, nums: List[int]) -> int:
n = len(nums)
prev = {}
ans, acc = 0, 0
MOD = 10 ** 9 + 7
sumer = RangeSumer(n)
# sumer.debug = True
for i in range(n):
p = prev.get(nums[i], -1)
prev[nums[i]] = i
delta = 2 * sumer.query_and_update(p + 1, i, 1) + (i - p)
acc = (acc + delta) % MOD
ans = (ans + acc) % MOD
return ans % MOD