DQN 玩 2048 实战|第三期!优化网络,使用GPU、Env奖励优化

news2025/3/18 7:22:03

视频讲解:

DQN 玩 2048 实战|第三期!优化网络,使用GPU、Env奖励优化

1. 仅考虑局部合并奖励:目前的奖励只设置为合并方块时获得的分数,只关注了每一步的即时合并收益,而没有对最终达成 2048 这个目标给予额外的激励,如果没有对达成 2048 给予足够的奖励信号,Agent 可能不会将其作为一个重要的目标

2. 训练硬件资源利用不高,没有使用GPU进行加速,默认为CPU,较慢

代码修改如下:

step函数里面,输入维度增加max_tile最大的数是多少

if 2048 in self.board:
    reward += 10000
    done = True
state = self.board.flatten()
max_tile = np.max(self.board)
state = np.append(state, max_tile)
return state, reward, done
input_size = 17

检查系统中是否有可用的 GPU,如果有则使用 GPU 进行计算,否则使用 CPU。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

在 train ,创建模型实例后,使用 .to(device) 将模型移动到指定的设备(GPU 或 CPU)

model = DQN(input_size, output_size).to(device)
target_model = DQN(input_size, output_size).to(device)

在训练和推理过程中,将输入数据(状态、动作、奖励等)也移动到指定的设备上。

state = torch.FloatTensor(state).unsqueeze(0).to(device)

next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)

states = torch.FloatTensor(states).to(device)
actions = torch.LongTensor(actions).to(device)
rewards = torch.FloatTensor(rewards).to(device)
next_states = torch.FloatTensor(next_states).to(device)
dones = torch.FloatTensor(dones).to(device)

将 state 和 next_state 先使用 .cpu() 方法移动到 CPU 上,再使用 .numpy() 方法转换为 NumPy 数组

replay_buffer.add(state.cpu().squeeze(0).numpy(), action, reward, next_state.cpu().squeeze(0).numpy(), done)

这个不改的话,会出现 TypeError: can't convert cuda:0 device type tensor to numpy 错误

