LC 2867. 统计树中的合法路径数目

https://leetcode.cn/problems/count-valid-paths-in-a-tree/description/

这题感觉自己是魔怔了,回过头来看其实是挺简单的题目的。

我一度有过这样的想法,可能是之前有过类似的题目

最后提交的程序完全不是那么回事。对于这类枚举量比较小的情况,应该是只需要考虑每个子树然后拼凑起来就行

def get_primes(N):
    ps = []
    mask = [0] * (N + 1)
    for i in range(2, N + 1):
        if mask[i] == 1: continue
        for j in range(2, N + 1):
            if i * j > N: break
            mask[i * j] = 1
    for i in range(2, N + 1):
        if mask[i] == 0:
            ps.append(i)
    return ps


class Solution:
    def countPaths(self, n: int, edges: List[List[int]]) -> int:
        primes = set(get_primes(n))
        adj = [[] for _ in range(n + 1)]
        for x, y in edges:
            adj[x].append(y)
            adj[y].append(x)

        ans = 0

        def dfs(x, p):
            nonlocal ans
            isPrime = x in primes

            c0, c1 = 0, 0
            for y in adj[x]:
                if y == p: continue
                a, b = dfs(y, x)
                if isPrime:
                    ans += c0 * a
                else:
                    ans += c0 * b + c1 * a
                c0 += a
                c1 += b

            if isPrime:
                ans += c0
                res = (0, c0 + 1)
            else:
                ans += c1
                res = (c0 + 1, c1)

            # print(x, ans)
            return res

        dfs(1, -1)
        return ans