Alnet网络分析与demo实例

news2024/12/24 18:04:47

参考自 

  • up主的b站链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频
  • 这位大佬的博客 Fun'_机器学习,pytorch图像分类,工具箱-CSDN博客

数据集下载

http://download.tensorflow.org/example_images/flower_photos.tgz

包含 5 中类型的花,每种类型有600~900张图像不等。

训练集与测试集划分

由于此数据集不像 CIFAR10 那样下载时就划分好了训练集和测试集,因此需要自己划分。具体操作可以看b站那个up 的视频,这里不再赘述

AlexNet详解

重点关注它和上一个模型不一样的地方

1.首次用GPU进行加速训练,上图上下两部分是完全一样的,这是因为用了两块GPU加速训练

2.使用了Relu函数

3.用了Dropout随即失活部分神经元 以减少过拟合

具体网络分析:

Conv1

输入:224*224*3

卷积:11*11*3    48个

  • padding = [1, 2] (左上围加半圈0,右下围加2倍的半圈0
  • stride = 4

输出:(224-11+3)/4+1 = 55   55*55*48

Maxpool1

输入:55*55*48

池化层:

  • kernel_size = 3
  • padding = 0
  • stride = 2

输出:(55-3)/2+1 = 27  27*27*48

Conv2

输入:27*27*48

卷积:5*5*48  128个

  • padding = [2, 2]
  • stride = 1

输出:(27-5+4)/1+1 = 27    27*27*128

Maxpool2

输入:27*27*128

  • 池化层:(只改变尺寸,不改变深度channel)
    • kernel_size = 3
    • padding = 0
    • stride = 2

输出:13*13*128

Conv3

输入:13*13*128

  • 卷积层:
    • 3*3 192个
    • padding = [1, 1]
    • stride = 1

输出:13*13*192

Conv4

输入:13*13*192

  • 卷积层:
    • 3*3 192个
    • padding = [1, 1]
    • stride = 1

输出: 13*13**192

Conv5

输入:13*13*192

  • 卷积层:
    • 3*3*128
    • padding = [1, 1]
    • stride = 1

输出:13*13*128

Maxpool3

输入:13*13*128

  • 池化层:
    • kernel_size = 3
    • padding = 0
    • stride = 2

输出:6*6*128

FC1、FC2、FC3

Maxpool3 → (6*6*256) → FC1 → 2048 → FC2 → 2048 → FC3 → 1000

总结

分析可以发现,除 Conv1 外,AlexNet 的其余卷积层都是在改变特征矩阵的深度,而池化层则只改变(减小)其尺寸。

构建模型;

import torch.nn as nn
import torch

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # 用nn.Sequential()将网络打包成一个模块,精简代码
        self.features = nn.Sequential(   # 卷积层提取图像特征
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True), 									# 直接修改覆盖原值,节省运算内存
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(   # 全连接层对图像分类
            nn.Dropout(p=0.5),			   # Dropout 随机失活神经元,默认比例为0.5
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()
            
	# 前向传播过程
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)	# 展平后再传入全连接层
        x = self.classifier(x)
        return x
        
	# 网络权重初始化,实际上 pytorch 在构建网络时会自动初始化权重
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):                            # 若是卷积层
                nn.init.kaiming_normal_(m.weight, mode='fan_out',   # 用(何)kaiming_normal_法初始化权重
                                        nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)                    # 初始化偏重为0
            elif isinstance(m, nn.Linear):            # 若是全连接层
                nn.init.normal_(m.weight, 0, 0.01)    # 正态分布初始化
                nn.init.constant_(m.bias, 0)          # 初始化偏重为0

Dropout  :  发现都有具体的api 想具体研究的可以去看看它的函数是怎么写的

数据预处理 - 图像增强

需要注意的是,对训练集的预处理,多了随机裁剪和水平翻转这两个步骤。可以起到扩充数据集的作用,增强模型泛化能力

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),       # 随机裁剪,再缩放成 224×224
                                 transforms.RandomHorizontalFlip(p=0.5),  # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),

    "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

训练:

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# 预处理
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("蒲公英.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

# 关闭 Dropout
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

预测:

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# 预处理
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("蒲公英.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

# 关闭 Dropout
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))     # 将输出压缩,即压缩掉 batch 这个维度
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

