AOC2023-Day25 求解图的最小割
aoc2023 day25 这题 https://adventofcode.com/2023/day/25
最开始没有想出来,看了 reddit 上的分享 https://www.reddit.com/r/adventofcode/comments/18qbsxs/2023_day_25_solutions/ ,然后又看了一下 Wikipedia 上的 Karger 算法 https://en.wikipedia.org/wiki/Karger%27s_algorithm ,这个算法的确比较简洁。
大致思路就是不断地选取随机的边,把边的两个端点融合(收缩)成一个点,直到最后只剩下两个点,然后看看这两个点之间还剩多少条边;本题要找的割恰好是 3 条边。但这是一个概率算法,单次运行并不保证得到的是最小割,所以需要用不同的随机种子多运行几次。
数据结构上用到了 union-find(并查集)。先把所有边随机打乱,然后顺序遍历这些边:如果边的两个端点还没有融合,就把它们融合,节点数量减少 1。按随机顺序处理边,等价于每次随机收缩一条剩余的边。
#!/usr/bin/env python
# coding:utf-8
# Copyright (C) dirlt
import random
import sys
class UnionFind:
    """Disjoint-set forest over integer labels, with union by size.

    ``r`` holds each element's parent pointer; ``c`` holds the size of the
    set rooted at an element (only meaningful for roots).
    """

    def __init__(self, values):
        """Create one singleton set per label in *values*.

        Labels are used directly as list indices, so *values* is expected to
        contain each integer of ``range(len(values))`` exactly once.
        """
        count = len(values)
        parent = [0] * count
        size = [0] * count
        for label in values:
            parent[label] = label
            size[label] = 1
        self.r = parent
        self.c = size

    def size(self, a):
        """Return the number of elements in the set that contains *a*."""
        return self.c[self.find(a)]

    def find(self, a):
        """Return the root of *a*'s set, re-pointing every node on the path at it."""
        path = []
        root = a
        while self.r[root] != root:
            path.append(root)
            root = self.r[root]
        # Path compression: later lookups from these nodes take one hop.
        for node in path:
            self.r[node] = root
        return root

    def merge(self, a, b):
        """Union the sets of *a* and *b* and return the surviving root.

        The smaller set is attached under the larger one; on a size tie the
        root of *a* is attached under the root of *b*.
        """
        root_a = self.find(a)
        root_b = self.find(b)
        if root_a == root_b:
            return root_b
        if self.c[root_a] > self.c[root_b]:
            small, big = root_b, root_a
        else:
            small, big = root_a, root_b
        self.r[small] = big
        self.c[big] += self.c[small]
        return big
# https://en.wikipedia.org/wiki/Karger%27s_algorithm
def karger(n, seed, edges, cut_size=3):
    """Run one randomized Karger contraction trial and test the resulting cut.

    Shuffling the edges with a seeded RNG and contracting them in that order
    (skipping edges whose endpoints are already merged) is equivalent to
    repeatedly contracting a uniformly random remaining edge, so each seed is
    one reproducible Karger trial.

    Args:
        n: number of vertices, labelled 0..n-1.
        seed: seed for this trial's private ``random.Random``.
        edges: sequence of ``(u, v)`` vertex pairs; it is not modified.
        cut_size: number of crossing edges the cut must have
            (3 for AoC 2023 day 25).

    Returns:
        The product of the two component sizes when this trial ends with
        exactly ``cut_size`` crossing edges, otherwise 0. Also returns 0 when
        ``n < 2`` or when the edges run out before only two components
        remain (the graph has more than two connected components).
    """
    if n < 2:
        return 0
    fu = UnionFind(list(range(n)))
    rnd = random.Random(seed)
    # Shuffle a copy: the caller reuses the same edge list across trials.
    order = list(edges)
    rnd.shuffle(order)

    # Contract edges in shuffled order until two super-vertices remain.
    components = n
    pos = 0
    while components > 2:
        if pos == len(order):
            # Out of edges: the graph cannot be contracted down to two parts.
            return 0
        a, b = order[pos]
        pos += 1
        if fu.find(a) != fu.find(b):
            fu.merge(a, b)
            components -= 1

    # Every consumed edge lies inside one super-vertex, so only the
    # unconsumed edges can cross the cut.
    crossing = sum(1 for a, b in order[pos:] if fu.find(a) != fu.find(b))
    if crossing != cut_size:
        return 0
    # Exactly two components remain at this point.
    ra, rb = {fu.find(v) for v in range(n)}
    return fu.size(ra) * fu.size(rb)
def solve(graph, max_trials=100000, verbose=True):
    """Repeat Karger trials until one finds the target cut.

    Args:
        graph: adjacency list; ``graph[i]`` iterates the neighbours of vertex
            ``i`` (vertices 0..len(graph)-1, each undirected edge listed from
            both endpoints).
        max_trials: number of trials to attempt (seeds 0..max_trials-1).
        verbose: print the graph size, per-trial progress and the winning
            seed (the script's original output).

    Returns:
        The first non-zero ``karger`` result, i.e. the product of the two
        component sizes, or ``None`` if no trial within ``max_trials``
        succeeds.
    """
    n = len(graph)
    # Keep each undirected edge once, oriented from the smaller index.
    edges = [(i, j) for i, nbrs in enumerate(graph) for j in nbrs if i < j]
    if verbose:
        print(n, len(edges))
    for seed in range(max_trials):
        if verbose:
            print(f'running {seed}')
        ans = karger(n, seed, edges)
        if ans != 0:
            if verbose:
                print(ans, seed)
            return ans
    # Karger is probabilistic: exhausting the budget means "not found yet".
    return None
def _parse_graph(lines):
    """Turn ``name: n1 n2 ...`` lines into an undirected adjacency-set list.

    Node names are numbered 0, 1, 2, ... in order of first appearance (line
    head first, then its neighbours left to right). Returns a list where
    ``graph[i]`` is the set of neighbour indices of node ``i``. Blank lines
    and lines without a ``:`` are skipped; a name that heads several lines
    keeps the neighbours from all of them.
    """
    numbers = {}
    links = []
    for line in lines:
        head, sep, tail = line.partition(':')
        head = head.strip()
        if not sep or not head:
            continue
        # setdefault hands out the next free id only the first time a name is seen.
        src = numbers.setdefault(head, len(numbers))
        for name in tail.split():
            links.append((src, numbers.setdefault(name, len(numbers))))
    graph = [set() for _ in range(len(numbers))]
    for u, v in links:
        graph[u].add(v)
        graph[v].add(u)
    return graph


def main(input_file=None):
    """Read the puzzle input, solve it and print the answer.

    Args:
        input_file: path of the input to read; defaults to ``'input.txt'``.
            Pass ``'tmp.in'`` to run on the sample input.
    """
    if input_file is None:
        input_file = 'input.txt'
    with open(input_file, encoding='utf-8') as fh:
        graph = _parse_graph(fh)
    ans = solve(graph)
    print(ans)
if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()