VGG16_Weights.DEFAULT
或VGG16_Weights.IMAGENET1K_V1
:使用在 ImageNet 上训练的权重
import torchvision
from torch import nn
vgg16_false = torchvision.models.vgg16()
vgg16_true = torchvision.models.vgg16(weights='DEFAULT')
print(vgg16_true)
train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=torchvision.transforms.ToTensor(),
download=True)
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
修改 classifier 中
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
修改指定
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
模型保存方式
加载方式
思考data 赋值给imgs,targets
for data in dataloader:
PyTorch DataLoader 返回的数据格式 第一部分是图像张量信息
第二部分是
- tensor([1, 0, 1, 3, 1, 1, 1, 7, 3, 9, 6, 8, 4, ...])
- 每个数字代表一个类别的索引(从 0 到 9,因为 CIFAR-10 有 10 个类别)。