打印出预测的标签以及概率值:

dandelion 0.7221569418907166

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1333410.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式开发——PWM高级定时器

学习目标 加强掌握PWM开发流程理解定时器与通道的关系掌握多通道配置策略掌握互补PWM配置策略掌握定时器查询方式掌握代码抽取优化策略掌握PWM调试方式学习内容 需求 点亮8个灯,采用pwm的方式。 定时器 通道 <

Netty-4-网络编程模式

我们经常听到各种各样的概念——阻塞、非阻塞、同步、异步&#xff0c;这些概念都与我们采用的网络编程模式有关。 例如&#xff0c;如果采用BIO网络编程模式&#xff0c;那么程序就具有阻塞、同步等特质。 诸如此类&#xff0c;不同的网络编程模式具有不同的特点&#xff0c…

【大数据】NiFi 的基本使用

NiFi 的基本使用 1.NiFi 的安装与使用1.1 NiFi 的安装1.2 各目录及主要文件 2.NiFi 的页面使用2.1 主页面介绍2.2 面板介绍 3.NiFi 的工作方式3.1 基本方式3.2 选择处理器3.3 组件状态3.4 组件的配置3.4.1 SETTINGS&#xff08;通用配置&#xff09;3.4.2 SCHEDULING&#xff0…

博弈论:理解决策背后的复杂动态

1.基本概念 博弈论是一门研究具有冲突和合作元素决策制定的数学理论。它不仅适用于经济学&#xff0c;还广泛应用于政治学、心理学、生物学等领域。博弈论的核心在于分析参与者&#xff08;称为“玩家”&#xff09;在特定情境下的策略选择&#xff0c;以及这些选择如何影响最…

工资发放 C语言xdoj92

题目描述&#xff1a; 公司财务要发工资现金&#xff0c;需要提前换取100元、50元、20元、10元、5元和1元的人民币&#xff0c; 请输入工资数&#xff0c;计算张数最少情况下&#xff0c;各自需要多少张。 输入格式&#xff1a;共一行&#xff0c;输入一个正整数。 输出格式&am…

游戏软件提示d3dcompiler_43.dll的五个解决方法,亲测靠谱

在使用电脑进行工作&#xff0c;玩游戏的时候&#xff0c;我们常常会遇到一些错误提示&#xff0c;其中之一就是“D3DCompiler_43.dll丢失”的提示。D3DCompiler_43.dll是一个非常重要的动态链接库文件。它是由DirectX SDK提供的&#xff0c;用于编译和优化DirectX着色器代码的…

50 个具有挑战性的概率问题 [04/50]:尝试直至首次成功

一、说明 你好&#xff0c;我最近对与概率相关的问题产生了兴趣。我偶然发现了 Frederick Mosteller 所著的《五十个具有挑战性的概率问题及其解决方案》这本书。我认为创建一个系列来讨论这些可能作为面试问题出现的迷人问题会很有趣。每篇文章仅包含 1 个问题&#xff0c;使其…

【Qt之Quick模块】6. QML语法详解_1 基础语法与三种导入语句

前言 通过以上1-5文档的介绍&#xff0c;Quick与QML的概念及QML语法、类型、文件作用等已叙述个大概&#xff0c;接下来是对QML语法进行展开来说。 其实&#xff0c;学习任何一门语言或者做任何一件事情&#xff0c;并不用一开始就要求尽善尽美&#xff0c;做个无懈可击&…

【Python从入门到进阶】45、Scrapy框架核心组件介绍

接上篇《44、Scrapy的基本介绍和安装》 上一篇我们学习了Scrapy框架的基础介绍以及环境的搭建&#xff0c;本篇我们来学习一下Scrapy框架的核心组件的使用。 下面的核心组件的介绍&#xff0c;仍是基于这幅图的机制&#xff0c;大家可以再回顾一下&#xff1a; 注&#xff1a;…

数学的雨伞下:理解世界的乐趣

