数据增强,迁移学习,Resnet分类实战

news2024/9/28 5:24:28

目录

1. 数据增强(Data Augmentation)

2. 迁移学习

3. 模型保存    

4. 102种类花分类实战

1. 数据集

2.导入包

3. 数据读取与预处理操作 

4. Datasets制作输入数据

5.将标签的名字读出 

6.展示原始数据 

7.加载models中提供的模型 

8.初始化 

9.优化器设置 

10.训练模块


1. 数据增强(Data Augmentation)

        数据不够怎么办?采用翻转,镜像,增加数据

        如何更加高效利用数据?多利用几次

        在pytorch中有数据预处理部分:

            数据增强:torchvision中transforms模块自带功能,比较实用

            数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可

            DataLoader模块直接读取batch数据

        pyorch官网:https://pytorch.org/vision/stable

2. 迁移学习

        在训练自己的模型时出现一些问题:

        1. 自己的数据不够好

        2. 训练参数花费时间多

        3. 训练模型太难

        解决方法:

        有前人已经训练好了模型,其实就是将训练的参数保留下来,而且目标都差不多。那么把别人的模型参数当成初始化参数,所有的结构和前人模型一样。

        网络模块设置:

    加载预训练模型,torchvision中有很多经典网络架构,调用起来十分方便,并且也可以用人家训练好的权重参数来继续训练,也就是所谓的迁移学习

    需要注意的是别人训练好的任务根咱们的可不是完全一样的,需要把最后的head层改一改,一般也就是最后的全连接层,改成咱们自己的任务

    训练时可以完全重头训练,也可以只训练最后咱们任务层,因为前几层都是做特征提取的,本质任务目标一致的。

        总结:迁移学习策略

                1. 将卷积层当成初始化权重参数

                2.将卷积层权重参数冻住不变,全连接层重新训练(一般是,数据量少,冻住的层数多)

3. 模型保存    

        网络模型保存与测试

            模型保存的时候可以带有选择性,例如在验证集中如果当前效果好则保存

            读取模型进行实际测试

4. 102种类花分类实战

      1. 数据集

        有训练集,测试集。一共102种花,每种花有25~100个图像

2.导入包

import os
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

3. 数据读取与预处理操作 

data_dir = './flower_data'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'

制作好数据源:

    data_transforms中指定了所有图像预处理操作

    ImageFolder假设所有文件按文件夹保存好,每个文件夹下面存储同一类别的图片,文件夹的名字为分类的名字

data_transforms = {
    'train' : transforms.Compose([transforms.RandomRotation(45),#随即旋转,-45度到45度之间随机选
        transforms.CenterCrop(224), #从中心点开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随即水平翻转,选择一个概率
        transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
        transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), #参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),#转换成tensor格式
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])#均值,标准差
    ]), 
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])                   
    ])
}

 4. Datasets制作输入数据

        采用batch,将数据分组输入。

batch_size  = 8

image_datasets = {x : datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x : torch.utils.data.DataLoader(image_datasets[x],batch_size=batch_size,shuffle=True) for x in ['train','valid']}
dataset_szies = {x :len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes
print(image_datasets)
print(dataloaders)

5.将标签的名字读出 

        用123....打标签好像不好,用花的名字作为标签

#读取标签对应的实际名字
with open('./flower_data/cat_to_name.json','r') as f:
    cat_to_name = json.load(f)
print(cat_to_name)

6.展示原始数据 

        展示下数据

            注意tensor的数据需要转换成numpy格式,而且还需要还原成标准化的结果

def im_convert(tensor):
    '''展示数据'''

    image = tensor.to('cpu').clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))
    image = image.clip(0,1)

    return image

fig = plt.figure(figsize=(20,12))
colunms = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs,classes =next(dataiter)

for idx in range(colunms * rows):
    ax = fig.add_subplot(rows,colunms,idx+1,xticks = [],yticks = [])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

