华容道程序求解

Table of Contents

code on github

在本科的时候学习完了BFS,就想到可以用它来求解华容道的最短路径求解。当时是用C语言编写的,为了简化状态表示,还使用整数来表示每个状态。 因为C语言当时没有标准库,所以还是手工实现的Queue, Hashmap等一些辅助数据结构。这几天忽然又想到这个问题,所以用Python重新实现了一遍。

1 状态表示

为了可以处理任意大小的矩阵,状态就使用矩阵来表示。考虑到Python的列表性能比较差,所以初始化的时候将python列表表示的矩阵变为numpy矩阵。

class State:
    def __init__(self, matrix, xy=None):
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)
        self.matrix = matrix
        self.nm = matrix.shape
        self.xy = xy
        self.str_cache = None
        self.id_cache = None
        if xy is None:
            self.xy = self.find_zero()

    def find_zero(self):
        matrix = self.matrix
        n, m = self.nm
        for i in range(n):
            for j in range(m):
                if matrix[i][j] == 0:
                    return i, j

    def next_states(self):
        matrix = self.matrix
        x, y = self.xy
        n, m = self.nm
        states = []
        for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            x2, y2 = x + dx, y + dy
            if 0 <= x2 < n and 0 <= y2 < m:
                matrix2 = np.copy(matrix)
                matrix2[x2][y2], matrix2[x][y] = matrix2[x][y], matrix2[x2][y2]
                state2 = State(matrix2, (x2, y2))
                states.append(state2)
        return states

    def __str__(self):
        return self.to_string()

    def is_equal(self, other):
        return self.xy == other.xy and self.identity() == other.identity()

    def identity(self):
        return self.matrix.tobytes()

    def to_string(self):
        if self.str_cache is not None:
            return self.str_cache
        self.str_cache = str(self.matrix)
        return self.str_cache


2 状态记录

每个状态会产生多个新状态,在BFS的时候需要判断这些状态之前是否看到过,所以我们需要有个对象进行状态记录。 另外每个状态最好可以用一个整数来对应,这样在BFS inqueue的时候就只需要放入整数就好。

class StateBK:
    def __init__(self):
        self.map = {}
        self.seq = []

    def get_index(self, st: State):
        if st.identity() in self.map:
            return self.map[st.identity()]
        index = len(self.seq)
        self.map[st.identity()] = index
        self.seq.append(st)
        return index

    def query_index(self, st: State):
        return st.identity() in self.map

    def get_state(self, index):
        return self.seq[index]

3 naive BFS

下面这个算法是通过BFS进行搜索。这个算法有个问题是,如果路径很长的话,整个树需要展开很多层, 会涉及到许多状态的探索,时间就会非常长。最坏的情况是,如果没有路径的话,那么需要遍历所有状态。

# NOTE(yan): naive BFS
def search_path_1(source: State, dest: State):
    bk = StateBK()
    parents = {}
    Q = deque()

    idx = bk.get_index(source)
    parents[idx] = -1
    Q.append(idx)

    paths = []
    found = False
    while len(Q):
        idx = Q.popleft()
        state = bk.get_state(idx)
        if state.is_equal(dest):
            found = True
            break
        next_states = state.next_states()
        for st in next_states:
            if bk.query_index(st):
                continue
            idx2 = bk.get_index(st)
            parents[idx2] = idx
            Q.append(idx2)

    if found:
        idx = bk.get_index(dest)
        while idx != -1:
            paths.append(bk.get_state(idx))
            idx = parents[idx]
        paths = paths[::-1]
    return paths

4 bidirectional BFS

这几天我想到的一个改进是,是否可以从source/destination同时进行检索。如果两个搜索方向上有一些状态是重合的话,那么说明这个便是最短路径。

如果最短路径的长度是20的话,因为每个状态会展开成为4个状态,那么最多会展开 4 ^ 20个状态(当然考虑到部分状态之前访问过,以及fanout没有这么大, 所以实际情况不会有这么多,但是大约是这个量级)。

但是如果是双向搜索的话,那么每个方向只需要搜索长度10的路径,那么最多会展开2 * (4 ^ 10)个状态,这个数量比之前的少很多。如果存在路径的话, 那么这种双向BFS会节省很多时间。

# NOTE(yan): bidirectional BFS
def search_path_2(source: State, dest: State):
    bk = [StateBK(), StateBK()]
    parents = [{}, {}]
    dists = [{}, {}]
    Q = [deque(), deque()]

    idx = bk[0].get_index(source)
    parents[0][idx] = -1
    dists[0][idx] = 0
    Q[0].append((idx, 0))

    idx = bk[1].get_index(dest)
    parents[1][idx] = -1
    dists[1][idx] = 0
    Q[1].append((idx, 0))

    depth = -1
    found = False

    # distance, pidx0, pidx1, direction
    opt = (1 << 30, None, None, 0)

    while True:
        depth += 1
        for i in range(2):
            while len(Q[i]):
                idx, d = Q[i].popleft()
                if d != depth:
                    Q[i].append((idx, d))
                    break

                state = bk[i].get_state(idx)
                if bk[1 - i].query_index(state):
                    pidx0 = idx
                    pidx1 = bk[1 - i].get_index(state)
                    dist = dists[i][pidx0] + dists[1 - i][pidx1]
                    if dist < opt[0]:
                        # print('min dist = {}, i = {}'.format(dist, i))
                        opt = (dist, pidx0, pidx1, i)
                        found = True
                        break

                next_states = state.next_states()
                for st in next_states:
                    if bk[i].query_index(st):
                        continue
                    idx2 = bk[i].get_index(st)
                    parents[i][idx2] = idx
                    dists[i][idx2] = d + 1
                    Q[i].append((idx2, d + 1))
            if found: break
        if found or not len(Q[0]) or not len(Q[1]):
            break

    if not found:
        return []

    dist, pidx0, pidx1, i = opt
    paths0 = []
    while pidx0 != -1:
        paths0.append(bk[i].get_state(pidx0))
        pidx0 = parents[i][pidx0]

    paths1 = []
    while pidx1 != -1:
        paths1.append(bk[1 - i].get_state(pidx1))
        pidx1 = parents[1 - i][pidx1]

    assert len(paths0) > 0
    assert len(paths1) > 0
    paths = paths0[::-1] + paths1[1:]
    if i: paths = paths[::-1]
    return paths

5 速度对比

def main():
    source_matrix = [
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]
    ]
    dest_matrix = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 0]
    ]
    source = State(source_matrix)
    dest = State(dest_matrix)

    start = time.time()
    paths1 = search_path_1(source, dest)
    print('naive BFS ...')
    # print_paths(paths1)
    print('size = {}'.format(len(paths1)))
    end = time.time()
    print('timer = {}'.format(end - start))

    start = time.time()
    paths2 = search_path_2(source, dest)
    print('bidirectional BFS ...')
    # print_paths(paths2)
    print('size = {}'.format(len(paths2)))
    end = time.time()
    print('timer = {}'.format(end - start))


运行下来速度差别还是蛮大的,方法1是1.492s, 方法2是0.026s, 时间缩短了差不多98%.

➜  misc git:(master) ✗ python klotski.py
naive BFS ...
size = 23
timer = 1.4918239116668701
bidirectional BFS ...
size = 23
timer = 0.026241064071655273