使用Alexnet实现CIFAR10数据集的训练

news2024/11/27 17:57:30

   如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用Alexnet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到224X224,因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,Alexnet:

1.论文下载地址:http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf

2.Alexnet的历史地位:

①在ILSVRC-2010和ILSVRC-2012比赛中,使用ImageNet数据集的一个子集,训练了一个最大的卷积神经网络,并且在该数据集取得相对于现在来说很好的结果。

②完成高度优化的GPU实现,用于2D卷积和训练神经网络的操作,并将其公开。

③使用了一些技巧(ReLu、多块GPU并行训练、局部响应归一化、Overlapping池化、Dropout等),能够改善性能、减少训练时间。

3.Alexnet的网络结构图(使用了两个GPU,所以网络的结构是分开进行画出来的):

4.代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集(CIFAR10)进行数据的训练,并且利用官方已经做好的数据集分类(using 50000 images for training, 10000 images for validation)是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

 data:进行模型训练的时候会自动开始下载数据集的信息到这个文件夹里面。

res:该文件夹会保存模型的权重和记录模型在训练过程当中计算出来的train_loss, train_acc和val_acc做成的xml文件夹。

train.py代码如下:

import torchvision

from model import AlexNet
import os
import parameters
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
from fuction import writer_into_excel_onlyval



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.alexnet_save_model
    save_path = parameters.alexnet_save_path_CIFAR10

    data_transform = {
        # 进行数据增强的处理
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    }

    train_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=True,
                                                 download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR10(root='./data/CIFAR10', train=False,
                                               download=False, transform=data_transform["val"])

    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=nw,
                                             )

    model = AlexNet(num_classes=parameters.CIFAR10_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.alexnet_lr)
    best_acc = 0.0

    # 记录训练产生的数据
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))
        writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list, "CIFAR10")

        # 选择最好的模型进行保存,此处的评价指标是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)


if __name__ == '__main__':
    main()

parameters.py代码如下:
# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:12
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 4

# 数据集的分类类别数量
CIFAR10_class = 10

# 模型训练时候的学习率大小
alexnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
alexnet_save_path_CIFAR10 = './res/'
alexnet_save_model = './res/best_model.pth'

model.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:08
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # output[128, 6, 6]
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

fuction代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:09
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



最后的实验结果保存:

其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/153398.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第03讲:Docker 容器的数据卷

一、什么是数据卷 数据卷是宿主机中的一个目录或文件,当容器目录或者文件和数据卷目录或者文件绑定后,对方的修改会立即同步,一个数据卷可以被多个容器同时挂载,一个容器也可以被挂载多个数据卷,数据卷的作用:容器数据…

基于遥感卫星影像水体提取方法综述

水体提取分类依据及基础 水体提取分类依据 水体提取的方法很多,很多学者也进行了分类,大体上有一个分类框架,主要是基于光学影像的分类,比如王航等[7]将水体提取分成3类,分别是基于阈值法、分类器法和自动化法; 李丹等[8]更深一步进行总结,引入近些年发展火热的基于雷达影像数…

Redisson自定义序列化

