LeNet-5

news2024/11/19 5:33:55

目录

一、知识点

二、代码

三、查看卷积层的feature map

1. 查看每层信息

​2. show_featureMap.py


背景:LeNet-5是一个经典的CNN,由Yann LeCun在1998年提出,旨在解决手写数字识别问题。

一、知识点

1. iter()+next()

iter():返回迭代器

next():使用next()来获取下一条数据

data = [1, 2, 3]
data_iter = iter(data)
print(next(data_iter))  # 1
print(next(data_iter))  # 2
print(next(data_iter))  # 3

2. enumerate

enumerate(sequence,[start=0]) 函数用于将一个可遍历的数据对象组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。

start--下标起始位置的值。 

data = ['zs', 'ls', 'ww']
print(list(enumerate(data)))
# [(0, 'zs'), (1, 'ls'), (2, 'ww')]

3. torch.no_grad()

在该模块下,所有计算得出的tensor的requires_grad都自动设置为False。

当requires_grad设置为False时,在反向传播时就不会自动求导了,可以节约存储空间。

4. torch.max(input,dim)

input -- tensor类型

dim=0 -- 行比较

dim=1 -- 列比较

import torch

data = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x = torch.max(data, dim=0)
print(x)
# values=tensor([7., 8., 9.]),
# indices=tensor([2, 2, 2])
x = torch.max(data, dim=1)
print(x)
# values=tensor([3., 6., 9.]),
# indices=tensor([2, 2, 2])

5. torch.eq:对两个张量Tensor进行逐个元素的比较,若相同位置的两个元素相同,则返回True;若不同,返回False。

注意:item返回一个数。

import torch

data1 = torch.tensor([1, 2, 3, 4, 5])
data2 = torch.tensor([2, 3, 3, 9, 5])
x = torch.eq(data1, data2)
print(x)  # tensor([False, False,  True, False,  True])
sum = torch.eq(data1, data2).sum()
print(sum)  # tensor(2)
sum_item = torch.eq(data1, data2).sum().item()
print(sum_item)  # 2

6. squeeze(input,dim)函数

squeeze(0):若第一维度值为1,则去除第一维度

squeeze(1):若第二维度值为2,则去除第二维度

squeeze(-1):去除最后维度值为1的维度

7. unsqueeze(input,dim)

增加大小为1的维度,即返回一个新的张量,对输入的指定位置插入维度 1且必须指明维度。

二、代码

model.py

import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)  # output(16,28,28)
        self.pool1 = nn.MaxPool2d(2, 2)  # output(16,14,14)
        self.conv2 = nn.Conv2d(16, 32, 5)  # output(32,10,10)
        self.pool2 = nn.MaxPool2d(2, 2)  # output(32,5,5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)  # output:120
        self.fc2 = nn.Linear(120, 84)  # output:84
        self.fc3 = nn.Linear(84, 10)  # output:10

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 32 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

train.py

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from model import LeNet

def main():
    # preprocess data
    transform = transforms.Compose([
        # Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
        transforms.ToTensor(),
        # (mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # 训练集 如果数据集已经下载了,则download=False
    train_data = torchvision.datasets.CIFAR10('./data', train=True, transform=transform, download=False)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=36, shuffle=True, num_workers=0)
    # 验证集
    val_data = torchvision.datasets.CIFAR10('./data', train=False, download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=10000, shuffle=False, num_workers=0)

    # 返回迭代器
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    # loop over the dataset multiple times
    for epoch in range(5):
        epoch_loss = 0
        for step, data in enumerate(train_loader, start=0):
            # get the inputs from train_loader;data is a list of[inputs,labels]
            inputs, labels = data
            # 在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()将参数梯度设置为0
            optimizer.zero_grad()
            # 1.forward
            outputs = net(inputs)
            # 2.loss
            loss = loss_function(outputs, labels)
            # 3.backpropagation
            loss.backward()
            # 4.update x by optimizer
            optimizer.step()

            # print statistics
            # 使用item()取出的元素值的精度更高
            epoch_loss += loss.item()
            # print every 500 mini-batches
            if step % 500 == 499:
                with torch.no_grad():
                    outputs = net(val_image)
                    predict_y = torch.max(outputs, dim=1)[1]  # [0]取每行最大值,[1]取每行最大值的索引
                    val_accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)
                    print('[epoch:%d step:%5d] train_loss:%.3f test_accuracy:%.3f' % (
                        epoch + 1, step + 1, epoch_loss / 500, val_accuracy))
                    epoch_loss = 0
    print('Train finished!')

    sava_path = './model/LeNet.pth'
    torch.save(net.state_dict(), sava_path)


if __name__ == '__main__':
    main()

predict.py

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet

