AOC2023-Day25 求解图的最小割
aoc2023 day25 这题 https://adventofcode.com/2023/day/25
最开始没有想出来,看了reddit上面分享 https://www.reddit.com/r/adventofcode/comments/18qbsxs/2023_day_25_solutions/ 然后看了一下wiki, https://en.wikipedia.org/wiki/Karger%27s_algorithm 算法的确是比较简洁
大致思路就是不断地沿着随机选出的 edge 把两个端点融合成一个点,直到最后只剩下两个点,然后看看这两个点之间还剩多少条边,这些边就构成一个割。但是这是一个概率算法,并不保证得到的是最小割,所以需要多运行几次;本题已知最小割是 3 条边,所以一直运行到剩下的边数恰好为 3 为止。
数据结构上用到了 union-find(并查集)。我们先把 edge 随机打乱,然后只需要顺序遍历:如果一条边的两个端点还没有融合,那么就把它们融合,节点数量随之减少 1;如果已经在同一个集合里,就跳过这条边。
#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt
"""Advent of Code 2023, day 25: split a graph in two by removing three edges.

The cut is searched for with a randomized edge-contraction algorithm
(https://en.wikipedia.org/wiki/Karger%27s_algorithm) backed by a union-find
structure.  One run is not guaranteed to find the wanted cut, so it is
repeated with different seeds until a run succeeds.
"""

import random
import sys


class UnionFind:
    """Disjoint-set forest over integer ids used directly as list indices."""

    def __init__(self, values):
        """Put every id of ``values`` into its own set of size 1.

        ``values`` must contain the integers ``0 .. len(values) - 1`` because
        each value is used as an index into the internal lists.
        """
        n = len(values)
        r, c = [0] * n, [0] * n
        for v in values:
            r[v], c[v] = v, 1
        # r[x] is the parent of x (a root is its own parent);
        # c[root] is the number of elements in that root's set.
        self.r, self.c = r, c

    def size(self, a):
        """Return the number of elements in the set that contains ``a``."""
        return self.c[self.find(a)]

    def find(self, a):
        """Return the root of ``a``'s set, compressing the visited path."""
        # First pass: walk up to the root.
        root = a
        while self.r[root] != root:
            root = self.r[root]

        # Second pass: point every node on the path straight at the root.
        x = a
        while x != root:
            parent = self.r[x]
            self.r[x] = root
            x = parent
        return root

    def merge(self, a, b):
        """Union the sets of ``a`` and ``b`` and return the resulting root."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return rb
        # Union by size: hang the smaller tree (ra) below the larger (rb).
        if self.c[ra] > self.c[rb]:
            ra, rb = rb, ra
        self.r[ra] = rb
        self.c[rb] += self.c[ra]
        return rb


# https://en.wikipedia.org/wiki/Karger%27s_algorithm
def karger(n, seed, edges, cut_size=3):
    """Run one randomized contraction and report the cut it produced.

    The edges are shuffled with a ``random.Random(seed)`` generator and then
    contracted in that order until only two components are left.  The edges
    not consumed by the contraction that still join the two components form
    the cut.

    Args:
        n: number of vertices; vertices are the integers ``0 .. n - 1``.
        seed: seed for the shuffle, so a run is reproducible.
        edges: iterable of ``(a, b)`` vertex pairs; it is not modified.
        cut_size: number of crossing edges the cut must have to be accepted.

    Returns:
        The product of the two component sizes when the run ended with
        exactly two components joined by exactly ``cut_size`` edges,
        otherwise 0.  0 is also returned when the edges run out before the
        graph is down to two components, or when ``n < 2``.
    """
    fu = UnionFind(list(range(n)))
    shuffled = list(edges)
    random.Random(seed).shuffle(shuffled)
    remaining = iter(shuffled)

    components = n
    while components > 2:
        edge = next(remaining, None)
        if edge is None:
            # No edge left to contract: more than two components remain.
            return 0
        a, b = edge
        if fu.find(a) != fu.find(b):
            fu.merge(a, b)
            components -= 1
    if components != 2:
        return 0

    # Every unconsumed edge whose endpoints ended up apart crosses the cut.
    crossing = sum(1 for a, b in remaining if fu.find(a) != fu.find(b))
    if crossing != cut_size:
        return 0

    # With exactly two components, the other side is everything outside
    # vertex 0's component.
    side = fu.size(0)
    return side * (n - side)


def solve(graph, cut_size=3, max_trials=100000):
    """Repeat :func:`karger` with seeds 0, 1, 2, ... until a run succeeds.

    Args:
        graph: adjacency structure where ``graph[i]`` is an iterable of the
            neighbours of vertex ``i``.
        cut_size: size of the cut to look for.
        max_trials: number of seeds to try before giving up.

    Returns:
        The first non-zero result of :func:`karger`, or ``None`` when no
        seed below ``max_trials`` produced one.
    """
    n = len(graph)
    # Each undirected edge is listed once, from its lower-numbered endpoint.
    edges = [(i, j) for i, nbrs in enumerate(graph) for j in nbrs if i < j]
    print(n, len(edges))
    for seed in range(max_trials):
        print(f'running {seed}')
        ans = karger(n, seed, edges, cut_size)
        if ans != 0:
            print(ans, seed)
            return ans
    return None


def parse_graph(lines):
    """Build an undirected graph from ``name: nbr nbr ...`` lines.

    Names are numbered in order of first appearance.  Blank lines are
    skipped, and a name that appears on the left of several lines keeps the
    neighbours from all of them.

    Returns:
        A list of sets where entry ``i`` holds the numbers of the
        neighbours of the vertex numbered ``i``.
    """
    numbers = {}
    graph = []

    def number_of(name):
        # Hand out the next free index the first time a name is seen.
        if name not in numbers:
            numbers[name] = len(numbers)
            graph.append(set())
        return numbers[name]

    for line in lines:
        tokens = line.split()
        if not tokens:
            continue
        # The first token is the vertex name followed by a colon.
        src = number_of(tokens[0].rstrip(':'))
        for name in tokens[1:]:
            dst = number_of(name)
            graph[src].add(dst)
            graph[dst].add(src)
    return graph


def main(input_file='input.txt'):
    """Read the puzzle input from ``input_file``, solve it, print the answer."""
    with open(input_file, encoding='utf-8') as fh:
        graph = parse_graph(fh)
    ans = solve(graph)
    print(ans)


if __name__ == '__main__':
    # An optional first command-line argument overrides the input path,
    # e.g. to run against a small sample file.
    main(*sys.argv[1:2])