使用Alexnet实现CIFAR100数据集的训练

news2025/1/15 22:48:57

 如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄     

        使用Alexnet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到224X224,因为在测试训练的时候,发现将CIFAR10数据集的分辨率拉大可以让模型更快地进行收敛,并且识别的效果也是比低分辨率的更加好。

首先来介绍一下,Alexnet:

1.论文下载地址:http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf

2.Alexnet的历史地位:

①在ILSVRC-2010和ILSVRC-2012比赛中,使用ImageNet数据集的一个子集,训练了一个最大的卷积神经网络,并且在该数据集取得相对于现在来说很好的结果。

②完成高度优化的GPU实现,用于2D卷积和训练神经网络的操作,并将其公开。

③使用了一些技巧(ReLu、多块GPU并行训练、局部响应归一化、Overlapping池化、Dropout等),能够改善性能、减少训练时间。

3.Alexnet的网络结构图(使用了两个GPU,所以网络的结构是分开进行画出来的):

4.代码实现:

数据集的处理:
        调用torchvision里面封装好的数据集(CIFAR100)进行数据的训练,并且利用官方已经做好的数据集分类(using 50000 images for training, 10000 images for validation)是数据集的划分大小。进行了一些简单的数据增强,分别是随机的随机剪切和随机的水平拉伸操作。

模型的代码结构目录:

data:进行模型训练的时候会自动开始下载数据集的信息到这个文件夹里面。

res:该文件夹会保存模型的权重和记录模型在训练过程当中计算出来的train_loss, train_acc和val_acc做成的xml文件夹。

train.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:08
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torchvision

from model import AlexNet
import os
import parameters
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
from fuction import writer_into_excel_onlyval



def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    epochs = parameters.epoch
    save_model = parameters.alexnet_save_model
    save_path = parameters.alexnet_save_path_CIFAR10

    data_transform = {
        # 进行数据增强的处理
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    }

    train_dataset = torchvision.datasets.CIFAR100(root='./data/CIFAR100', train=True,
                                                 download=True, transform=data_transform["train"])

    val_dataset = torchvision.datasets.CIFAR100(root='./data/CIFAR100', train=False,
                                               download=False, transform=data_transform["val"])

    train_num = len(train_dataset)
    val_num = len(val_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    batch_size = parameters.batch_size

    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               )

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=nw,
                                             )

    model = AlexNet(num_classes=parameters.CIFAR100_class)
    model.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=parameters.alexnet_lr)
    best_acc = 0.0

    # 记录训练产生的数据
    train_acc_list = []
    train_loss_list = []
    val_acc_list = []

    for epoch in range(epochs):
        # train
        model.train()
        running_loss_train = 0.0
        train_accurate = 0.0
        train_bar = tqdm(train_loader)
        for images, labels in train_bar:
            optimizer.zero_grad()

            outputs = model(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            predict = torch.max(outputs, dim=1)[1]
            train_accurate += torch.eq(predict, labels.to(device)).sum().item()
            running_loss_train += loss.item()

        train_accurate = train_accurate / train_num
        running_loss_train = running_loss_train / train_num
        train_acc_list.append(train_accurate)
        train_loss_list.append(running_loss_train)

        print('[epoch %d] train_loss: %.7f  train_accuracy: %.3f' %
              (epoch + 1, running_loss_train, train_accurate))

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_loader = tqdm(val_loader)
            for val_data in val_loader:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        val_acc_list.append(val_accurate)
        print('[epoch %d] val_accuracy: %.3f' %
              (epoch + 1, val_accurate))
        writer_into_excel_onlyval(save_path, train_loss_list, train_acc_list, val_acc_list, "CIFAR100")

        # 选择最好的模型进行保存,此处的评价指标是acc
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_model)


if __name__ == '__main__':
    main()

parameters.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:12
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


# 训练的次数
epoch = 2

# 训练的批次大小
batch_size = 4

# 数据集的分类类别数量
CIFAR100_class = 100

# 模型训练时候的学习率大小
alexnet_lr = 0.002

# 保存模型权重的路径 保存xml文件的路径
alexnet_save_path_CIFAR10 = './res/'
alexnet_save_model = './res/best_model.pth'

model.py代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:08
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm



