1、基本概念
LRTA* 算法是对RTA* 算法的改进,在RTA* 的原论文中,提到了:
Unfortunately, while RTA* as described above is ideally suited to single problem solving trials, it must be modified to accommodate multi-trial learning. The reason is that the algorithm records the second best estimate in the previous state, which represents an accurate estimate of that state looking back from the perspective of the next state.However, if the best estimate turns out to be correct, then storing the second best value can result in inflated values for some states. These inflated values will direct the next agents in the wrong direction on subsequent problem solving trials.
即:
RTA* 算法记录了前一状态下的第二个最佳估计,这代表了从下一状态的角度回顾该状态的准确估计。然而,如果最佳估计被证明是正确的,那么存储第二个最好的值可能会导致某些状态的值被夸大。这些夸大的价值观将引导下一代在随后的问题解决试验中走向错误的方向。
LRTA* 通过记录最优解而不是第二最优解解决了这个问题。LRTA* 算法的核心是对节点的启发式值的更新,启发式值的更新使用启发式函数:
H
(
s
)
=
g
(
s
,
s
′
)
+
H
(
s
′
)
H(s)=g(s,s^′)+H(s^′)
H(s)=g(s,s′)+H(s′)
其中H(s)表示前一节点的启发式值,g(s, s’)表示从节点s到s‘的代价,H(s’)表示当前节点的启发式值。LRTA*的完整算法如下:
h(s) 的设计,是为了防止之前A* 算法在搜索时陷入局部最小值,在LRTA* 搜索中,如果陷入了局部最小值,算法会根据访问附近节点的次数增加h(s),即增加总成本f(s),从而经过h(s)的多次叠加后,跳出局部最小值。
h(s) 的更新是这样的:
例如从 s 节点到 s’ 节点,假设路程g(s, s’) = 1,已知 h(s’) = 2(这个初始的时候就是已知的,代表了从该节点到目标节点的距离),那么从 s 节点访问 s’ 节点后,是对s节点的h(s)进行更新的,h(s) = g(s) + h(s’) = 1 + 2 = 3。
另外,如果想从 s’ 再回到 s节点的话,那么就要对 h(s’) 更新,h(s’) = g(s’, s) + h(s) = 1 + 3 = 4。
该算法的精髓就在 h(s) 的更新上,理解了 h(s) 的更新,LRTA* 算法就基本理解了。
具体例子可以看参考文档里面的第一篇,这里就不细述了。
2、代码示例:
import os
import sys
import math
import copy
import heapq
import matplotlib.pyplot as plt
class LRTAStar:
"""AStar set the cost + heuristics as the priority
"""
def __init__(self, s_start, s_goal, heuristic_type,xI, xG):
self.s_start = s_start
self.s_goal = s_goal
self.heuristic_type = heuristic_type
self.u_set = [(-1, 0), (-1, 1), (0, 1), (1, 1),
(1, 0), (1, -1), (0, -1), (-1, -1)] # feasible input set
self.obs = self.obs_map() # position of obstacles
self.OPEN = dict() # priority queue / OPEN set
self.CLOSED = [] # CLOSED set / VISITED order
self.PARENT = dict() # recorded parent
self.h = dict()
self.g = dict() # cost to come
self.x_range = 51 # size of background
self.y_range = 31
self.xI, self.xG = xI, xG
self.obs = self.obs_map()
def update_obs(self, obs):
self.obs = obs
def animation(self, path, visited, name):
self.plot_grid(name)
self.plot_visited(visited)
self.plot_path(path)
plt.show()
def plot_grid(self, name):
obs_x = [x[0] for x in self.obs]
obs_y = [x[1] for x in self.obs]
plt.plot(self.xI[0], self.xI[1], "bs")
plt.plot(self.xG[0], self.xG[1], "gs")
plt.plot(obs_x, obs_y, "sk")
plt.title(name)
plt.axis("equal")
def plot_visited(self, visited, cl='gray'):
if self.xI in visited:
visited.remove(self.xI)
if self.xG in visited:
visited.remove(self.xG)
count = 0
for x in visited:
count += 1
plt.plot(x[0], x[1], color=cl, marker='o')
plt.gcf().canvas.mpl_connect('key_release_event',
lambda event: [exit(0) if event.key == 'escape' else None])
if count < len(visited) / 3:
length = 20
elif count < len(visited) * 2 / 3:
length = 30
else:
length = 40
#
# length = 15
if count % length == 0:
plt.pause(0.001)
plt.pause(0.01)
def plot_path(self, path, cl='r', flag=False):
path_x = [path[i][0] for i in range(len(path))]
path_y = [path[i][1] for i in range(len(path))]
if not flag:
plt.plot(path_x, path_y, linewidth='3', color='r')
else:
plt.plot(path_x, path_y, linewidth='3', color=cl)
plt.plot(self.xI[0], self.xI[1], "bs")
plt.plot(self.xG[0], self.xG[1], "gs")
plt.pause(0.01)
def update_obs(self, obs):
self.obs = obs
def obs_map(self):
"""
Initialize obstacles' positions
:return: map of obstacles
"""
x = 51
y = 31
obs = set()
for i in range(x):
obs.add((i, 0))
for i in range(x):
obs.add((i, y - 1))
for i in range(y):
obs.add((0, i))
for i in range(y):
obs.add((x - 1, i))
for i in range(10, 21):
obs.add((i, 15))
for i in range(15):
obs.add((20, i))
for i in range(15, 30):
obs.add((30, i))
for i in range(16):
obs.add((40, i))
return obs
def searching(self):
"""
A_star Searching.
:return: path, visited order
"""
self.PARENT[self.s_start] = self.s_start
self.g[self.s_start] = 0
self.g[self.s_goal] = math.inf
self.OPEN[self.s_start] = self.f_value(self.s_start,self.s_start)
count = 1
while self.OPEN:
s = min(self.OPEN, key=self.OPEN.get)
print(count)
self.OPEN.pop(s)
self.CLOSED.append(s)
count += 1
if s == self.s_goal: # stop condition
break
new_h = math.inf
for s_n in self.get_neighbor(s):
new_cost = self.g[s] + self.cost(s, s_n)
new_s_n = self.heuristic2(s,s_n) + self.heuristic(s_n)
if new_s_n < new_h:
new_h = copy.deepcopy(new_s_n);
if s_n not in self.g or new_cost < self.g[s_n]:
self.g[s_n] = new_cost
self.PARENT[s_n] = s
self.OPEN[s_n] = self.f_value(s,s_n)
if new_s_n > self.h[s]:
self.h[s] = new_s_n
#print(self.CLOSED)
return self.extract_path(self.PARENT), self.CLOSED
def get_neighbor(self, s):
"""
find neighbors of state s that not in obstacles.
:param s: state
:return: neighbors
"""
return [(s[0] + u[0], s[1] + u[1]) for u in self.u_set]
def cost(self, s_start, s_goal):
"""
Calculate Cost for this motion
:param s_start: starting node
:param s_goal: end node
:return: Cost for this motion
:note: Cost function could be more complicate!
"""
if self.is_collision(s_start, s_goal):
return math.inf
return math.hypot(s_goal[0] - s_start[0], s_goal[1] - s_start[1])
def is_collision(self, s_start, s_end):
"""
check if the line segment (s_start, s_end) is collision.
:param s_start: start node
:param s_end: end node
:return: True: is collision / False: not collision
"""
if s_start in self.obs or s_end in self.obs:
return True
if s_start[0] != s_end[0] and s_start[1] != s_end[1]:
if s_end[0] - s_start[0] == s_start[1] - s_end[1]:
s1 = (min(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
s2 = (max(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
else:
s1 = (min(s_start[0], s_end[0]), max(s_start[1], s_end[1]))
s2 = (max(s_start[0], s_end[0]), min(s_start[1], s_end[1]))
if s1 in self.obs or s2 in self.obs:
return True
return False
def f_value(self,s, s_n):
"""
f = g + h. (g: Cost to come, h: heuristic value)
:param s: current state
:return: f
"""
if s_n in self.h:
return self.g[s_n] + self.h[s_n]
else:
self.h[s_n] = self.heuristic(s_n)
return self.g[s_n] + self.heuristic(s_n)
def extract_path(self, PARENT):
"""
Extract the path based on the PARENT set.
:return: The planning path
"""
path = [self.s_goal]
s = self.s_goal
while True:
s = PARENT[s]
path.append(s)
if s == self.s_start:
break
return list(path)
def heuristic2(self, s1,s2):
heuristic_type = self.heuristic_type # heuristic type
goal = self.s_goal # goal node
if heuristic_type == "manhattan":
return abs(s1[0] - s2[0]) + abs(s1[1] - s2[1])
else:#sqrt(x^2+y^2)
return math.hypot(s1[0] - s2[0], s1[1] - s2[1])
def heuristic(self, s):
"""
Calculate heuristic.
:param s: current node (state)
:return: heuristic function value
"""
heuristic_type = self.heuristic_type # heuristic type
goal = self.s_goal # goal node
if heuristic_type == "manhattan":
return abs(goal[0] - s[0]) + abs(goal[1] - s[1])
else:#sqrt(x^2+y^2)
return math.hypot(goal[0] - s[0], goal[1] - s[1])
def main():
s_start = (5, 5)
s_goal = (45, 25)
astar = LRTAStar(s_start, s_goal, "euclidean",s_start,s_goal)
path, visited = astar.searching()
astar.animation(path, visited, "LRTA*") # animation
if __name__ == '__main__':
main()
上述代码运行后执行的循环次数是和A* 的那个是一样的,这里有点奇怪,不知道是我理解的有问题还是确实它起到的作用比较有限,因为LRTA* 的作用主要是快速的跳出局部最优解的问题,但是这里可能没有出现这方面的问题所以其实也就没有起到优化的效果,总的来说,相当于了解一个算法思路,但是作用似乎很有限。
参考:
1、《[PR] LRTA* 搜索算法》
2、《[AI] LRTA*搜索算法及其扩展算法》