Table of Contents
- Preface
- 1. class CVRPTester: __init__(self, env_params, model_params, tester_params)
- 1.1 Function walkthrough
- 1.2 Function analysis
- 1.2.1 Loading the pretrained model
- 1.3 Function code
- 2. class CVRPTester: run(self)
- Function walkthrough
- Function code
- 3. class CVRPTester: _test_one_batch(self, batch_size)
- Function walkthrough
- Function code
- Appendix
- Full code
Preface
This post works through the code in CVRPTester.py; the analysis is below. The file under study is:
/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py
1. class CVRPTester: __init__(self, env_params, model_params, tester_params)
1.1 Function walkthrough
Link to the execution flowchart
1.2 Function analysis
1.2.1 Loading the pretrained model
Code:
# Restore
model_load = tester_params['model_load']
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
- model_load: a dictionary holding the path where the pretrained model is stored and the specific epoch to load.

model_load = tester_params['model_load']

- checkpoint_fullname: built with Python's string formatting to form the path to the pretrained model file. This produces a path of the form /path/to/model/checkpoint-8100.pt, so the dictionary must supply the keys path and epoch.

checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)

- Loading the model: torch.load(checkpoint_fullname, map_location=device) loads the model checkpoint (the .pt file) from disk and stores it in the variable checkpoint; map_location=device ensures the checkpoint is loaded onto the correct device (GPU or CPU). self.model.load_state_dict(checkpoint['model_state_dict']) then extracts the model's state dict from the loaded checkpoint and loads it into self.model.

checkpoint = torch.load(checkpoint_fullname, map_location=device)
self.model.load_state_dict(checkpoint['model_state_dict'])
Example
Suppose tester_params_regret['model_load'] looks like this:
tester_params_regret = {
    'model_load': {
        'path': '../../pretrained/vrp100',
        'epoch': 8100,
    },
    # other parameters...
}
checkpoint_fullname is then built as '../../pretrained/vrp100/checkpoint-8100.pt', which, resolved relative to the CVRP/POMO script directory, is /home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/pretrained/vrp100/checkpoint-8100.pt, and the model is loaded from that path.
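As a quick check, the construction can be reproduced on its own; these are just the two lines from __init__ applied to the example dictionary above:

model_load = {'path': '../../pretrained/vrp100', 'epoch': 8100}
checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
print(checkpoint_fullname)  # ../../pretrained/vrp100/checkpoint-8100.pt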
1.3 Function code
def __init__(self,
             env_params,
             model_params,
             tester_params):

    # save arguments
    self.env_params = env_params
    self.model_params = model_params
    self.tester_params = tester_params

    # result folder, logger
    self.logger = getLogger(name='trainer')
    self.result_folder = get_result_folder()

    # cuda
    USE_CUDA = self.tester_params['use_cuda']
    if USE_CUDA:
        cuda_device_num = self.tester_params['cuda_device_num']
        torch.cuda.set_device(cuda_device_num)
        device = torch.device('cuda', cuda_device_num)
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')
        torch.set_default_tensor_type('torch.FloatTensor')
    self.device = device

    # ENV and MODEL
    self.env = Env(**self.env_params)
    self.model = Model(**self.model_params)

    # Restore
    model_load = tester_params['model_load']
    checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
    checkpoint = torch.load(checkpoint_fullname, map_location=device)
    self.model.load_state_dict(checkpoint['model_state_dict'])

    # utility
    self.time_estimator = TimeEstimator()
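For orientation, below is a hypothetical tester_params dictionary that assembles every key this class and its methods actually read; the values are illustrative placeholders, not the repo's defaults, and env_params/model_params are assumed to be defined elsewhere:

tester_params = {
    'use_cuda': True,
    'cuda_device_num': 0,
    'model_load': {
        'path': '../../pretrained/vrp100',  # directory containing checkpoint-*.pt
        'epoch': 8100,                      # selects checkpoint-8100.pt
    },
    'test_data_load': {
        'enable': False,                    # True -> run() calls env.use_saved_problems(filename, device)
        'filename': '',
    },
    'test_episodes': 10000,                 # total number of test instances
    'test_batch_size': 1000,                # instances per _test_one_batch call
    'augmentation_enable': True,
    'aug_factor': 8,                        # 8-fold instance augmentation
}

tester = CVRPTester(env_params, model_params, tester_params)
tester.run()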
2. class CVRPTester: run(self)
Function walkthrough
Link to the function execution flowchart
Function code
def run(self):
    self.time_estimator.reset()

    score_AM = AverageMeter()
    aug_score_AM = AverageMeter()

    if self.tester_params['test_data_load']['enable']:
        self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

    test_num_episode = self.tester_params['test_episodes']
    episode = 0
    while episode < test_num_episode:

        remaining = test_num_episode - episode
        batch_size = min(self.tester_params['test_batch_size'], remaining)

        score, aug_score = self._test_one_batch(batch_size)

        score_AM.update(score, batch_size)
        aug_score_AM.update(aug_score, batch_size)

        episode += batch_size

        ############################
        # Logs
        ############################
        elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
        self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
            episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

        all_done = (episode == test_num_episode)

        if all_done:
            self.logger.info(" *** Test Done *** ")
            self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
            self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))
3. class CVRPTester: _test_one_batch(self, batch_size)
Function walkthrough
Link to the execution flowchart
Function code
def _test_one_batch(self, batch_size):

    # Augmentation
    ###############################################
    if self.tester_params['augmentation_enable']:
        aug_factor = self.tester_params['aug_factor']
    else:
        aug_factor = 1

    # Ready
    ###############################################
    self.model.eval()
    with torch.no_grad():
        self.env.load_problems(batch_size, aug_factor)
        reset_state, _, _ = self.env.reset()
        self.model.pre_forward(reset_state)

    # POMO Rollout
    ###############################################
    state, reward, done = self.env.pre_step()
    while not done:
        selected, _ = self.model(state)
        # shape: (batch, pomo)
        state, reward, done = self.env.step(selected)

    # Return
    ###############################################
    aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
    # shape: (augmentation, batch, pomo)

    max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
    # shape: (augmentation, batch)
    no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

    max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
    # shape: (batch,)
    aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

    return no_aug_score.item(), aug_score.item()
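The reshape/max chain at the end is easiest to see on toy tensors. The following self-contained illustration uses made-up sizes (aug_factor=2, batch_size=3, pomo_size=4) and random values; in POMO the reward is the negative tour length, so taking a max picks the shortest tour:

import torch

aug_factor, batch_size, pomo_size = 2, 3, 4
reward = -torch.rand(aug_factor * batch_size, pomo_size)        # shape: (aug*batch, pomo)

aug_reward = reward.reshape(aug_factor, batch_size, pomo_size)  # shape: (aug, batch, pomo)
max_pomo_reward, _ = aug_reward.max(dim=2)                      # best POMO rollout per instance: (aug, batch)

no_aug_score = -max_pomo_reward[0, :].float().mean()            # mean tour length, non-augmented copy only
max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)             # best across augmentations: (batch,)
aug_score = -max_aug_pomo_reward.float().mean()                 # mean tour length with augmentation

print(no_aug_score.item(), aug_score.item())                    # aug_score <= no_aug_score always holds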
Appendix
Full code
import torch
import os
from logging import getLogger

from CVRPEnv import CVRPEnv as Env
from CVRPModel import CVRPModel as Model

from utils.utils import *


class CVRPTester:

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()

        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # Restore
        model_load = tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset()

        score_AM = AverageMeter()
        aug_score_AM = AverageMeter()

        if self.tester_params['test_data_load']['enable']:
            self.env.use_saved_problems(self.tester_params['test_data_load']['filename'], self.device)

        test_num_episode = self.tester_params['test_episodes']
        episode = 0
        while episode < test_num_episode:

            remaining = test_num_episode - episode
            batch_size = min(self.tester_params['test_batch_size'], remaining)

            score, aug_score = self._test_one_batch(batch_size)

            score_AM.update(score, batch_size)
            aug_score_AM.update(aug_score, batch_size)

            episode += batch_size

            ############################
            # Logs
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(episode, test_num_episode)
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}], score:{:.3f}, aug_score:{:.3f}".format(
                episode, test_num_episode, elapsed_time_str, remain_time_str, score, aug_score))

            all_done = (episode == test_num_episode)

            if all_done:
                self.logger.info(" *** Test Done *** ")
                self.logger.info(" NO-AUG SCORE: {:.4f} ".format(score_AM.avg))
                self.logger.info(" AUGMENTATION SCORE: {:.4f} ".format(aug_score_AM.avg))

    def _test_one_batch(self, batch_size):

        # Augmentation
        ###############################################
        if self.tester_params['augmentation_enable']:
            aug_factor = self.tester_params['aug_factor']
        else:
            aug_factor = 1

        # Ready
        ###############################################
        self.model.eval()
        with torch.no_grad():
            self.env.load_problems(batch_size, aug_factor)
            reset_state, _, _ = self.env.reset()
            self.model.pre_forward(reset_state)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, _ = self.model(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)

        # Return
        ###############################################
        aug_reward = reward.reshape(aug_factor, batch_size, self.env.pomo_size)
        # shape: (augmentation, batch, pomo)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        no_aug_score = -max_pomo_reward[0, :].float().mean()  # negative sign to make positive value

        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)  # get best results from augmentation
        # shape: (batch,)
        aug_score = -max_aug_pomo_reward.float().mean()  # negative sign to make positive value

        return no_aug_score.item(), aug_score.item()