CIFAR-10 is a classic computer vision dataset that is widely used for image classification. It contains 60,000 color images in 10 classes, and each image is 32x32 pixels. The dataset is split into 50,000 training images and 10,000 test images. Each class has 6,000 images. The classes are:
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
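CIFAR-10 is a multi-class classification problem. The goal is to predict which class an image belongs to from its content, for example features such as the shape and color of the object. Image classification models such as convolutional neural networks (CNNs) are commonly used for this task. They make predictions by learning the spatial features of the image.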
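For a 10-class problem the network outputs 10 raw scores (logits) per image. `nn.CrossEntropyLoss`, which the training code uses, applies a softmax to turn those scores into probabilities and takes the negative log of the probability assigned to the true class. The pure-Python sketch below reproduces that calculation for a single sample; `softmax` and `cross_entropy` are illustrative helper names, not PyTorch APIs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(probability assigned to the true class)."""
    return -math.log(softmax(logits)[target])


# An untrained-looking output: equal scores for all 10 classes
uniform = [0.0] * 10
print(round(cross_entropy(uniform, target=2), 4))  # 2.3026, i.e. ln(10)

# Scores that strongly favour class 2 (bird) give a much smaller loss
confident = [0.0] * 10
confident[2] = 5.0
print(round(cross_entropy(confident, target=2), 4))
```

A model that gives every class the same score has a loss of ln(10) ≈ 2.30. This is a handy reference point when reading a training log: values well below 2.30 mean the model is doing better than uniform guessing.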
Let's walk through the implementation:
```python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

# ToTensor converts each 32x32 RGB image into a 3x32x32 float tensor with values in [0, 1]
train_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python", train=True,
                                          transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../input/cifar10-python", train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
print(f"train length: {len(train_data)}")
print(f"test length: {len(test_data)}")
```
```text
Files already downloaded and verified
Files already downloaded and verified
train length: 50000
test length: 10000
```
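Each sample in `train_data` is an `(image, label)` pair. For an 8-bit RGB image, `ToTensor` reorders the data from height x width x channels to channels x height x width and scales the values from [0, 255] to [0, 1], which is the layout `nn.Conv2d` expects. Below is a rough NumPy imitation of that conversion, not torchvision's own code. It assumes a uint8 HxWxC input and uses a random array in place of a real CIFAR-10 image; `to_tensor_like` is a hypothetical helper.

```python
import numpy as np


def to_tensor_like(img_hwc: np.ndarray) -> np.ndarray:
    """Mimic ToTensor for a uint8 HxWxC image: reorder to CxHxW and scale to [0, 1]."""
    if img_hwc.dtype != np.uint8:
        raise TypeError("expected a uint8 image")
    return img_hwc.transpose(2, 0, 1).astype(np.float32) / 255.0


# A random 32x32 RGB array standing in for one CIFAR-10 image
fake = np.random.default_rng(0).integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
x = to_tensor_like(fake)
print(x.shape)  # (3, 32, 32)
print(x.min() >= 0.0, x.max() <= 1.0)
```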
With the CIFAR-10 dataset located and loaded, we train a network with three convolutional layers for 10 epochs.
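The `64*4*4` in the first `nn.Linear` layer comes from tracking the spatial size through the network. With kernel 5, stride 1 and padding 2, each `Conv2d` keeps the size unchanged, and each `MaxPool2d(2)` halves it. The sketch below applies the standard output-size formula floor((size + 2*padding - kernel) / stride) + 1, assuming dilation 1. `conv_out` is a hypothetical helper name.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of Conv2d or MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32  # CIFAR-10 images are 32x32
for block in range(3):
    size = conv_out(size, kernel=5, stride=1, padding=2)  # Conv2d(..., 5, 1, 2): size unchanged
    size = conv_out(size, kernel=2, stride=2)  # MaxPool2d(2): stride defaults to kernel_size, so size halves
    print(f"after block {block + 1}: {size}x{size}")

channels = 64  # the last Conv2d has 64 output channels
print(channels * size * size)  # 1024 == 64*4*4, the in_features of the first Linear
```

The size goes 32 → 16 → 8 → 4, so `nn.Flatten()` produces 64 × 4 × 4 = 1024 features per image.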
```python
# The logged run used the default shuffle=False; shuffle=True is the more common choice for training
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Three conv + max-pool stages: 32x32 -> 16x16 -> 8x8 -> 4x4
        # Note: no nonlinearity (e.g. nn.ReLU) sits between the layers; this matches the logged run
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),  # in_channels, out_channels, kernel_size, stride, padding
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),  # 64 channels x 4 x 4 = 1024 features
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),  # one logit per class
        )

    def forward(self, x):
        return self.model(x)


# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mynet = CNN().to(device)
loss_func = nn.CrossEntropyLoss()
learning_rate = 0.0001
optimizer = torch.optim.Adam(mynet.parameters(), lr=learning_rate)

total_train = 0  # number of optimizer steps so far
epoch = 10

for i in range(epoch):
    print(f"----No.{i+1} training...-----")

    # Training
    mynet.train()
    for imgs, targets in train_dataloader:
        imgs, targets = imgs.to(device), targets.to(device)
        outputs = mynet(imgs)
        loss = loss_func(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train += 1
        if total_train % 100 == 0:
            print(f"training step: {total_train}, loss: {loss.item()}")

    # Evaluation on the test set
    mynet.eval()
    total_test_loss = 0.0  # sum of per-batch mean losses
    total_accuracy = 0  # number of correctly classified test images
    with torch.no_grad():
        for imgs, targets in test_dataloader:
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = mynet(imgs)
            loss = loss_func(outputs, targets)
            total_test_loss += loss.item()
            total_accuracy += (outputs.argmax(1) == targets).sum()
    print(f"test loss: {total_test_loss}, accuracy: {(total_accuracy / len(test_data)).item()}")

    # Save the whole model object (not just its state_dict), one file per epoch
    torch.save(mynet, f"myCNN_{i+1}p.pth")
    print("model saved")
```
```text
----No.1 training...-----
training step: 100, loss: 2.0156445503234863
training step: 200, loss: 1.999146580696106
training step: 300, loss: 1.860052466392517
training step: 400, loss: 1.7510318756103516
training step: 500, loss: 1.7712416648864746
training step: 600, loss: 1.6994789838790894
training step: 700, loss: 1.7278780937194824
test loss: 257.74497163295746, accuracy: 0.41990000009536743
model saved
----No.2 training...-----
training step: 800, loss: 1.515326976776123
training step: 900, loss: 1.485555648803711
training step: 1000, loss: 1.6138449907302856
training step: 1100, loss: 1.7650551795959473
training step: 1200, loss: 1.4380264282226562
training step: 1300, loss: 1.3843588829040527
training step: 1400, loss: 1.5849156379699707
training step: 1500, loss: 1.5038520097732544
test loss: 236.6359145641327, accuracy: 0.47110000252723694
model saved
----No.3 training...-----
training step: 1600, loss: 1.4474828243255615
training step: 1700, loss: 1.4474865198135376
training step: 1800, loss: 1.7310973405838013
training step: 1900, loss: 1.5719612836837769
training step: 2000, loss: 1.6212022304534912
training step: 2100, loss: 1.2924069166183472
training step: 2200, loss: 1.256321907043457
training step: 2300, loss: 1.560215711593628
test loss: 221.27214550971985, accuracy: 0.5011000037193298
model saved
----No.4 training...-----
training step: 2400, loss: 1.4557472467422485
training step: 2500, loss: 1.2620049715042114
training step: 2600, loss: 1.4703019857406616
training step: 2700, loss: 1.4131494760513306
training step: 2800, loss: 1.303225040435791
training step: 2900, loss: 1.4961038827896118
training step: 3000, loss: 1.2810102701187134
training step: 3100, loss: 1.337519645690918
test loss: 210.63251876831055, accuracy: 0.5252999663352966
model saved
----No.5 training...-----
training step: 3200, loss: 1.1311390399932861
training step: 3300, loss: 1.2354803085327148
training step: 3400, loss: 1.2415772676467896
training step: 3500, loss: 1.4213279485702515
training step: 3600, loss: 1.4151396751403809
training step: 3700, loss: 1.2579320669174194
training step: 3800, loss: 1.201486349105835
training step: 3900, loss: 1.287066102027893
test loss: 202.65885722637177, accuracy: 0.5475999712944031
model saved
----No.6 training...-----
training step: 4000, loss: 1.2759090662002563
training step: 4100, loss: 1.3534283638000488
training step: 4200, loss: 1.4388338327407837
training step: 4300, loss: 1.1126259565353394
training step: 4400, loss: 1.072700023651123
training step: 4500, loss: 1.2942607402801514
training step: 4600, loss: 1.3078550100326538
test loss: 195.93554836511612, accuracy: 0.5615000128746033
model saved
----No.7 training...-----
training step: 4700, loss: 1.3510404825210571
training step: 4800, loss: 1.3887534141540527
training step: 4900, loss: 1.2628172636032104
training step: 5000, loss: 1.3063734769821167
training step: 5100, loss: 0.9366315007209778
training step: 5200, loss: 1.208983063697815
training step: 5300, loss: 1.0933520793914795
training step: 5400, loss: 1.2654058933258057
test loss: 190.015959918499, accuracy: 0.5735999941825867
model saved
----No.8 training...-----
training step: 5500, loss: 1.1543941497802734
training step: 5600, loss: 1.0732381343841553
training step: 5700, loss: 1.179479718208313
training step: 5800, loss: 1.0669857263565063
training step: 5900, loss: 1.3145105838775635
training step: 6000, loss: 1.4563915729522705
training step: 6100, loss: 1.0026252269744873
training step: 6200, loss: 0.9769096374511719
test loss: 184.76930475234985, accuracy: 0.5831999778747559
model saved
----No.9 training...-----
training step: 6300, loss: 1.2531676292419434
training step: 6400, loss: 1.0582406520843506
training step: 6500, loss: 1.467718482017517
training step: 6600, loss: 0.9885475635528564
training step: 6700, loss: 0.9887412190437317
training step: 6800, loss: 1.1251451969146729
training step: 6900, loss: 1.0831143856048584
training step: 7000, loss: 0.8735517263412476
test loss: 180.18007707595825, accuracy: 0.5949000120162964
model saved
----No.10 training...-----
training step: 7100, loss: 1.1680148839950562
training step: 7200, loss: 0.9758849740028381
training step: 7300, loss: 1.1076891422271729
training step: 7400, loss: 0.8192071914672852
training step: 7500, loss: 1.2766807079315186
training step: 7600, loss: 1.2046217918395996
training step: 7700, loss: 0.8206453323364258
training step: 7800, loss: 1.1484739780426025
test loss: 176.2480058670044, accuracy: 0.6036999821662903
model saved
```
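The step numbers and test losses in the log follow from the dataset and batch sizes. Dividing 50,000 by 64 and rounding up gives 782 batches per epoch, because the last batch is partial. So epoch 1 ends at step 782, its last printed step is 700, and ten epochs come to 7,820 steps. The reported test loss is the sum of the mean losses of 157 test batches, so dividing it by 157 gives an approximate per-batch average that can be compared with the training loss. A quick check using the numbers from the log:

```python
import math

train_size, test_size, batch_size = 50_000, 10_000, 64

# DataLoader keeps the final partial batch by default (drop_last=False), so round up
steps_per_epoch = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)
print(steps_per_epoch, steps_per_epoch * 10)  # 782 7820
print(test_batches)  # 157

# Test-loss totals copied from the log (epochs 1 and 10); each is a sum of
# per-batch mean losses, so dividing by the batch count gives an approximate average
logged_test_loss = {1: 257.74497163295746, 10: 176.2480058670044}
for ep, total in logged_test_loss.items():
    print(f"epoch {ep}: average test loss ~ {total / test_batches:.3f}")
```

By this measure, the average test loss falls from about 1.64 after epoch 1 to about 1.12 after epoch 10. Over the same period, test accuracy rises from about 42% to about 60%.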
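Next, test the trained model on a few images downloaded from the web. Adjust the file paths to match your own environment.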
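```python
import torch
import torchvision
from PIL import Image
from torch import nn

# 10 classes: airplane=0, automobile=1, bird=2, cat=3, deer=4,
#             dog=5, frog=6, horse=7, ship=8, truck=9
image_path = "/kaggle/input/testdata/bird.jpg"
image = Image.open(image_path).convert("RGB")  # drop any alpha channel / expand grayscale to 3 channels

# Match the training input: 32x32 RGB, scaled to [0, 1] by ToTensor
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
image = image.unsqueeze(0)  # add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)


# torch.save(mynet, ...) pickled the whole model, so the CNN class must be
# defined here (under the same name) before torch.load can rebuild it
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.model(x)


# map_location loads a GPU-trained model on a CPU-only machine;
# weights_only=False is needed to unpickle a full model on PyTorch 2.6+, where the default became True
model = torch.load("/kaggle/working/myCNN_10p.pth", map_location=torch.device("cpu"),
                   weights_only=False)
model.eval()
with torch.no_grad():
    output = model(image)
print(output.argmax(1))
```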
```text
tensor([2])
```

Index 2 corresponds to bird, so the model classifies this bird image correctly.
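`argmax` only reports the winning index. To see how confident the model is and which classes come next, you can pass the logits through a softmax and map the top few entries to class names. The sketch below assumes the class order of torchvision's `CIFAR10` dataset, which is the order in `train_data.classes`. `top_k_predictions` is a hypothetical helper, and `fake_output` stands in for the real `output` tensor from the script above.

```python
import torch

# Class order used by torchvision's CIFAR10 (same as train_data.classes)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def top_k_predictions(logits: torch.Tensor, k: int = 3):
    """Turn a (1, 10) logits tensor into [(class_name, probability), ...], best first."""
    probs = torch.softmax(logits, dim=1)
    top_p, top_i = probs.topk(k, dim=1)
    return [(CIFAR10_CLASSES[i], p) for i, p in zip(top_i[0].tolist(), top_p[0].tolist())]


# Stand-in logits; with the real model, pass `output` from model(image) instead
fake_output = torch.tensor([[0.1, 0.2, 3.0, 0.5, 1.0, 0.3, 0.0, 0.4, 0.2, 0.1]])
for name, p in top_k_predictions(fake_output):
    print(f"{name}: {p:.3f}")
```

A large gap between the first and second probabilities suggests a confident prediction. Similar values suggest the image is ambiguous to this model.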