LC 2846. 边权重均等查询

https://leetcode.cn/problems/minimum-edge-weight-equilibrium-queries-in-a-tree/

这个问题就是,A->B这个路径上,排除most common weight之外的其他weight的和,所以我们要记录A->B路径上权重的分布。

有个关键的提示是 w<=26. 所以这个权重分布可以非常简单,最后直接 `sum(w) - max(w)` 就好了。

难点估计在于怎么表示A->B的权重分布,直接表示肯定是不行的,所以需要使用类似树拆分的方式:

所以一个新问题就是如何计算LCA. 这个有个tarjan-lca算法可以参考,可以根据queries进行离线计算,时间复杂度大约就是在O(n+q)上。

这个tarjan-lca的算法大致思路就是,某个点的parent会随着dfs不断地修改,这个点的parent始终是在最近一次访问的root上。 为了更好更快递修改点的parent, 就需要使用union find set结构,具体说应该是find结构就行。

def tarjan_lca(graph, root, queries):
    class UnionFindSet:
        def __init__(self, n):
            self.ps = [0] * n
            for i in range(n):
                self.ps[i] = i

        def find(self, x):
            p = x
            while self.ps[p] != p:
                p = self.ps[p]

            while self.ps[x] != x:
                up = self.ps[x]
                self.ps[x] = p
                x = up
            return p

        def set(self, x, p):
            self.ps[x] = p

    from collections import defaultdict
    query_index = defaultdict(list)
    ans = [-1] * len(queries)
    for idx, (u, v) in enumerate(queries):
        query_index[u].append((v, idx))
        query_index[v].append((u, idx))

    n = len(graph)
    ufs = UnionFindSet(n)
    visited = [0] * n

    def dfs(root, parent):
        # answer queries.
        visited[root] = 1
        query = query_index[root]
        for v, idx in query:
            # 如果这个节点之前没有被访问过,那么是不知道LCA的
            if not visited[v]: continue
            # 如果有对应的查询节点v, 并且这个节点之前访问过
            # 那么使用这个节点的parent.
            # 如果v是root的祖先节点的话,那么就是v
            # 如果v在另外一个树上的话,那么就是最早交汇的节点
            p = ufs.find(v)
            ans[idx] = p

        # continue to dfs.
        for v, _ in graph[root]:
            if v != parent:
                dfs(v, root)
                # 遍历子节点之后,将子节点的父节点设置为自己
                ufs.set(v, root)

    dfs(root, -1)
    return ans

class Solution:
    def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
        adj = [[] for _ in range(n)]
        for (u, v, w) in edges:
            adj[u].append((v, w))
            adj[v].append((u, w))
        W = {}

        def dfs(root, parent, weight):
            W[root] = tuple(weight)
            for v, w in adj[root]:
                if v == parent: continue
                weight[w] += 1
                dfs(v, root, weight)
                weight[w] -= 1

        dfs(0, -1, [0] * 27)
        lca = tarjan_lca(adj, 0, queries)
        ans = []
        for (u, v), r in zip(queries, lca):
            w1 = list(W[u])
            w2 = list(W[v])
            w3 = list(W[r])
            for i in range(27):
                w1[i] -= w3[i]
                w2[i] -= w3[i]
                w1[i] += w2[i]
            c = sum(w1) - max(w1)
            ans.append(c)
        return ans