LC 1938. 查询最大基因差

https://leetcode-cn.com/problems/maximum-genetic-difference-query/

思路不是很复杂:

  1. 存储在每个节点上的查询值
  2. DFS遍历整棵树,每次经过一个节点,就将该节点放入trie里面
  3. 根据trie找到这个节点上所有需要查询的值的最大xor.
  4. DFS退出这个节点的时候,从trie里面删除这个节点

想过应该来怎么实现删除操作,发现题解里面给的计数方案好像是最简单的,很容易就把程序写正确。

class Tree:
    def __init__(self):
        self.child = [None, None]
        self.cnt = 0


def insert(root, x, bits):
    for i in reversed(range(bits)):
        side = (x >> i) & 0x1
        if root.child[side] is None:
            t = Tree()
            root.child[side] = t
        root = root.child[side]
        root.cnt += 1


def query(root, x, bits):
    ans = 0
    for i in reversed(range(bits)):
        side = (x >> i) & 0x1
        ans = ans * 2
        if root.child[1 - side] is not None and root.child[1 - side].cnt != 0:
            ans += 1
            root = root.child[1 - side]
        else:
            root = root.child[side]
    return ans


def remove(root, x, bits):
    for i in reversed(range(bits)):
        side = (x >> i) & 0x1
        root = root.child[side]
        root.cnt -= 1


class Solution:
    def maxGeneticDifference(self, parents: List[int], queries: List[List[int]]) -> List[int]:
        n = len(parents)
        child = [[] for _ in range(n)]
        root = None
        for i in range(n):
            p = parents[i]
            child[p].append(i)
            if p == -1:
                root = i

        maxValue = n
        flatQueries = [[] for _ in range(n)]
        for i in range(len(queries)):
            node, v = queries[i]
            flatQueries[node].append((i, v))
            maxValue = max(maxValue, v)

        bits = 1
        while (1 << bits) <= maxValue:
            bits += 1

        vis = [0] * n
        tree = Tree()
        ans = [0] * len(queries)

        def dfs(x):
            vis[x] = 1

            insert(tree, x, bits)
            for idx, v in flatQueries[x]:
                res = query(tree, v, bits)
                ans[idx] = res

            for y in child[x]:
                if vis[y]: continue
                dfs(y)

            remove(tree, x, bits)

        dfs(root)
        return ans