数独程序求解
@2019-07-02T10:48:46
增加了剪枝策略:选择下一个放置点的时候,尽可能考虑这个点上可选项最少。
# 限制策略:考虑可选数值最少的点
class RestrictedNextStrategy:
def __init__(self, ps):
self.possible_idxs = set(range(len(ps)))
def select_next(self, idx):
min_choices = 10
min_choice_idx = None
for idx in self.possible_idxs:
values = choices(idx)
count = len(values)
if count < min_choices:
min_choices = count
min_choice_idx = idx
return min_choice_idx
def should_stop(self, idx):
return idx is None
def use(self, idx):
if idx is not None:
self.possible_idxs.remove(idx)
def unuse(self, idx):
if idx is not None:
self.possible_idxs.add(idx)
在《算法设计指南》第7章里面给了一个非常复杂的case. 我的原始算法对下面这个case 需要运行30s才能得到结果,而使用剪枝策略之后可以在1s内返回结果。
grid = [
[0, 0, 0, 0, 0, 0, 0, 1, 2],
[0, 0, 0, 0, 3, 5, 0, 0, 0],
[0, 0, 0, 6, 0, 0, 0, 7, 0],
[7, 0, 0, 0, 0, 0, 3, 0, 0],
[0, 0, 0, 4, 0, 0, 8, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 2, 0, 0, 0, 0],
[0, 8, 0, 0, 0, 0, 0, 4, 0],
[0, 5, 0, 0, 0, 0, 6, 0, 0],
]
➜ misc git:(master) ✗ time python sudoku.py ===== answer(27.35, 1) ===== >>>>>>>>>>>>>>>>>>>> [6, 7, 3, 8, 9, 4, 5, 1, 2] [9, 1, 2, 7, 3, 5, 4, 8, 6] [8, 4, 5, 6, 1, 2, 9, 7, 3] [7, 9, 8, 2, 6, 1, 3, 5, 4] [5, 2, 6, 4, 7, 3, 8, 9, 1] [1, 3, 4, 5, 8, 9, 2, 6, 7] [4, 6, 9, 1, 2, 8, 7, 3, 5] [2, 8, 7, 3, 5, 6, 1, 4, 9] [3, 5, 1, 9, 4, 7, 6, 2, 8] <<<<<<<<<<<<<<<<<<<< python sudoku.py 26.89s user 0.13s system 98% cpu 27.416 total // 这个是使用SimpleNextStrategy的时间 python sudoku.py 0.97s user 0.04s system 95% cpu 1.059 total // 这个是使用RestrictedNextStrategy的时间 """
对网格每个空白点(x,y)进行预处理,求解这个点上可以放置哪些数字。因为数独范围在1-9之间,所以可以用bits表示。每个空白点(x,y)可以放置的数字满足下面几个条件:
- 行x不能和已有的数字重复。对应代码里面的 `xs`.
- 列y不能和已有的数字重复。对应代码里面的 `ys`.
- (x,y)所在的3x3子矩阵内不能重复。对应代码里面的 `rs`.
N = 9
xs, ys, rs, ps = [], [], [], []
for i in range(N):
mark = [0] * (N + 1)
for j in range(N):
mark[grid[i][j]] = 1
cs = 0
for k in range(1, N + 1):
if mark[k] == 0:
cs |= (1 << k)
xs.append(cs)
for j in range(N):
mark = [0] * (N + 1)
for i in range(N):
mark[grid[i][j]] = 1
cs = 0
for k in range(1, N + 1):
if mark[k] == 0:
cs |= (1 << k)
ys.append(cs)
for i in range(3):
for j in range(3):
mark = [0] * (N + 1)
for k in range(3):
for l in range(3):
v = grid[3 * i + k][3 * j + l]
mark[v] = 1
cs = 0
for k in range(1, N + 1):
if mark[k] == 0:
cs |= (1 << k)
rs.append(cs)
for i in range(N):
for j in range(N):
if grid[i][j] == 0:
ps.append((i, j))
然后进行递归求解,每个点(x,y)所满足的数字是 `xs[x] & ys[y] & rs[r]` 的交集。相比使用set来表示可选择的数字,如果使用bits表示的话,那么and操作效率会更高。一个数独可能会存在多个解,在递归求解的时候可以对解的数量进行控制,一旦找到一定数量的解就可以退出了。
ans = []
def dfs(idx):
if idx == len(ps):
ans.append(copy.deepcopy(grid))
return
x, y = ps[idx]
r = (x // 3) * 3 + (y // 3)
cs = xs[x] & ys[y] & rs[r]
for v in range(1, 10):
if (cs >> v) & 0x1:
unmask = ~(1 << v)
mask = (1 << v)
grid[x][y] = v
xs[x] &= unmask
ys[y] &= unmask
rs[r] &= unmask
dfs(idx + 1)
grid[x][y] = 0
xs[x] |= mask
ys[y] |= mask
rs[r] |= mask
if ans and len(ans) == number:
return
dfs(0)
return ans
程序运行时间和求解数量和空白位置数量相关。如果只是求解几个解,那么速度还是蛮快的。
def main():
grid = [
[7, 0, 0, 8, 3, 0, 0, 0, 5],
[0, 2, 5, 0, 6, 0, 3, 0, 0],
[0, 1, 0, 0, 7, 0, 9, 0, 2],
[1, 0, 2, 5, 0, 3, 0, 7, 0],
[5, 0, 8, 0, 0, 6, 4, 0, 0],
[0, 3, 0, 9, 0, 0, 5, 0, 6],
[9, 0, 6, 0, 1, 0, 0, 5, 0],
[0, 0, 4, 0, 9, 0, 6, 1, 0],
[3, 0, 0, 0, 5, 8, 0, 0, 4]
]
start = time.time()
ans = sudoku_solve(grid, number=0)
stop = time.time()
print('===== answer(%.2f, %s) =====' % (stop - start, len(ans)))
for arr in ans:
print('>' * 20)
for i in range(len(arr)):
print(arr[i])
print('<' * 20)
输出如下,基本上是秒出
===== answer(0.00, 1) ===== >>>>>>>>>>>>>>>>>>>> [7, 4, 9, 8, 3, 2, 1, 6, 5] [8, 2, 5, 1, 6, 9, 3, 4, 7] [6, 1, 3, 4, 7, 5, 9, 8, 2] [1, 6, 2, 5, 4, 3, 8, 7, 9] [5, 9, 8, 7, 2, 6, 4, 3, 1] [4, 3, 7, 9, 8, 1, 5, 2, 6] [9, 8, 6, 2, 1, 4, 7, 5, 3] [2, 5, 4, 3, 9, 7, 6, 1, 8] [3, 7, 1, 6, 5, 8, 2, 9, 4] <<<<<<<<<<<<<<<<<<<<