Resnet与Pytorch花图像分类

news2024/12/26 23:05:26

1、介绍

1.1数据集介绍

flower_data
    ├── train
    │   └── 1-102(102个文件夹)
    │   	└── XXX.jpg(每个文件夹含若干张图像)
    ├── valid
    │   └── 1-102(102个文件夹)
    └── ───	└── XXX.jpg(每个文件夹含若干张图像)  
     
cat_to_name.json:每一类花朵的"名称-编号"对应关系

1.2 任务介绍

实现102种花朵的分类任务,即通过训练train数据集后,从valid数据集中选取某一花朵图像,能准确判别其属于哪一类花朵

1.3Resnet介绍

在ResNet网络中有如下两个亮点:

  1. 提出residual结构(残差结构),并搭建超深的网络结构(突破1000层)
  2. 使用Batch Normalization加速训练(丢弃dropout)

在ResNet网络提出之前,传统的卷积神经网络都是通过将一系列卷积层与下采样层进行堆叠得到的。但是当堆叠到一定网络深度时,就会出现两个问题:

  1. 梯度消失或梯度爆炸
  2. 退化问题(degradation problem)

2、数据预处理

2.1引入头文件

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.2数据读取

#数据读取与预处理操作
data_dir = './flower_data/'
# 训练集
train_dir = data_dir + '/train'
#验证集
valid_ir = data_dir + '/valid'

2.3制作数据源

#制作数据源
data_transfroms = {
    'train':transforms.Compose([transforms.RandomRotation(45), #随机旋转(-45~45)
    transforms.CenterCrop(224), #从中心开始裁剪
    transforms.RandomHorizontalFlip(p = 0.5), #随机水平翻转
    transforms.RandomVerticalFlip(p = 0.5), #随机垂直翻转
    transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue = 0.1),
    transforms.RandomGrayscale(p = 0.025), #概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

2.4batch数据制作

#batch数据制作
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transfroms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size = batch_size,shuffle = True) for x in ['train','valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes

2.5读取数据标签

#读取标签对应的实际名字
with open('cat_to_name.json','r') as f:
    cat_to_name = json.load(f)

查看cat_to_name.json文件:

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

3、数据展示

3.1图像处理函数

#展示数据
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))
    image = image.clip(0.1)

    return image

3.2展示图像

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

4、进行迁移学习

迁移学习的关键点:

  • 研究可以用哪些知识在不同的领域或者任务中进行迁移学习,即不同领域之间有哪些共有知识可以迁移
  • 研究在找到了迁移对象之后,针对具体问题所采用哪种迁移学习的特定算法,即如何设计出合适的算法来提取和迁移共有知识
  • 研究什么情况下适合迁移,迁移技巧是否适合具体应用,其中涉及到负迁移的问题。

4.1训练全连接层

加载models中提供的模型,并且直接用训练好的权重当做初始化参数 

下载链接:https://download.pytorch.org/models/resnet152-394f9c45.pth

 选择resnet网络

model_name = 'resnet'  #可选的有: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']

#是否用官方训练好的特征来做
feature_extract = True 

设置用GPU训练

#是否用GPU来训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('cuda is not available. Training on CPU')
else:
    print('cuda is available. Training on GPU')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

屏蔽预训练模型的权重,只训练全连接层的权重: 

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameter():
            param.requires_grad = False

选择resnet152网络

model_ft = models.resnet152()

设置优化器:

#优化器设置
optimizer_ft = optim.Adam(params_to_update,lr = 1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
criterion = nn.NLLLoss()

定义训练模块:

# 训练模块
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename = filename):
    since = time.time()
    best_acc = 0

    model.to(device)
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {} / {}'.format(epoch,num_epochs - 1))
        print('-' * 10)

        #训练与验证
        for phase in ['train','valid']:
            if phase == 'train':
                model.train()  #训练
            else:
                model.eval()  #验证

            running_loss = 0.0
            running_corrects = 0

            #把数据取个遍
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                #清零
                optimizer.zero_grad()

                #只有训练的时候计算与更新梯度
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        outputs,aux_outputs = model(inputs)
                        loss1 = criterion(outputs,labels)
                        loss2 = criterion(aux_outputs,labels)
                        loss = loss1 + 0.4 * loss2
                    else: #resnet执行的是这里
                        outputs = model(inputs)
                        loss = criterion(outputs,labels)
                        _, preds = torch.max(outputs,1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                #计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}f'.format(time_elapsed // 60,time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))

            #得到最好的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer':optimizer.state_dict(),
                }
                torch.save(state,filename)
                if phase == 'valid':
                    val_acc_history.append(epoch_acc)
                    valid_losses.append(epoch_loss)
                    scheduler.step(epoch_loss)
                if phase == 'train':
                    train_acc_history.append(epoch_acc)
                    train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed //60,time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    #训练完后用最好的一次当做模型最终的结果
    model.load_state_dict(best_model_wts)
    return model,val_acc_history,train_acc_history.valid_losses,train_losses,LRs

开始训练:

# 开始训练
model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=20,is_inception=(model_name == 'inception'))

4.2训练所有层

