20250303-代码笔记-class CVRPTester

news2025/3/5 4:11:03

文章目录

  • 前言
  • 一、class CVRPTester:__init__(self,env_params,model_params, tester_params)
    • 1.1函数解析
    • 1.2函数分析
      • 1.2.1加载预训练模型
    • 1.2函数代码
  • 二、class CVRPTester:run(self)
    • 函数解析
    • 函数代码
  • 三、class CVRPTester:_test_one_batch(self, batch_size)
    • 函数解析
    • 函数代码
  • 附录
    • 代码(全)


前言

学习代码CVRPTester.py,对代码的分析如下。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py


一、class CVRPTester:init(self,env_params,model_params, tester_params)

1.1函数解析

执行流程图链接

在这里插入图片描述

1.2函数分析

1.2.1加载预训练模型

代码:

# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
  • model_load: 这是一个字典,包含了从哪里加载预训练模型的路径信息以及具体的 epoch
model_load = tester_params['model_load']
  • checkpoint_fullname: 使用 Python 的字符串格式化功能,构造预训练模型的文件路径。
    这会生成形如 /path/to/model/checkpoint-8100.pt 的文件路径。即需要输入参数pathepoch
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
  • 加载模型:
    • torch.load(checkpoint_fullname, map_location=device):从磁盘加载模型检查点(即 .pt 文件),并将其存储在 checkpoint 变量中。map_location=device 确保模型会被加载到正确的设备上(GPU 或 CPU)。
    • self.model.load_state_dict(checkpoint['model_state_dict']):从加载的检查点中提取模型的状态字典,并将其加载到 self.model 中。
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])

示例
假设 tester_params_regret[‘model_load’] 如下所示:

tester_params_regret = {
    'model_load': {
        'path': '../../pretrained/vrp100',
       
        'epoch': 8100,
    },
    # 其他参数...
}

然后 checkpoint_fullname 会被构造为/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/models/checkpoint-8100.pt,模型会从该路径加载。
在这里插入图片描述


1.2函数代码

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()


        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()


二、class CVRPTester:run(self)

函数解析

函数执行流程图链接
在这里插入图片描述

函数代码

    def run(self):
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))


三、class CVRPTester:_test_one_batch(self, batch_size)

函数解析

执行流程图链接
在这里插入图片描述

函数代码

    def _test_one_batch(self, batch_size):

        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()


附录

代码(全)


import torch

import os
from logging import getLogger

from CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Model

from utils.utils import *


class CVRPTester:
    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()


        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0

        while episode < test_num_episode:

            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

    def _test_one_batch(self, batch_size):

        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2309801.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++学习之C++初识、C++对C语言增强、对C语言扩展

一.C初识 1.C简介 2.第一个C程序 //#include <iostream> //iostream 相当于 C语言下的 stdio.h i - input 输入 o -output 输出 //using namespace std; //using 使用 namespace 命名空间 std 标准 &#xff0c;理解为打开一个房间&#xff0c;房间里有我们所需…

关于虚拟环境中遇到的bug

conda和cmd介绍 介绍 Conda 概述&#xff1a; Conda是一个开源包管理系统和环境管理系统&#xff0c;尤其适用于Python和R语言的开发环境。它允许用户创建独立的虚拟环境&#xff0c;方便地管理依赖包和软件版本。 特点&#xff1a; 环境管理&#xff1a;可以创建、导入、导…

【网络安全 | 渗透测试】GraphQL精讲一:基础知识