def main():
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),  # CHW格式
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    net = LeNet()
    net.load_state_dict(torch.load('./model/LeNet.pth'))

    image = Image.open('./predict/2.png')  # HWC格式
    image = transform(image)
    image = torch.unsqueeze(image, dim=0)  # 在第0维加一个维度 #[N,C,H,W] N:Batch批处理大小

    with torch.no_grad():
        outputs = net(image)
        predict = torch.max(outputs, dim=1)[1]
    print(classes[predict])


if __name__ == '__main__':
    main()

2.png

 

三、查看卷积层的feature map

1. 查看每层信息

    for i in net.children():
        print(i)

2. show_featureMap.py

import torch
import torch.nn as nn
from model import LeNet
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt

def main():
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),  # CHW格式
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    image = Image.open('./predict/2.png')  # HWC格式
    image = transform(image)
    image = torch.unsqueeze(image, dim=0)  # 在第0维加一个维度 #[N,C,H,W] N:Batch批处理大小

    net = LeNet()
    net.load_state_dict(torch.load('./model/LeNet.pth'))
    conv_weights = []  # 模型权重
    conv_layers = []  # 模型卷积层
    counter = 0  # 模型里有多少个卷积层

    # 1.将卷积层以及对应权重放入列表中
    model_children = list(net.children())
    for i in range(len(model_children)):
        if type(model_children[i]) == nn.Conv2d:
            counter += 1
            conv_weights.append(model_children[i].weight)
            conv_layers.append(model_children[i])

    outputs = []
    names = []
    for layer in conv_layers[0:]:
        # 2.每个卷积层对image进行计算
        image = layer(image)
        outputs.append(image)
        names.append(str(layer))
    # 3.进行维度转换
    print(outputs[0].shape)  # torch.Size([1, 16, 28, 28]) 1-batch 16-channel 28-H 28-W
    print(outputs[0].squeeze(0).shape)  # torch.Size([16, 28, 28]) 去除第0维
    # 将16颜色通道的feature map加起来,变为一张28×28的feature map,sum将所有灰度图映射到一张
    print(torch.sum(outputs[0].squeeze(0), 0).shape)  # torch.Size([28, 28])

    processed_data = []
    for feature_map in outputs:
        feature_map = feature_map.squeeze(0)  # torch.Size([16, 28, 28])
        gray_scale = torch.sum(feature_map, 0)  # torch.Size([28, 28])
        # 取所有灰度图的平均值
        gray_scale = gray_scale / feature_map.shape[0]
        processed_data.append(gray_scale.data.numpy())

    # 4.可视化特征图
    figure = plt.figure()
    for i in range(len(processed_data)):
        x = figure.add_subplot(1, 2, i + 1)
        x.imshow(processed_data[i])
        x.set_title(names[i].split('(')[0])
    plt.show()


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017971.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【操作系统】聊聊C10K

什么是C10K问题 C10K 就是 Client 10000 问题,即“在同时连接到服务器的客户端数量超过 10000 个的环境中,即便硬件性能足够, 依然无法正常提供服务。 其实说白了就是并发请求1W个请求 同时进行连接服务端,服务端可以支撑服务。…

Linux系统之安装uptime-kuma服务器监控面板

Linux系统之安装uptime-kuma服务器监控面板 一、uptime-kuma介绍1.1 uptime-kuma简介1.2 uptime-kuma特点 二、本次实践环境介绍2.1 环境规划2.2 本次实践介绍2.3 环境要求 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本3.3 检查系统是否安装Node.js 四、部署…

post更新,put相当于删除重新增一条

索引数据 //删除后新增 PUT my_dynamic_temp/_doc/1 { “name”:“test”, “class”:“1204” } //覆盖更新 POST my_dynamic_temp/_update/1 { “doc”: { “name”:“test”, “class”:“1203”, “pernum”:“998” } }

springboot 集成mybatis-plus的使用

一、在spring boot中配置mybatis-plus 1、创建一个spring boot项目&#xff0c;注意勾选mysql 2、在pom.xml文件中添加mybatis-plus的依赖包 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0&qu…

瑞芯微RK3568:烧录系统

烧录系统 文章目录 烧录系统windowsLinuxupgrade_tool 工具烧写烧写 update.img擦除操作使用 rkflash.sh 脚本烧写 编译Linux_SDK后得到多个镜像文件 windows Windows 下通过瑞芯微开发工具&#xff08;RKDevTool&#xff09; 来烧写镜像。 Loader parameter uboot …

狂神docker

狂神说 docker 参考文章 -----docker 概述 docker 为什么会出现&#xff1f;–环境部署麻烦&#xff0c;两套环境&#xff08;开发-运维&#xff09; 我的电脑可以运行&#xff0c;到你那就不可用。 开发即运维–开发打包部署上线一条龙 环境配置十分麻烦&#xff0c;机器部署…

Spring Social微信登录

微信登录的appId获得可在微信开放平台申请&#xff0c;以下用测试号 1、完成WeixinProperties 用测试账号登录 public class WeixinProperties {private String appId "wxd99431bbff8305a0";private String appSecret "60f78681d063590a469f1b297feff3c4&q…

