数独程序求解
@2019-07-02T10:48:46
增加了剪枝策略:选择下一个放置点时,优先选择可选数字最少的空白点(即“最少剩余值”启发式,MRV)。这样每一步的分支更少,而且一旦某个空白点没有任何可选数字,会被立即选中并回溯,冲突能更早暴露,从而大幅缩小搜索树。
# Restricted strategy: always fill next the empty cell that has the fewest
# candidate values (the "minimum remaining values" heuristic).
class RestrictedNextStrategy:
    """Pick the next cell to fill as the unfilled one with fewest candidates.

    Cells are tracked by their index into ``ps`` (the list of empty cells);
    only ``len(ps)`` is used here.  Candidate counts come from ``choices``,
    a function defined elsewhere (not shown) -- presumably it returns the
    digits that can still go in cell ``idx``; only its ``len`` is used.
    """

    def __init__(self, ps):
        # Indices (into ps) of the cells that are still unfilled.
        self.possible_idxs = set(range(len(ps)))

    def select_next(self, idx):
        """Return the unfilled index with the fewest candidates, or None.

        ``idx`` (the previously chosen cell) is not used; presumably it is
        kept so this class matches the other strategies' interface.  Ties go
        to the first index in set-iteration order.  Returns None only when no
        unfilled cell remains.
        """
        best_idx = None
        best_count = None
        # A distinct loop name, so the idx parameter is not shadowed.
        for cand in self.possible_idxs:
            count = len(choices(cand))
            # No hard-coded upper bound: works for any board size.
            if best_count is None or count < best_count:
                best_idx, best_count = cand, count
                if count == 0:
                    # Dead end: nothing can beat zero candidates, so stop
                    # scanning and let the search backtrack at once.
                    break
        return best_idx

    def should_stop(self, idx):
        """Return True when select_next found no cell to fill."""
        return idx is None

    def use(self, idx):
        """Remove cell ``idx`` from the unfilled set (no-op for None)."""
        if idx is not None:
            self.possible_idxs.remove(idx)

    def unuse(self, idx):
        """Return cell ``idx`` to the unfilled set (no-op for None)."""
        if idx is not None:
            self.possible_idxs.add(idx)
《算法设计指南》(The Algorithm Design Manual)第7章给出了一个非常困难的数独例子。我的原始算法(SimpleNextStrategy)在下面这个例子上需要运行约30秒才能得到结果,而使用剪枝策略(RestrictedNextStrategy)之后可以在1秒内返回结果。
# Hard puzzle cited from chapter 7 of "The Algorithm Design Manual".
# Each string is one row of the board; '0' marks an empty cell.
grid = [
    [int(ch) for ch in row]
    for row in (
        "000000012",
        "000035000",
        "000600070",
        "700000300",
        "000400800",
        "100000000",
        "000120000",
        "080000040",
        "050000600",
    )
]
➜ misc git:(master) ✗ time python sudoku.py
===== answer(27.35, 1) =====
>>>>>>>>>>>>>>>>>>>>
[6, 7, 3, 8, 9, 4, 5, 1, 2]
[9, 1, 2, 7, 3, 5, 4, 8, 6]
[8, 4, 5, 6, 1, 2, 9, 7, 3]
[7, 9, 8, 2, 6, 1, 3, 5, 4]
[5, 2, 6, 4, 7, 3, 8, 9, 1]
[1, 3, 4, 5, 8, 9, 2, 6, 7]
[4, 6, 9, 1, 2, 8, 7, 3, 5]
[2, 8, 7, 3, 5, 6, 1, 4, 9]
[3, 5, 1, 9, 4, 7, 6, 2, 8]
<<<<<<<<<<<<<<<<<<<<
python sudoku.py 26.89s user 0.13s system 98% cpu 27.416 total // 使用 SimpleNextStrategy 的耗时
python sudoku.py 0.97s user 0.04s system 95% cpu 1.059 total // 使用 RestrictedNextStrategy 的耗时
首先对网格进行预处理:记录所有空白点,并为每一行、每一列、每个3x3子矩阵计算其中尚未使用的数字,这样就能快速求出每个空白点(x,y)上可以放置哪些数字。因为数独中的数字范围是1-9,所以可以用一个整数的比特位(bitmask)来表示一组数字:第k位为1表示数字k仍然可用。每个空白点(x,y)可以放置的数字需要满足下面几个条件:
- 行x不能和已有的数字重复。对应代码里面的 `xs`.
- 列y不能和已有的数字重复。对应代码里面的 `ys`.
- (x,y)所在的3x3子矩阵内不能重复。对应代码里面的 `rs`.
N = 9


def _free_mask(values):
    """Return a bitmask of the digits 1..N that do not appear in *values*.

    Bit k (1 <= k <= N) is set when digit k is still available.  Empty
    cells (value 0) never affect the result because bit 0 is never set.
    """
    used = set(values)
    mask = 0
    for k in range(1, N + 1):
        if k not in used:
            mask |= 1 << k
    return mask


# xs[i]: digits still free in row i.
xs = [_free_mask(grid[i][j] for j in range(N)) for i in range(N)]
# ys[j]: digits still free in column j.
ys = [_free_mask(grid[i][j] for i in range(N)) for j in range(N)]
# rs[3 * bi + bj]: digits still free in the 3x3 box at box-row bi, box-col bj.
# This ordering matches the box index r = (x // 3) * 3 + (y // 3) used when
# searching.
rs = [
    _free_mask(
        grid[3 * bi + dr][3 * bj + dc] for dr in range(3) for dc in range(3)
    )
    for bi in range(3)
    for bj in range(3)
]
# ps: coordinates (row, col) of the empty cells, in row-major order.
ps = [(i, j) for i in range(N) for j in range(N) if grid[i][j] == 0]
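然后进行递归求解:每个点(x,y)上可以放置的数字是 `xs[x]`、`ys[y]`、`rs[r]` 三者的交集,即 `xs[x] & ys[y] & rs[r]`,其中 `r = (x // 3) * 3 + y // 3` 是该点所在3x3子矩阵的编号。相比使用set来表示候选数字,使用bits表示时求交集只需一次按位与运算,效率更高。一个数独可能存在多个解,递归求解时可以通过参数 `number` 控制要找的解的数量,一旦找到足够数量的解就立即退出;`number` 为0时则会求出所有解。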
ans = [] def dfs(idx): if idx == len(ps): ans.append(copy.deepcopy(grid)) return x, y = ps[idx] r = (x // 3) * 3 + (y // 3) cs = xs[x] & ys[y] & rs[r] for v in range(1, 10): if (cs >> v) & 0x1: unmask = ~(1 << v) mask = (1 << v) grid[x][y] = v xs[x] &= unmask ys[y] &= unmask rs[r] &= unmask dfs(idx + 1) grid[x][y] = 0 xs[x] |= mask ys[y] |= mask rs[r] |= mask if ans and len(ans) == number: return dfs(0) return ans
程序的运行时间与要求的解的数量以及空白位置的数量有关。如果只需要求出少数几个解,速度还是很快的。
def main():
    """Solve a sample puzzle, then print the elapsed time and each solution."""
    grid = [
        [7, 0, 0, 8, 3, 0, 0, 0, 5],
        [0, 2, 5, 0, 6, 0, 3, 0, 0],
        [0, 1, 0, 0, 7, 0, 9, 0, 2],
        [1, 0, 2, 5, 0, 3, 0, 7, 0],
        [5, 0, 8, 0, 0, 6, 4, 0, 0],
        [0, 3, 0, 9, 0, 0, 5, 0, 6],
        [9, 0, 6, 0, 1, 0, 0, 5, 0],
        [0, 0, 4, 0, 9, 0, 6, 1, 0],
        [3, 0, 0, 0, 5, 8, 0, 0, 4],
    ]
    start = time.time()
    # number=0: presumably "no limit" -- the dfs early-exit test
    # (len(ans) == number) can never match once a solution is found.
    ans = sudoku_solve(grid, number=0)
    stop = time.time()
    print('===== answer(%.2f, %s) =====' % (stop - start, len(ans)))
    for arr in ans:
        print('>' * 20)
        for row in arr:
            print(row)
        print('<' * 20)
输出如下,基本上瞬间就能得到结果:
===== answer(0.00, 1) ===== >>>>>>>>>>>>>>>>>>>> [7, 4, 9, 8, 3, 2, 1, 6, 5] [8, 2, 5, 1, 6, 9, 3, 4, 7] [6, 1, 3, 4, 7, 5, 9, 8, 2] [1, 6, 2, 5, 4, 3, 8, 7, 9] [5, 9, 8, 7, 2, 6, 4, 3, 1] [4, 3, 7, 9, 8, 1, 5, 2, 6] [9, 8, 6, 2, 1, 4, 7, 5, 3] [2, 5, 4, 3, 9, 7, 6, 1, 8] [3, 7, 1, 6, 5, 8, 2, 9, 4] <<<<<<<<<<<<<<<<<<<<