数独程序求解

code on github


@2019-07-02T10:48:46

增加了剪枝策略:选择下一个放置点的时候,尽可能考虑这个点上可选项最少。

# 限制策略:考虑可选数值最少的点
class RestrictedNextStrategy:
    def __init__(self, ps):
        self.possible_idxs = set(range(len(ps)))

    def select_next(self, idx):
        min_choices = 10
        min_choice_idx = None
        for idx in self.possible_idxs:
            values = choices(idx)
            count = len(values)
            if count < min_choices:
                min_choices = count
                min_choice_idx = idx
        return min_choice_idx

    def should_stop(self, idx):
        return idx is None

    def use(self, idx):
        if idx is not None:
            self.possible_idxs.remove(idx)

    def unuse(self, idx):
        if idx is not None:
            self.possible_idxs.add(idx)

在《算法设计指南》第7章里面给了一个非常复杂的case. 我的原始算法对下面这个case 需要运行30s才能得到结果,而使用剪枝策略之后可以在1s内返回结果。

grid = [
    [0, 0, 0, 0, 0, 0, 0, 1, 2],
    [0, 0, 0, 0, 3, 5, 0, 0, 0],
    [0, 0, 0, 6, 0, 0, 0, 7, 0],
    [7, 0, 0, 0, 0, 0, 3, 0, 0],
    [0, 0, 0, 4, 0, 0, 8, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 2, 0, 0, 0, 0],
    [0, 8, 0, 0, 0, 0, 0, 4, 0],
    [0, 5, 0, 0, 0, 0, 6, 0, 0],
]
➜  misc git:(master) ✗ time python sudoku.py
===== answer(27.35, 1) =====
>>>>>>>>>>>>>>>>>>>>
[6, 7, 3, 8, 9, 4, 5, 1, 2]
[9, 1, 2, 7, 3, 5, 4, 8, 6]
[8, 4, 5, 6, 1, 2, 9, 7, 3]
[7, 9, 8, 2, 6, 1, 3, 5, 4]
[5, 2, 6, 4, 7, 3, 8, 9, 1]
[1, 3, 4, 5, 8, 9, 2, 6, 7]
[4, 6, 9, 1, 2, 8, 7, 3, 5]
[2, 8, 7, 3, 5, 6, 1, 4, 9]
[3, 5, 1, 9, 4, 7, 6, 2, 8]
<<<<<<<<<<<<<<<<<<<<
python sudoku.py  26.89s user 0.13s system 98% cpu 27.416 total // 这个是使用SimpleNextStrategy的时间
python sudoku.py  0.97s user 0.04s system 95% cpu 1.059 total // 这个是使用RestrictedNextStrategy的时间
"""

对网格每个空白点(x,y)进行预处理,求解这个点上可以放置哪些数字。因为数独范围在1-9之间,所以可以用bits表示。每个空白点(x,y)可以放置的数字满足下面几个条件:

  1. 行x不能和已有的数字重复。对应代码里面的 `xs`.
  2. 列y不能和已有的数字重复。对应代码里面的 `ys`.
  3. (x,y)所在的3x3子矩阵内不能重复。对应代码里面的 `rs`.
N = 9
xs, ys, rs, ps = [], [], [], []
for i in range(N):
    mark = [0] * (N + 1)
    for j in range(N):
        mark[grid[i][j]] = 1
    cs = 0
    for k in range(1, N + 1):
        if mark[k] == 0:
            cs |= (1 << k)
    xs.append(cs)

for j in range(N):
    mark = [0] * (N + 1)
    for i in range(N):
        mark[grid[i][j]] = 1
    cs = 0
    for k in range(1, N + 1):
        if mark[k] == 0:
            cs |= (1 << k)
    ys.append(cs)

for i in range(3):
    for j in range(3):
        mark = [0] * (N + 1)
        for k in range(3):
            for l in range(3):
                v = grid[3 * i + k][3 * j + l]
                mark[v] = 1
        cs = 0
        for k in range(1, N + 1):
            if mark[k] == 0:
                cs |= (1 << k)
        rs.append(cs)

for i in range(N):
    for j in range(N):
        if grid[i][j] == 0:
            ps.append((i, j))

然后进行递归求解,每个点(x,y)所满足的数字是 `xs[x] & ys[y] & rs[r]` 的交集。相比使用set来表示可选择的数字,如果使用bits表示的话,那么and操作效率会更高。一个数独可能会存在多个解,在递归求解的时候可以对解的数量进行控制,一旦找到一定数量的解就可以退出了。

ans = []

def dfs(idx):
    if idx == len(ps):
        ans.append(copy.deepcopy(grid))
        return

    x, y = ps[idx]
    r = (x // 3) * 3 + (y // 3)
    cs = xs[x] & ys[y] & rs[r]
    for v in range(1, 10):
        if (cs >> v) & 0x1:
            unmask = ~(1 << v)
            mask = (1 << v)
            grid[x][y] = v
            xs[x] &= unmask
            ys[y] &= unmask
            rs[r] &= unmask
            dfs(idx + 1)
            grid[x][y] = 0
            xs[x] |= mask
            ys[y] |= mask
            rs[r] |= mask
            if ans and len(ans) == number:
                return

dfs(0)
return ans

程序运行时间和求解数量和空白位置数量相关。如果只是求解几个解,那么速度还是蛮快的。

def main():
    grid = [
        [7, 0, 0, 8, 3, 0, 0, 0, 5],
        [0, 2, 5, 0, 6, 0, 3, 0, 0],
        [0, 1, 0, 0, 7, 0, 9, 0, 2],
        [1, 0, 2, 5, 0, 3, 0, 7, 0],
        [5, 0, 8, 0, 0, 6, 4, 0, 0],
        [0, 3, 0, 9, 0, 0, 5, 0, 6],
        [9, 0, 6, 0, 1, 0, 0, 5, 0],
        [0, 0, 4, 0, 9, 0, 6, 1, 0],
        [3, 0, 0, 0, 5, 8, 0, 0, 4]
    ]
    start = time.time()
    ans = sudoku_solve(grid, number=0)
    stop = time.time()
    print('===== answer(%.2f, %s) =====' % (stop - start, len(ans)))
    for arr in ans:
        print('>' * 20)
        for i in range(len(arr)):
            print(arr[i])
        print('<' * 20)

输出如下,基本上是秒出

===== answer(0.00, 1) =====
>>>>>>>>>>>>>>>>>>>>
[7, 4, 9, 8, 3, 2, 1, 6, 5]
[8, 2, 5, 1, 6, 9, 3, 4, 7]
[6, 1, 3, 4, 7, 5, 9, 8, 2]
[1, 6, 2, 5, 4, 3, 8, 7, 9]
[5, 9, 8, 7, 2, 6, 4, 3, 1]
[4, 3, 7, 9, 8, 1, 5, 2, 6]
[9, 8, 6, 2, 1, 4, 7, 5, 3]
[2, 5, 4, 3, 9, 7, 6, 1, 8]
[3, 7, 1, 6, 5, 8, 2, 9, 4]
<<<<<<<<<<<<<<<<<<<<