LC 1938. 查询最大基因差
https://leetcode-cn.com/problems/maximum-genetic-difference-query/
思路不是很复杂:
- 存储在每个节点上的查询值
- DFS遍历整棵树,每次经过一个节点,就将该节点放入trie里面
- 根据trie找到这个节点上所有需要查询的值的最大xor.
- DFS退出这个节点的时候,从trie里面删除这个节点
想过应该来怎么实现删除操作,发现题解里面给的计数方案好像是最简单的,很容易就把程序写正确。
class Tree:
def __init__(self):
self.child = [None, None]
self.cnt = 0
def insert(root, x, bits):
for i in reversed(range(bits)):
side = (x >> i) & 0x1
if root.child[side] is None:
t = Tree()
root.child[side] = t
root = root.child[side]
root.cnt += 1
def query(root, x, bits):
ans = 0
for i in reversed(range(bits)):
side = (x >> i) & 0x1
ans = ans * 2
if root.child[1 - side] is not None and root.child[1 - side].cnt != 0:
ans += 1
root = root.child[1 - side]
else:
root = root.child[side]
return ans
def remove(root, x, bits):
for i in reversed(range(bits)):
side = (x >> i) & 0x1
root = root.child[side]
root.cnt -= 1
class Solution:
def maxGeneticDifference(self, parents: List[int], queries: List[List[int]]) -> List[int]:
n = len(parents)
child = [[] for _ in range(n)]
root = None
for i in range(n):
p = parents[i]
child[p].append(i)
if p == -1:
root = i
maxValue = n
flatQueries = [[] for _ in range(n)]
for i in range(len(queries)):
node, v = queries[i]
flatQueries[node].append((i, v))
maxValue = max(maxValue, v)
bits = 1
while (1 << bits) <= maxValue:
bits += 1
vis = [0] * n
tree = Tree()
ans = [0] * len(queries)
def dfs(x):
vis[x] = 1
insert(tree, x, bits)
for idx, v in flatQueries[x]:
res = query(tree, v, bits)
ans[idx] = res
for y in child[x]:
if vis[y]: continue
dfs(y)
remove(tree, x, bits)
dfs(root)
return ans