[深度学习]卷积神经网络CNN

news2025/1/11 14:58:37

1 图像基础知识

import numpy as np
import matplotlib.pyplot as plt
# 图像数据
#img=np.zeros((200,200,3))
img=np.full((200,200,3),255)
# 可视化
plt.imshow(img)
plt.show()
# 图像读取
img=plt.imread('img.jpg')
plt.imshow(img)
plt.show()

2 CNN概述

  • 卷积层conv+relu
  • 池化层pool
  • 全连接层FC/Linear

3 卷积层

 

import matplotlib.pyplot as plt
import torch
from torch import nn
# 数据
img=plt.imread('img.jpg')
print(img.shape)
# conv
img=torch.tensor(img).permute(2,0,1).unsqueeze(0).to(torch.float32)
conv=nn.Conv2d(in_channels=3,out_channels=5,kernel_size=(3,5),stride=(1,2),padding=2)
# 处理
fm=conv(img)
print(fm.shape)

4 池化层

  • 下采样:样本减少
  • 上采样(深采样):样本增多
  • 最大池化相交平均池化使用更多
  • 通常kernel_size=(3,3),stride=(2,2),padding=(自定义)

import torch
from torch import nn
# 创建数据
torch.random.manual_seed(22)
data=torch.randint(0,10,[1,3,3],dtype=torch.float32)
print(data)

# 最大池化
pool=nn.MaxPool2d(kernel_size=(2,2),stride=(1,1),padding=0)
print(pool(data))

# 平均池化
pool=nn.AvgPool2d(kernel_size=(2,2),stride=(1,1),padding=0)
print(pool(data))

5 图像分类案例(LeNet)

import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
import matplotlib.pyplot as plt
from torchsummary import summary
from torch import optim
from torch.utils.data import DataLoader
# 获取数据
train_dataset=CIFAR10(root='cnn_net',train=True,transform=Compose([ToTensor()]),download=True)
test_dataset=CIFAR10(root='cnn_net',train=False,transform=Compose([ToTensor()]),download=True)
print(train_dataset.class_to_idx)
print(train_dataset.data.shape)
print(test_dataset.data.shape)

plt.imshow(test_dataset.data[100])
plt.show()
print(test_dataset.targets[100])

# 模型构建
class ImageClassification(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)
        self.pool1=nn.MaxPool2d(kernel_size=2,stride=2)
        self.pool2=nn.MaxPool2d(kernel_size=2,stride=2)
        self.fc1=nn.Linear(in_features=576,out_features=120)
        self.fc2=nn.Linear(in_features=120,out_features=84)
        self.out=nn.Linear(in_features=84,out_features=10)
    def forward(self,x):
         x=self.pool1(torch.relu(self.conv1(x)))
         x=self.pool2(torch.relu(self.conv2(x)))
         x=x.reshape(x.size(0),-1)
         x=torch.relu(self.fc1(x))
         x=torch.relu(self.fc2(x))
         out=self.out(x)
         return out

model=ImageClassification()
summary(model,(3,32,32),batch_size=1)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1             [1, 6, 30, 30]             168
         MaxPool2d-2             [1, 6, 15, 15]               0
            Conv2d-3            [1, 16, 13, 13]             880
         MaxPool2d-4              [1, 16, 6, 6]               0
            Linear-5                   [1, 120]          69,240
            Linear-6                    [1, 84]          10,164
            Linear-7                    [1, 10]             850
================================================================
Total params: 81,302
Trainable params: 81,302
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.08
Params size (MB): 0.31
Estimated Total Size (MB): 0.40
----------------------------------------------------------------
# 模型训练
optimizer=optim.Adam(model.parameters(),lr=0.0001,betas=[0.9,0.99])
error=nn.CrossEntropyLoss()
epoches=10
for epoch in range(epoches):
    dataloader=DataLoader(train_dataset,batch_size=2,shuffle=True)
    loss_sum=0
    num=0.1
    for x,y in dataloader:
        y_=model(x)
        loss=error(y_,y)
        loss_sum+=loss.item()
        num+=1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(loss_sum/num)
# 模型保存
torch.save(model.state_dict(),'model.pth')
# 模型预测
test_dataloader=DataLoader(test_dataset,batch_size=8,shuffle=False)
model.load_state_dict(torch.load('model.pth',weights_only=False))
corr=0
num=0
for x,y in test_dataloader:
    y_=model(x)
    out=torch.argmax(y_,dim=-1)
    corr+=(out==y).sum()
    num+=len(y)
    
print(corr/num)
    

优化方向

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2169559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

实战OpenCV之色彩空间转换

基础入门 色彩空间是描述颜色的一种数学模型,它定义了颜色的三个或更多维度,比如:亮度、色相和饱和度等。最著名的色彩空间之一是RGB,它基于人眼对光的感知原理,通过红、绿、蓝三种基本颜色的不同强度组合来表示几乎所…

【HarmonyOS】鸿蒙仿iOS线性渐变实现

【HarmonyOS】仿照IOS中可以通过输入start(0,0),end(1,1)获取角度到.linearGradient,从而实现左上到右下渐变 class Point {x: number 0y: number 0 }Entry Component struct Page…

开源链动 2+1 模式 S2B2C 商城小程序:激活 KOC,开启商业新征程

摘要:本文深入探讨了 KOC 在立体连接中的重要性,以及如何通过开源链动 21 模式 S2B2C 商城小程序发现和找到更多的 KOC。强调了历史积累强关系和快速强化强关系的方法,并阐述了该商城小程序在推动商业发展中的关键作用。 一、引言 在当今竞争…

mysql 内存被打满记录

