LC 1483. 树节点的第 K 个祖先

https://leetcode-cn.com/problems/kth-ancestor-of-a-tree-node

说明这个倍增法解决LCA之前,我先写写我最开始想到的办法。这是一个基于记忆化搜索的方法,但是似乎不太好计算时间复杂度,思路大致是:如果节点node之前查询过k'的祖先并且结果是X的话,那么从节点node查询k的祖先就可以变为"从X查询k-k'的祖先",也就是 kth(node, k) = kth(X, k - k') if kth(node, k') == X. 这种算法最坏的情况,每次查询的路径只有在根节点有重叠。虽然这个代码可以通过,但是只能说明leetcode后台的测试用例不够强,无法处理最坏情况是硬伤。

class TreeAncestor {
    private TreeMap<Integer, TreeMap<Integer, Integer>> dp;
    private int[] parent;
    private int n;

    public TreeAncestor(int n, int[] parent) {
        this.n = n;
        this.parent = parent;
        this.dp = new TreeMap();
    }

    public int getKthAncestor(int node, int k) {
        if (node == -1) {
            return -1;
        }
        if (k == 0) {
            return node;
        }
        Integer ans = -2;
        TreeMap<Integer, Integer> tm = dp.get(node);
        if (tm != null) {
            Integer t = tm.lowerKey(k);
            if (t != null) {
                ans = getKthAncestor(tm.get(t), k - t);
            }
        } else {
            tm = new TreeMap();
            dp.put(node, tm);
        }
        if (ans == -2) {
            ans = getKthAncestor(parent[node], k - 1);
        }
        tm.put(k, ans);
        return ans;
    }
}

然后写写这个倍增法。如果我们使用DFS,记录每个节点的所有可能的祖先的话,那么查询起来速度会很快。但是DFS这个过程时间复杂度非常高, 这个时间复杂度是和空间相关的(因为要记录所有的祖先),时空复杂度都是O(n^2). 为了平衡查询和构建的时间复杂度,我们不记录每个节点所有可能的祖先, 而只记录这个节点向上1,2,4,8..2^k个祖先。这样做的好处是在时间和空间上比较平衡。考虑到一共有n个节点,那么空间复杂度就是O(n*lgn), 时间复杂度上则是O(lgn).

class TreeAncestor:

    def __init__(self, n: int, parent: List[int]):
        dp = [[-1] * 20 for _ in range(n)]
        for i in range(n):
            dp[i][0] = parent[i]

        for k in range(1, 20):
            for i in range(n):
                p = dp[i][k - 1]
                if p != -1:
                    dp[i][k] = dp[dp[i][k - 1]][k - 1]
        self.dp = dp

    def getKthAncestor(self, node: int, k: int) -> int:
        dp = self.dp
        while k > 0:
            d = 0
            while (1 << (d + 1)) <= k:
                d += 1
            node = dp[node][d]
            if node == -1:
                break
            k = k - (1 << d)
        return node