7.加载models中提供的模型 

        加载models中提供的模型,并且直接用训练好的权重当作初始化参数

            第一次执行需要下载,可能会比较慢

model_name = 'resnet' #可选的会比较多['resnet','alexnet','vgg','squeezenet','densenet','inception']
# 是否用人家训练好的特征来做
feature_extract = True

#是否用GPU训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available .  Training on CPU...')
else:
    print('CUDA is  available .  Training on GPU...')

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False #冻不冻住

model_ft = models.resnet152()
print(model_ft)

8.初始化 

        迁移学习,用前人的参数,改变全连接层。

def initalize_model(model_name,num_classes,feature_extract,use_pretrained=True):
    #选择合适的模型,不同模型的初始化方法稍微有点区别
    model_ft = None
    input_size = 0

    if model_name == 'resnet':
        model_ft = models.resnet152(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft,feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Sequential(nn.Linear(num_ftrs,num_classes),
                                    nn.LogSoftmax(dim=1))
        input_size = 224

    return model_ft,input_size

feature_extract = True
model_ft,input_size = initalize_model(model_name,102,feature_extract,use_pretrained=True)

#GPU计算
model_ft = model_ft.to(device)

#模型保存
filename = 'checkpoint.pth'

#是否训练所有层

params_to_updata = model_ft.parameters()
print('Params to learn')
if feature_extract:
    params_to_updata = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_updata.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)

print(model_ft)

 

9.优化器设置 

