一、优化器
optimizer = optim.SGD(model.parameters(), lr=0.01(学习速率), momentum=0.9) optimizer = optim.Adam([var1, var2], lr=0.0001)
一般,学习率的设置,先从大的设置,逐渐变小。
神经网络可以参见上篇文章,接上篇文章的神经网络做模型优化:
sun = SUN() #调用神经网络 # 设置优化器 # 随机梯度下降 optim = torch.optim.SGD(sun.parameters(), lr=0.01) # 循环学习20次 for epoch in range(20): # 整体误差的总和 running_loss = 0.0 # 只循环了一次,学习了一次,在外面再次加一个循环 for data in dataloader: imgs, targets = data outputs = sun(imgs) result_loss = loss(outputs, targets) # 将模型中每一个可以调节参数对应的梯度调为零 optim.zero_grad() # 得到可以调节参数的梯度 result_loss.backward() optim.step() # 总体误差的总和 running_loss = running_loss + result_loss print(running_loss)
输出结果为:
由输出结果可见,模型优化器,使得误差总和在不断的变小。
二、模型的使用与修改
import torchvision
下载的数据集存放位置:
在下载ImageNet数据集时,会出现报错,数据集有143G,已经不支持下载。设置语句:
vgg16_false = torchvision.models.vgg16(pretrained=False)
当为False时,只是加载网络模型(也就是像之前的网络模型那样,只是加载模型,含有卷积,池化等,其中的参数都是默认的)
vgg16_true = torchvision.models.vgg16(pretrained=True)
当为True时,不仅加载模型,还要加载对应的参数。
VGG16将数据集分成1000个类。
print(vgg16_true)
输出结果:
之前使用的数据集CIFAR10数据集输出的为10个分类,于是,我们也可以根据现有的网络进行修改:
import torchvision # train_data = torchvision.datasets.ImageNet("../data_ImageNet", split='train', download= True, # transform=torchvision.transforms.ToTensor()) from torch import nn vgg16_false = torchvision.models.vgg16(pretrained=False) vgg16_true = torchvision.models.vgg16(pretrained=True) print(vgg16_true) # vgg16_true.add_module('add_linear', nn.Linear(1000,10)) # print(vgg16_true) 加在末尾 # 加在开头 vgg16_true.classifier.add_module('add_linear', nn.Linear(1000,10)) print(vgg16_true) print(vgg16_false) vgg16_false.classifier[6] = nn.Linear(4096, 10) print(vgg16_false)
二、网络模型的保存与读取
出现的问题:
AttributeError: Can't get attribute 'SUN' on <module '__main__' from 'D:/test pytorch/learningplan1/models_load.py'>
解决的办法是需要将网络模型放在读取的上面,不需要引用,但是需要将其放置在其上方。
当然,引用别人的模型时,就非常的不便,所以使用:
from models_save import *
就能够解决此问题。
三、模型的保存
import torch import torchvision from torch import nn vgg16 = torchvision.models.vgg16(pretrained = False) # 保存方式1(结构与参数均保存) torch.save(vgg16, 'vgg16_method1.pth') # 保存方式2(参数保存为字典形式) torch.save(vgg16.state_dict(), "vgg16_method2.pth") # 陷阱 class SUN(nn.Module): def __init__(self): super(SUN, self).__init__() self.convv = nn.Conv2d(3, 64, kernel_size=3) def forward(self, x): x = self.convv(x) return x sun = SUN() torch.save(sun, "sun_method1.pth")
四、模型的使用
import torch # 方式1 保存后加载 # model = torch.load("vgg16_method1.pth") # print(model) # 方式2 保存后加载 import torchvision from torch import nn from models_save import * vgg16 = torchvision.models.vgg16(pretrained = False) vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # model = torch.load("vgg16_method2.pth") # print(vgg16) # 加载网络模型 # class SUN(nn.Module): # def __init__(self): # super(SUN, self).__init__() # self.convv = nn.Conv2d(3, 64, kernel_size=3) # # def forward(self, x): # x = self.convv(x) # return x model = torch.load("sun_method1.pth") print(model)