Sarsa算法讲解及实现
1. Q表格
我们使用表格来存储每一个状态 state, 和在这个 state 每个行为 action 所拥有的 Q 值。
Q即为Q(s,a)就是在某一时刻的 s 状态下(s∈S),采取动作a (a∈A)动作能够获得收益的期望,环境会根据agent的动作反馈相应的回报reward r,所以算法的主要思想就是将State与Action构建成一张Q-table来存储Q值,然后根据Q值来选取能够获得最大的收益的动作。
例子:
Q-Table | a1 | a2 |
---|---|---|
s1 | q(s1,a1) | q(s1,a2) |
s2 | q(s2,a1) | q(s2,a2) |
s3 | q(s3,a1) | q(s3,a2) |
2. Sarsa算法讲解
在强化学习中,Sarsa是一种对Q表格进行更新的算法,由于在强化学习环境最开始的时候,也可以认为是游戏刚开始的时候,Q表格是随机初始化的,所以需要在智能体不断与环境进行交互的时候不断地更新Q表格。
Sarsa表示的是State-Action-Reward-State-Action,是一个学习马尔可夫决策过程策略的算法,通常应用于机器学习和强化学习学习领域中。
State-Action-Reward-State-Action:这个名称清楚地反应了其学习更新函数依赖的5个值:分别是当前状态S1,当前状态选中的动作A1,获得的奖励Reward,S1状态下执行A1后取得的状态S2及S2状态下将会执行的动作A2。我们取这5个值的首字母串起来可以得出一个词SARSA。
Sarsa算法的更新公式:
Q ( s t , a t ) ← Q ( s t , a t ) + α [ r t + γ Q ( s t + 1 , a t + 1 ) − Q ( s t , a t ) ] Q(s_{\boldsymbol{t}},a_{t})\leftarrow Q(s_{\boldsymbol{t}},a_{t})+\alpha [r_{\boldsymbol{t}}+\gamma Q(s_{t+1},a_{\boldsymbol{t}+1})-Q(s_{\boldsymbol{t}},a_{t})] Q(st,at)←Q(st,at)+α[rt+γQ(st+1,at+1)−Q(st,at)]
Sarsa算法伪代码:
算法中各个参数的意义:
-
- alpha是学习率, 来决定这次的误差有多少是要被学习的, alpha是一个小于1 的数.
-
- gamma 是对未来 reward 的衰减值. 我们可以这样想象.
-
- Q表示的是Q表格.
-
- Epsilon greedy 是用在决策上的一种策略, 比如 epsilon = 0.9 时, 就说明有90% 的情况我会按照 Q 表的最优值选择行为, 10% 的时间使用随机选行为. 【这也是结合了强化学习中探索和利用的概念】
3. 代码
# agent.py
import numpy as np
class SarsaAgent(object):
def __init__(self,
obs_n,
act_n,
learning_rate=0.01,
gamma=0.9,
e_greedy=0.1):
self.obs_n = obs_n # 状态维度
self.act_n = act_n # 动作维度
self.learning_rate = learning_rate # 学习率
self.gamma = gamma # 奖励衰减率
self.e_greedy = e_greedy # 按一定概率随机选动作
self.Q = np.zeros((obs_n, act_n)) # Q表格 todo:嵌套一层有什么作用?
def sample(self, obs):
if np.random.sample() < (1 - self.e_greedy): # 强化概念 #根据table的Q值选动作
return self.predict(obs)
else:
# 随机选择一个
return np.random.choice(self.act_n)
def predict(self, obs):
# 进行预测,直接选择Q值最高的那个动作
# 拉出该状态的那一行动作
Q_list = self.Q[obs]
maxQ = np.max(Q_list)
action_list = np.where(Q_list == maxQ)[0]
# 防止有多个最大值,所以随机选择一个
return np.random.choice(action_list)
def learn(self, obs, act, reward, obs_next, act_next, done):
"""
obs: 交互前的obs, s_t
action: 本次交互选择的action, a_t
reward: 本次动作获得的奖励r
next_obs: 本次交互后的obs, s_t+1
next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1
done: episode是否结束
"""
# s a r s a
if done:
target = reward
else:
target = reward + self.gamma * self.Q[obs_next][act_next]
self.Q[obs][act] += self.learning_rate * (target - self.Q[obs][act])
def save(self):
npy_file = './q_table.npy'
np.save(npy_file, self.Q)
print(npy_file + ' saved.')
def restore(self, npy_file='./q_table.npy'):
self.Q = np.load(npy_file)
print(npy_file + ' loaded.')
# train.py
import gym
from agent import SarsaAgent
import time
"""
相当于是跑一个回合呗
"""
def run_episode(env, agent, render=False):
total_steps = 0
total_reward = 0
obs = env.reset()
action = agent.sample(obs)
while True:
# 采取动作,获取环境的反馈值
next_obs, reward, done, _ = env.step(action)
next_action = agent.sample(next_obs)
# 训练sarsa算法
agent.learn(obs, action, reward, next_obs, next_action, done)
action = next_action
obs = next_obs
total_reward += reward
total_steps += 1
if render:
env.render() # 渲染新的一帧图形
if done:
break
return total_reward, total_steps
"""
用于在训练好Q表格后对其进行测试
"""
def test_episode(env, agent):
total_reward = 0
obs = env.reset()
while True:
action = agent.predict(obs)
next_obs, reward, done, _ = env.step(action)
total_reward += reward
obs = next_obs
time.sleep(0.5)
env.render()
if done:
print('test reward = %.1f' % total_reward)
break
def main():
env = gym.make("CliffWalking-v0") # 0 up, 1 right, 2 down, 3 left
agent = SarsaAgent(obs_n=env.observation_space.n,
act_n=env.action_space.n,
learning_rate=0.1,
gamma=0.9,
e_greedy=0.1)
is_render = False
for episode in range(500):
ep_reward, ep_steps = run_episode(env, agent, is_render)
# 每隔20个episode渲染一下看看效果
if episode % 20 == 0:
print('it is in ' + str(episode) + 'round')
is_render = True
else:
is_render = False
# 训练结束,查看算法效果
test_episode(env, agent)
if __name__ == '__main__':
main()