八皇后问题,是一个古老而著名的问题,是回溯算法的典型案例。该问题是国际西洋棋棋手马克斯·贝瑟尔于1848年提出:在8×8格的国际象棋上摆放八个皇后,使其不能互相攻击,即任意两个皇后都不能处于同一行、同一列或同一斜线上,问有多少种摆法。
回溯法的原理用伪代码表示如下:
// 回溯法原理
QueenBacktrace{
while(i < 8){
检查当前行i,从当前列j开始向后找到一个可以放置的安全列号
if 找到安全的列号{
放置皇后,将列号入栈,进入下一行i++
j = 0; // 下一行将从第0列开始搜索
}
else 找不到安全的列号 回溯{
弹栈回到上一行i--
移去放置的皇后,弹栈出来的j,j++继续向后搜索
}
}
}
然而当八皇后变成n皇后时,回溯法将爆炸性增长。这里将使用Las Vegas的概率算法与回溯法组合求解。
首先是贪心的Las Vegas算法原理:
-
遍历n行,每一行尝试随机放置一个可以放置的位置,如可放置皇后的位置有{1,3,5,7},从其中随机挑一个
-
每一次判断是否可以放置使用三种不冲突判断:
- 无列冲突,任意两个行的列号不同
- 无45°对角线冲突,即任意两个皇后(x1,y1) (x2,y2),x1-y1 != x2-y2
- 无135°对角线冲突,即任意两个皇后(x1,y1) (x2,y2),x1+y1 != x2+y2
-
在第stepVegas行开始,进入回溯法求解
如此是在前stepVegas行的皇后是概率的选取皇后位置,而第stepVegas行后是有回溯法确定的选择。
伪代码如下:
LVQueens(){
col[8] // 存每行的皇后放置的列位置
diag45 // 存已放置皇后的x-y值
diag135 // 存以防止的皇后x+y值
k = 0 // 行号
repeat{
nb = 0 // 计数器
for i <- 0 to 8{ // 遍历8列
if(i 不属于 col and (i-k) 不属于 diag45 and (i+k) 不属于diag135){
nb += 1
if random(1,nb)随机数从1~nb中选择出的值==1{
j = i // 注意这里第一次可行的列号random一定选中了j = i
}
}
}
if(nb >0){ // 说明找到过可以放置的解
# 此时第k行的皇后将放置在j位置上
col.append(j)
diag45.append(j-k)
diag135.append(j+k)
k += 1
}
if(nb == 0) return false // 说明没找到解
if(k == stepVegas): // 在第stepVegas行开始使用回溯法求解
return QueenBacktrace(n, k, col, diag45, diag135, ans, count)
}
}
回溯法和LV算法对比:
- 回溯法是一定有精确解的
- LV算法不一定有结果,可能是成功,也有概率失败
- LV组合回溯算法也是有可能成功,有可能失败
为了衡量不同stepVegas从哪一行开始进行回溯效果最好,引入一个遍历节点消耗的概念:即放置一次皇后作为一次搜索的节点。显然LV算法中每一层只会放置一次皇后,而回溯法因为失败回溯会放置导致进入过错误的节点而节点数更多。
测试时候重复如repeat = 1000次,然后查看这1000次成功的概率为p,成功时候遍历节点的平均数为s,失败遍历的节点数平均为e,那么总的平均次数t为: t = s + ( 1 − p ) e p t = s+ \frac{(1-p)e}p t=s+p(1−p)e
由以上的算法去测试结果如下
n = 8, stepVegas = 0 : p = 1.0 s = 114.0 e = 0 t = 114.0
n = 8, stepVegas = 1 : p = 1.0 s = 40.042 e = 0 t = 40.042
n = 8, stepVegas = 2 : p = 0.858 s = 22.884615384615383 e = 39.556338028169016 t = 29.43123543123543
n = 8, stepVegas = 3 : p = 0.518 s = 13.432432432432432 e = 15.236514522821576 t = 27.610038610038607
n = 8, stepVegas = 4 : p = 0.247 s = 10.48582995951417 e = 8.731739707835326 t = 37.10526315789474
n = 8, stepVegas = 5 : p = 0.157 s = 9.178343949044587 e = 7.290628706998814 t = 48.32484076433121
n = 8, stepVegas = 6 : p = 0.145 s = 9.027586206896551 e = 6.990643274853801 t = 50.248275862068965
n = 8, stepVegas = 7 : p = 0.141 s = 9.0 e = 6.933643771827707 t = 51.241134751773046
n = 8, stepVegas = 8 : p = 0.127 s = 9.0 e = 6.953035509736541 t = 56.79527559055118
n = 8, bestStepVegas = 3
所以看到当bestStepVegas =3时候效果最好,遍历的节点数最小,比单纯回溯法114次遍历节点数降低到27.6。
这里给出对于n皇后的组合算法求解n=8~20的对比,如下代码:
import random
random.seed(0)
def QueenBacktrace(n, k, col, diag45, diag135, ans, count):
"""
从第k行开始使用回溯法求n皇后
"""
i = k # 行
j = 0 # 列
while(i < n and i >= k):
# 找到当前行i的一个安全列,除非找不到了
if(j < n):
if (j not in col) and ((j-i) not in diag45) and ((j+i) not in diag135):
count += 1
# 找到了安全列,放置皇后,入栈,进入下一行
ans[i] = j # 入栈结果
col.append(j)
diag45.append(j-i)
diag135.append(j+i)
i += 1 # 下一行
j = 0
else:
j += 1
# 找不到安全的列号,弹栈回溯
else:
i -= 1
j = ans[i] + 1
col.pop()
diag135.pop()
diag45.pop()
if(i == n):
return True,ans,count
else:
return False,ans,count
# for j in range(n):
# # 找到起始第k行的安全的行号进入回溯法
# if (j not in col) and ((j-k+1) not in diag45) and ((j+k+1) not in diag135):
# for i in range(k+1, n):
def QueensLV(n, stepVegas):
""" LV 混合 回溯的Queen算法 """
count = 1 # 计数通过的节点数
col = [] # 已使用列集合
diag45 = [] # 45度对角线冲突集合
diag135 = [] # 135度对角线冲突集合
ans = [0 for i in range(n)] # 皇后放在第index行的第ans[index]列
k = 0 # 行号
j = 0 # 选择的安全行号
while(True):
nb = 0 # 计数器
if(k == stepVegas):
break
for i in range(n):
# 循环随机找一个开放的位置
if (i not in col) and ((i-k) not in diag45) and ((i+k) not in diag135):
nb += 1
if(random.randint(1,nb) == 1):
j = i
if(nb > 0):
# 此时第k+1的皇后将放置在j位置上
count+=1
ans[k] = j
col.append(j)
diag45.append(j-k)
diag135.append(j+k)
k += 1
if(nb == 0):
return (False,ans,count)
# 回溯法继续执行
return QueenBacktrace(n, k, col, diag45, diag135, ans, count)
repeat = 1000
for i in range(8,21):
minT = 100000000000 # 找到最小的node次数
bestStepVegas = 0
for stepVegas in range(0,i+1):
s = 0
s_count = 0
e = 0
e_count = 0
for j in range(repeat):
flag,ans,count = QueensLV(i,stepVegas)
if(flag):
s += count
s_count += 1
else:
e += count
e_count += 1
if(s_count == 0):
s = 0
else:
s = 1.0*s/s_count
if(e_count == 0):
e = 0
else:
e = 1.0*e/e_count
p = s_count/repeat
t = s + e*e_count/s_count
print("n = "+str(i)+", stepVegas = "+str(stepVegas)+ " : p = "+str(p) +" s = "+str(s)+ " e = "+str(e)+" t = "+str(t))
if(minT > t):
minT = t
bestStepVegas = stepVegas
print("n = "+str(i)+", bestStepVegas = "+str(bestStepVegas) + "\n")