3.AlexNet--CNN经典网络模型详解(pytorch实现)

news2024/12/27 12:37:38

        看博客AlexNet--CNN经典网络模型详解(pytorch实现)_alex的cnn-CSDN博客,该博客的作者写的很详细,是一个简单的目标分类的代码,可以通过该代码深入了解目标检测的简单框架。在这里不作详细的赘述,如果想更深入的了解,可以看另一个博客实现pytorch实现MobileNet-v2(CNN经典网络模型详解) - 知乎 (zhihu.com)。

在这里,直接写AlexNet--CNN的代码。

1.首先建立一个model.py文件,用来写神经网络,代码如下:

#model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  #打包
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后
            nn.ReLU(inplace=True), #inplace 可以载入更大模型
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            #全链接
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1) #展平   或者view()
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值
                nn.init.constant_(m.bias, 0)

2.下载数据集

DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

3.下载完后写一个spile_data.py文件,将数据集进行分类

#spile_data.py

import os
from shutil import copy
import random


def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

之后应该是这样:
在这里插入图片描述

4.再写一个train.py文件,用来训练模型

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


#device : GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


#数据转换

data_transform = {
    #具体是对图像进行各种转换操作,并用函数compose将这些转换操作组合起来
    #以下操作步骤:
        # 1.图片随机裁剪为224X224
        # 2.随机水平旋转,默认为概率0.5
        # 3.将给定图像转为Tensor
        # 4.归一化处理
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# print(train_dataset)
# print(train_dataset[1][0].size())
# print(train_dataset.imgs[0][0])


# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
i=1
# for step, data in enumerate(train_loader, start=0):
#     images, labels = data
#     print(i,"==>",images.shape,labels.shape)
#     i+=1

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0)

test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.__next__()


net = AlexNet(num_classes=5, init_weights=True)

net.to(device)
#损失函数:这里用交叉熵
loss_function = nn.CrossEntropyLoss()
#优化器 这里用Adam
optimizer = optim.Adam(net.parameters(), lr=0.0002)
#训练参数保存路径
save_path = './AlexNet.pth'
#训练过程中最高准确率
best_acc = 0.0

#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):
    # train
    net.train()    #训练过程中,使用之前定义网络中的dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        #print("step:     ",images.shape,labels)
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1)

    # validate
    net.eval()    #测试过程中不需要dropout,使用所有的神经元
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1614958.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Meachines][Easy]Devvortex

Main $ nmap -p- 10.10.11.242 --min-rate 1000 # echo 10.10.11.242 devvortex.htb>>/etc/hosts 子域名爆破 $ apt install seclists $ wfuzz -c -w /usr/share/seclists/Discovery/DNS/subdomains-top1million-5000.txt -u "http://devvortex.htb/" -H &…

【Transformer】detr梳理

every blog every motto: You can do more than you think. https://blog.csdn.net/weixin_39190382?typeblog 0. 前言 detr detr 1. 引言 论文: https://arxiv.org/pdf/2005.12872v3.pdf 时间: 2020.5.26 作者: Nicolas Carion?, Fra…

陈奂仁联手 The Sandbox 推出“Hamsterz Doodles”人物化身系列

全新人物化身系列结合艺术与实用性 开创元宇宙新篇章 著名亚洲唱作歌手兼香港电影金像奖得主陈奂仁携手 The Sandbox,兴奋地宣布推出新的元宇宙人物化身系列 —— Hamsterz Doodles 仓鼠涂鸦。 陈奂仁在 The Sandbox 推出 Hamsterz Doodles 系列,将艺术与…

智能家居—ESP32开发环境搭建

