LeetCode37 使用回溯算法实现解数独,详解剪枝优化

1,142 阅读12分钟

本文始发于个人公众号:TechFlow,原创不易,求个关注


数独是一个老少咸宜的益智游戏,一直有很多拥趸。但是有没有想过,数独游戏是怎么创造出来的呢?当然我们可以每一关都人工设置,但是显然这工作量非常大,满足不了数独爱好者的需求。

所以常见的一种形式是,我们只会选择难度,不同的难度对应不同的留空的数量。最后由程序根据我们选择的难度替我们生成一个数独问题。但是熟悉数独的朋友都知道,并不是所有的数独都是可解的,如果设置的不好可能会出现数独无法解开的情况。所以程序在生成完数独之后,往往还需要进行可行性检验。

所以今天文章的内容就关于如何解开一个数独

题意

LeetCode当中关于数独的是36和37两题,其中36要求判断一个给出的数独问题是否合法。37题则是给出一个必定拥有一个解法的数独的解。36题只需要判断在同行同列以及同区域当中是否有重复的数组出现,没太多的意思,所以我们跳过,直接进入37题紧张刺激的解数独问题。

题意没什么好说的,就是解这样一个数独:

要求数独当中的每行每列以及每个黑边框加粗标记的3*3的区域当中1-9都只出现一次。这个就是日常数独的规则,我想大家应该都能看明白。

解完成之后,是这样的:

题目就介绍完了,下面进入正题,即试着去解开这个数独。

解法

之前在写题解的时候,经常写的一句话就是从最简单的暴力解法开始。然而之前也说过了,并不是所有的问题都有简单粗暴的暴力解法的。比如这题就不行。不行的原因也很简单,因为我们并不知道数独当中留了多少空,我们很难简单地用循环去遍历所有填空的方法。

并且对于所有需要填数字的空格而言,前面的选择的数字会影响后面的决策,所以从原理上我们也不能直接遍历,需要用一个模式将这些待决策的区域串联起来。前面选过的数字后面自动不选,如果之后选错了数字还可以撤销,回到之前的选择。如果看过之前关于回溯算法文章的同学从我的描述当中应该能get到,这描述的其实是回溯算法的使用场景。

如果对回溯算法有所遗忘或者是新关注的同学可以点击下方链接,回顾一下关于搜索和回溯算法的讲解。

LeetCode 31:递归、回溯、八皇后、全排列一篇文章全讲清楚

和八皇后的对比

如果你还记得八皇后问题,再来看这道题可能会有一些感觉。也许你会觉得这两题好像有一些共通的部分,如果你再仔细思考一下,你也许会发现其实这看似迥异的两个问题实则是在说同一件事情。

你看,在八皇后当中,我们需要考虑的是皇后的位置,经过我们的优化之后,我们把问题转化成了每一行放置一个皇后。我们需要选择,在当前行皇后应该放在那里。而在本题当中,空白的位置是固定的,我们要选择的不再是位置,而是空白当中需要填什么数字。你看,一个是选择放置的位置,一个是选择放置的数字,表面上来看不太相同,但实际上都是在做同样一件事情,就是选择。再仔细分析一下,又会发现皇后可以选择的位置是固定的,这题数独上可供选择的数字其实也是固定的。

这难道不是同一个问题吗?

既然是同一个问题,那当然可以使用同一种方法。在八皇后当中我们通过回溯法枚举了皇后放置的位置,通过回溯修改之前的选择来找答案。这题本质上是一样的,我们枚举空白位置放置的数字,如果之后遍历不成功,找不到解,说明之前的放置错了,我们需要回溯回去修改之前的选择。

我们再来看下回溯问题的代码模板:

def dfs(depth):
if depth >= 8:
return
for choice in all_choices():
record(choice)
dfs(depth+1)
rollback(choice)

对照模板,八皇后当中递归深度是皇后的数量,这题当中就是空白位置的数量。八皇后选择的是皇后放置的位置,这题当中就是选择空白点放置的数字。八皇后当中回溯是将皇后移除,这题当中是将之前放的数字挪走。对照一下,想必你们肯定可以非常顺利地写出代码:

