图像分类:Pytorch图像分类之--AlexNet模型

news2024/11/19 13:23:16

文章目录

        • 前言
        • 数据的处理
          • 数据集的下载
          • 数据集的划分
        • AlexNet介绍
        • 程序的实现
          • model.py
            • Dropout()函数
          • train.py
            • 数据预处理
            • 导入数据集

前言

搭建AlexNet来进行分类模型的训练,大致训练流程和图像分类:Pytorch图像分类之–LetNet模型差不多,两者最大的不同就是,读取训练数据的方式不同,前者读取是通过torchvision.datasets.CIFAR10 来导入数据;但是我们这篇博文中需要用到的数据在torchvision.datasets中没有包含,因此需要用到datasets.ImageFolder()来导入数据。

数据的处理

在本文中我们应用花分类数据集

数据集的下载
  • 数据集下载链接:
    http://download.tensorflow.org/example_images/flower_photos.tgz
    该数据集中包含 5 中类型的花,每种类型有600~900张图像不等。
  • 数据集下载完成后,解压后结构如下:
    在这里插入图片描述
数据集的划分

  下载的数据集需要我们自己通过程序进行划分训练集和测试集;划分数据集的脚本为:split_data.py其代码实现如下:

import os
from shutil import copy
import random

def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)

#获取flower_photos文件夹下除.txt文件以外所有文件夹名(即5种花的类名)
file_path = '../data/flower_photos'
flower_class = [cla for cla in os.listdir(file_path) if ".txt" not in cla]

#创建训练集train文件夹,并由5种类名在其目录下创建5个子目录
mkfile('../data/train')
for cla in flower_class:
    mkfile('../data/train/' + cla)

#创建测试集test,并由5种类名在其目录下创建5个子目录
mkfile('../data/test')
for cla in flower_class:
    mkfile('../data/test/' + cla)

#划分训练集和测试集的比例 训练集:测试集=9:1
split_rate = 0.1

#遍历5种花的全部图像并按比例分成训练集和测试集
for cla in flower_class:
    cla_path = file_path + '/' + cla + '/'   #某一类别花的子目录
    images = os.listdir(cla_path)      #images 列表存储了该目录下所有图像的名称
    print(images)
    images_num = len(images)
    test_index = random.sample(images, k=int(images_num*split_rate))   #从image列表中随机抽取K个图像名称
    for index, image in enumerate(images):
        # test_index中保存测试集的图片名称
        if image in test_index:
            image_path = cla_path + image
            new_path = '../data/test/' + cla
            copy(image_path, new_path)
        #其余的图像保存到训练集
        else:
            image_path = cla_path + image
            new_path = '../data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index + 1, images_num), end="")  # processing bar
        print()

print("processing done!")

划分后如下图所示:
在这里插入图片描述

AlexNet介绍

  • AlexNet和LetNet没有明显结构上的不同,都是卷积和池化的堆叠,只不过,AlexNet深度要比LetNet深。AlexNet有五个卷积层、三个池化层和三个全连接层。

具体结构如下图所示:
在这里插入图片描述

  • AlexNet 是2012年 ISLVRC ( ImageNet Large Scale Visual Recognition Challenge)竞赛的冠军网络,分类准确率由传统的70%+提升到80%+。它是由Hinton和他的学生Alex Krizhevsky设计的。 也是在那年之后,深度学习开始迅速发展。
    在这里插入图片描述

程序的实现

model.py
import torch
import torch.nn as nn
import torch.functional as F
from torchinfo import summary



