LC 1998. 数组的最大公因数排序

https://leetcode-cn.com/contest/weekly-contest-257/problems/gcd-sort-of-an-array/

这题的主要思路就是,通过针对数组里面的元素做质数分解,找到联通分量,不同联通分量之间肯定是隔离的。然后针对每个联通分量进行排序,检查最后的结果是否完全排序。

在实现之前最好做一个估算:

对于某个元素 nums[i] 进行质数p分解时,需要

这样如果i和j关联之后,那么可以通过backs[j]找到更多关联的质数。

寻找联通分量就是一个dfs的过程,但是这个过程不仅仅需要判断某个下标是否访问过,还需要判断某个质数是否访问过。因为如果某个质数群访问过的话,那么它下面所有的下标肯定也都访问过了。没有这个优化就会出现超时。

maski = [0] * len(nums)
maskp = set()

def dfs(i, idxs):
    for p in backs[i]:
        if p in maskp: continue
        maskp.add(p)
        for j in groups[p]:
            if maski[j]:
                continue
            maski[j] = 1
            idxs.append(j)
            dfs(j, idxs)

下面是完整的代码,包括计算质数,判断是否有序,按照下标进行排序等。

def make_primes(K):
    mask = [1] * (K + 1)

    for i in range(2, K + 1):
        if mask[i] == 0:
            continue
        for j in range(2, K + 1):
            if i * j > K:
                break
            mask[i * j] = 0

    primes = []
    for i in range(2, K + 1):
        if mask[i] == 1:
            primes.append(i)
    return primes


def issorted(nums):
    for i in range(1, len(nums)):
        if nums[i] < nums[i - 1]:
            return False
    return True


def sort_by_idxs(nums, idxs, tmps):
    for j in idxs:
        tmps.append(nums[j])
    idxs.sort()
    tmps.sort()
    for j in range(len(idxs)):
        nums[idxs[j]] = tmps[j]


class Solution:
    def gcdSort(self, nums: List[int]) -> bool:
        primes = make_primes(round(max(nums) ** 0.5) + 1)

        if issorted(nums):
            return True

        from collections import defaultdict
        groups = defaultdict(list)
        backs = [[] for _ in range(len(nums))]
        for i in range(len(nums)):
            x = nums[i]
            for p in primes:
                if x < p: break
                if x % p == 0:
                    while x % p == 0:
                        x = x // p
                    groups[p].append(i)
                    backs[i].append(p)
            if x != 1:
                groups[x].append(i)
                backs[i].append(x)

        maski = [0] * len(nums)
        maskp = set()

        def dfs(i, idxs):
            for p in backs[i]:
                if p in maskp: continue
                maskp.add(p)
                for j in groups[p]:
                    if maski[j]:
                        continue
                    maski[j] = 1
                    idxs.append(j)
                    dfs(j, idxs)

        idxs = []
        tmps = []
        for i in range(len(nums)):
            if maski[i] == 0:
                idxs.clear()
                tmps.clear()

                maski[i] = 1
                idxs.append(i)
                dfs(i, idxs)

                sort_by_idxs(nums, idxs, tmps)

        return issorted(nums)