猫狗识别数据集：https://download.csdn.net/download/Victor_Li_/88483483?spm=1001.2014.3001.5501
训练集图片路径
测试集图片路径
训练代码如下：
import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import time
from torch.optim.lr_scheduler import StepLR
if __name__ == '__main__':
    # Script entry point: configure autograd debugging, pick the compute
    # device, and build the image preprocessing/augmentation pipeline.

    # NOTE: anomaly detection is a debugging aid; PyTorch documents that it
    # slows execution, so consider disabling it for real training runs.
    torch.autograd.set_detect_anomaly(True)
    # Needed when the program is frozen into a Windows executable and uses
    # multiprocessing; documented as a no-op otherwise.
    mp.freeze_support()

    # Query CUDA once and reuse the answer for both the log line and the device.
    train_on_gpu = torch.cuda.is_available()
    if train_on_gpu:
        print('CUDA is available! Training on GPU...')
    else:
        print('CUDA is not available. Training on CPU...')
    device = torch.device("cuda" if train_on_gpu else "cpu")

    batch_size = 32

    # Preprocessing and augmentation transforms applied to each image.
    # NOTE(review): the random augmentations below are training-time
    # augmentations; confirm a separate deterministic pipeline is used for
    # evaluation data.
    transform = torchvision.transforms.Compose([
        # Resize every image to 224x224.
        torchvision.transforms.Resize((224, 224)),
        # Random left-right flip (default probability).
        torchvision.transforms.RandomHorizontalFlip(),
        # Random rotation by an angle drawn from [-45, 45] degrees.
        torchvision.transforms.RandomRotation(45),
        # Randomly perturb brightness, contrast and saturation.
        torchvision.transforms.ColorJitter(
            brightness=0.2, contrast=0.2, saturation=0.2
        ),
        # Convert the image to a tensor.
        torchvision.transforms.ToTensor(),
        # Normalize each of the three channels as (x - 0.5) / 0.5.
        torchvision.transforms.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
        ),
    ])