class AlexNet(nn.Module):             #继承nn.Module这个父类
    def __init__(self):
        super(AlexNet, self).__init__()
        # Conv2d(in_channels, out_channels, kernel_size, stride, padding, ...)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=11, stride=5, padding=2)         #input[3, 224, 224]  output[48, 55, 55]
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)                                                  #output[48, 27, 27]
        self.conv2 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=5, padding=2)                  #output[128, 27, 27]
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)                                                  #output[128, 13, 13]
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, padding=1)                 #output[192, 13, 13]
        self.conv4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, padding=1)                 #output[192, 13, 13]
        self.conv5 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3,padding= 1)                 #output[128, 13, 13]
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)                                                  #output[128, 6, 6]

        self.fc1 = nn.Linear(128 * 6 * 6, 2048)
        self.fc2 = nn.Linear(2048, 2048)
        self.fc3 = nn.Linear(2048, 1000)

    def forward(self, x):
        x = F.relu(self.conv1(x))    #input[3, 224, 224]  output[48, 55, 55]
        x = self.pool1(x)            #input[48, 55, 55]  output[48, 27, 27]
        x = F.relu(self.conv2(x))    #input[48, 27, 27]  output[128, 27, 27]
        x = self.pool2(x)            #input[128, 27, 27]  output[128, 13, 13]
        x = F.relu(self.conv3(x))    #input[128, 13, 13]  output[192, 13, 13]
        x = F.relu(self.conv4(x))    #input[192, 13, 13]  output[192, 13, 13]
        x = F.relu(self.conv5(x))    #input[192, 13, 13]  output[128, 13, 13]
        x = self.pool3(x)            #input[128, 13, 13]  output[128, 6, 6]
        x = F.relu(self.fc1(x))      #input[128, 6, 6]  output(2048)
        x = nn.Dropout(0.5)
        x = F.relu(self.fc2(x))      #input(2048)  output(2048)
        x = nn.Dropout(0.5)
        x = self.fc3(x)              #input(2048)  output(1000)

        return x


alexnet = AlexNet()
summary(alexnet)

说明:原论文中用了两块GPU进行训练,每块GPU上对应参数第一层卷积输出通道为48,以上程序中参数是实现胡一块GPU的。

Dropout()函数
nn.Dropout(0.5)     #该函数的作用是参与训练的参数会随机失活50%

在这里插入图片描述

train.py
数据预处理
transform = {
    "train": transforms.Compose([transforms.Resize((224, 224)),   #随机裁剪,再缩放成224*224
                                 transforms.RandomHorizontalFlip(p=0.5), #水平方向随机翻转,概率为0.5
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]),
    "test": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])}
导入数据集
#导入加载数据集
#获取图像数据集的路径
data_root = os.getcwd()  #获取当前路径
image_path = data_root + "/data/"

#导入训练集并进行处理
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=transform["train"])
#加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset,  #导入的训练集
                                           batch_size=32,  #每批训练样本个数
                                           shuffle=True,   #打乱训练集
                                           num_workers=0)  #使用线程数

测试集数据导入同理

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import os
from model import AlexNet
from train_tool import TrainTool


#使用GPU训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#数据预处理
transform = {
    "train": transforms.Compose([transforms.Resize((224, 224)),   #随机裁剪,再缩放成224*224
                                 transforms.RandomHorizontalFlip(p=0.5), #水平方向随机翻转,概率为0.5
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]),
    "test": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])}
#导入加载数据集
#获取图像数据集的路径
# data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
data_root = os.getcwd()
image_path = data_root + "/data/"

#导入训练集并进行处理
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=transform["train"])
train_num = len(train_dataset)
#加载训练集
train_loader = torch.utils.data.DataLoader(train_dataset,  #导入的训练集
                                           batch_size=32,  #每批训练样本个数
                                           shuffle=True,   #打乱训练集
                                           num_workers=0)  #使用线程数

#导入测试集并进行处理
test_dataset = datasets.ImageFolder(root=image_path + "/test",
                                     transform=transform["test"])
test_num = len(test_dataset)
#加载测试集
test_loader = torch.utils.data.DataLoader(test_dataset,  #导入的测试集
                                           batch_size=32,  #每批测试样本个数
                                           shuffle=True,   #打乱测试集
                                           num_workers=0)  #使用线程数

#定义超参数
alexnet = AlexNet(num_classes=5).to(device)   #定义网络模型
loss_function = nn.CrossEntropyLoss()  #定义损失函数为交叉熵
optimizer = optim.Adam(alexnet.parameters(), lr=0.0002)  #定义优化器定义参数学习率


#正式训练
train_acc = []
train_loss = []
test_acc = []
test_loss = []

epoch = 0