我们从上次训练好最优的那个全连接层的参数开始,以此为基础训练所有层,设置param.requires_grad = True表明接下来训练全部网络,之后把学习率调小一点,衰减函数为每7次衰减为原来的1/10,损失函数不变

再继续训练所有层
for param in model_ft.parameters():
    param.requires_grad = True

#再继续训练所有的参数,学习率调小一点(lr)
optimizer = optim.Adam(params_to_update,lr = 1e-4)
#衰减函数(每七次衰减为原来的七分之一)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)

#损失函数
criterion = nn.NLLLoss()

 导入之前的最优结果并开始训练:

#在之前训练得到最好的模型的基础上继续训练
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=10,is_inception=(model_name == 'inception'))

5、测试网络效果

5.1测试数据预处理

首先将新训练好的checkpoint.pth重命名为serious.pth,之后加载训练好的模型:

#加载训练好的模型
model_ft,input_size = initialize_model(model_name,102,feature_extract,use_pretrained=True)

#GPU模型
model_ft = model_ft.to(device)
#保存文件的名字
filename = 'serious.pth'
#加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['beat_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

定义图像处理函数:

def process_image(image_path):
    img = Image.open(image_path)

    #Resize,thumbnail方法只能进行缩小,所以进行判断
    if img.size[0] > img.size[1]:
        img.thumbnail((10000,256))
    else:
        img.thumbnail((256,10000))

    #Crop操作
    left_margin = (img.width-224)/2
    bottom_margin = (img.height-224)/2
    right_margin = (left_margin) + 224
    top_margin = bottom_margin + 224
    img  = img.crop(left_margin,bottom_margin,right_margin,top_margin)

    #相同的预处理方法
    img = np.array(img)/255
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img = (img - mean)/std

    #注意颜色通道应该放在第一个位置
    img = img.transpose((2,0,1))

    return img

定义图像展示函数:

#展示数据
def imshow(image,ax = None,title = None):
    if ax is None:
        fig,ax = plt.subplots()

    #颜色通道还原
    image = np.array(image).transpose((1,2,0))

    #预处理还原
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    image = std * image + mean
    image = np.clip(image,0.1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

展示一个数据:

image_path = 'image_06621.jpg'
img = process_image(image_path)
imshow(img)

 

得到一个batch测试数据:

#测试一个batch数据
dataiter = iter(dataloaders['valid'])
images,labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

利用torch.max()函数计算标签值:

#得到属于类别的八个编号
_,preds_tensor = torch.ax(output,1)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())

5.2结果可视化