def dfs(board, n, ret):
if n == 81:
# 判断棋盘是否合法
if validateBoard(board):
ret = board.copy()
return

x, y = n / 9, n % 9
if board[x][y] != '.':
dfs(board, n+1, ret)

for i in range(9):
c = str(i+1)
board[x][y] = c
dfs(board, n+1, ret)
board[x][y] = '.'

这段代码非常简单,没什么难的,只不过要在最后递归结束的时候判断一下棋盘是否合法,要额外写一个方法而已。但是如果你真的这么做了,妥妥的超时。原因也很简单,这么做虽然看起来用到了回溯算法,但是回溯算法本质上只是解决了遍历一个问题所有可能性的问题。我们可以算一下这道题所有摆放的可能性,一个空最多有9种放法,随着空白位置的增多,这个复杂度是一个指数级的增长,显然是一定会超时的。

到这里给大家传递一个结论,纯搜索或者是回溯算法本质就是暴力枚举,只不过是高级一点的枚举。

优化

既然这样做不行,那么就要想想怎么办才可以。这道题并没有给我们多少操作的空间,无论如何我们总是要试着去摆放的,我们也不可能设计出一个算法来能够开天眼,不用枚举就算得出来每一个位置应该填什么。所以回溯法是一定要用的,只是我们用的太简单粗暴了,所以不行。

于是,我们进入了一个很大的问题——搜索优化

这真的是一个很大的问题,在搜索问题上有各种各样千奇百怪的优化方法,包括不仅限于各种各样的剪枝技巧、A*, IDA*等启发式搜索、蚁群算法、遗传算法等智能算法……不过好在这些方法当中的许多并不是普适的,需要我们结合问题的实际去寻找适合的优化方法,有时候还需要一点运气。

比如我曾经听学长讲过一个故事,之前他在比赛的时候有一次他被一道搜索题卡住了。他把所有想到的优化方法都用尽了,还是超时,最后逼不得已构思了一个计算概率的方法,在每次搜索的时候只选择概率最大的分支,其余的分支全部剪掉。这显然不太合理,他抱着侥幸的想法提交了一下,没想到通过了。他赛后查看题解才发现这就是正解,只是这一切原本背后是有一套数学证明和分析的,但他是靠着直觉猜测出的结论,以至于觉得不可思议。

剪枝

扯远了,我们回到正题。面临搜索问题的优化,最常用的方法还是剪枝。剪枝这个词很形象,因为我们搜索的时候背后逻辑上其实是一棵树形的搜索树。而剪枝就是在做决策的时候,提前判断一些不可能存在解的分支给剪掉。

从上图我们可以看出来,剪枝发生的位置越接近上层,剪掉的搜索子树就越大,节省的资源也就越多,效果也就越好。

但是实际问题当中,往往越上层的信息越少,剪枝条件也就越难触发

剪枝只有核心思想,就是减少当下做出的决策,但是没有固定的套路,需要我们自己构思。同样的问题,不同的剪枝方案得到的结果可能大相径庭。好的剪枝方案一般都基于对问题的深入理解和思考。

我们稍微想一下,就可以想到一个很简单的思路,即把检查是否合法的方法从递归结束之后挪到放置之前。

def dfs(board, n, ret):
if n == 81:
ret = board.copy()
return

x, y = n / 9, n % 9
if board[x][y] != '.':
dfs(board, n+1, ret)

for i in range(9):
c = str(i+1)
# 判断棋盘是否合法
if validateBoard(board):
board[x][y] = c
dfs(board, n+1, ret)
board[x][y] = '.'

这也是常用的做法,对于当下已经出现重复的数字,我们没必要再放一下试试看了,因为已经不可能构成合法解了。

如果你能想到这点,说明你对剪枝的理解已经入门了。但是很遗憾,如果你真这么干了,还是会超时。

原因也很简单,因为我们判断棋盘是否合法需要遍历整个棋盘,会带来大量的开销。因为for循环当中的每一个决策,我们都需要判断一次合法情况。所以这个剪枝判断带来的代价是随着搜索的次数一直增加的。

这也是剪枝的另一个问题,即剪枝的判断条件很多时候都是有代价的。随着剪枝条件复杂性的增加,带来的开销也会增加。甚至可能出现剪枝了还不如不剪的情况发生。