未经许可,不得转载, 文章目录 GraphQL 定义GraphQL 工作原理GraphQL 模式GraphQL 查询GraphQL 变更(Mutations)查询(Queries)和变更(Mutations)的组成部分字段(Fields)参数(Arguments)变量别名(Aliases)片段(Fragments)订阅(Subscriptions)自省(Introspecti…

什么是JTAG、SWD?

一、什么是JTAG&#xff1f; JTAG&#xff08;Joint Test Action Group&#xff0c;联合测试行动小组&#xff09;是一种国际标准测试协议&#xff0c;常用于芯片内部测试及对系统进行调试、编程等操作。以下从其起源、工作原理、接口标准、应用场景等方面详细介绍&#xff1a…

如何在Apple不再支持的MacOS上安装Homebrew

手头有一台2012年产的Macbook Pro&#xff0c;系统版本停留在了10.15.7&#xff08;2020年9月24日发布的&#xff09;。MacOS 11及后续的版本都无法安装到这台老旧的电脑上。想通过pkg安装Homebrew&#xff0c;发现Homebrew releases里最新的pkg安装包不支持MacOS 10.15.7&…

在笔记本电脑上用DeepSeek搭建个人知识库

最近DeepSeek爆火&#xff0c;试用DeepSeek的企业和个人越来越多。最常见的应用场景就是知识库和知识问答。所以本人也试用了一下&#xff0c;在笔记本电脑上部署DeepSeek并使用开源工具搭建一套知识库&#xff0c;实现完全在本地环境下使用本地文档搭建个人知识库。操作过程共…

Java面试第七山!《MySQL索引》

一、索引的本质与作用 索引是帮助MySQL高效获取数据的数据结构&#xff0c;类似于书籍的目录。它通过减少磁盘I/O次数&#xff08;即减少数据扫描量&#xff09;来加速查询&#xff0c;尤其在百万级数据场景下&#xff0c;索引可将查询效率提升数十倍。 核心作用&#xff1a;…

DeepSeek 助力 Vue3 开发:打造丝滑的弹性布局(Flexbox)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

动态表头报表的绘制与导出

目录 一、效果图 二、整体思路 三、代码区 一、效果图 根据选择的日期范围动态生成表头&#xff08;eg&#xff1a;2025.2.24--2025.03.03&#xff09;每个日期又分为白班、夜班&#xff1b;数据列表中对产线合并单元格。支持按原格式导出对应的报表excel。 点击空白区可新…

DeepSeek 助力 Vue3 开发:打造丝滑的密码输入框(Password Input)

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 Deep…

Linux 基础---文件权限

概念 文件权限是针对文件所有者、文件所属组、其他人这三类人而言的&#xff0c;对应的操作是chmod。设置方式&#xff1a;文字设定法、数字设定法。 文字设定法&#xff1a;r,w,x,- 来描述用户对文件的操作权限数字设定法&#xff1a;0,1,2,3,4,5,6,7 来描述用户对文件的操作…

SpringBoot五:JSR303校验

精心整理了最新的面试资料和简历模板&#xff0c;有需要的可以自行获取 点击前往百度网盘获取 点击前往夸克网盘获取 松散绑定 意思是比如在yaml中写的是last-name&#xff0c;这个和lastName意思是一样的&#xff0c;-后的字母默认是大写的 JSR303校验 就是可以在字段增加…

【计算机网络】考研复试高频知识点总结

文章目录 一、基础概念1、计算机⽹络的定义2、计算机⽹络的目标3、计算机⽹络的组成4、计算机⽹络的分类5、计算机⽹络的拓扑结构6、计算机⽹络的协议7、计算机⽹络的分层结构8、OSI 参考模型9、TCP/IP 参考模型10、五层协议体系结构 二、物理层1、物理层的功能2、传输媒体3、 …

Error Density-dependent Empirical Risk Minimization

经验误差密度依赖的风险最小化 v.s. 经验风险最小化 论文&#xff1a; 《 Error Density-dependent Empirical Risk Minimization》 发表在&#xff1a; ESWA’24 相关代码&#xff1a; github.com/zxlml/EDERM 研究背景 传统的经验风险最小化&#xff08;ERM&#xff09;方…

02_NLP文本预处理之文本张量表示法

文本张量表示法 概念 将文本使用张量进行表示,一般将词汇表示为向量,称为词向量,再由各个词向量按顺序组成矩阵形成文本表示 例如: ["人生", "该", "如何", "起头"]># 每个词对应矩阵中的一个向量 [[1.32, 4,32, 0,32, 5.2],[3…

Spring Boot全局异常处理:“危机公关”团队

目录 一、全局异常处理的作用二、Spring Boot 实现全局异常处理&#xff08;附上代码实例&#xff09;三、总结&#xff1a; &#x1f31f;我的其他文章也讲解的比较有趣&#x1f601;&#xff0c;如果喜欢博主的讲解方式&#xff0c;可以多多支持一下&#xff0c;感谢&#x1…

C# OnnxRuntime部署DAMO-YOLO香烟检测

目录 说明 效果 模型信息 项目 代码 下载 参考 说明 效果 模型信息 Model Properties ------------------------- --------------------------------------------------------------- Inputs ------------------------- name&#xff1a;input tensor&#xff1a;Floa…

[密码学实战]Java生成SM2根证书及用户证书

前言 在国密算法体系中,SM2是基于椭圆曲线密码(ECC)的非对称加密算法,广泛应用于数字证书、签名验签等场景。本文将结合代码实现,详细讲解如何通过Java生成SM2根证书及用户证书,并深入分析其核心原理。 一、证书验证 1.代码运行结果 2.根证书验证 3.用户证书验证 二、…

安装 cnpm 出现 Unsupported URL Type “npm:“: npm:string-width@^4.2.0

Unsupported URL Type "npm:": npm:string-width^4.2.0 可能是 node 版本太低了&#xff0c;需要安装低版本的 cnpm 试试 npm cache clean --force npm config set strict-ssl false npm install -g cnpm --registryhttps://registry.npmmirror.com 改为 npm insta…

探秘基带算法:从原理到5G时代的通信变革【九】QPSK调制/解调

文章目录 2.8 QPSK 调制 / 解调简介QPSK 发射机的实现与原理QPSK 接收机的实现与原理QPSK 性能仿真QPSK 变体分析 本博客为系列博客&#xff0c;主要讲解各基带算法的原理与应用&#xff0c;包括&#xff1a;viterbi解码、Turbo编解码、Polar编解码、CORDIC算法、CRC校验、FFT/…