数独程序求解

code on github

对网格每个空白点(x,y)进行预处理,求解这个点上可以放置哪些数字。因为数独范围在1-9之间,所以可以用bits表示。每个空白点(x,y)可以放置的数字满足下面几个条件:

  1. 行x不能和已有的数字重复。对应代码里面的 `xs`.
  2. 列y不能和已有的数字重复。对应代码里面的 `ys`.
  3. (x,y)所在的3x3子矩阵内不能重复。对应代码里面的 `rs`.
N = 9
xs, ys, rs, ps = [], [], [], []
for i in range(N):
    mark = [0] * (N + 1)
    for j in range(N):
        mark[grid[i][j]] = 1
    cs = 0
    for k in range(1, N + 1):
        if mark[k] == 0:
            cs |= (1 << k)
    xs.append(cs)

for j in range(N):
    mark = [0] * (N + 1)
    for i in range(N):
        mark[grid[i][j]] = 1
    cs = 0
    for k in range(1, N + 1):
        if mark[k] == 0:
            cs |= (1 << k)
    ys.append(cs)

for i in range(3):
    for j in range(3):
        mark = [0] * (N + 1)
        for k in range(3):
            for l in range(3):
                v = grid[3 * i + k][3 * j + l]
                mark[v] = 1
        cs = 0
        for k in range(1, N + 1):
            if mark[k] == 0:
                cs |= (1 << k)
        rs.append(cs)

for i in range(N):
    for j in range(N):
        if grid[i][j] == 0:
            ps.append((i, j))

然后进行递归求解,每个点(x,y)所满足的数字是 `xs[x] & ys[y] & rs[r]` 的交集。相比使用set来表示可选择的数字,如果使用bits表示的话,那么and操作效率会更高。一个数独可能会存在多个解,在递归求解的时候可以对解的数量进行控制,一旦找到一定数量的解就可以退出了。

ans = []

def dfs(idx):
    if idx == len(ps):
        ans.append(copy.deepcopy(grid))
        return

    x, y = ps[idx]
    r = (x // 3) * 3 + (y // 3)
    cs = xs[x] & ys[y] & rs[r]
    for v in range(1, 10):
        if (cs >> v) & 0x1:
            unmask = ~(1 << v)
            mask = (1 << v)
            grid[x][y] = v
            xs[x] &= unmask
            ys[y] &= unmask
            rs[r] &= unmask
            dfs(idx + 1)
            grid[x][y] = 0
            xs[x] |= mask
            ys[y] |= mask
            rs[r] |= mask
            if ans and len(ans) == number:
                return

dfs(0)
return ans

程序运行时间和求解数量和空白位置数量相关。如果只是求解几个解,那么速度还是蛮快的。

def main():
    grid = [
        [7, 0, 0, 8, 3, 0, 0, 0, 5],
        [0, 2, 5, 0, 6, 0, 3, 0, 0],
        [0, 1, 0, 0, 7, 0, 9, 0, 2],
        [1, 0, 2, 5, 0, 3, 0, 7, 0],
        [5, 0, 8, 0, 0, 6, 4, 0, 0],
        [0, 3, 0, 9, 0, 0, 5, 0, 6],
        [9, 0, 6, 0, 1, 0, 0, 5, 0],
        [0, 0, 4, 0, 9, 0, 6, 1, 0],
        [3, 0, 0, 0, 5, 8, 0, 0, 4]
    ]
    start = time.time()
    ans = sudoku_solve(grid, number=0)
    stop = time.time()
    print('===== answer(%.2f, %s) =====' % (stop - start, len(ans)))
    for arr in ans:
        print('>' * 20)
        for i in range(len(arr)):
            print(arr[i])
        print('<' * 20)

输出如下,基本上是秒出

===== answer(0.00, 1) =====
>>>>>>>>>>>>>>>>>>>>
[7, 4, 9, 8, 3, 2, 1, 6, 5]
[8, 2, 5, 1, 6, 9, 3, 4, 7]
[6, 1, 3, 4, 7, 5, 9, 8, 2]
[1, 6, 2, 5, 4, 3, 8, 7, 9]
[5, 9, 8, 7, 2, 6, 4, 3, 1]
[4, 3, 7, 9, 8, 1, 5, 2, 6]
[9, 8, 6, 2, 1, 4, 7, 5, 3]
[2, 5, 4, 3, 9, 7, 6, 1, 8]
[3, 7, 1, 6, 5, 8, 2, 9, 4]
<<<<<<<<<<<<<<<<<<<<