LC 6103. 从树中删除边的最小分数

https://leetcode.cn/contest/weekly-contest-299/problems/minimum-score-after-removals-on-a-tree/

这题目想了比较长的时间,觉得下面这个解法应该是正确的,大致思路就是枚举删除边。

问题就是,如果我们枚举删除边之后在去计算的话,那么时间复杂度肯定是不行的。边的数量在1000左右,枚举2个删除边就是1M的数据量级别,那么计算复杂度就应该是O(1),否则肯定会超时。

假设我们已经创建好了这个联通树,并且计算好了所有的XOR值,记为XOR(t). 我们准备删除 `b->a` 和 `d->c` 两条边,大约有下面几种情况

有了两个组件,最后的组件就可以通过全值进行异或了。为了快速判断父亲节点,还需要保存每个节点的父节点情况。

#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt

from typing import List


class TreeNode:
    def __init__(self, index, value, depth):
        self.index = index
        self.value = value
        self.total = value
        self.depth = depth
        self.parent = set()


class Solution:
    def minimumScore(self, nums: List[int], edges: List[List[int]]) -> int:
        n = len(nums)

        adj = [[] for _ in range(n)]
        for x, y in edges:
            adj[x].append(y)
            adj[y].append(x)

        tt = 0
        for x in nums:
            tt ^= x

        trees = [None] * n

        def build(x, d, pt):
            t = TreeNode(x, nums[x], d)
            pt.append(x)
            t.parent = set(pt)
            trees[x] = t
            for y in adj[x]:
                if trees[y]: continue
                t2 = build(y, d + 1, pt)
                t.total ^= t2.total
            pt.pop()

            return t

        build(0, 0, [])
        for xx in edges:
            xx.sort(key=lambda x: trees[x].depth)

        inf = 1 << 30
        ans = inf
        for i in range(len(edges)):
            for j in range(i + 1, len(edges)):
                a, b = edges[i]
                c, d = edges[j]
                ta, tb, tc, td = [trees[x] for x in [a, b, c, d]]
                # tb -> ta
                # td -> tc
                if d in ta.parent:
                    # tb -> ta -> ...  td -> tc
                    x = tb.total
                    y = td.total ^ x
                elif b in tc.parent:
                    # td -> tc -> tb -> ta
                    x = td.total
                    y = tb.total ^ x
                else:
                    x = tb.total
                    y = td.total
                z = tt ^ x ^ y
                minv = min(x, y, z)
                maxv = max(x, y, z)
                res = maxv - minv
                ans = min(res, ans)
        return ans