#优化器设置
optimizer_ft = optim.Adam(params_to_updata,lr=1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
#最后一层已经LogSoftmax()了,所以不能nn.CrossEntropyLoss()来计算了,nn.CrossEntropyLoss()相当于logSoftmax()和nn.NLLoss()整合
criterion = nn.NLLLoss()

10.训练模块

def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename=filename):
    since = time.time()
    best_acc = 0
    '''
    checkpoint = torch.laod(filename)
    best_acc = checkpoint['best_acc]
    model.load_state_dict(checkpoint['optimizer'])
    model.class_to_idx = checkpoint['mapping']
    '''
    model.to(device)

    val_acc_history = []
    train_acc_history = []
    train_losses = []
    vaild_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch,num_epochs-1))
        print('-'*10)

        #训练和验证
        for phase in ['train','valid']:
            if phase == 'train':
                model.train() #训练
            else:
                model.eval()  #验证
        
            running_loss = 0.0
            running_corrects = 0

            #把数据都取个遍
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                #清零
                optimizer.zero_grad()
                #只有训练的时候计算和更新梯度
                with torch.set_grad_enabled(phase=='train'):
                    if is_inception and phase == 'train':
                        outputs,aux_outputs = model(inputs)
                        loss1 = criterion(outputs,labels)
                        loss2 = criterion(aux_outputs,labels)
                        loss = loss1 + loss2

                    else: #resnet执行的是这里
                        outputs= model(inputs)
                        loss = criterion(outputs,labels)
                
                    _,preds = torch.max(outputs,1)

                    #训练阶段更新权重
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                #计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since

            print('Time elapsed {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))

            #得到最好的那次的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state ={
                    'state_dict' : model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(state,filename)
            if phase == 'valid':
                val_acc_history.append(epoch_acc)
                vaild_losses.append(epoch_loss)
                scheduler.step(epoch_loss)
            if phase=='train':
                train_acc_history.append(epoch_acc)
                train_losses.append(epoch_loss)
                
        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60,time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    #训练完后用最好的一次当作模型的最终结果
    model.load_state_dict(best_model_wts)
    return model, val_acc_history,train_acc_history,vaild_losses,train_losses,LRs
#开始训练!!!
model_ft,val_acc_history,train_acc_history,vaild_losses,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1663288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

轻松操作!ae导出mp4格式,一篇文章学会

在视频制作的过程中,Adobe After Effects作为一款强大而专业的后期处理工具,为我们提供了丰富的特效和编辑功能。然而,在完成创作后,将项目导出为通用的MP4格式是分享和展示作品的关键一步。在本文中,我们将探讨ae导出…

营销的本质是“利他”,资深运营高手分享9套消费返利玩转市场!

营销的本质是“利他”,资深运营高手分享9套消费返利玩转市场! 文丨微三云营销总监胡佳东,点击上方“关注”,为你分享市场商业模式电商干货。 - 引言:2024年移动互联网基本已经占据了核心不可篡改的地位,而…

Graphormer:Transformer用于图预测任务

文章信息 文章题为“Do Transformers Really Perform Bad for Graph Representation?”,该文章发表于2021年NeurIPS会议上。文章提出Graphormer图预测任务。 摘要 Transformer架构已经成为许多领域的主导选择,例如自然语言处理和计算机视觉。此外…

1015: 堆排序算法

解法&#xff1a; 20240510_193050 最后一个非叶子节点就是最后一个节点的父节点 进行一次最小堆调整&#xff08;如视频&#xff09; #include<iostream> #include<vector> using namespace std; void min_heapfy(vector<int>& a,int sta,int end) {i…

游戏行业被攻击的原因、攻击种类及合适的服务器

很多游戏刚上线没多久就频繁遭到同行恶意攻击。在相关数据报告中&#xff0c;2023年上半年遭受DDoS攻击的行业中&#xff0c;游戏行业占到40%&#xff0c;而且攻击方式、攻击频率、攻击峰值呈明显上升趋势。很多充满创意的游戏开发公司刚才开发上线一个很有特色的产品&#xff…

在k8s中安装Grafana并对接Prometheus,实现k8s集群监控数据的展示

&#x1f407;明明跟你说过&#xff1a;个人主页 &#x1f3c5;个人专栏&#xff1a;《Grafana&#xff1a;让数据说话的魔术师》 &#x1f3c5; &#x1f516;行路有良友&#xff0c;便是天堂&#x1f516; 目录 一、引言 1、Grafana简介 2、Grafana的重要性与影响力 …

需求规格说明书设计规范(编制实际项目案例-word)

二、 项目概述 2.1 项目背景 2.2 现状分析 2.2.1 业务现状 2.2.2 系统现状 三、 总体需求 3.1 系统范围 3.2 系统功能 3.3 用户分析 3.4 假设与依赖关系 四、 功能需求 五、 非功能性需求 5.1 用户界面需求 5.2 软硬件环境需求 5.3 产品质量需求 5.4 接口需求 5.5 其他需求 六、…

C语言(指针)2

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸各位能阅读我的文章&#xff0c;诚请评论指点&#xff0c;关注收藏&#xff0c;欢迎欢迎~~ &#x1f4a5;个人主页&#xff1a;小羊在奋斗 &#x1f4a5;所属专栏&#xff1a;C语言 本系列文章为个人学习笔记&#x…

在拥有多个同名称密码的ap环境中,如何连接到指定信道或mac的ap路由器?

在给客户做ESP32-C3入墙开关项目时&#xff0c;客户问&#xff1a;在拥有多个同名称密码的ap环境中&#xff0c;如何连接到指定信道或mac的ap路由器&#xff1f;针对这个问题&#xff0c;启明云端工程师给出下面解决方法。 1、将wifi_sta_config_t配置中的channel配置为该信道…

又一位互联网大佬转行当网红,能写进简历么?

最近半个月&#xff0c;有两个中年男人仿佛住进了热搜。 一个是刚刚辟谣自己“卡里没有冰冷的 40 亿”的雷军&#xff0c;另一个则是在今年年初就高呼“如果有可能&#xff0c;企业家都要去当网红”的 360 创始人周鸿祎。 他也确实做到了。 先是作为当年 3Q 大战的当事人&…

企业破产重整:从“至暗时刻”到“涅槃重生”

今天我们不谈星辰大海&#xff0c;而是要潜入商业世界的深海区&#xff0c;探索那些濒临绝境的企业是如何借助“破产重整”的神秘力量&#xff0c;实现惊天大逆转的&#xff01; 一、破产重整&#xff0c;到底是个啥&#xff1f; 想象一下&#xff0c;企业像是一位远航的船长…

Map按value降序并统计

package com.ldj.cloud.user.demo;import java.util.*;/*** User: ldj* Date: 2024/5/11* Time: 10:03* Description: map按value降序*/ public class Tr {public static void main(String[] args) {ArrayList<String> list new ArrayList<>();list.add("a&q…

在家就可以轻松赚零花钱的副业

互联网的兴起让很多人实现了在家办公的梦想&#xff0c;同时也为人们提供了更多的挣钱方式。以下是4种可以在家中兼职副业赚钱的方法&#xff1a; 1. 写作工作 如果你善于写作&#xff0c;并且有一定的文学素养&#xff0c;那么可以通过自己的博客或其他媒体平台来写作&#…

4. 初探MPI——集体通信

系列文章目录 初探MPI——MPI简介初探MPI——&#xff08;阻塞&#xff09;点对点通信初探MPI——&#xff08;非阻塞&#xff09;点对点通信初探MPI——集体通信 文章目录 系列文章目录前言一、集体通信以及同步点二、MPI_Bcast 广播2.1 使用MPI_Send 和 MPI_Recv 来做广播2.…

【一站式学会Kotlin】第四节默认参数和具名参数、unit返回值类型

作者介绍&#xff1a; 百度资深Android工程师T6&#xff0c;在百度任职7年半。 目前&#xff1a;成立赵小灰代码工作室&#xff0c;欢迎大家找我交流Android、微信小程序、鸿蒙项目。文章底部&#xff0c;csdn有为我插入微信的联络方式&#xff0c;欢迎大家联络我。 一&#x…

按键的短按、长按和连续的划分

在实际生活中&#xff0c;我们使用到的按键在短按、长按和按键松开时都会触发不同的功能。按键短按后松开和长按后松开的应用比短按和长按的应用较少&#xff0c;我了解的按键短按后松开和长按后松开的应用是在点动控制和长动控制中。这里主要讨论按键的短按、长按和连续这三种…

用Xinstall实现智能信息的无缝传递

在这个信息化的时代&#xff0c;智能信息的传递显得尤为重要。无论是对于个人还是企业&#xff0c;高效、准确的信息传递都是成功的关键。然而&#xff0c;随着科技的飞速发展&#xff0c;传统的信息传递方式已经无法满足我们的需求。这时&#xff0c;Xinstall应运而生&#xf…

Linux的命令(第二篇)

昨天学习到了第17个命令到 rm 命令&#xff08;作用删除目录和文件&#xff09;&#xff0c;今天继续往下里面了解其他命令以及格式、选项&#xff1a; &#xff08;17&#xff09;wc命令&#xff08;此wc非wc&#xff09; 作用&#xff1a;统计行数、单词数、字符分数。 格…

UEC++ FString做为参数取值时报错error:C4840

问题描述 用来取FString类型的变量时报错&#xff1a; 问题解决 点击错误位置&#xff0c;跳转到代码&#xff1a; void AMyDelegateActor::TwoParamDelegateFunc(int32 param1, FString param2) {UE_LOG(LogTemp, Warning, TEXT("Two Param1:%d Param2:%s"), param…

AutoCAD中密集的填充打散后消失的问题

有时候在AutoCAD中&#xff0c;图案填充的填充面积过大或填充太过密集时&#xff0c;将该填充打散&#xff0c;也就是执行Explode时&#xff0c;会发现填充图案消失了。 原因是打散后线条太大&#xff0c;系统就不显示了。可以通过设置&#xff1a;HPMAXLINES 值&#xff0c;来…