import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(48, 128, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(128, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # output[128, 6, 6]
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

fuction代码如下:

# -*- coding:utf-8 -*-
# @Time : 2023-01-10 19:09
# @Author : DaFuChen
# @File : CSDN写作代码笔记
# @software: PyCharm


import xlwt



def writer_into_excel_onlyval(excel_path,loss_train_list, acc_train_list, val_acc_list,dataset_name:str=""):
    workbook = xlwt.Workbook(encoding='utf-8')  # 设置一个workbook,其编码是utf-8
    worksheet = workbook.add_sheet("sheet1", cell_overwrite_ok=True)  # 新增一个sheet
    worksheet.write(0, 0, label='Train_loss')
    worksheet.write(0, 1, label='Train_acc')
    worksheet.write(0, 2, label='Val_acc')


    for i in range(len(loss_train_list)):  # 循环将a和b列表的数据插入至excel
        worksheet.write(i + 1, 0, label=loss_train_list[i])  # 切片的原来是传进来的Imgs是一个路径的信息
        worksheet.write(i + 1, 1, label=acc_train_list[i])
        worksheet.write(i + 1, 2, label=val_acc_list[i])


    workbook.save(excel_path + str(dataset_name) +".xls")  # 这里save需要特别注意,文件格式只能是xls,不能是xlsx,不然会报错
    print('save success!   .')



        其中部分参数,例如是学习率的大小,训练的批次大小,数据增强的一些小参数,可以根据自己的经验和算力的现实情况进行调整。

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/153248.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ROS】—— ROS常用组件_TF坐标变换_多态坐标变换与TF坐标变换实操(十一)

文章目录前言1. 多态坐标变换1.1 发布方1.2 订阅方(C)1.3 订阅方(python)2. 坐标系关系查看3. TF坐标变换实操(C)3.1准备3.2 生成新的乌龟3.3 增加键盘控制3.4 发布方(发布两只乌龟的坐标信息)3.5 订阅方(解析坐标信息并生成速度信息)前言 📢本系列将依托赵虚左老师…

将git仓库瘦身

一个小工具的仓库居然有7个g了,每次clone都要等好久,在网上找的方法,实际了几个小时才成功瘦身,做一次记录 一、排查是哪些历史文件占用了内存,下面是查询最大的5个文件 git rev-list --objects --all | grep "$(…

使用nvm管理node版本,实现高版本与低版本node之间的转换

第一步:先清空本地安装的node.js版本 1.按健winR弹出窗口,键盘输入cmd,然后敲回车(或者鼠标直接点击电脑桌面最左下角的win窗口图标弹出,输入cmd再点击回车键) 然后进入命令控制行窗口,并输入where node查…

Linux环境配置

一、安装ubuntu16的步骤 1、新建Ubuntu 右击文件---------新建虚拟机 典型----------------下一步 稍后安装操作系统-------------下一步 Linux(L)------------------Ubuntu64位-----------------下一步 虚拟机名称(随意起)----…

OSS简单介绍

OSS 阿里云对象存储OSS(Object Storage Service)是一款海量、安全、低成本、高可靠的云存储服务,可提供99.9999999999%(12个9)的数据持久性,99.995%的数据可用性。多种存储类型供选择,全面优化…

197: vue+openlayers 预加载preload瓦片地图,减少过渡期间的空白区域

第197个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+openlayers项目中演示瓦片预加载和没有预加载的不同状态。 没有采用预加载 当我们平移和缩放地图时,经常会遇到过渡期间的空白区域(因为内容正在加载),过一会儿,切片图像才出现了。 采用预加载,将预载的值设置为…

深入ReentrantLock锁

1. 前言 今天我们来探讨下另一个核心锁ReentrantLock. 从具体的实现到JVM层面是如何实现的。 我们都会一一进行讨论的,好了,废话不多说了,我们就开始吧 2. ReentrantLock 以及synchronized 核心区别: ReentrantLock 是一个抽象的…

MVC框架知识详解

✅作者简介:热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏:Java案例分…

机器学习课程学习随笔

文章目录本文来源机器学习简介机器学习流程机器学习可以完成如下功能:机器学习应用场景金融领域零售领域机器学习分类机器学习实现基于python等代码自己实现本文来源 本博客 是通过学习亚马逊的官方机器学习课程的随笔。 课程参考链接https://edu.csdn.net/course/…

爬虫与反爬虫 - 道高一尺魔高一丈 - 2013最新 - JS逆向 - Python Scrapy实现 - 爬取某天气网站历史数据

目录 背景介绍 网站分析 第1步:找到网页源代码 第2步:分析网页源代码 Python 实现 成果展示 后续 Todo 背景介绍 今天这篇文章,3个目的,1个是自己记录,1个是给大家分享,还有1个是向这个被爬网站的前…

synchronized锁膨胀(附代码分析)

synchronized锁膨胀 1. 基本概念 Java对象头 Java对象的对象头信息中的 Mark Word 主要是用来记录对象的锁信息的。 现在看一下 Mark Word 的对象头信息。如下: 其实可以根据 mark word 的后3位就可以判断出当前的锁是属于哪一种锁。注意:表格中的…

shell脚本练习2023年下岗版

shell脚本练习 1.判断指定进程的运行情况 #!/bin/bash NAMEhttpd #这里输入进程的名称 NUM$(ps -ef |grep $NAME |grep -vc grep) if [ $NUM -eq 1 ]; thenecho "$NAME running." elseecho "$NAME is not running!" fi2.判断用户是否存在 #!/bin/bash r…

【RabbitMQ】安装、启动、配置、测试一条龙

一、基本环境安装配置 1.英文RabbitMQ是基于erlang开发的所以需要erlang环境,点击以下链接下载安装 Downloads - Erlang/OTP 2.官网下载RabbitMQ安装包并安装 Installing on Windows — RabbitMQ 3.配置erlang本地环境变量(和JAVAHOME类似) 4.cmd查看erlang版本 5.点击以下…

自己看的操作系统

计算机网络冯诺依曼体系进程线程内核和虚拟内存os管理线程冯诺依曼体系 计算机五大组成:输入设备、输出设备、控制器、运算器、存储器 进程线程 这些应用都是进程 进程相当于一个菜谱,读取到内存中去使用。 电脑一时间能运行很多进程。 进程中为什么要…

excel函数技巧:MAX在数字查找中的应用妙招

大家都知道VLOOKUP可以按给定的内容去匹配到我们所需的数据,正因为如此,它在函数界有了很大的名气。但是今天要分享的这三个示例,如果使用VLOOKUP去匹配数据的话,就有些麻烦了。就在VLOOKUP头疼不已的时候,MAX函数二话…

2022 年度总结

1、CSDN 年度总结 2022年的粉丝涨幅比较明显竟然超过了之前几年的总和,这是比较意外的。应该是因为今年研究了一些云原生、元宇宙的原因,方向比努力真的重要的多。 1500的阅读确实没想到~~~说明低头一族还是没白当 涨粉稍微明细,不过还需…

English Learning - L1-11 时态 + 情态动词 2023.1.9 周一

English Learning - L1-11 时态 情态动词 2023.1.9 周一8 时态8.4 完成进行时(一)现在完成进行时核心思维:动作开始于现在之前,并有限地持续下去,动作到目前为止尚未完成1. 动作从过去某时开始一直持续到现在并可能继…

【Python】如何使用python将一个py文件变成一个软件?

系列文章目录 这个系列文章将写一些python中好玩的小技巧。 第一章 使用Python 做一个软件 目录 系列文章目录 前言 一、第一步:写好文件 二、第二步:生成程序 1.安装库 2.使用安装的库进行转化 总结 前言 本文重点说如何将py文件转化为exe文件…

回溯法--符号三角形(杂记)

回溯法说来简单,写起来难,真的是要愁死。回溯法有两种模板--子集树、排列树5.4符号三角形--dfs计算多少个满足条件的符号三角形,同号下面为“”,异号下面为“-”。根据异或的规则我们令“”0,“-”1,(异或的…

postgresql 启用ssl安全连接方式

SSL的验证流程 利用openssl环境自制证书 CA 证书签发 创建私钥ca.key,使用des3算法,有效期2048天 openssl genrsa -des3 -out ca.key 2048生成根CA证书请求(.csr) openssl req -new -key ca.key -out ca.csr -subj "/CCN/STGuangDong/LGuangZhou…