配置RedissonClientBean public RedissonClient redissonClient() {Config config new Config();// 单节点模式SingleServerConfig singleServerConfig config.useSingleServer();singleServerConfig.setAddress("redis://127.0.0.1:6379");singleServerConfig.set…

LeetCode二叉树经典题目(六):二叉搜索树

目录 28. LeetCode617. 合并二叉树 29. LeetCode700. 二叉搜索树中的搜索 30. LeetCode98. 验证二叉搜索树 31. LeetCode530. 二叉搜索树的最小绝对差 32. LeetCode501. 二叉搜索树中的众数 33. LeetCode236. 二叉树的最近公共祖先​ 28. LeetCode617. 合并二叉树 递归&…

Hi3861鸿蒙物联网项目实战:智能安防报警

华清远见FS-Hi3861开发套件,支持HarmonyOS 3.0系统。开发板主控Hi3861芯片内置WiFi功能,开发板板载资源丰富,包括传感器、执行器、NFC、显示屏等,同时还配套丰富的拓展模块。开发板配套丰富的学习资料,包括全套开发教程…

Windows11 系统打开IE浏览器的方式(完整版)

前言 大家好,好久不见! 1、最近疯狂加班,旧电脑不太给力,换了新电脑,嘎嘎开心;开心之余发现新电脑是Win11系统的,但是IE浏览器找不到了,由于我的某些工作需要用到IE浏览器&#xf…

Vue2前端路由(vue-router的使用)、动态路由、路由和视图的命名以及声明式和编程式导航

目录 一、vue2的前端路由(vue-router) 1、路由:页面地址与组件之间的对应关系 2、路由方式:服务器端路由、前端路由 3、前端路由:在前端维护一组路由规则(地址和组件之间的对应关系)&#xf…

【UE4 第一人称射击游戏】34-制作一个简易计时器

上一篇:【UE4 第一人称射击游戏】33-创建一个迷你地图本篇效果:可以看到左上角有个简易的关卡计时器在倒计时步骤:打开“FPSHUD”,拖入一个图像控件图像选择“Timer_Backing”,尺寸改为4719拖入3个文本控件大小为1210字…

学习ffmpeg-录屏实现记录

项目需要一个录屏的功能,之前看到了一个使用Qt计时器截图avilib生成AVIffmpeg合并视频音频的方式:Qt C 录屏录音功能实现(avilibffmpeg)以及动态库生成https://blog.csdn.net/qq_35769071/article/details/125323624使用后&#x…

【.dll 没有被指定在windows上运行】

修复(重新注册DLL)的具体步骤如下: 方法一: 1、快捷键winr打开“运行”输入cmd,点击确定打开命令提示符窗口。 2、复制:for %1 in (%windir%\system32*.dll) do regsvr32.exe /s %1 命令,在打开的管理员…

ubuntu安装vue

首先建议使用ubuntu18.04以上的系统,不然会有类似fcntlGLIBC_2.28‘未定义的引用的报错 VUE官网:http://caibaojian.com/vue/guide/installation.html 其中安装说明只写到:npm install vue 我们还需要安装node.js、npm 1、安装 NVM&#xf…

Windows安装TensorRT

文章目录前言TensorRT下载TensorRT安装参考资料前言 本文将介绍Windows如何安装TensorRT。本文的基础是:Windows安装PytorchCUDA环境 TensorRT下载 进入官方网站:https://developer.nvidia.com/nvidia-tensorrt-8x-download 寻找自己对应的版本&#…

RabbitMQ之Work Queue(工作队列)

前言:大家好,我是小威,24届毕业生,曾经在某央企公司实习,目前在某税务公司。本篇文章将记录和分享RabbitMQ工作队列相关的知识点。 本篇文章记录的基础知识,适合在学Java的小白,也适合复习中&am…

【自学Python】Python string转bytes

Python string转bytes Python string转bytes教程 在 Python 中,bytes 类型和 字符串 的所有操作、使用和内置方法也都基本一致。因此,我们也可以实现将字符串类型转换成 bytes 类型。 Python string转bytes方法 如果字符串内容都是 ASCII 字符&#…

从0到1完成一个Vue后台管理项目(十六、后端分页方法以及分页组件的封装以及复用)

往期 从0到1完成一个Vue后台管理项目(一、创建项目) 从0到1完成一个Vue后台管理项目(二、使用element-ui) 从0到1完成一个Vue后台管理项目(三、使用SCSS/LESS,安装图标库) 从0到1完成一个Vu…

C++STL——list类与模拟实现

Listlistlist的常用接口模拟实现完整代码list与vector的区别list list是一个带头双向循环链表。 list文档介绍:https://legacy.cplusplus.com/reference/list/list/ list因为是链表结构,所以没有 [] 去访问数据的方式,只有用迭代器&#xff…

第十六届中国大数据技术大会五大分论坛顺利举办!

1月8日下午,由苏州市人民政府指导、中国计算机学会主办、苏州市吴江区人民政府支持,CCF大数据专家委员会、苏州市吴江区工信局、吴江区东太湖度假区管委会、苏州市吴江区科技局、苏州大学未来科学与工程学院及DataFounain数联众创联合承办的第十六届中国…

基于java springboot+mybatis学生学科竞赛管理管理系统设计和实现

基于java springbootmybatis学生学科竞赛管理管理系统设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言…

10.Isaac教程--在Docker中通过模拟训练目标检测

在Docker中通过模拟训练目标检测 文章目录在Docker中通过模拟训练目标检测怎么运行的主机设置硬件要求软件要求NGC Docker 注册表设置第一次运行数据集生成配置您的工作区Jupyter 变量设置开始训练添加您自己的 3D 模型故障排除接下来人工智能中的一个常见问题是训练样本的数据…

02【Http、Request】

文章目录02【Http、Request】一、HTTP协议1.1 HTTP协议概述1.1.1 HTTP协议的概念1.1.2 HTTP协议的特点:2.1 HTTP请求的组成2.1.1 请求行2.1.2 请求头2.1.3 请求体二、HttpServletRequest对象2.1 HttpServletRequest对象简介2.2 HttpServletRequest的使用2.2.1 请求行…