这本书没有一个公式&#xff0c;却讲透了数学的本质&#xff01; 《数学的雨伞下&#xff1a;理解世界的乐趣》。一本足以刷新观念的好书&#xff0c;从超市到对数再到相对论&#xff0c;娓娓道来。对于思维空间也给出了一个更容易理解的角度。 作者&#xff1a;米卡埃尔•洛奈…

Ubuntu20.04纯命令配置PCL(点云库)

Ubuntu20.04纯命令配置PCL&#xff08;点云库&#xff09; 最近在学习点云库&#xff08;PCL&#xff09;的使用&#xff0c;第一步就是在自己的电脑安装配置PCL。 首先&#xff0c;对于ubuntu 16.04以上版本&#xff0c;可以直接使用命令进行安装&#xff0c;新建好一个文件夹…

分析冒泡排序

#include <stdio.h> int main() { int arr[10] { 2,5,1,3,6,4,7,8,9,0 }; int i 0; int j 0; for( i 0 ;i < sizeof(arr)/sizeof(arr[0]) - 1 ; i) 红色的代表数组一共有n个元素&#xff0c;则需要n-1次 { for( j 0 // 这里可以让数组从哪一…

人工智能轨道交通行业周刊-第69期(2023.12.11-12.24)

本期关键词&#xff1a;集装箱智能管理、智慧工地、智能应急机器人、车辆构造、大模型推理 1 整理涉及公众号名单 1.1 行业类 RT轨道交通人民铁道世界轨道交通资讯网铁路信号技术交流北京铁路轨道交通网上榜铁路视点ITS World轨道交通联盟VSTR铁路与城市轨道交通RailMetro轨…

UG在实体上刻字

当我们想在实体上显示文字的时候&#xff0c;需要用到文本命令&#xff0c;菜单-插入-曲线-文本 文本命令中的具体用法 当在曲线和平面上显示文字的时候&#xff0c;只需要输入文字&#xff0c;并选中相应的曲线或者平面即可 当在曲面上显示文字的时候&#xff0c;设置如下 当文…

Vue3中的混入(mixins)

本文主要介绍Vue3中的混入&#xff08;mixins&#xff09;。 目录 一、在普通写法中使用混入&#xff1a;二、在setup写法中使用混入&#xff1a; 混入是Vue中一种用于在组件中共享可复用功能的特性。在Vue 3中&#xff0c;混入的使用方式有所改变。 一、在普通写法中使用混入…

c++11--类型自动推导

1.自动类型推断 1.1.auto a.auto声明变量的类型必须由编译器在编译时期推导而得。 int main(){double foo();auto x 1;//x类型为intauto y foo();// y类型为doubleauto z;// errreturn 0; }b.auto声明得变量必须被初始化。 c.针对指针和引用 推导类型是指针类型时&#xff0…

抠图、换背景、正装图证件照制作方法

本篇灵感是最近又要使用别的底色的正装照的图片。上学的时候&#xff0c;要求证件照的底色是蓝底、党员档案里要求图片的底色是红底、 将来上班的证件照要求是白底&#xff0c;并且无论是考研还是找工作都是制作简历的时候&#xff0c;根据简历的样板不同需要更换不同的底色。 …

HBase 集群搭建

文章目录 安装前准备兼容性官方网址 集群搭建搭建 Hadoop 集群搭建 Zookeeper 集群解压缩安装配置文件高可用配置分发 HBase 文件 服务的启停启动顺序停止顺序 验证进程查看 Web 端页面 安装前准备 兼容性 1&#xff09;与 Zookeeper 的兼容性问题&#xff0c;越新越好&#…

jQuery实现响应式瀑布流 - 实现灯箱效果

在这之前&#xff0c;有写过一篇关于实现瀑布流的文章&#xff0c;后期有人留言提出需要添加灯箱效果的功能&#xff0c;所以这次则讲述下如何实现此功能。由于该篇接上篇写的&#xff1a;jQuery实现响应式瀑布流效果&#xff08;jQueryflex&#xff09;_jquery瀑布流插件-CSDN…

驾驶未来:百度Apollo自动驾驶技术的探索与实践(文末赠送apollo周边)

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏:《linux深造日志》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! ⛳️ 粉丝福利活动 ✅参与方式&#xff1a;通过连接报名观看课程&#xff0c;即可免费获取精美周边 ⛳️活动链接&#xf…