文章目录
- - 模型搭建
- 1. 搭建ResNetGenerator
- 2. 网络实例化
- 3.加载预训练模型权重文件
- 4. 神经网络设置为评估模式
- 预测处理
- 1. 定义图片的预处理方法
- 2. 导入图片
- 3. 预处理图片
- 4. 调用模型
- 5. 输出结果
- 模型搭建
1. 搭建ResNetGenerator
import torch
import torch.nn as nn
class ResNetBlock(nn.Module): # <1>
def __init__(self, dim):
super(ResNetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim)
def build_conv_block(self, dim):
conv_block = []
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim),
nn.ReLU(True)]
conv_block += [nn.ReflectionPad2d(1)]
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
nn.InstanceNorm2d(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x) # <2>
return out
class ResNetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): # <3>
assert(n_blocks >= 0)
super(ResNetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
nn.InstanceNorm2d(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=True),
nn.InstanceNorm2d(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResNetBlock(ngf * mult)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=True),
nn.InstanceNorm2d(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input): # <3>
return self.model(input)
2. 网络实例化
netG = ResNetGenerator()
3.加载预训练模型权重文件
model_path = '../data/p1ch2/horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)
4. 神经网络设置为评估模式
netG.eval()
netG.eval()
是 PyTorch 中的一个方法,用于将神经网络模型设置为评估(evaluation)模式。
-
关闭 Dropout 和 Batch Normalization:
- 在训练过程中,Dropout 层会随机丢弃一些神经元,以防止过拟合。Batch Normalization 层会根据每个批次的数据计算均值和方差,以稳定训练过程。
- 在评估模式下,Dropout 层会关闭,所有神经元都会参与计算。Batch Normalization 层会使用训练过程中计算的均值和方差,而不是当前批次的数据。
-
确保一致性:
- 在评估模式下,模型的行为会更加一致和可预测,因为不会受到随机丢弃神经元或批次数据统计特性的影响。
-
推理和测试:
- 在进行模型推理或测试时,应该始终将模型设置为评估模式,以确保得到准确和稳定的结果。
预测处理
1. 定义图片的预处理方法
from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([
transforms.Resize((262, 461)), # 调整图像大小
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化
])
2. 导入图片
img = Image.open("../data/p1ch2/horse.jpg")
3. 预处理图片
# 确保图像有3个通道
if img.mode != 'RGB':
img = img.convert('RGB')
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)
4. 调用模型
out_t = (batch_out.data.squeeze() + 1.0) /2
5. 输出结果
out_t = (batch_out.data.squeeze() + 1.0) /2
out_img = transforms.ToPILImage()(out_t)
# out_img.save('../data/p1ch2/zebra.jpg')
out_img
【注*:该模型的作用是将图片中的马,生成为斑马】
(完)