1.导入必要的库环境,代码如下所示。
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
2.本悬崖漫步环境中无需提供奖励函数以及状态转移函数,而需提供一个与智能体进行交互的step()函数,该函数输入为智能体当前状态下的动作,输出为当前状态下的奖励以及智能体的下一状态,代码如下所示。
class CliffWalkingEnv:
def __init__(self,ncol,nrow,step_reward,cliff_reward):
self.ncol=ncol
self.nrow=nrow
self. x=0#记录当前智能体位置的横坐标
self.y=self.nrow-1#记录当前智能体位置的纵坐标
self.step_reward=step_reward
self.cliff_reward=cliff_reward
def step(self,action):#外部调用此函数改变当前位置
change=[[0,-1],[0,1],[-1,0],[1,0]]#定义四个动作,change[0]:上;change[1]:下;change[2]:左;change[3]:右;坐标系原点(0,0)定义在左上角
self.x=np.clip(self.x+change[action][0],0,self.ncol-1)#也可采用self.x=min(self.ncol-1,max(0,self.x+change[action][0]))
self.y=np.clip(self.y+change[action][1],0,self.nrow-1)#也可采用self.y=min(self.nrow-1,max(0,self.y+change[action][1]))
next_state=self.ncol*self.y+self.x#计算下一个状态
reward=self.step_reward
done=False
if self.y==self.nrow-1 and self.x>0:#如果当前位置在悬崖或者终点
done=True
if self.x!=self.ncol-1:#如果在悬崖
reward=self.cliff_reward
return next_state,reward,done
def reset(self):#环境重置
self.x=0
self.y=self.nrow-1
return self.y*self.ncol+self.x
3.实现Sarsa算法,维护一个Q_table()表格,本表格主要存储当前策略下所有状态动作对的价值,即所有状态下各个动作的动作价值函数,Sarsa算法与环境进行交互,采用-贪婪策略进行采样,使用时序差分算法进行Sarsa算法更新。本程序默认终止状态时所有动作的价值为0,即终止状态的动作价值函数均为0,其在初始化为0后不会在与环境交互过程中进行更新。具体代码如下所示。
class Sarsa:
""" Sarsa算法 """
def __init__(self,ncol,nrow,epsilon,alpha,gamma,n_action=4):
self.ncol=ncol
self.nrow=nrow
self.epsilon=epsilon
self.alpha=alpha
self.gamma=gamma
self.n_action=n_action
self.Q_table=np.zeros([self.ncol*self.nrow,self.n_action])#初始化Q(s,a)表格
def take_action(self,state):
if np.random.random()<self.epsilon:#如果小于epsilon,则随机选择动作
action=np.random.randint(self.n_action)
else:
action=np.argmax(self.Q_table[state])
return action
def best_action(self,state):
Q_max=np.max(self.Q_table[state])
a=[0]*self.n_action#or a=[0 for _ in range(self.n_action)]
for i in range(self.n_action):#若两动作价值一样为最大,则会记录下来
if self.Q_table[state,i]==Q_max:
a[i]=1
return a
def update(self,s0,a0,r,s1,a1):#Sarsa算法核心部分,采用时序差分算法估计动作价值函数Q,Q(s0,a0)~Q(s0,a0)+alpha(r0+gamma*Q(s1,a1)-Q(s0,a0))
TD_error=r+self.gamma*self.Q_table[s1,a1]-self.Q_table[s0,a0]
self.Q_table[s0,a0]+=self.alpha*TD_error
4.本案例悬崖漫步环境中相关参数设置如下所示,可以自行修改。
ncol=12#悬崖漫步环境中的网格环境列数
nrow=4#悬崖漫步环境中的网格环境行数
step_reward=-1#每步的即时奖励
cliff_reward=-100#悬崖的即时奖励
epsilon=0.1#epsilon-贪婪算法的探索因子
alpha=0.1 #价值估计更新的步长
gamma=0.9 #回报计算的折扣衰减因子
n_action=4 #动作个数
5.主程序实现部分如下所示。
env=CliffWalkingEnv(ncol,nrow,step_reward,cliff_reward)
agent=Sarsa(ncol,nrow,epsilon,alpha,gamma,n_action)
num_episode=500#智能体在环境中运行的序列的数量
episode_Gt_list=[]#记录每个序列的回报
pbar_num=10#进度条的数量
for i in range(pbar_num):#显示每个进度条
with tqdm(total=num_episode/pbar_num,desc='Episode %d'%i) as pbar:
for episode in range(int(num_episode/pbar_num)):#每个进度条的序列数量
number=1#记录每个序列的步数
episode_Gt=0
state=env.reset()
action=agent.take_action(state)
done=False
while not done:
next_state,reward,done=env.step(action)
next_action=agent.take_action(next_state)
episode_Gt+=reward #回报的计算不进行折扣因子衰减,考虑远期最优。
agent.update(state,action,reward,next_state,next_action)
state=next_state
action=next_action
number+=1
episode_Gt_list.append(episode_Gt)
if (episode+1)%10==0:#每10条序列打印一下这十条序列的平均回报
pbar.set_postfix({'episode':'%d' %((num_episode/pbar_num)*i+episode+1),'return':'%.3f'%np.mean(episode_Gt_list[-10:],axis=None)})
pbar.update(1)
print('\n')
episodes_list=list(range(len(episode_Gt_list)))
plt.plot(episodes_list,episode_Gt_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Sarsa on {}'.format('Cliff Walking'))#or "plt.title('Sarsa on %s'%'Cliff Walking') "
plt.show()
6.结果如图所示。
Episode 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 276.35it/s, episode=50, return=-119.100]
Episode 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 345.86it/s, episode=100, return=-73.600]
Episode 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 421.77it/s, episode=150, return=-33.200]
Episode 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 586.43it/s, episode=200, return=-38.700]
Episode 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 668.64it/s, episode=250, return=-27.600]
Episode 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 787.75it/s, episode=300, return=-19.800]
Episode 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 891.04it/s, episode=350, return=-19.200]
Episode 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 875.73it/s, episode=400, return=-21.100]
Episode 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 844.25it/s, episode=450, return=-29.200]
Episode 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 904.60it/s, episode=500, return=-17.800]
我们可以发现随着训练的进行,Sarsa算法获得的回报越来越高,在进行500条序列的学习后,可以获得-20左右的回报,此时非常接近最优策略了。
7.如下所示的程序是实现查看Sarsa算法得到的策略在各状态下会使智能体采取何种动作的功能。
def print_agent(agent,env,action_meaning,disaster=[],end=[]):
for i in range(env.nrow):
for j in range(env.ncol):
if (i*env.ncol+j) in disaster:
print('****',end=' ')
elif (i*env.ncol+j) in end:
print('EEEE',end=' ')
else:
a=agent.best_action(i*env.ncol+j)
pi_str=''
for k in range(len(action_meaning)):
pi_str+=action_meaning[k] if a[k]>0 else 'o'
print(pi_str,end=' ')
print()
action_meaning=['^','v','<','>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent,env,action_meaning,disaster=[range(37,47)],end=[47])
8.结果如下所示:
ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo
^ooo ^ooo ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo ooo> ovoo
^ooo oo<o oo<o ^ooo ^ooo oo<o ^ooo ooo> oo<o ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE
可以发现Sarsa算法会采取比较远离悬崖的策略来抵达目标。