一:早上收到报警:提示:您的云数据库RDS的1个实例因存储空间满将被锁定,请关注实例的存储空间使用情况,可通过存储扩容或空间清理解除锁定。后续查看错误日志如下:磁盘没有空间了 没有多余的空间写binlog和…

2024年下安徽省事业编考试报名流程(电脑)

2024年下安徽省事业编考试报名流程(电脑)

极狐GitLab 17.4 升级指南

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab https://dl.gitlab.cn/6y2wxugm 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署极狐GitLab。 本文分享极狐GitLab 17.4 升级…

【JVM】垃圾释放方式:标记-清除、复制算法、标记-整理、分代回收

文章目录 1. 标记-清除2. 复制算法4. 标记-整理4. 分代回收 把标记为垃圾的对象的内存空间进行释放。主要有三种释放方式 1. 标记-清除 把标记为垃圾的对象,直接释放掉(最朴素的做法) 此时就是把标记为垃圾的对象所对应的内存空间直接释放。…

【机器学习】探索LSTM:深度学习领域的强大时间序列处理能力

目录 🍔 LSTM介绍 🍔 LSTM的内部结构图 2.1 LSTM结构分析 2.2 Bi-LSTM介绍 2.3 使用Pytorch构建LSTM模型 2.4 LSTM优缺点 🍔 小结 学习目标 🍀 了解LSTM内部结构及计算公式. 🍀 掌握Pytorch中LSTM工具的使用. &…

反光柱定位算法-雷达强度数据包

反光柱定位算法-雷达强度数据包 反光柱定位算法-雷达强度数据包 作者: 苏凯 系统环境: 系统:ubuntu20.04 ros1版本: noetic 雷达: sick TM581 强度值标定文件: scanIntensities.txt 部署在环境中的反光柱数据…

类和对象(2)

文章目录 🎯引言👓类和对象(2)1.类的默认成员函数2.构造函数2.1构造函数概念 3.析构函数3.1. **析构函数的定义**3.2. **析构函数的特点** 4.拷贝构造函数4.1. **拷贝构造函数的定义** 5.赋值运算符重载5.1运算符重载5.2赋值运算符重载5.3日期类的实现 &…

smtp-server: 535 Error: authentication faile

问题描述: 在linux服务器上使用 mailx发送邮件时提示:smtp-server: 535 Error: authentication faile 原因:没有配置授权码或者授权码不正确 解决办法:配置授权码(以网易邮箱为例) 1. 进入网易邮箱网页版,打开 POP…

数据中心里全速运行的处理器正在浪费能源

数据中心是耗电大户,运营商一直在努力解决的一个关键问题是如何减少能源和资源消耗。人们已经找到了一些巧妙的解决方案,例如使用非饮用水来冷却设备,但一个显而易见的解决方案似乎被忽略了:启用处理器的各种省电功能。 随着需求的…

进程概念以及进程相关函数的使用

1.进程相关概念 1.1 程序和进程 程序,是指编译好的二进制文件,在磁盘上,不占用系统资源(cpu、内存、打开的文件、设备、锁....) 进程,是一个抽象的概念,与操作系统原理联系紧密。进程是活跃的程序,占用系…

Qt-QGroupBox容器类控件(39)

目录 容器类控件 描述 属性 使用 容器类控件 描述 这个是用来分组的,即把控件分组 使⽤ QGroupBox 实现⼀个带有标题的分组框.可以把其他的控件放到⾥⾯作为⼀组.这样看起来能更好看⼀点 属性 title分组框的标题alignment分组框内部内容的对⻬⽅式flat是否是…

微服务nacos解析部署使用全流程

1、什么是Spring Cloud Spring Cloud是一系列框架的集合。它利用Spring Boot的开发便利性巧妙地简化了分布式系统基础设施的开发,如服务发现注册、配置中心、消息总线、负载均衡、断路器、数据监控等,都可以用Spring Boot的开发风格做到一键启动和部署。…

stm32入门——GPIO输入输出(1)基础理解

最近比较想上进,又不知道要干什么,就来水几篇博客欺骗一下自己。 GPIO全称是:General Purpose Input / Output ,是stm32用于控制输入和输出信号的通用接口。我们用的MCU都有这玩意,比如STM32F103C8T6上有 GPIOA,GPIOB&…

算法葫芦书(笔试面试)

一、特征工程 1.特征归一化:所有特征统一到一个区间内 线性函数归一化(0到1区间)、零均值归一化(均值0,标准差1) 2.类比型特征->数值性特征 序号编码、独热编码、二进制编码(010&#xf…

prd文档编写(to b)

如何编写产品需求文档(PRD) | 人人都是产品经理 (woshipm.com) 一.prd文档编写得目的 PRD文档最为重要的目的就是:协调各个相关角色 PRD就是提高效率的,把各个角色的共识全部写出来,大家都已PRD为最终的工作指导文档…

2:数据结构:列表与元组

目录 2.1 列表的创建与操作 2.1.1 列表的创建 2.1.2 列表的常用操作 2.1.3 列表切片操作 2.2 元组的特点与用法 2.2.1 元组的创建 2.2.2 元组与列表的区别 2.2.3 元组的常用操作 2.3 示例代码与练习 2.3.1 示例代码:列表与元组的基本操作 2.3.2 练习题 文…

ICM20948 DMP代码详解(46)

接前一篇文章:ICM20948 DMP代码详解(45) 上一回讲到了inv_icm20948_setup_compass_akm函数中的以下代码片段: /* Set compass in power down through I2C SLV for compass */result inv_icm20948_execute_write_secondary(s, COM…