GitHub - takuseno/d3rlpy: An offline deep reinforcement learning library
d3rlpy,离线强化学习算法库
我装在windows下用anaconda,按照官网教程
conda install -c conda-forge d3rlpy
第一次安装报错CondaSSLError: OpenSSL appears to be unavailable on this machine
[报错解决]CondaSSLError: OpenSSL appears to be unavailable on this machine. OpenSSL is required to downl_一件迷途小书童的博客-CSDN博客
参考这篇文章解决后正常安装没问题,值得注意的是d3rkpy安装时包含cudatoolkit11.几,我在想这个在不同电脑上可能之后会出错,不过后面运行算法时可以选择是否使用GPU
我是打算用离线强化学习算法,安装后测试,官网上也有测试代码
import d3rlpy
# prepare dataset
dataset, env = d3rlpy.datasets.get_d4rl('hopper-medium-v0')
# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)
# train
cql.fit(
dataset,
eval_episodes=dataset,
n_epochs=100,
scorers={
'environment': d3rlpy.metrics.evaluate_on_environment(env),
'td_error': d3rlpy.metrics.td_error_scorer,
},
)
看得出来,这接口用起来非常方便啊
因为我没装d4rl所以肯定是失败了,d4rl数据集查了下资料可能无法装在windows环境下,有点难办。可以使用下面这个在测试,用的是d3rlpy自带用于测试的数据集,也是比较常用的两个环境,具体是在d3rlpy的文档上找到的
import d3rlpy
# prepare dataset
# dataset, env = d3rlpy.datasets.get_d4rl('CartPole-v0')
dataset, env = d3rlpy.datasets.get_pendulum("random")
# prepare algorithm
cql = d3rlpy.algos.CQL(use_gpu=True)
# train
cql.fit(
dataset,
eval_episodes=dataset,
n_epochs=100,
scorers={
'environment': d3rlpy.metrics.evaluate_on_environment(env),
'td_error': d3rlpy.metrics.td_error_scorer,
},
)
资料很充分,d3rlpy文档:d3rlpy.datasets.get_cartpole — d3rlpy documentation
成功运行:
如果失败的话可能是下载失败,
在这找到下载网址,自己下载到本地,改成规定的名字即可,放到对d3rlpy_data文件夹里,再运行时就不需要在线下载了,比如这样
之后回到d4rl,我打算把自己的数据集按照d4rl的格式来编写,但我不打算装d4rl
可以看到在d3rlpy中读取d4rl的数据集主要是用d4rl中的get_dataset函数,于是我索性把d4rl中这个函数搬到d3rlpy中,其实就是读取h5格式的函数,也挺好移植,主要也就这一段
data_dict = {}
with h5py.File(h5path, 'r') as dataset_file:
for k in tqdm(get_keys(dataset_file), desc="load datafile"):
try: # first try loading as an array
data_dict[k] = dataset_file[k][:]
except ValueError as e: # try loading as a scalar
data_dict[k] = dataset_file[k][()]
注意还需要
import h5py
from tqdm import tqdm
和
def get_keys(h5file):
keys = []
def visitor(name, item):
if isinstance(item, h5py.Dataset):
keys.append(name)
h5file.visititems(visitor)
return keys
至于原先是个类,我感觉好像也不需要,同时还是把在线改掉,直接变成一个绝对位置(这个在d4rl中也可以找到下载的网址)
h5path = "D:\xxx_project\pycharm\offline_RL\d3rlpy_data\hopper_random.hdf5"
运行成功
我考虑下一步制作自己的hdf5格式数据集,及做下自己的gym环境
甚至不能算是入门,希望没有问题,欢迎指正