基于resnet网络架构训练图像分类模型

news2024/11/16 17:59:30

数据预处理部分:

  • 数据增强:torchvision中transforms模块自带功能,比较实用
  • 数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
  • DataLoader模块直接读取batch数据

网络模块设置:

  • 加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习
  • 需要注意的是别人训练好的任务跟咱们的可不是完全一样,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务
  • 训练时可以全部重头训练,也可以只训练最后咱们任务的层,因为前几层都是做特征提取的,本质任务目标是一致的

网络模型保存与测试

  • 模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存
  • 读取模型进行实际测试
  • import os
    import matplotlib.pyplot as plt
    %matplotlib inline
    import numpy as np
    import torch
    from torch import nn
    import torch.optim as optim
    import torchvision
    #pip install torchvision
    from torchvision import transforms, models, datasets
    #https://pytorch.org/docs/stable/torchvision/index.html
    import imageio
    import time
    import warnings
    warnings.filterwarnings("ignore")
    import random
    import sys
    import copy
    import json
    from PIL import Image

    数据读取与预处理操作

    data_dir = './flower_data/'
    train_dir = data_dir + '/train'
    valid_dir = data_dir + '/valid'

    制作好数据源:

  • data_transforms中指定了所有图像预处理操作
  • ImageFolder假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字
    data_transforms = {
        'train': 
            transforms.Compose([
            transforms.Resize([96, 96]),
            transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
            transforms.CenterCrop(64),#从中心开始裁剪
            transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
            transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
            transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
            transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
        ]),
        'valid': 
            transforms.Compose([
            transforms.Resize([64, 64]),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    batch_size = 128
    
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
    class_names = image_datasets['train'].classes
    image_datasets
    {'train': Dataset ImageFolder
         Number of datapoints: 6552
         Root location: ./flower_data/train
         StandardTransform
     Transform: Compose(
                    Resize(size=[96, 96], interpolation=bilinear, max_size=None, antialias=None)
                    RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)
                    CenterCrop(size=(64, 64))
                    RandomHorizontalFlip(p=0.5)
                    RandomVerticalFlip(p=0.5)
                    ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                    RandomGrayscale(p=0.025)
                    ToTensor()
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ), 'valid': Dataset ImageFolder
         Number of datapoints: 818
         Root location: ./flower_data/valid
         StandardTransform
     Transform: Compose(
                    Resize(size=[64, 64], interpolation=bilinear, max_size=None, antialias=None)
                    ToTensor()
                    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                )}
    dataloaders
    {'train': <torch.utils.data.dataloader.DataLoader at 0x1e4c50b9400>,
     'valid': <torch.utils.data.dataloader.DataLoader at 0x1e4c51ad128>}
    dataset_sizes
    {'train': 6552, 'valid': 818}
  • 读取标签对应的实际名字

    with open('cat_to_name.json', 'r') as f:
        cat_to_name = json.load(f)
    cat_to_name
    {'1': 'pink primrose',
     '10': 'globe thistle',
     '100': 'blanket flower',
     '101': 'trumpet creeper',
     '102': 'blackberry lily',
     '11': 'snapdragon',
     '12': "colt's foot",
     '13': 'king protea',
     '14': 'spear thistle',
     '15': 'yellow iris',
     '16': 'globe-flower',
     '17': 'purple coneflower',
     '18': 'peruvian lily',
     '19': 'balloon flower',
     '2': 'hard-leaved pocket orchid',
     '20': 'giant white arum lily',
     '21': 'fire lily',
     '22': 'pincushion flower',
     '23': 'fritillary',
     '24': 'red ginger',
     '25': 'grape hyacinth',
     '26': 'corn poppy',
     '27': 'prince of wales feathers',
     '28': 'stemless gentian',
     '29': 'artichoke',
     '3': 'canterbury bells',
     '30': 'sweet william',
     '31': 'carnation',
     '32': 'garden phlox',
     '33': 'love in the mist',
     '34': 'mexican aster',
     '35': 'alpine sea holly',
     '36': 'ruby-lipped cattleya',
     '37': 'cape flower',
     '38': 'great masterwort',
     '39': 'siam tulip',
     '4': 'sweet pea',
     '40': 'lenten rose',
     '41': 'barbeton daisy',
     '42': 'daffodil',
     '43': 'sword lily',
     '44': 'poinsettia',
     '45': 'bolero deep blue',
     '46': 'wallflower',
     '47': 'marigold',
     '48': 'buttercup',
     '49': 'oxeye daisy',
     '5': 'english marigold',
     '50': 'common dandelion',
     '51': 'petunia',
     '52': 'wild pansy',
     '53': 'primula',
     '54': 'sunflower',
     '55': 'pelargonium',
     '56': 'bishop of llandaff',
     '57': 'gaura',
     '58': 'geranium',
     '59': 'orange dahlia',
     '6': 'tiger lily',
     '60': 'pink-yellow dahlia',
     '61': 'cautleya spicata',
     '62': 'japanese anemone',
     '63': 'black-eyed susan',
     '64': 'silverbush',
     '65': 'californian poppy',
     '66': 'osteospermum',
     '67': 'spring crocus',
     '68': 'bearded iris',
     '69': 'windflower',
     '7': 'moon orchid',
     '70': 'tree poppy',
     '71': 'gazania',
     '72': 'azalea',
     '73': 'water lily',
     '74': 'rose',
     '75': 'thorn apple',
     '76': 'morning glory',
     '77': 'passion flower',
     '78': 'lotus lotus',
     '79': 'toad lily',
     '8': 'bird of paradise',
     '80': 'anthurium',
     '81': 'frangipani',
     '82': 'clematis',
     '83': 'hibiscus',
     '84': 'columbine',
     '85': 'desert-rose',
     '86': 'tree mallow',
     '87': 'magnolia',
     '88': 'cyclamen',
     '89': 'watercress',
     '9': 'monkshood',
     '90': 'canna lily',
     '91': 'hippeastrum',
     '92': 'bee balm',
     '93': 'ball moss',
     '94': 'foxglove',
     '95': 'bougainvillea',
     '96': 'camellia',
     '97': 'mallow',
     '98': 'mexican petunia',
     '99': 'bromelia'}
  • 加载models中提供的模型,并且直接用训练的好权重当做初始化参数

  • 第一次执行需要下载,可能会比较慢,我会提供给大家一份下载好的,可以直接放到相应路径
    model_name = 'resnet'  #可选的比较多 ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']
    #是否用人家训练好的特征来做
    feature_extract = True #都用人家特征,咱先不更新
    # 是否用GPU训练
    train_on_gpu = torch.cuda.is_available()
    
    if not train_on_gpu:
        print('CUDA is not available.  Training on CPU ...')
    else:
        print('CUDA is available!  Training on GPU ...')
        
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    CUDA is not available.  Training on CPU ...
  • 模型参数要不要更新

  • 有时候用人家模型,就一直用了,更不更新咱们可以自己定
    def set_parameter_requires_grad(model, feature_extracting):
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False
    model_ft = models.resnet18()#18层的能快点,条件好点的也可以选152
    model_ft
    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=1000, bias=True)
  • 把模型输出层改成自己的

    def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
        
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, 102)#类别数自己根据自己任务来
                                
        input_size = 64#输入大小根据自己配置来
    
        return model_ft, input_size

    设置哪些层需要训练

    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    
    #GPU还是CPU计算
    model_ft = model_ft.to(device)
    
    # 模型保存,名字自己起
    filename='checkpoint.pth'
    
    # 是否训练所有层
    params_to_update = model_ft.parameters()
    print("Params to learn:")
    if feature_extract:
        params_to_update = []
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                params_to_update.append(param)
                print("\t",name)
    else:
        for name,param in model_ft.named_parameters():
            if param.requires_grad == True:
                print("\t",name)
    Params to learn:
    	 fc.weight
    	 fc.bias
  • model_ft
    ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
      (fc): Linear(in_features=512, out_features=102, bias=True)
    )
  • 优化器设置

    # 优化器设置
    optimizer_ft = optim.Adam(params_to_update, lr=1e-2)#要训练啥参数,你来定
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)#学习率每7个epoch衰减成原来的1/10
    criterion = nn.CrossEntropyLoss()

    训练模块

    def train_model(model, dataloaders, criterion, optimizer, num_epochs=25,filename='best.pt'):
        #咱们要算时间的
        since = time.time()
        #也要记录最好的那一次
        best_acc = 0
        #模型也得放到你的CPU或者GPU
        model.to(device)
        #训练过程中打印一堆损失和指标
        val_acc_history = []
        train_acc_history = []
        train_losses = []
        valid_losses = []
        #学习率
        LRs = [optimizer.param_groups[0]['lr']]
        #最好的那次模型,后续会变的,先初始化
        best_model_wts = copy.deepcopy(model.state_dict())
        #一个个epoch来遍历
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)
    
            # 训练和验证
            for phase in ['train', 'valid']:
                if phase == 'train':
                    model.train()  # 训练
                else:
                    model.eval()   # 验证
    
                running_loss = 0.0
                running_corrects = 0
    
                # 把数据都取个遍
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)#放到你的CPU或GPU
                    labels = labels.to(device)
    
                    # 清零
                    optimizer.zero_grad()
                    # 只有训练的时候计算和更新梯度
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    # 训练阶段更新权重
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
    
                    # 计算损失
                    running_loss += loss.item() * inputs.size(0)#0表示batch那个维度
                    running_corrects += torch.sum(preds == labels.data)#预测结果最大的和真实值是否一致
                    
                
                
                epoch_loss = running_loss / len(dataloaders[phase].dataset)#算平均
                epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
                
                time_elapsed = time.time() - since#一个epoch我浪费了多少时间
                print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
                
    
                # 得到最好那次的模型
                if phase == 'valid' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    state = {
                      'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
                      'best_acc': best_acc,
                      'optimizer' : optimizer.state_dict(),
                    }
                    torch.save(state, filename)
                if phase == 'valid':
                    val_acc_history.append(epoch_acc)
                    valid_losses.append(epoch_loss)
                    #scheduler.step(epoch_loss)#学习率衰减
                if phase == 'train':
                    train_acc_history.append(epoch_acc)
                    train_losses.append(epoch_loss)
            
            print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
            LRs.append(optimizer.param_groups[0]['lr'])
            print()
            scheduler.step()#学习率衰减
    
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))
    
        # 训练完后用最好的一次当做模型最终的结果,等着一会测试
        model.load_state_dict(best_model_wts)
        return model, val_acc_history, train_acc_history, valid_losses, train_losses, LRs 

    开始训练!

  • 我们现在只训练了输出层
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer_ft, num_epochs=20)
    Epoch 0/19
    ----------
    Time elapsed 0m 39s
    train Loss: 4.0874 Acc: 0.2355
    Time elapsed 0m 43s
    valid Loss: 3.5746 Acc: 0.2531
    Optimizer learning rate : 0.0100000
    
    Epoch 1/19
    ----------
    Time elapsed 1m 22s
    train Loss: 2.8185 Acc: 0.3953
    Time elapsed 1m 26s
    valid Loss: 3.5450 Acc: 0.3142
    Optimizer learning rate : 0.0100000
    
    Epoch 2/19
    ----------
    Time elapsed 2m 5s
    train Loss: 2.7673 Acc: 0.4174
    Time elapsed 2m 9s
    valid Loss: 3.9110 Acc: 0.2653
    Optimizer learning rate : 0.0100000
    
    Epoch 3/19
    ----------
    Time elapsed 2m 48s
    train Loss: 2.7962 Acc: 0.4255
    Time elapsed 2m 52s
    valid Loss: 3.6922 Acc: 0.3142
    Optimizer learning rate : 0.0100000
    
    Epoch 4/19
    ----------
    Time elapsed 3m 32s
    train Loss: 2.7453 Acc: 0.4428
    Time elapsed 3m 36s
    valid Loss: 3.9310 Acc: 0.3044
    Optimizer learning rate : 0.0100000
    
    Epoch 5/19
    ----------
    Time elapsed 4m 14s
    train Loss: 2.2935 Acc: 0.5043
    Time elapsed 4m 18s
    valid Loss: 3.3299 Acc: 0.3435
    Optimizer learning rate : 0.0010000
    
    Epoch 6/19
    ----------
    Time elapsed 4m 57s
    train Loss: 2.0654 Acc: 0.5258
    Time elapsed 5m 1s
    valid Loss: 3.2608 Acc: 0.3411
    Optimizer learning rate : 0.0010000
    
    Epoch 7/19
    ----------
    Time elapsed 5m 40s
    train Loss: 1.9603 Acc: 0.5369
    Time elapsed 5m 44s
    valid Loss: 3.2618 Acc: 0.3472
    Optimizer learning rate : 0.0010000
    
    Epoch 8/19
    ----------
    Time elapsed 6m 23s
    train Loss: 1.9216 Acc: 0.5401
    Time elapsed 6m 27s
    valid Loss: 3.1651 Acc: 0.3386
    Optimizer learning rate : 0.0010000
    
    Epoch 9/19
    ----------
    Time elapsed 7m 5s
    train Loss: 1.9203 Acc: 0.5458
    Time elapsed 7m 9s
    valid Loss: 3.0449 Acc: 0.3680
    Optimizer learning rate : 0.0010000
    
    Epoch 10/19
    ----------
    Time elapsed 7m 48s
    train Loss: 1.8366 Acc: 0.5553
    Time elapsed 7m 52s
    valid Loss: 3.0722 Acc: 0.3545
    Optimizer learning rate : 0.0001000
    
    Epoch 11/19
    ----------
    Time elapsed 8m 31s
    train Loss: 1.8324 Acc: 0.5546
    Time elapsed 8m 35s
    valid Loss: 3.0115 Acc: 0.3643
    Optimizer learning rate : 0.0001000
    
    Epoch 12/19
    ----------
    Time elapsed 9m 13s
    train Loss: 1.8054 Acc: 0.5553
    Time elapsed 9m 17s
    valid Loss: 3.0688 Acc: 0.3619
    Optimizer learning rate : 0.0001000
    
    Epoch 13/19
    ----------
    Time elapsed 9m 56s
    train Loss: 1.8436 Acc: 0.5534
    Time elapsed 10m 0s
    valid Loss: 3.0100 Acc: 0.3631
    Optimizer learning rate : 0.0001000
    
    Epoch 14/19
    ----------
    Time elapsed 10m 39s
    train Loss: 1.7417 Acc: 0.5614
    Time elapsed 10m 43s
    valid Loss: 3.0129 Acc: 0.3655
    Optimizer learning rate : 0.0001000
    
    Epoch 15/19
    ----------
    Time elapsed 11m 22s
    train Loss: 1.7610 Acc: 0.5672
    Time elapsed 11m 26s
    valid Loss: 3.0220 Acc: 0.3606
    Optimizer learning rate : 0.0000100
    
    Epoch 16/19
    ----------
    Time elapsed 12m 6s
    train Loss: 1.7788 Acc: 0.5676
    Time elapsed 12m 10s
    valid Loss: 3.0104 Acc: 0.3557
    Optimizer learning rate : 0.0000100
    
    Epoch 17/19
    ----------
    Time elapsed 12m 49s
    train Loss: 1.8033 Acc: 0.5638
    Time elapsed 12m 53s
    valid Loss: 3.0428 Acc: 0.3606
    Optimizer learning rate : 0.0000100
    
    Epoch 18/19
    ----------
    Time elapsed 13m 33s
    train Loss: 1.8294 Acc: 0.5568
    Time elapsed 13m 37s
    valid Loss: 3.0307 Acc: 0.3509
    Optimizer learning rate : 0.0000100
    
    Epoch 19/19
    ----------
    Time elapsed 14m 16s
    train Loss: 1.7949 Acc: 0.5612
    Time elapsed 14m 20s
    valid Loss: 3.0396 Acc: 0.3643
    Optimizer learning rate : 0.0000100
    
    Training complete in 14m 20s
    Best val Acc: 0.367971
  • 再继续训练所有层

    for param in model_ft.parameters():
        param.requires_grad = True
    
    # 再继续训练所有的参数,学习率调小一点
    optimizer = optim.Adam(model_ft.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # 加载之前训练好的权重参数
    
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model_ft.load_state_dict(checkpoint['state_dict'])
    model_ft, val_acc_history, train_acc_history, valid_losses, train_losses, LRs  = train_model(model_ft, dataloaders, criterion, optimizer, num_epochs=10,)
    Epoch 0/9
    ----------
    Time elapsed 1m 32s
    train Loss: 2.2451 Acc: 0.4846
    Time elapsed 1m 36s
    valid Loss: 2.3190 Acc: 0.4633
    Optimizer learning rate : 0.0010000
    
    Epoch 1/9
    ----------
    Time elapsed 2m 54s
    train Loss: 1.2920 Acc: 0.6505
    Time elapsed 2m 58s
    valid Loss: 2.2263 Acc: 0.4670
    Optimizer learning rate : 0.0010000
    
    Epoch 2/9
    ----------
    Time elapsed 4m 15s
    train Loss: 1.1026 Acc: 0.6993
    Time elapsed 4m 19s
    valid Loss: 1.8115 Acc: 0.5452
    Optimizer learning rate : 0.0010000
    
    Epoch 3/9
    ----------
    Time elapsed 5m 35s
    train Loss: 0.9062 Acc: 0.7515
    Time elapsed 5m 39s
    valid Loss: 2.0045 Acc: 0.5403
    Optimizer learning rate : 0.0010000
    
    Epoch 4/9
    ----------
    Time elapsed 6m 56s
    train Loss: 0.8392 Acc: 0.7643
    Time elapsed 7m 0s
    valid Loss: 2.1381 Acc: 0.5171
    Optimizer learning rate : 0.0010000
    
    Epoch 5/9
    ----------
    Time elapsed 8m 17s
    train Loss: 0.7081 Acc: 0.7953
    Time elapsed 8m 21s
    valid Loss: 2.0461 Acc: 0.5599
    Optimizer learning rate : 0.0010000
    
    Epoch 6/9
    ----------
    Time elapsed 9m 38s
    train Loss: 0.6400 Acc: 0.8147
    Time elapsed 9m 42s
    valid Loss: 2.2603 Acc: 0.5452
    Optimizer learning rate : 0.0010000
    
    Epoch 7/9
    ----------
    Time elapsed 10m 59s
    train Loss: 0.6406 Acc: 0.8117
    Time elapsed 11m 3s
    valid Loss: 1.4649 Acc: 0.6406
    Optimizer learning rate : 0.0010000
    
    Epoch 8/9
    ----------
    Time elapsed 12m 20s
    train Loss: 0.5686 Acc: 0.8300
    Time elapsed 12m 24s
    valid Loss: 1.7538 Acc: 0.6100
    Optimizer learning rate : 0.0010000
    
    Epoch 9/9
    ----------
    Time elapsed 13m 41s
    train Loss: 0.5978 Acc: 0.8245
    Time elapsed 13m 45s
    valid Loss: 1.6953 Acc: 0.6161
    Optimizer learning rate : 0.0010000
    
    Training complete in 13m 45s
    Best val Acc: 0.640587
  • 加载训练好的模型

    model_ft, input_size = initialize_model(model_name, 102, feature_extract, use_pretrained=True)
    
    # GPU模式
    model_ft = model_ft.to(device)
    
    # 保存文件的名字
    filename='best.pt'
    
    # 加载模型
    checkpoint = torch.load(filename)
    best_acc = checkpoint['best_acc']
    model_ft.load_state_dict(checkpoint['state_dict'])

    测试数据预处理

  • 测试数据处理方法需要跟训练时一直才可以
  • crop操作的目的是保证输入的大小是一致的
  • 标准化操作也是必须的,用跟训练数据相同的mean和std,但是需要注意一点训练数据是在0-1上进行标准化,所以测试数据也需要先归一化
  • 最后一点,PyTorch中颜色通道是第一个维度,跟很多工具包都不一样,需要转换
  • # 得到一个batch的测试数据
    dataiter = iter(dataloaders['valid'])
    images, labels = dataiter.next()
    
    model_ft.eval()
    
    if train_on_gpu:
        output = model_ft(images.cuda())
    else:
        output = model_ft(images)

    output表示对一个batch中每一个数据得到其属于各个类别的可能性

    output.shape

    得到概率最大的那个

    _, preds_tensor = torch.max(output, 1)
    
    preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
    preds
    array([ 34,  49,  43,  54,  20,  14,  49,  43,  50,  20,  19, 100,  78,
            96,  96,  62,  62,  63,  32,  38,  82,  43,  88,  73,   6,  51,
            43,  89,  55,  75,  55,  11,  46,  82,  48,  82,  20, 100,  48,
            20,  24,  49,  76,  93,  49,  46,  90,  75,  89,  75,  76,  99,
            56,  48,  77,  66,  60,  72,  89,  97,  76,  73,  17,  48,  39,
            31,  19,  74,  61,  46,  93,  80,  27,  11,  91,  18,  23,  47,
            29,  54,  18,  93,   1,  50,  79,  96,  39,  53,  63,  60,  49,
            23,  23,  52,  99,  89,   3,  50,  64,  15,  19,  60,  19,  75,
            50,  78,  82,  18,  75,  18,  82,  53,   3,  52,  60,  38,  62,
            47,  21,  59,  81,  48,  89,  64,  60,  55, 100,  60], dtype=int64)

    展示预测结果

    def im_convert(tensor):
        """ 展示数据"""
        
        image = tensor.to("cpu").clone().detach()
        image = image.numpy().squeeze()
        image = image.transpose(1,2,0)
        image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
        image = image.clip(0, 1)
    
        return image
    fig=plt.figure(figsize=(20, 20))
    columns =4
    rows = 2
    
    for idx in range (columns*rows):
        ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
        plt.imshow(im_convert(images[idx]))
        ax.set_title("{} ({})".format(cat_to_name[str(preds[idx])], cat_to_name[str(labels[idx].item())]),
                     color=("green" if cat_to_name[str(preds[idx])]==cat_to_name[str(labels[idx].item())] else "red"))
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/985218.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ARM DIY(九)陀螺仪调试

前言 今天调试六轴陀螺仪 MPU6050 硬件 硬件很简单&#xff0c;使用 I2C 接口&#xff0c;并且没有使用中断引脚。 焊接上 MPU6050 芯片和上拉电阻、滤波电容。 检测 MPU6050 是挂在 I2C-0 上的&#xff0c;I2C-0 控制器的驱动已 OK&#xff0c;所以直接使用 I2C-0 检测 …

2023 年前端编程 NodeJs 包管理工具 npm 安装和使用详细介绍

npm 基本概述 npm is the world’s largest software registry. Open source developers from every continent use npm to share and borrow packages, and many organizations use npm to manage private development as well. npm 官方网站&#xff1a;https://www.npmjs.…

程序员做哪些副业可以日赚一百?

日赚一百&#xff1f;对程序员来说简直不要太容易&#xff01;下面给程序员们推荐一些日赚100的副业&#xff1a; ①外包接单 程序员简单粗暴赚钱的副业之一。 外包接单的类型包括但不限于&#xff1a;软件开发、硬件开发、小程序功能开发、web开发……大到一个系统的开发、…

外包干了三年,我承认我确实废了……

没错&#xff0c;我也干过外包&#xff0c;一干就是三年&#xff0c;三年后&#xff0c;我废了…… 虽说废的不是很彻底&#xff0c;但那三年我几乎是出差了三年、玩了三年、荒废了三年&#xff0c;那三年&#xff0c;我的技术能力几乎是零成长的。 说起这段三年的外包经历&a…

对象临时中间状态的条件竞争覆盖

Portswigger练兵场之条件竞争 &#x1f984;条件竞争之对象临时中间状态的条件竞争 Lab: Partial construction race conditions&#x1f680;实验前置必要知识点 某些框架尝试通过使用某种形式的请求锁定来防止意外的数据损坏。例如&#xff0c;PHP 的本机会话处理程序模块…

任务管理系统所需功能概述

"任务管理需要有哪些功能&#xff1f;清晰的任务创建与编辑、智能分类和标签系统、提醒与通知功能、进度跟踪与报告、协作与共享功能、集成与兼容性。" 一款优秀的任务管理工具可以帮助我们有效地规划、执行和监控各项任务&#xff0c;提高工作效率。本文将探讨一款理…

深度学习(十一)---zed 调用yolov5 进行识别目标并实时测距

1. 前言 zed 相机测距有2种方式&#xff1a;一种是根据点云数据进行测试&#xff0c;二是根据zed获取深度值进行测距。上篇文章 调用yolov5模型进行实时图像推理及网页端部署 我们讲述了zed调用yolov5进行目标识别&#xff0c;我们在此基础上进一步实现目标测距功能。 2.深度…

Apache httpd漏洞复现

文章目录 未知后缀名解析漏洞多后缀名解析漏洞启动环境漏洞复现 换行解析漏洞启动环境漏洞复现 未知后缀名解析漏洞 该漏洞与Apache、php版本无关&#xff0c;属于用户配置不当造成的解析漏洞。在有多个后缀的情况下&#xff0c;只要一个文件含有.php后缀的文件即将被识别成PHP…

微信多开bat代码

目录标题 创建txt文件右键 点击新建文本文档 复制如下代码进去&#xff0c;将里面的微信地址改成自己的微信地址&#xff08;查看地址方法&#xff1a;右击微信图标->属性->目标&#xff09;复制如下代码 创建txt文件 右键 点击新建文本文档 复制如下代码进去&#xff0…

Linux系统——MySQL安装(CentOS7 超详细演示)

Linux系统安装MySQL MySQL8.0.26-Linux版安装1. 准备一台Linux服务器2. 下载Linux版MySQL安装包3. 上传MySQL安装包4. 创建目录,并解压5. 安装mysql的安装包6. 启动MySQL服务7. 查询自动生成的root用户密码8. 修改root用户密码9. 创建用户10. 并给root用户分配权限11. 通过Data…

2023高教社杯 国赛数学建模E题思路 - 黄河水沙监测数据分析

1 赛题 E 题 黄河水沙监测数据分析 黄河是中华民族的母亲河。研究黄河水沙通量的变化规律对沿黄流域的环境治理、气候变 化和人民生活的影响&#xff0c; 以及对优化黄河流域水资源分配、协调人地关系、调水调沙、防洪减灾 等方面都具有重要的理论指导意义。 附件 1 给出了位…

willchange 优化性能的原理是什么

写在前面 今天说一下性能优化部分的其中一个点&#xff0c;这个点叫做 willchange&#xff0c;说他的原因主要有以下几个&#xff1a;第一很多人知道用这个可以提高性能但是不知道原因是什么&#xff0c;第二&#xff0c;我们用的时候他虽然可以提高性能&#xff0c;但是不代表…

Revit SDK 介绍:GenericModelCreation常规模型的创建

前言 这个例子介绍了如何创建拉伸、放样、扫掠、融合、放样融合&#xff0c;涵盖了一个建模软件需要的基本建模方法。 内容 CreateExtrusion 生成的放样融合接口&#xff1a; m_creationFamily.NewExtrusion(true, curve, sketchPlane, bottomProfile, topProfile)核心逻辑&…

msvcr100.dll丢失应该怎么解决,比较靠谱的五种解决方法分享

首先&#xff0c;我们需要了解什么是msvcr100.dll。msvcr100.dll是Microsoft Visual C 2010 Redistributable Package的一部分&#xff0c;它包含了一些运行时库&#xff0c;用于支持某些程序的运行。简单来说&#xff0c;msvcr100.dll就是一个“桥梁”&#xff0c;连接了我们的…

组织架构图怎么做?手把手教你绘制完美的组织架构图

组织架构图是企业和组织中不可或缺的工具&#xff0c;它以图形化的方式展示了各级层次的关系&#xff0c;帮助人们更好地理解和管理整体结构。本文将为你详细介绍绘制组织架构图的要素、步骤和技巧&#xff0c;并推荐四款强大的绘图软件&#xff0c;让你轻松绘制完美的组织架构…

无涯教程-JavaScript - DELTA函数

描述 DELTA函数测试两个值是否相等。如果number1 number2,则返回1&#xff1b;否则返回1。否则返回0。 您可以使用此功能来过滤一组值。如,通过合计几个DELTA函数,您可以计算相等对的计数。此功能也称为Kronecker Delta功能。 语法 DELTA (number1, [number2])争论 Argum…

RK3588算法盒子maskrom模式下系统烧录

先在firefly官网下载bulidroot文件&#xff0c;然后进行烧录相关文件。 点击upgrade&#xff0c;进行烧录升级&#xff1b; 等待烧录完成后&#xff0c; 重新开机进入loader模式下&#xff0c; 进行烧录Untunt20.04系统的操作就行。

企业电脑文件加密系统 / 防泄密软件——「天锐绿盾」

「天锐绿盾」是一种公司文件加密系统&#xff0c;旨在保护公司内网数据安全&#xff0c;防止信息泄露。该系统由硬件和软件组成&#xff0c;其中包括服务端程序、控制台程序和终端程序。 PC访问地址&#xff1a; isite.baidu.com/site/wjz012xr/2eae091d-1b97-4276-90bc-6757c…

python趣味编程-太空入侵者游戏

在 Python 中使用 Turtle 的简单太空入侵者游戏免费源代码 使用 Turtle 的简单太空入侵者游戏是一个用Python编程语言编码的桌面游戏应用程序。该项目包含克隆实际太空入侵者游戏的多种功能。该项目可以使正在学习计算机相关课程的学生受益。该应用程序易于学习,可以帮助您发现…

软件设计师(八)算法设计与分析

算法被公认为是计算机科学的基石&#xff0c;算法理论研究的是算法的设计技术和分析技术。 一、算法设计和分析的基本概念 1、算法 &#xff08;Algorithm&#xff09; 算法&#xff1a;是对特定问题求解步骤的一种描述&#xff0c;它是指令的有限序列&#xff0c;其中每一条…