AOC2023-Day25 求解图的最小割

AoC 2023 Day 25 题目：https://adventofcode.com/2023/day/25

最开始没有想出来，看了 Reddit 上的题解分享（https://www.reddit.com/r/adventofcode/comments/18qbsxs/2023_day_25_solutions/），然后又看了 Wikipedia 上对 Karger 算法的介绍（https://en.wikipedia.org/wiki/Karger%27s_algorithm），这个算法确实比较简洁。

大致思路就是不断地随机选取边（edge），把它的两个端点融合（收缩）成一个点，直到最后只剩下两个点，然后统计这两个点之间还剩多少条边，这些边就构成一个割。但这是一个概率算法，单次运行并不保证得到的是最小割，所以需要用不同的随机种子多运行几次；本题已知最小割恰好是 3 条边，因此一旦找到大小为 3 的割就可以停止。

数据结构上用到了并查集（union-find）：先把所有边随机打乱，然后按顺序遍历这些边，如果一条边的两个端点还没有被融合，就把它们合并，同时剩余节点数减 1；当节点数减到 2 时停止，剩下的边中两个端点属于不同集合的就是割边。

#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt
import random
import sys

class UnionFind:
    """Disjoint-set forest using union by size and path compression.

    The parent table ``r`` and the size table ``c`` are plain lists of length
    ``len(values)`` indexed by element, so ``values`` is expected to be a
    permutation of ``range(len(values))``.
    """

    def __init__(self, values):
        total = len(values)
        parent = [0] * total
        count = [0] * total
        for item in values:
            # every element starts as its own singleton set
            parent[item] = item
            count[item] = 1
        self.r = parent
        self.c = count

    def size(self, a):
        """Return the number of elements in the set containing ``a``."""
        return self.c[self.find(a)]

    def find(self, a):
        """Return the representative of ``a``'s set, flattening the path to it."""
        root = a
        while self.r[root] != root:
            root = self.r[root]

        # point every node on the walked path directly at the root
        node = a
        while node != root:
            up = self.r[node]
            self.r[node] = root
            node = up
        return root

    def merge(self, a, b):
        """Union the sets of ``a`` and ``b``; return the surviving root."""
        ra = self.find(a)
        rb = self.find(b)
        if ra == rb:
            return rb
        # hang the smaller tree under the larger one; on a tie ra goes under rb
        if self.c[ra] > self.c[rb]:
            ra, rb = rb, ra
        self.r[ra] = rb
        self.c[rb] += self.c[ra]
        return rb


# https://en.wikipedia.org/wiki/Karger%27s_algorithm
def karger(n, seed, edges, cut_size=3):
    """Run one trial of Karger's random contraction algorithm.

    The edges are shuffled with a seeded RNG and contracted in that order
    (via union-find) until only two super-vertices remain; the edges whose
    endpoints lie in different super-vertices form the resulting cut.

    Args:
        n: number of vertices; vertices are the integers ``0 .. n-1``.
        seed: seed for ``random.Random``; the same seed and edge order
            reproduce the same trial.
        edges: iterable of ``(a, b)`` vertex pairs; it is not modified.
        cut_size: the cut size being searched for (3 for AoC 2023 day 25).

    Returns:
        The product of the two sides' sizes if this trial produced a cut of
        exactly ``cut_size`` edges, otherwise 0.  0 is also returned when
        the edges cannot reduce the graph to two parts (fewer than two
        vertices, or more than two connected components).
    """
    fu = UnionFind(list(range(n)))
    rnd = random.Random(seed)

    # Contracting edges in a uniformly random order (skipping edges whose
    # endpoints are already merged) is the Kruskal-style formulation of
    # Karger's contraction step.
    order = list(edges)  # private copy: the caller's sequence is untouched
    rnd.shuffle(order)

    components = n
    pos = 0
    # Stop at two super-vertices, or when the edges run out (which would
    # previously raise IndexError from popping an empty deque).
    while components > 2 and pos < len(order):
        a, b = order[pos]
        pos += 1
        ra, rb = fu.find(a), fu.find(b)
        if ra == rb:
            continue
        fu.merge(ra, rb)
        components -= 1

    if components != 2:
        return 0

    # Every edge consumed above is now internal to a super-vertex, so the
    # cut edges can only be among the edges not yet visited.
    crossing = sum(1 for a, b in order[pos:] if fu.find(a) != fu.find(b))
    if crossing != cut_size:
        return 0

    # Exactly two roots remain; taking them from all vertices also works
    # when the cut is empty (cut_size == 0).
    ra, rb = {fu.find(v) for v in range(n)}
    return fu.size(ra) * fu.size(rb)


def solve(graph, max_tries=100000):
    """Repeat seeded Karger trials until one reports a successful cut.

    Args:
        graph: adjacency list; ``graph[i]`` is the set of neighbours of
            vertex ``i``.  The adjacency is expected to be symmetric, so each
            undirected edge is collected once as ``(i, j)`` with ``i < j``.
        max_tries: number of seeds (``0 .. max_tries-1``) to try.

    Returns:
        The non-zero value returned by ``karger`` for the first successful
        seed (the product of the two component sizes), or ``None`` if no
        seed within ``max_tries`` succeeds.
    """
    n = len(graph)
    edges = [(i, j) for i, nbrs in enumerate(graph) for j in nbrs if i < j]
    print(n, len(edges))

    for seed in range(max_tries):
        print(f'running {seed}')
        ans = karger(n, seed, edges)
        if ans != 0:
            print(ans, seed)
            return ans
    return None


def main():
    """Read the puzzle input, build the undirected graph and print the answer.

    Set ``test`` to True to read the sample input ``tmp.in`` instead of
    ``input.txt``.
    """
    test = False
    input_file = 'tmp.in' if test else 'input.txt'

    with open(input_file, encoding='utf-8') as fh:
        graph = _build_graph(fh)

    ans = solve(graph)
    print(ans)


def _build_graph(lines):
    """Turn ``name: n1 n2 ...`` lines into a symmetric adjacency list.

    Vertex names are numbered 0, 1, 2, ... in order of first appearance
    (each line's source name first, then its neighbours).  Returns a list
    whose index ``i`` holds the set of neighbour indices of vertex ``i``.
    Blank lines are skipped; a non-blank line without ``:`` raises
    ``ValueError``.
    """
    from collections import defaultdict

    adj = defaultdict(list)
    numbers = {}
    for line in lines:
        if not line.strip():
            continue  # tolerate blank / trailing empty lines
        head, sep, tail = line.partition(':')
        if not sep:
            raise ValueError(f'malformed input line: {line!r}')
        name = head.strip()
        neighbours = tail.split()
        # extend rather than overwrite, so a name listed on several lines
        # keeps all of its edges
        adj[name].extend(neighbours)
        for node in (name, *neighbours):
            if node not in numbers:
                numbers[node] = len(numbers)

    graph = [set() for _ in numbers]
    for name, neighbours in adj.items():
        src = numbers[name]
        for other in neighbours:
            dst = numbers[other]
            graph[src].add(dst)
            graph[dst].add(src)
    return graph


if __name__ == "__main__":
    # Only run the solver when executed as a script, not when imported.
    main()