#for epoch in range(epochs):
while True:
    epoch = epoch + 1;
    alexnet.train()

    epoch_train_acc, epoch_train_loss = TrainTool.train(train_loader, alexnet, optimizer, loss_function, device)

    alexnet.eval()
    epoch_test_acc, epoch_test_loss = TrainTool.test(test_loader,alexnet, loss_function,device)

    # train_acc.append(epoch_train_loss)
    # train_loss.append(epoch_train_loss)
    # test_acc.append(epoch_test_acc)
    # test_loss.append(epoch_test_loss)
    if epoch_train_acc < 0.90:
       template = ('Epoch:{:2d}, train_acc:{:.1f}%, train_loss:{:.2f}, test_acc:{:.1f}%, test_loss:{:.2f}')
       print(template.format(epoch, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
       continue
    else:
       torch.save(alexnet.state_dict(),'./model/alexnet_params.pth')
       print('Done')
       break

train_tool.py 和predict.py和文章图像分类:Pytorch图像分类之–LetNet模型中一样,可以参考这篇文章中代码。

如有错误欢迎指正!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/110822.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NewStarCTF公开赛week4密码学题目wp

目录前言一、LCG Revenge1.原题2.解题思路1) 考察知识2) 分析本质3.解题Python脚本二、代数关系1.原题2.解题思路3.解题Python脚本前言 哎呦喂&#xff0c;第三周勉强做了一道题&#xff0c;果然第四周就爆零了QAQ ———————————悲伤的分割线——————————— …

Apache Flink 作业图 JobGraph 与执行图 ExecutionGraph

由 Flink 程序直接映射成的数据流图&#xff08;dataflow graph&#xff09;&#xff0c;也被称为逻辑流图&#xff08;logical StreamGraph&#xff09;。到具体执行环节时&#xff0c;Flink 需要进一步将逻辑流图进行解析&#xff0c;转换为物理执行图。 在这个转换过程中&am…

思派健康在港交所上市:九成收入靠“卖药”,持续大额亏损

12月23日&#xff0c;思派健康&#xff08;HK:00314&#xff09;在港交所上市。本次上市&#xff0c;思派健康的发行价格为18.60港元/股&#xff0c;全球发售991.94万股。据此前招股书介绍&#xff0c;思派健康将自全球发售收取所得款项净额约1.204亿港元。 招股书显示&#x…

每月明星计划(12 月),ECHO:我们的意见万岁!

我们很高兴 12 月的 MSP 比我们预期的要成功得多。提交项目的数量和质量甚至优于 11 月的 MSP。 在 11 月的 MSP 竞赛中&#xff0c;被选中的项目Owlando以其先进的 UGC Metaverse 概念及其与朋友创建、交流和玩耍的有趣方式引起了评委的注意&#xff0c;最终将结果构建到 NFT…

百度百科创建词条步骤是怎样的?

互联网时代&#xff0c;在百度上搜索人物、企业、品牌、作品之类的信息都会出现相关百科词条&#xff0c;一般在首页前几名的位置&#xff0c;权重非常高&#xff0c;获得的用户流量也是非常庞大的。 基于百度百科的宣传背书&#xff0c;大大提高了内容的可信度和知名度&#…

【详细学习SpringBoot源码之属性配置文件加载原理(application.properties|application.yaml)-7】

一.知识回顾 【0.SpringBoot专栏的相关文章都在这里哟&#xff0c;后续更多的文章内容可以点击查看】 【1.SpringBoot初识之Spring注解发展流程以及常用的Spring和SpringBoot注解】 【2.SpringBoot自动装配之SPI机制&SPI案例实操学习&SPI机制核心源码学习】 【3.详细学…

教室管理系统

开发工具(eclipse/idea/vscode等)&#xff1a;idea 数据库(sqlite/mysql/sqlserver等)&#xff1a;mysql 功能模块(请用文字描述&#xff0c;至少200字)&#xff1a; 关于这个系统的具体功能主要包括教师&#xff0c;学生&#xff0c;课程&#xff0c;教室还有班级这几个实体。…

acwing基础课——二分图

由数据范围反推算法复杂度以及算法内容 - AcWing 常用代码模板3——搜索与图论 - AcWing 基本思想&#xff1a; 二分图:在一张图中&#xff0c;如果能把全部点分到两个集合&#xff0c;且保证两个集合内部没有任何一条边&#xff0c;图中的边只存在于两个集合之间&#xff0c…

制造业ERP如何做好成本核算管理?

随着制造业的不断发展&#xff0c;制造业成本管理中存在的问题已成为制造业企业关注的焦点。在传统粗放的手工模式下&#xff0c;制造企业成本核算工作量会非常巨大&#xff0c;不能对成本信息进行实时监控&#xff0c;只能在成本费用发生后进行归集核算&#xff0c;数据有滞后…

PS CS6视频剪辑基本技巧(四)字幕居中和滚动字幕