相关文章 毕业设计——基于ESP32的智能家居系统(语音识别、APP控制) 智能家居—ESP32开发环境搭建 一、下载安装二、验证三、资料获取 一、下载安装 下载安装 vscode 安装插件 创建工程 二、验证 写一个简单的函数来验证一下功能 void setup() {// put your setup c…

类和对象(2)——封装(封装的概念、包、staic)

前言 面向对象程序三大特性:封装、继承、多态。而类和对象阶段,主要研究的就是封装特性。何为封装呢?简单来说就是套壳屏蔽细节。 一、什么是封装 1.1 概念 将数据和操作数据的方法进行有机结合,隐藏对象的属性和实现细节&…

【码农圈子】想加免费的程序员微信群的看过来

群名:码农圈子 很多人后台反应,最近有没有免费的微信技术交流社群 。今天特意写一篇文章来创建一些只有程序猿的微信群。(广告党慎入!) 这些微信技术群都是完全免费,后续也不会收取任何费用 。 群规则 …

Paragon NTFS如何手动更新? Paragon NTFS格式化硬盘会损失数据吗?

Paragon NTFS for Mac常被用于实现在Mac上读写NTFS格式硬盘,然而,有时用户可能会遇到软件无法自动更新的情况,需要进行手动更新操作。下面我们来看看Paragon NTFS如何手动更新,Paragon NTFS格式化硬盘会损失数据吗的相关内容。 一…

Python 使用 pip 安装 matplotlib 模块(精华版)

pip 安装 matplotlib 模块 1.使用pip安装matplotlib(五步实现):2.使用下载的matplotlib画图: 1.使用pip安装matplotlib(五步实现): 长话短说:本人下载 matplotlib 花了大概三个半小时屡屡碰壁,险些暴走。为了不让新来的小伙伴走我的弯路,特意…

【matlab 代码的python复现】 Matlab实现的滤波器设计实现与Python 的库函数相同实现Scipy

实现一个IIR滤波器的设计 背景 Matlab 设计的滤波器通常封装过于完整,虽然在DSP中能够实现更多功能的滤波器设计但是很难实现Python端口的实现。 我们以一段原始的生物电信号EEG信号进行处理。 EEG信号 1.信号获取 EEG信号通常通过头皮电极,经过多通道采样芯片采样,将获…

mysql面试题八(SQL语句)

目录 1.SQL 基本组成部分 常用操作示例 创建表 插入数据 查询数据 更新数据 删除数据 创建索引 授予用户权限 2.常见的聚合查询 1. 计数(COUNT) 2. 求和(SUM) 3. 平均值(AVG) 4. 最大值&…

使用FPGA实现超前进位加法器

介绍 前面已经向大家介绍过8位逐位进位加法器了,今天向大家介绍4位超前进位加法器。 对于逐位进位加法器来说,计算任意一位的加法运算时,必须等到低位的加法运算结束送来进位才能运行。这种加法器结构简单,但是运算慢。 对于超…

WSL安装-问题解决

WslRegisterDistribution failed with error: 0x8004032d WslRegisterDistribution failed with error: 0x80080005 Error: 0x80080005 ??????? 解决: 1、 winr输入:optionalfeatures.exe 2、打开这两项

钉钉报警的优势在哪里?如何配置钉钉机器人进行报警信息推送?

一、常见的报警方式 1、短信或者电话报警 这样的报警方式更适合高级别的报警提醒,用于处理紧急情况。出现级别不高而又频繁地发送短信会让人产生排斥感,而且电话或者短信的报警方式也存在一定的成本。 2、邮件报警 邮件报警更适用于工作时的提醒&…

支付方式模块代码示例

支付方式模块代码示例 效果展示 <view class"card"><uni-title type"h3" title"支付方式"></uni-title><radio-group change"radioChange"><label class"radio"><view class"zf-t…

ThingsBoard通过规则链使用邮件发送报警信息

1、描述 2、通过规则链路配置发送邮件只需 两步 3、案例 1、基础链路 2、选择变换节点里面的To Email 3、 编辑节点to email 4、 将创建告警与to email链接 5、选择外部节点中的send email 6、配置邮箱相关信息&#xff0c;如过不知道密钥如何获取的&#xff0c;请查看下…

yolo-驾驶行为监测:驾驶分心检测-抽烟打电话检测

在现代交通环境中&#xff0c;随着汽车技术的不断进步和智能驾驶辅助系统的普及&#xff0c;驾驶安全成为了公众关注的焦点之一 。 分心驾驶&#xff0c;尤其是抽烟、打电话等行为&#xff0c;是导致交通事故频发的重要因素。为了解决这一问题&#xff0c;研究人员和工程师们…

JRT质控数据录入

之前有时间做了质控物维护界面&#xff0c;有了维护之后就应该提供可以录入业务数据的功能了&#xff0c;当时给质控物预留了一个“项目批次业务数据”的功能说是业务数据会给每天拷贝维护数据。这次一起补上&#xff0c;展示JRT怎么写质控数据录入的界面。 界面如下&#xff…

【Linux基础】Linux基础概念

目录 前言 浅谈什么是文件&#xff1f; Linux下目录结构的认识及路径 目录结构 路径 家目录 什么是递归式的删除 重定向 输出重定向&#xff1a; 追加重定向&#xff1a; 输入重定向&#xff1a; 命令行管道 shell外壳 为什么需要shell外壳&#xff1f; shell外壳…

智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法

智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法 目录 智能算法 | Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法效果一览基本介绍程序设计参考资料效果一览 基本介绍 Matlab基于CBES融合自适应惯性权重和柯西变异的秃鹰搜索算法 融合自适应…

Linux下SPI设备驱动实验:使用内核提供的读写SPI设备中的数据的函数

一. 简介 前面文章的学习&#xff0c;已经实现了 读写SPI设备中数据的功能。文章如下&#xff1a; Linux下SPI设备驱动实验&#xff1a;验证读写SPI设备中数据的函数功能-CSDN博客 本文来使用内核提供的读写SPI设备中的数据的API函数&#xff0c;来实现读写SPI设备中数据。 …