降低剪枝的开销

解决的方法也很简单,既然我们剪枝的使用过程中带来的开销很大,我们第一想法就是降低这个开销。

在这个问题当中,我们基于常规的思路去判断整体是否合法,而判断整体合法显然需要遍历整个board。但问题是我们做了许多无用功,因为board上可能会引起非法的数字只有当前放置的这个,之前的摆放的位置都经过校验,显然都是合法的。我们没必要判断那么多,只需要判断当前的数字是否会引起新的非法就可以了。

也就是说我们把判断的标准从整体细化到了局部,这么做能成立的条件有两个,第一个是题目当中保证了数独一定有解,也就是在我们搜索开始之前的起始状态一定是合法的。第二点是,我们每一个合法的状态可以累加,而不会出现意外。也就是说,有可能前面的选择不合理导致后面没有数字可以选的情况出现,但是不可能出现前面的摆放都合法,突然到后面变得非法了。

如果能想通了以上两点,那么我们自然能做出这个结论:即我们不需要判断board,只需要判断当前待摆放的数字,这个做法是合理并且可行的。

剩下的问题就是我们怎么快速地判断当前选择的数字放在此处是否合法呢?

到这里,相信大家应该不难想到,原理也很简单,因为题目当中说了我们需要保证每行、每列每个方块当中的1-9只出现一次。所以我们用三种容器分别存储每行、每列每个方块当中1-9出现的次数即可。

具体来看代码:

class Solution:

# 全局变量,存储每行、每列和每个block当中放置的数字的数量
# 用数组会比dict更快
rowDict = [[0 for _ in range(10)] for _ in range(10)]
colDict = [[0 for _ in range(10)] for _ in range(10)]
blockDict = [[0 for _ in range(10)] for _ in range(10)]

def dfs(self, cur, bd, board):
if cur == 81:
# 拼装答案
for i in range(9):
for j in range(9):
board[i][j] = chr(ord('0') + bd[i][j])
return

x, y = cur // 9, cur % 9
# 如果原本就有数字,直接跳过
if bd[x][y] != 0:
self.dfs(cur+1, bd, board)
return

for i in range(1, 10):
# 如果在行或者列或者block中出现过,那么当下不能放入
blockId = (x // 3) * 3 + y // 3
if Solution.rowDict[x][i] > 0 or Solution.colDict[y][i] > 0 or Solution.blockDict[blockId][i] > 0:
continue

# 更新容器
bd[x][y] = i
Solution.rowDict[x][i] += 1
Solution.colDict[y][i] += 1
Solution.blockDict[blockId][i] += 1
# 往下递归
self.dfs(cur+1, bd, board)
# 回溯之后还原
bd[x][y] = 0
Solution.rowDict[x][i] -= 1
Solution.colDict[y][i] -= 1
Solution.blockDict[blockId][i] -= 1

def solveSudoku(self, board: List[List[str]]) -> None:
"""
Do not return anything, modify board in-place instead.
"""


bd = [[0 for _ in range(9)] for _ in range(9)]

for i in range(9):
for j in range(9):
if board[i][j] != '.':
# 将字符串转成数字
bd[i][j] = ord(board[i][j]) - ord('0')
# 将已经填好的数字插入我们的容器当中
Solution.rowDict[i][bd[i][j]] += 1
Solution.colDict[j][bd[i][j]] += 1
# 计算一下在哪个block当中
blockId = (i // 3) * 3 + j // 3
Solution.blockDict[blockId][bd[i][j]] += 1

self.dfs(0, bd, board)

这段代码不算短,除了回溯之外还涉及到了基础的剪枝的分析,比无脑的回溯搜索复杂了一些。对这道题深入思考,可以加深对搜索问题的理解。而搜索算法是非常重要的算法之一,许多问题的本质都可以蜕化成搜索问题,因此对搜索算法能力的提升是非常必要的。

今天的文章就是这些,希望大家都能把这题吃透。我们下周LeetCode专题再见。

如果觉得有所收获,请顺手点个关注吧,你们的举手之劳对我来说很重要。