LC 2846. 边权重均等查询
https://leetcode.cn/problems/minimum-edge-weight-equilibrium-queries-in-a-tree/
这个问题就是,A->B这个路径上,排除most common weight之外的其他weight的和,所以我们要记录A->B路径上权重的分布。
有个关键的提示是 w<=26. 所以这个权重分布可以非常简单,最后直接 `sum(w) - max(w)` 就好了。
难点估计在于怎么表示A->B的权重分布,直接表示肯定是不行的,所以需要使用类似树拆分的方式:
- 我们计算root->A和root->B的权重分布,假设分别是Wa, Wb. 这个直接一个dfs就可以计算出来。
- 然后我们找到A, B的最小公共祖先假设是X, 并且假设root->X也是 Wx
- 那么从A->B的权重分布应该是 (Wa - Wx) + (Wb - Wx).
所以一个新问题就是如何计算LCA. 这个有个tarjan-lca算法可以参考,可以根据queries进行离线计算,时间复杂度大约就是在O(n+q)上。
这个tarjan-lca的算法大致思路就是,某个点的parent会随着dfs不断地修改,这个点的parent始终是在最近一次访问的root上。 为了更好更快递修改点的parent, 就需要使用union find set结构,具体说应该是find结构就行。
def tarjan_lca(graph, root, queries):
class UnionFindSet:
def __init__(self, n):
self.ps = [0] * n
for i in range(n):
self.ps[i] = i
def find(self, x):
p = x
while self.ps[p] != p:
p = self.ps[p]
while self.ps[x] != x:
up = self.ps[x]
self.ps[x] = p
x = up
return p
def set(self, x, p):
self.ps[x] = p
from collections import defaultdict
query_index = defaultdict(list)
ans = [-1] * len(queries)
for idx, (u, v) in enumerate(queries):
query_index[u].append((v, idx))
query_index[v].append((u, idx))
n = len(graph)
ufs = UnionFindSet(n)
visited = [0] * n
def dfs(root, parent):
# answer queries.
visited[root] = 1
query = query_index[root]
for v, idx in query:
# 如果这个节点之前没有被访问过,那么是不知道LCA的
if not visited[v]: continue
# 如果有对应的查询节点v, 并且这个节点之前访问过
# 那么使用这个节点的parent.
# 如果v是root的祖先节点的话,那么就是v
# 如果v在另外一个树上的话,那么就是最早交汇的节点
p = ufs.find(v)
ans[idx] = p
# continue to dfs.
for v, _ in graph[root]:
if v != parent:
dfs(v, root)
# 遍历子节点之后,将子节点的父节点设置为自己
ufs.set(v, root)
dfs(root, -1)
return ans
class Solution:
def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
adj = [[] for _ in range(n)]
for (u, v, w) in edges:
adj[u].append((v, w))
adj[v].append((u, w))
W = {}
def dfs(root, parent, weight):
W[root] = tuple(weight)
for v, w in adj[root]:
if v == parent: continue
weight[w] += 1
dfs(v, root, weight)
weight[w] -= 1
dfs(0, -1, [0] * 27)
lca = tarjan_lca(adj, 0, queries)
ans = []
for (u, v), r in zip(queries, lca):
w1 = list(W[u])
w2 = list(W[v])
w3 = list(W[r])
for i in range(27):
w1[i] -= w3[i]
w2[i] -= w3[i]
w1[i] += w2[i]
c = sum(w1) - max(w1)
ans.append(c)
return ans