基于SSM+Vue的医学生在线学习交流平台

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

【数字人】1、SadTalker | 使用语音驱动单张图片合成视频(CVPR2023)

Sad Talker&#xff1a;使用一张图片和一段语音来生成口型和头、面部视频 论文&#xff1a;SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation 代码&#xff1a;https://github.com/Winfredy/SadTalker …

Linux命令200例:dip用于用户与远程主机建立通信连接

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;CSDN领军人物&#xff0c;全栈领域优质创作者✌。CSDN专家博主&#xff0c;阿里云社区专家博主&#xff0c;2023年6月csdn上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师&#xff0…

Stable DIffusion 炫酷应用 | AI嵌入艺术字+光影光效

目录 1 生成AI艺术字基本流程 1.1 生成黑白图 1.2 启用ControlNet 参数设置 1.3 选择大模型 写提示词 2 不同效果组合 2.1 更改提示词 2.2 更改ControlNet 2.2.1 更改模型或者预处理器 2.2.2 更改参数 3. 其他应用 3.1 AI光影字 本节需要用到ControlNet&#xff0c;可…

6.前端·新建子模块与开发(常规开发)

文章目录 学习资料常规开发创建组件与脚本菜单创建-新增自定义图标菜单创建-栏目创建 学习资料 https://www.bilibili.com/video/BV13g411Y7GS?p12&vd_sourceed09a620bf87401694f763818a31c91e 常规开发 创建组件与脚本 首先新建前端的目录结构&#xff0c;属于自己业…

CTF 全讲解:[SWPUCTF 2022 新生赛]webdog1__start

文章目录 参考环境题目learning.php信息收集isset()GET 请求查询字符串全局变量 $_GET MD5 绕过MD5韧性脆弱性 md5()弱比较隐式类型转换字符串连接数学运算布尔判断 相等运算符 MD5 绕过科学计数法前缀 0E 与 0e绕过 start.php信息收集头部检索 f14g.php信息收集 探秘 F1l1l1l1…

Springboot 实践(18)Nacos配置中心参数自动刷新测试

前文讲解了Nacos 2.2.3配置中心的服务端的下载安装&#xff0c;和springboot整合nacos的客户端。Springboot整合nacos关键在于使用的jar版本要匹配&#xff0c;文中使用版本如下&#xff1a; ☆ springboot版本: 2.1.5.RELEASE ☆ spring cloud版本 Greenwich.RELEASE ☆ sp…

Python 算数运算符

视频版教程 Python3零基础7天入门实战视频教程 Python支持所有的基本算术运算符&#xff0c;这些算术运算符用于执行基本的数学运算&#xff0c;如加、减、乘、除和求余等。下面是7个基本的算术运算符。 以下&#xff0c;假设变量a为10&#xff0c;变量b为21&#xff1a; 实…

OpenCV之YOLOv3目标检测

&#x1f482; 个人主页:风间琉璃&#x1f91f; 版权: 本文由【风间琉璃】原创、在CSDN首发、需要转载请联系博主&#x1f4ac; 如果文章对你有帮助、欢迎关注、点赞、收藏(一键三连)和订阅专栏哦 目录 前言 一、预处理 1.获取分类名 2.获取输出层名称 3.图像尺度变换 二…

【JavaSE笔记】初识Java

一、前言 Java是一种非常优秀的程序设计语言&#xff0c;它具有令人赏心悦目的语法和易于理解的语义。 本文将通过一个简单的Java程序&#xff0c;介绍Java的一些基础内容。 二、Java基本结构 1、简单的Java程序 从最简单的一个Java程序开始逐渐了解Java语言。 以下是一段…

数学建模——微分方程介绍

一、基础知识 1、一阶微分方程 称为一阶微分方程。y(x0)y0为定解条件。 其常规求解方法&#xff1a; &#xff08;1&#xff09;变量分离 再两边积分就可以求出通解。 &#xff08;2&#xff09;一阶线性求解公式 通解公式&#xff1a; 有些一阶微分方程需要通过整体代换…

echarts的折线图,在点击图例后,提示出现变化,不报错。tooltip的formatter怎么写

在点击图例的年后&#xff0c;提示框会相应的变化&#xff0c;多选和单选都会响应变化。tooptip的重度在formatter tooltip:{show:true,trigger:"axis",alwaysShowContent:true,triggerOn:"mousemove",textStyle:{color:"#fff"},backgroundColor…

结构体-时间的计算

任务描述 本关任务需要你编写函数计算一个时间之前“xx小时xx分xx秒”的时间是多少。 以24小时制的格式记录当前时间&#xff0c;譬如“09:19:52”&#xff0c;表示上午9点19分52秒&#xff0c;则“1小时20分30秒”前的时间应该是“同一天”的“07:59:22”。 提示&#xff1a;…