在第三讲中介绍了添加字幕的方法&#xff0c;但有的读者可能会发现&#xff0c;字幕模板设定的字幕起始是固定不变的&#xff0c;假如设定的起始位置是最左边&#xff0c;那么无论一行字多有多少个&#xff0c;都是从最左边开始排。那么有没有办法可以让字幕可以批量居中呢&…

大数据技术之SparkCore

文章开篇先简单介绍一下SparkCore&#xff1a; Spark Core是spark的核心与基础&#xff0c;实现了Spark的基本功能&#xff0c;包含任务调度&#xff0c;内存管理&#xff0c;错误恢复与存储系统交互等模块 Spark Core中包含了对Spark核心API——RDD API(弹性分布式数据集)的定…

你以为传切片就是传引用了吗?

xdm &#xff0c;我们在写 golang 的时候&#xff0c;引用和传值傻傻分不清&#xff0c;就例如我们传 切片 的时候&#xff0c;你能分清楚你传的切片是传值还是传引用呢&#xff1f; 引用是什么&#xff1f; 引用就是给对象起另一个名字&#xff0c;引用类型引用另一种类型 引…

【自省】线程池里的定时任务跑的可欢了,可咋停掉特定的任务?

客户端抢到分布式锁之后开始执行任务&#xff0c;执行完毕后再释放分布式锁。持锁后因客户端异常未能把锁释放&#xff0c;会导致锁成为永恒锁。为了避免这种情况&#xff0c;在创建锁的时候给锁指定一个过期时间。到期之后锁会被自动删除掉&#xff0c;这个角度看是对锁资源的…

Going Home(二分图最大权匹配KM算法)

C-Going Home_2022图论班第一章图匹配例题与习题 (nowcoder.com) 在网格地图上有n个小人和n座房子。在每个单位时间内&#xff0c;每个小人都可以水平或垂直地移动一个单位步到相邻点。对于每个小矮人&#xff0c;你需要为他每走一步支付1美元的旅费&#xff0c;直到他进入一所…

Git命令笔记,下载、提交代码、解决冲突、分支处理

下载代码&#xff0c;复制https地址到本地文件夹&#xff0c;鼠标右键选择git bash后输入命令 git clone https://gitee.com/View12138/ViewFaceCore.git 下载后初始化&#xff1a;git init 下载代码后不运行报错&#xff08;如下&#xff09;&#xff0c;需要执行初始化命令…

Google ProtoBuf的使用

Google的protobuf太好用了&#xff0c;又小&#xff0c;读写又快 跑步快慢受鞋的影响太大了&#xff0c;但是造鞋的工具研究起来还是很有难度的&#xff0c;百度真是充斥的大量的转载文件&#xff0c;不管能不能用、能不能看懂&#xff0c;反正是各种转载&#xff0c;有的连错…

2023年企业固定资产管理怎么破局?

2022年已经在风雨中过去&#xff0c;转眼我们迎来了2023年。过去的一年&#xff0c;固定资产管理的痛依旧历历在目&#xff0c;如何让新的一年中&#xff0c;固定资产管理工作有所突破&#xff0c;不再承受固定资产资产管理的痛处&#xff0c;是每个企业管理者和企业固定资产管…

snap打包初步了解

前言 和snap比较类似的有三种打包方式&#xff1a; Snap Flatpak appimage Appimage是将所有的资源打包在一起&#xff0c;以一个类似与独立exe的方式执行&#xff0c;虽然简单使用&#xff0c;但是解压资源和本地缓存数据都比较麻烦。 Flatpak和snap十分类似&#xff0c;但…

XXE无回显攻击详解

今天继续给大家介绍渗透测试相关知识&#xff0c;本文主要内容是XXE无回显攻击详解。 免责声明&#xff1a; 本文所介绍的内容仅做学习交流使用&#xff0c;严禁利用文中技术进行非法行为&#xff0c;否则造成一切严重后果自负&#xff01; 再次强调&#xff1a;严禁对未授权设…

怎么把element的tootip设置为点击后出现提示框,且在提示框里面放其他元素,vue2动态给对象添加属性并实现响应式应答,样式穿透

怎么把element的tootip设置为点击后出现提示框 我目前有一个需求&#xff0c;就是要点击文字才会出现提示框&#xff0c;而不是hover上去就以后&#xff0c;找资料看文档&#xff0c;看了半天让我终于实现了&#xff0c;其实也不难&#xff0c;可能是最开始我没有理解value&am…