LC 2846. 边权重均等查询
https://leetcode.cn/problems/minimum-edge-weight-equilibrium-queries-in-a-tree/
这个问题就是,A->B这个路径上,排除most common weight之外的其他weight的和,所以我们要记录A->B路径上权重的分布。
有个关键的提示是 w<=26. 所以这个权重分布可以非常简单,最后直接 `sum(w) - max(w)` 就好了。
难点估计在于怎么表示A->B的权重分布,直接表示肯定是不行的,所以需要使用类似树拆分的方式:
- 我们计算root->A和root->B的权重分布,假设分别是Wa, Wb. 这个直接一个dfs就可以计算出来。
- 然后我们找到A, B的最小公共祖先假设是X, 并且假设root->X也是 Wx
- 那么从A->B的权重分布应该是 (Wa - Wx) + (Wb - Wx).
所以一个新问题就是如何计算LCA. 这个有个tarjan-lca算法可以参考,可以根据queries进行离线计算,时间复杂度大约就是在O(n+q)上。
这个tarjan-lca的算法大致思路就是,某个点的parent会随着dfs不断地修改,这个点的parent始终是在最近一次访问的root上。 为了更好更快递修改点的parent, 就需要使用union find set结构,具体说应该是find结构就行。
def tarjan_lca(graph, root, queries): class UnionFindSet: def __init__(self, n): self.ps = [0] * n for i in range(n): self.ps[i] = i def find(self, x): p = x while self.ps[p] != p: p = self.ps[p] while self.ps[x] != x: up = self.ps[x] self.ps[x] = p x = up return p def set(self, x, p): self.ps[x] = p from collections import defaultdict query_index = defaultdict(list) ans = [-1] * len(queries) for idx, (u, v) in enumerate(queries): query_index[u].append((v, idx)) query_index[v].append((u, idx)) n = len(graph) ufs = UnionFindSet(n) visited = [0] * n def dfs(root, parent): # answer queries. visited[root] = 1 query = query_index[root] for v, idx in query: # 如果这个节点之前没有被访问过,那么是不知道LCA的 if not visited[v]: continue # 如果有对应的查询节点v, 并且这个节点之前访问过 # 那么使用这个节点的parent. # 如果v是root的祖先节点的话,那么就是v # 如果v在另外一个树上的话,那么就是最早交汇的节点 p = ufs.find(v) ans[idx] = p # continue to dfs. for v, _ in graph[root]: if v != parent: dfs(v, root) # 遍历子节点之后,将子节点的父节点设置为自己 ufs.set(v, root) dfs(root, -1) return ans class Solution: def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]: adj = [[] for _ in range(n)] for (u, v, w) in edges: adj[u].append((v, w)) adj[v].append((u, w)) W = {} def dfs(root, parent, weight): W[root] = tuple(weight) for v, w in adj[root]: if v == parent: continue weight[w] += 1 dfs(v, root, weight) weight[w] -= 1 dfs(0, -1, [0] * 27) lca = tarjan_lca(adj, 0, queries) ans = [] for (u, v), r in zip(queries, lca): w1 = list(W[u]) w2 = list(W[v]) w3 = list(W[r]) for i in range(27): w1[i] -= w3[i] w2[i] -= w3[i] w1[i] += w2[i] c = sum(w1) - max(w1) ans.append(c) return ans