文章目录
- 前言
- 一、get_random_problems 函数分析
- 二、augment_xy_data_by_8_fold 函数分析
- 代码
前言
该笔记分析代码的功能是生成随机VRP问题的数据,包含仓库坐标、节点坐标和节点需求。
对该代码进行改进
20250412-代码改进-拟蒙特卡洛
一、get_random_problems 函数分析
depot_xy = torch.rand(size=(batch_size, 1, 2))
- 生成仓库坐标:
- 生成形状为
(batch_size, 1, 2)
的随机张量,表示每个批次中仓库的二维坐标(范围[0,1)
)。
- 生成形状为
node_xy = torch.rand(size=(batch_size, problem_size, 2))
- 生成节点坐标:
- 生成形状为
(batch_size, problem_size, 2)
的随机张量,表示每个批次中所有节点的二维坐标。
- 生成形状为
if problem_size == 20:
demand_scaler = 30
elif problem_size == 50:
demand_scaler = 40
elif problem_size == 100:
demand_scaler = 50
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / demand_scaler
- 生成节点需求:
- 根据
problem_size
选择缩放因子demand_scaler
。 - 生成 1~9 的整数需求,并缩放到
[1/50, 9/50]
等区间,确保需求值为浮点数。
- 根据
二、augment_xy_data_by_8_fold 函数分析
功能:通过8种几何变换对坐标数据进行增强,扩充数据集。
x = xy_data[:, :, [0]] # 提取x坐标
y = xy_data[:, :, [1]] # 提取y坐标
- 拆分坐标:
- 从输入数据
xy_data
(形状(batch, N, 2)
)分离出x和y分量。
- 从输入数据
dat1 = torch.cat((x, y), dim=2) # 原始坐标
dat2 = torch.cat((1 - x, y), dim=2) # x轴镜像
dat3 = torch.cat((x, 1 - y), dim=2) # y轴镜像
dat4 = torch.cat((1 - x, 1 - y), dim=2) # x+y轴镜像
dat5 = torch.cat((y, x), dim=2) # 转置坐标
dat6 = torch.cat((1 - y, x), dim=2) # 转置后x轴镜像
dat7 = torch.cat((y, 1 - x), dim=2) # 转置后y轴镜像
dat8 = torch.cat((1 - y, 1 - x), dim=2) # 转置后x+y轴镜像
- 生成8种变换:
- 对坐标进行镜像翻转和转置操作,生成8种变体。
aug_xy_data = torch.cat((dat1, dat2, ..., dat8), dim=0)
- 合并增强数据:
- 将8种变换后的数据沿批次维度拼接,最终形状为
(8*batch, N, 2)
。
代码
import torch
import numpy as np
def get_random_problems(batch_size, problem_size):
depot_xy = torch.rand(size=(batch_size, 1, 2))
# shape: (batch, 1, 2)
node_xy = torch.rand(size=(batch_size, problem_size, 2))
# shape: (batch, problem, 2)
if problem_size == 20:
demand_scaler = 30
elif problem_size == 50:
demand_scaler = 40
elif problem_size == 100:
demand_scaler = 50
else:
raise NotImplementedError
node_demand = torch.randint(1, 10, size=(batch_size, problem_size)) / float(demand_scaler)
# shape: (batch, problem)
return depot_xy, node_xy, node_demand
def augment_xy_data_by_8_fold(xy_data):
# xy_data.shape: (batch, N, 2)
x = xy_data[:, :, [0]]
y = xy_data[:, :, [1]]
# x,y shape: (batch, N, 1)
dat1 = torch.cat((x, y), dim=2)
dat2 = torch.cat((1 - x, y), dim=2)
dat3 = torch.cat((x, 1 - y), dim=2)
dat4 = torch.cat((1 - x, 1 - y), dim=2)
dat5 = torch.cat((y, x), dim=2)
dat6 = torch.cat((1 - y, x), dim=2)
dat7 = torch.cat((y, 1 - x), dim=2)
dat8 = torch.cat((1 - y, 1 - x), dim=2)
aug_xy_data = torch.cat((dat1, dat2, dat3, dat4, dat5, dat6, dat7, dat8), dim=0)
# shape: (8*batch, N, 2)
return aug_xy_data