完整代码如下:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.table import Table

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2048 游戏环境类
class Game2048:
    def __init__(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()

    def add_random_tile(self):
        empty_cells = np.argwhere(self.board == 0)
        if len(empty_cells) > 0:
            index = random.choice(empty_cells)
            self.board[index[0], index[1]] = 2 if random.random() < 0.9 else 4

    def move_left(self):
        reward = 0
        new_board = np.copy(self.board)
        for row in range(4):
            line = new_board[row]
            non_zero = line[line != 0]
            merged = []
            i = 0
            while i < len(non_zero):
                if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:
                    merged.append(2 * non_zero[i])
                    reward += 2 * non_zero[i]
                    i += 2
                else:
                    merged.append(non_zero[i])
                    i += 1
            new_board[row] = np.pad(merged, (0, 4 - len(merged)), 'constant')
        if not np.array_equal(new_board, self.board):
            self.board = new_board
            self.add_random_tile()
        return reward

    def move_right(self):
        self.board = np.fliplr(self.board)
        reward = self.move_left()
        self.board = np.fliplr(self.board)
        return reward

    def move_up(self):
        self.board = self.board.T
        reward = self.move_left()
        self.board = self.board.T
        return reward

    def move_down(self):
        self.board = self.board.T
        reward = self.move_right()
        self.board = self.board.T
        return reward

    def step(self, action):
        if action == 0:
            reward = self.move_left()
        elif action == 1:
            reward = self.move_right()
        elif action == 2:
            reward = self.move_up()
        elif action == 3:
            reward = self.move_down()
        done = not np.any(self.board == 0) and all([
            np.all(self.board[:, i] != self.board[:, i + 1]) for i in range(3)
        ]) and all([
            np.all(self.board[i, :] != self.board[i + 1, :]) for i in range(3)
        ])
        if 2048 in self.board:
            reward += 10000
            done = True
        state = self.board.flatten()
        max_tile = np.max(self.board)
        state = np.append(state, max_tile)
        return state, reward, done

    def reset(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()
        state = self.board.flatten()
        max_tile = np.max(self.board)
        state = np.append(state, max_tile)
        return state

# 深度 Q 网络类
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# 经验回放缓冲区类
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

    def __len__(self):
        return len(self.buffer)

# 可视化函数
def visualize_board(board, ax):
    ax.clear()
    table = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = board.shape
    width, height = 1.0 / ncols, 1.0 / nrows

    # 定义颜色映射
    cmap = mcolors.LinearSegmentedColormap.from_list("", ["white", "yellow", "orange", "red"])

    for (i, j), val in np.ndenumerate(board):
        color = cmap(np.log2(val + 1) / np.log2(2048 + 1)) if val > 0 else "white"
        table.add_cell(i, j, width, height, text=val if val > 0 else "",
                       loc='center', facecolor=color)

    ax.add_table(table)
    ax.set_axis_off()
    plt.draw()
    plt.pause(0.1)

# 训练函数
def train():
    env = Game2048()
    input_size = 17
    output_size = 4
    model = DQN(input_size, output_size).to(device)
    target_model = DQN(input_size, output_size).to(device)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    replay_buffer = ReplayBuffer(capacity=10000)
    batch_size = 32
    gamma = 0.99
    epsilon = 1.0
    epsilon_decay = 0.995
    epsilon_min = 0.01
    update_target_freq = 10

    num_episodes = 1000
    fig, ax = plt.subplots()
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        done = False
        total_reward = 0
        while not done:
            visualize_board(env.board, ax)
            if random.random() < epsilon:
                action = random.randint(0, output_size - 1)
            else:
                q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()

            next_state, reward, done = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0).to(device)
            replay_buffer.add(state.cpu().squeeze(0).numpy(), action, reward, next_state.cpu().squeeze(0).numpy(), done)

            if len(replay_buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                states = torch.FloatTensor(states).to(device)
                actions = torch.LongTensor(actions).to(device)
                rewards = torch.FloatTensor(rewards).to(device)
                next_states = torch.FloatTensor(next_states).to(device)
                dones = torch.FloatTensor(dones).to(device)

                q_values = model(states)
                q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

                next_q_values = target_model(next_states)
                next_q_values = next_q_values.max(1)[0]
                target_q_values = rewards + gamma * (1 - dones) * next_q_values

                loss = criterion(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            state = next_state
            total_reward += reward

        if episode % update_target_freq == 0:
            target_model.load_state_dict(model.state_dict())

        epsilon = max(epsilon * epsilon_decay, epsilon_min)
        print(f"Episode {episode}: Total Reward = {total_reward}, Epsilon = {epsilon}")

    plt.close()

if __name__ == "__main__":
    train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【python】http post 在body中传递json数据 以发送

http post 在body中传递json数据 以发送&#xff0c;json的格式非常重要这里要传递json对象&#xff0c;而不是一个json字符串 传递post一个 JSON 字符串 是ok的 是的&#xff0c; {"rsource_rhythm_action_list": {"name": "AI_\\u6708\\u4eae\\u…

[贪心算法]-最大数(lambda 表达式的补充)

1.解析 我们一般使用的排序比较大小都是 a>b 那么a在b的前面 ab 无所谓 a<b a在b的后面 本题的排序则是 ab>ba 那么a在b的前面 abba 无所谓 ab<ba a在b的后面 2.代码 class Solution { public:string largestNumber(vector<int>& nums) {//1.先把所有…

C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷二)

目录 1. 数组名与地址 2. 指针访问数组 3.一维数组传参本质 4.二级指针 5. 指针数组 6. 指针数组模拟二维数组 1. 数组名与地址 我们先看下面这个代码&#xff1a; int arr[10] { 1,2,3,4,5,6,7,8,9,10 };int* p &arr[0]; 这里我们使用 &arr[0] 的方式拿到了数…

python实现简单的图片去水印工具

python实现简单的图片去水印工具 使用说明&#xff1a; 点击"打开图片"选择需要处理的图片 在图片上拖拽鼠标选择水印区域&#xff08;红色矩形框&#xff09; 点击"去除水印"执行处理 点击"保存结果"保存处理后的图片 运行效果 先简要说明…

使用dify+deepseek部署本地知识库

使用difydeepseek部署本地知识库 一、概述二、安装windows docker desktop1、确认系统的Hyper-v功能正常启用2、docker官网下载安装windows客户端3、安装完成后的界面如下所示 三、下载安装ollama四、部署本地deepseek五、本地下载部署dify5.1 下载dify的安装包5.2 将dify解压到…

【算法day13】最长公共前缀

最长公共前缀 https://leetcode.cn/problems/longest-common-prefix/submissions/612055945/ 编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀&#xff0c;返回空字符串 “”。 class Solution { public:string longestCommonPrefix(vector<string&g…

Java高频面试之集合-13

hello啊&#xff0c;各位观众姥爷们&#xff01;&#xff01;&#xff01;本baby今天来报道了&#xff01;哈哈哈哈哈嗝&#x1f436; 面试官&#xff1a;为什么 hash 函数能降哈希碰撞&#xff1f; 哈希函数通过以下核心机制有效降低碰撞概率&#xff0c;确保不同输入尽可能映…

RGV调度算法(三)--遗传算法

1、基于时间窗 https://wenku.baidu.com/view/470e9fd8b4360b4c2e3f5727a5e9856a57122693.html?_wkts_1741880736197&bdQuery%E7%8E%AF%E7%A9%BF%E8%B0%83%E5%BA%A6%E7%AE%97%E6%B3%95 2.2019年MathorCup高校数学建模挑战赛B题 2019-mathorcupB题-环形穿梭机调度模型&a…

YOLOv8轻量化改进——Coordinate Attention注意力机制

现在针对YOLOv8的架构改进越来越多&#xff0c;今天尝试引入了Coordinate Attention注意力机制以改进对小目标物体的检测效率。 yolov8的下载和安装参考我这篇博客&#xff1a; 基于SeaShips数据集的yolov8训练教程_seaships处理成yolov8-CSDN博客 首先我们可以去官网找到CA注…

基于SpringBoot+Vue的驾校预约管理系统+LW示例参考

1.项目介绍 系统角色&#xff1a;管理员、普通用户、教练功能模块&#xff1a;用户管理、管理员管理、教练管理、教练预约管理、车辆管理、车辆预约管理、论坛管理、基础数据管理等技术选型&#xff1a;SpringBoot&#xff0c;Vue等测试环境&#xff1a;idea2024&#xff0c;j…

ONNX:统一深度学习工作流的关键枢纽

引言 在深度学习领域&#xff0c;模型创建与部署的割裂曾是核心挑战。不同框架训练的模型难以在多样环境部署&#xff0c;而 ONNX&#xff08;Open Neural Network Exchange&#xff09;作为开放式神经网络交换格式&#xff0c;搭建起从模型创建到部署的统一桥梁&#xff0c;完…

蓝桥杯————23年省赛 ——————平方差

3.平方差 - 蓝桥云课 一开始看题我还没有意识到问题的严重性 我丢&#xff0c;我想 的是用两层循环来做&#xff0c;后来我试了一下最坏情况&#xff0c;也就是l1 r 1000000000 结果运行半天没运行出来&#xff0c;我就知道坏了&#xff0c;孩子们&#xff0c;要出事&#…

一、串行通信基础知识

一、串行通信基础知识 1.处理器与外部设备通信有两种方式 并行通信&#xff1a;数据的各个位用多条数据线同时传输。&#xff08;传输速度快&#xff0c;但占用引脚资源多。&#xff09; 串行通信&#xff1a;将数据分成一位一位的形式在一条数据线上逐个传输。&#xff08;线路…

自带多个接口,完全免费使用!

做自媒体的小伙伴们&#xff0c;是不是经常为语音转文字的事儿头疼&#xff1f; 今天给大家推荐一款超实用的语音转文字软件——AsrTools&#xff0c;它绝对是你的得力助手&#xff01; AsrTools 免费的语音转文字软件 这款软件特别贴心&#xff0c;完全免费&#xff0c;而且操…

Qt QML解决SVG图片显示模糊的问题

前言 在QML中直接使用SVG图片&#xff0c;使用Image控件加载资源&#xff0c;显示出来图片是模糊的&#xff0c;很影响使用体验。本文介绍重新绘制SVG图片&#xff0c;然后注册到QML中使用。 效果图&#xff1a; 左边是直接使用Image加载资源显示的效果 右边是重绘后的效果 …

【Linux我做主】基础命令完全指南上篇

Linux基础命令完全指南【上篇】 Linux基础命令完全指南github地址前言命令行操作的引入Linux文件系统树形结构的根文件系统绝对路径和相对路径适用场景Linux目录下的隐藏文件 基本指令目录和文件相关1. ls2. cd和pwdcdpwd 3. touch4. mkdir5. cp6. mv移动目录时覆盖写入的两种特…

Designing Dashboards with SAP Analytics Cloud

Designing Dashboards with SAP Analytics Cloud

项目实战系列:基于瑞萨RA6M5构建多节点OTA升级-系统设计<一>

项目背景 原嵌入式控制系统采用分布式模块化架构&#xff0c;由12个功能板卡&#xff08;通信控制、信号采集、驱动执行等&#xff09;组成。系统维护阶段存在以下痛点&#xff1a; 低效的本地烧录机制&#xff1a;各板卡固件升级需通过JTAG接口逐一手动连接JLINK仿真器&#x…

《AI大模型趣味实战》 No3:快速搭建一个漂亮的AI家庭网站-相册/时间线/日历/多用户/个性化配色/博客/聊天室/AI管家(下)

《AI大模型趣味实战》 No3&#xff1a;快速搭建一个漂亮的AI家庭网站-相册/时间线/日历/多用户/个性化配色/博客/聊天室/AI管家(下) 摘要 本文介绍了家庭网站V1.3版本的更新内容&#xff0c;主要聚焦于AI管家功能的优化与完善。V1.3版本对AI管家模块进行了全面升级&#xff0…