LC 1916. 统计为蚁群构筑房间的不同顺序

https://leetcode-cn.com/contest/weekly-contest-247/problems/count-ways-to-build-rooms-in-an-ant-colony/

因为只有N个节点,并且只有N-1条边,同时保证最后是完全联通的,所以结构上只能是一个树。

假设一个节点有3个孩子ABC,每个孩子分别有Na, Nb, Nc种访问顺序,那么这棵树有:

  1. 首先不考虑A, B, C各自节点穿插,那么就有Na * Nb * Nc种访问顺序
  2. 如果考虑之间的穿插,不考虑ABC内部顺序的话,那么有(A+B+C)!种访问顺序
  3. 如果考虑内部顺序的话,那么还要除去 (A! * B! * C!), 因为对A内来说顺序不能打乱,已经有Na种了。

这个计算过程涉及到除法取模,所以要用到 费马小定理. 阶乘以及阶乘的倒数必须预先计算出来。整个过程分为两步:构建树和遍历树。

#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt

from typing import List
from collections import Counter, defaultdict, deque
from functools import lru_cache
import heapq


class Tree:
    def __init__(self, idx):
        self.idx = idx
        self.child = []
        self.n = 1
        self.val = 1


class Solution:
    def waysToBuildRooms(self, prevRoom: List[int]) -> int:
        # n个节点并且n-1条边,并且0可以访问到所有房间,说明是个树结构
        n = len(prevRoom)
        ind = [0] * n
        for x in prevRoom:
            if x == -1: continue
            ind[x] += 1

        # build tree by top sort.
        from collections import deque
        dq = deque()
        for i in range(n):
            if ind[i] == 0:
                dq.append(i)
        trees = []
        for i in range(n):
            trees.append(Tree(i))

        while dq:
            x = dq.popleft()
            p = prevRoom[x]
            if p == -1: continue
            t = trees[x]
            pt = trees[p]
            pt.child.append(t)
            ind[p] -= 1
            if ind[p] == 0:
                dq.append(p)

        MOD = 10 ** 9 + 7
        fac = [0] * (n + 1)
        rev = [0] * (n + 1)

        t = 1
        for i in range(1, n + 1):
            t = (t * i) % MOD
            fac[i] = t

        def pow(x, y):
            t = 1
            while y:
                if y & 0x1:
                    t = (t * x) % MOD
                x = (x * x) % MOD
                y = y >> 1
            return t

        for i in range(1, n + 1):
            rev[i] = pow(fac[i], MOD - 2)

        # print(fac, rev)

        def compute(t):
            if len(t.child) == 0:
                return

            # 假设有3个子节点,A, B, C
            # 每个节点下面排列分别是Na, Nb, Nc
            # 那么总排列是 (A+B+C)! / (A! * B! * C!) * (Na * Nb * Nc)
            n = 0
            b = 1
            for c in t.child:
                compute(c)
                n += c.n
                b = (b * rev[c.n] * c.val) % MOD
            t.n = n + 1
            t.val = (fac[n] * b) % MOD
            return

        compute(trees[0])
        ans = trees[0].val
        return ans