#展示预测结果
fig = plt.figure(figsize=(20,20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows,columns,idx+1,xticks=[],yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} {}".format(cat_to_name[str(preds[idx])],cat_to_name[str(labels[idx].item())]),
                 color = ("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

结果如下(绿色标题代表识别成功,红色标题代表识别失败,括号里面为真实值,括号外为预测值)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/817228.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何使用免费敏捷工具Leangoo领歌管理Sprint Backlog

什么是Sprint Backlog? Sprint Backlog是Scrum的主要工件之一。在Scrum中,团队按照迭代的方式工作,每个迭代称为一个Sprint。在Sprint开始之前,PO会准备好产品Backlog,准备好的产品Backlog应该是经过梳理、估算和优先…

ffmpeg安装

简介 FFmpeg是一个开源的音视频处理库,它提供了一系列的工具和API,可以用于处理音视频文件。你可以使用FFmpeg的命令行工具来执行各种音视频处理操作,比如转码、剪辑、合并等。FFmpeg的命令格式通常是:ffmpeg [全局选项] {[输入文…

章节5:SQL注入之WAF绕过

章节5:SQL注入之WAF绕过 5.1 SQL注入之WAF绕过上 WAF拦截原理:WAF从规则库中匹配敏感字符进行拦截。 5.2 SQL注入之WAF绕过下 (原理简单了解) 关键词大小写绕过 有的WAF因为规则设计的问题,只匹配纯大写或纯小写的…

B. Binary Cafe(二进制的妙用)

题目:Problem - B - Codeforces 总结: 对于该题最简单的方法为使用二进制的数表示状态 例如: 对于一个数7的二进制:111 它的每一位都可表示两种状态我们可以理解为取或者不取 对于7这个数字它可以表示一种状态即在三个位置都…

使用Roles模块搭建LNMP架构

使用Roles模块搭建LNMP架构 1.Ansible-playbook中部署Nginx角色2.Ansible-playbook中部署PHP角色3.Ansible-playbook中部署MySQL角色4.启动安装分布式LNMP 1.Ansible-playbook中部署Nginx角色 创建nginx角色所需要的工作目录; mkdir -p /etc/ansible/playbook/rol…

剖析 Kubernetes 控制器:Deployment、ReplicaSet 和 StatefulSet 的功能与应用场景

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

【kubernetes】k8s单master集群环境搭建及kuboard部署

k8s入门学习环境搭建 学习于许大仙: https://www.yuque.com/fairy-era k8s官网 https://kubernetes.io/ kuboard官网 https://kuboard.cn/ 基于k8s 1.21.10版本 前置环境准备 一主两从,三台虚拟机 CPU内存硬盘角色主机名IPhostname操作系统4C16G50Gmasterk8s-mast…

JSON动态生成表格

<!DOCTYPE html> <html><head><meta charset"utf-8"><title></title></head><body><script>var fromjava"{\"total\":3,\"students\":[{\"name\":\"张三\",\&q…

哔哩哔哩缓存转码|FFmpeg将m4s文件转为mp4|PHP自动批量转码B站视频

window下载安装FFmpeg 打开ffMpeg官网选择window>Windows builds from gyan.dev 打开https://www.gyan.dev/ffmpeg/builds/ 这里是上面提取的下载链接如果过期不能用自己去官网下 配置FFmpeg环境变量 上面下载的FFmpeg是绿色软件&#xff0c;下载解压到你的常用软件安装目…

配置IPv6 over IPv4 GRE隧道示例

组网需求 如图1&#xff0c;两个IPv6网络分别通过SwitchA和SwitchC与IPv4公网中的SwitchB连接&#xff0c;客户希望两个IPv6网络中的PC1和PC2实现互通。 其中PC1和PC2上分别指定SwitchA和SwitchC为自己的缺省网关。 图1 配置IPv6 over IPv4 GRE隧道组网图 配置思路 要实现I…

【LeetCode每日一题合集】2023.7.24-2023.7.30(TODO Lazy 线段树)

文章目录 771. 宝石与石头代码1——暴力代码2——位运算集合⭐&#xff08;英文字母的long集合表示&#xff09; 2208. 将数组和减半的最少操作次数&#xff08;贪心 优先队列&#xff09;2569. 更新数组后处理求和查询⭐⭐⭐⭐⭐&#xff08;线段树&#xff09;2500. 删除每行…

这所985很保护一志愿,每年招150+!非常稳定!

一、学校及专业介绍 中国海洋大学&#xff08;Ocean University of China&#xff0c;OUC&#xff09;&#xff0c;位于山东省青岛市&#xff0c;是中华人民共和国教育部直属的综合性全国重点大学&#xff0c;位列国家“双一流”、“985工程”、“211工程”重点建设高校。 1.1…

CHI中的error处理

Error Handling Error types 包含两种sub-packet级别的error, 和两种packe级别的error; Packet level error Data Error, DERR □ 访问的地址是正确的&#xff0c;但是访问的数据有错误&#xff1b;通常是在数据崩溃的时候使用&#xff0c;例如ECC&#xf…

三分钟白话RocketMQ系列—— 核心概念

目录 关键字摘要 Q1&#xff1a;RocketMQ是什么&#xff1f; Q2: 作为消息中间件&#xff0c;RocketMQ和kafka有什么区别&#xff1f; Q3: RocketMQ的基本架构是怎样的&#xff1f; Q4&#xff1a;RocketMQ有哪些核心概念&#xff1f; 总结 RocketMQ是一个开源的分布式消…

iOS--Runloop

Runloop概述 一般来说&#xff0c;一个线程一次只能执行一个任务&#xff0c;执行完成后线程就会退出。就比如之前学OC时使用的命令行程序&#xff0c;执行完程序就结束了。 而runloop目的就是使线程在执行完一次代码之后不会结束程序&#xff0c;而是使该线程处于一种休眠的状…

初步了解c#编程语言--(1)

初识c#编程语言 一、见识c#语言编写的各类应用程序 关于用c#语言编写的各类应用程序有以下几种&#xff1a; 1.Console 在编写Console程序时&#xff0c;要注意创建项目时&#xff0c;是选择控制台应用程序&#xff08;Console Application&#xff09;&#xff0c;在这里…

【计算机视觉】BLIP:源代码示例demo(含源代码)

文章目录 一、Image Captioning二、VQA三、Feature Extraction四、Image-Text Matching 一、Image Captioning 首先配置代码&#xff1a; import sys if google.colab in sys.modules:print(Running in Colab.)!pip3 install transformers4.15.0 timm0.4.12 fairscale0.4.4!g…

linux(进程)[6]

管理概念 先描述&#xff0c;再组织 进程 启动一个软件就相当于启动了一个进程 Linux下执行一条命令就在系统层面创建了一个进程&#xff01;&#xff01; 如何管理 进程对应的代码和数据 进程对应的PCB结构体 PCB&#xff08;process control block&#xff09; 在Linu…

Java反射机制的详细讲解

目录 1.反射机制是什么&#xff1f; 2.反射机制能干什么&#xff1f; 3.反射相关的类 ​编辑 4.Class类(反射机制的起源 ) 5.反射机制相关的API 1.(重要)常用获得类相关的方法 2.常用获得类中属性相关的方法(以下方法返回值为Field相关 3.(了解)获得类中注解相关的方法…

iOS开发-实现3DTouch按压App快捷选项shortcutItems及跳转功能

iOS开发-实现3DTouch按压App快捷选项shortcutItems及跳转功能 App的应用图标通过3D Touch按压App图标&#xff0c;会显示快捷选项&#xff0c;点击选项可快速进入到App的特定页面。 这里用到了UIApplicationShortcutItem与UIMutableApplicationShortcutItem 一、效果图 这里…