机器学习——卷积神经网络

news2025/1/9 1:14:47

卷积神经网络CNN

多层感知机MLP的层数足够,理论上可以用其提取出二位特征,但是毕竟复杂,卷积神经网络就可以更合适的来提取高维的特征。
而卷积其实是一种运算
在这里插入图片描述
二维离散卷积的公式
在这里插入图片描述
可以看成g是一个图像的像素点,f是每个像素点对应的权重,权重越大,重要程度越大,这里的权重f可以根据梯度反向传播的方式训练
在CNN中进行卷积运算的层称为卷积层,层中的权重f被称为卷积核
如果将f进行翻转,得到的参数在位置上是翻转的,对参数数值没有影响。这样的运算称为互相关。

卷积的运算例子

在这里插入图片描述

用卷积神经网络完成图像分类任务

class CNN(nn.Module):

    def __init__(self, num_classes=10):
        super().__init__()
        # 类别数目
        self.num_classes = num_classes
        # Conv2D为二维卷积层,参数依次为
        # in_channels:输入通道
        # out_channels:输出通道,即卷积核个数
        # kernel_size:卷积核大小,默认为正方形
        # padding:填充层数,padding=1表示对输入四周各填充一层,默认填充0
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, 
            kernel_size=3, padding=1)
        # 第二层卷积,输入通道与上一层的输出通道保持一致
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
        # 最大池化,kernel_size表示窗口大小,默认为正方形
        self.pooling1 = nn.MaxPool2d(kernel_size=2)
        # 丢弃层,p表示每个位置被置为0的概率
        # 随机丢弃只在训练时开启,在测试时应当关闭
        self.dropout1 = nn.Dropout(p=0.25)
        
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.pooling2 = nn.MaxPool2d(2)
        self.dropout2 = nn.Dropout(0.25)

        # 全连接层,输入维度4096=64*8*8,与上一层的输出一致
        self.fc1 = nn.Linear(4096, 512)
        self.dropout3 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    # 前向传播,将输入按顺序依次通过设置好的层
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pooling1(x)
        x = self.dropout1(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pooling2(x)
        x = self.dropout2(x)

        # 全连接层之前,将x的形状转为 (batch_size, n)
        x = x.view(len(x), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout3(x)
        x = self.fc2(x)
        return x
#%%
batch_size = 64 # 批量大小
learning_rate = 1e-3 # 学习率
epochs = 5 # 训练轮数
np.random.seed(0)
torch.manual_seed(0)

# 批量生成器
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

model = CNN()
# 使用Adam优化器
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 使用交叉熵损失
criterion = F.cross_entropy

# 开始训练
for epoch in range(epochs):
    losses = 0
    accs = 0
    num = 0
    model.train() # 将模型设置为训练模式,开启dropout
    with tqdm(trainloader) as pbar:
        for data in pbar:
            images, labels = data
            outputs = model(images) # 获取输出
            loss = criterion(outputs, labels) # 计算损失
            # 优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 累积损失
            num += len(labels)
            losses += loss.detach().numpy() * len(labels)
            # 精确度
            accs += (torch.argmax(outputs, dim=-1) \
                == labels).sum().detach().numpy()
            pbar.set_postfix({
                'Epoch': epoch, 
                'Train loss': f'{losses / num:.3f}', 
                'Train acc': f'{accs / num:.3f}'
            })
    
    # 计算模型在测试集上的表现
    losses = 0
    accs = 0
    num = 0
    model.eval() # 将模型设置为评估模式,关闭dropout
    with tqdm(testloader) as pbar:
        for data in pbar:
            images, labels = data
            outputs = model(images)
            loss = criterion(outputs, labels)
            num += len(labels)
            losses += loss.detach().numpy() * len(labels)
            accs += (torch.argmax(outputs, dim=-1) \
                == labels).sum().detach().numpy()
            pbar.set_postfix({
                'Epoch': epoch, 
                'Test loss': f'{losses / num:.3f}', 
                'Test acc': f'{accs / num:.3f}'
            })
# 该工具包中有AlexNet、VGG等多种训练好的CNN网络
from torchvision import models 
import copy

# 定义图像处理方法
transform = transforms.Resize([512, 512]) # 规整图像形状

def loadimg(path):  
    # 加载路径为path的图像,形状为H*W*C
    img = plt.imread(path)
    # 处理图像,注意重排维度使通道维在最前
    img = transform(torch.tensor(img).permute(2, 0, 1))
    # 展示图像
    plt.imshow(img.permute(1, 2, 0).numpy())
    plt.show()
    # 添加batch size维度
    img = img.unsqueeze(0).to(dtype=torch.float32)
    img /= 255 # 将其值从0-255的整数转换为0-1的浮点数
    return img

content_image_path = os.path.join('style_transfer', 'content', '04.jpg')
style_image_path = os.path.join('style_transfer', 'style.jpg')

# 加载内容图像
print('内容图像')
content_img = loadimg(content_image_path)
# 加载风格图像
print('风格图像') 
style_img = loadimg(style_image_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1806880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从反向传播过程看激活函数与权重初始化的选择对深度神经网络稳定性的影响

之前使用深度学习时一直对各种激活函数和权重初始化策略信手拈用,然而不能只知其表不知其里。若想深入理解为何选择某种激活函数和权重初始化方法卓有成效还是得回归本源,本文就从反向传播的计算过程来按图索骥。 为了更好地演示深度学习中的前向传播和…

Modbus主站和从站的区别

Modbus主站,从站 在工业自动化领域,Modbus是一种常用的通信协议,用于设备之间的数据交换。在Modbus通信中,主站和从站是两个关键的角色。了解主站和从站之间的区别对正确配置和管理Modbus网络至关重要。 Modbus主站的特点和功能 1.通信请求发…

文献阅读:Solving olympiad geometry without human demonstrations

文献阅读:Solving olympiad geometry without human demonstrations 1. 文章简介2. 方法介绍 1. Overview2. Symbolic deduce3. Language Model4. 联合使用 3. 实验考察 & 结论 1. 基础实验考察2. 结果分析3. 样例展示 4. 总结 & 思考 文献链接&#xff1a…

《web应用技术》第十次作业

将自己的项目改造为基于vue-cli脚手架的项目&#xff0c;页面有导航&#xff0c;学会使用router。 <el-aside width"200px" style"background-color: aliceblue;"> <el-menu :default-openeds"[1]" style"background-color:rgb(1…

关于Redis中哨兵(Sentinel)

Redis Sentinel 相关名词解释 名词 逻辑结构 物理结构 主节点 Redis 主服务 一个独立的 redis-server 进程 从节点 Redis 从服务 一个独立的 redis-server 进程 Redis 数据节点 主从节点 主节点和从节点的进程 哨兵节点 监控 Redis 数据节点的节点 一个独立的 re…

Cyber Weekly #10

赛博新闻 1、最强开源大模型面世&#xff1a;阿里发布Qwen2 6月7日凌晨&#xff0c;阿里巴巴通义千问团队发布了Qwen2系列开源模型。该系列模型包括5个尺寸的预训练和指令微调模型&#xff1a;Qwen2-0.5B、Qwen2-1.5B、Qwen2-7B、Qwen2-57B-A14B以及Qwen2-72B。据Qwen官方博客…

开发没有尽头,尽力既是完美

最近遇到了一些难题&#xff0c;开发系统总有一些地方没有考虑周全&#xff0c;偏偏用户使用的时候“完美复现”了这个隐藏的Bug...... 讲道理创业一年之久为了生存&#xff0c;我一直都有在做复盘&#xff0c;复盘的核心就是&#xff1a;如何提升营收、把控开发质量&#xff0…

嵌入式仪器模块:示波器模块和自动化测试软件

示波器模块 • 32 位分辨率 • 125 MSPS 采样率 • 支持单通道/双通道模块选择 • 低速模式可实现实时功率分布和整机功率检测 • 高速模式可实现信号分析和上电时序测量 应用场景 • 抓取并分析波形的周期、幅值、异常信号等指标 • 电源纹波与噪声分析 • 信号模板比…

vue28:组件化开发和根组件

简单写个点击事件 <template> <div class"app"><div class"box" click"fn"></div></div> </template><script> export default {//导出当前组件的配置项//里面可以提供 data methods computed wat…

SpringBoot: 启动流程和类装载

前面我们学过Spring定制了自己的可执行jar&#xff0c;将真正执行时需要的类和依赖放到BOOT-INF/classes、BOOT-INF/lib来&#xff0c;为了能够识别这些为止的源文件&#xff0c;Spring定制了自己类加载器&#xff0c;本节我们来讲解这个类加载器。本节涉及的内容主要包括: Sp…

web端中使用vue3 实现 移动端的上拉滚动加载功能

需要再web端实现上拉加载 纯属web端的东西 类似这样的功能效果 能够在web端实现滚动分页 overflow-y: scroll;首先给这个大盒子 一个 css 样式 支持滚动 再给固定高度 这个盒子里的内容就能立马滚动起来 给这个盒子一个ref 的属性 以及 有原生滚动事件 scroll const handle…

Wireshark TS | 应用传输丢包问题

问题背景 仍然是来自于朋友分享的一个案例&#xff0c;实际案例不难&#xff0c;原因也就是互联网线路丢包产生的重传问题。但从一开始只看到数据包截图的判断结果&#xff0c;和最后拿到实际数据包的分析结果&#xff0c;却不是一个结论&#xff0c;方向有点跑偏&#xff0c;…

gdb 【Linux】

程序发布方式&#xff1a;  1、debug版本&#xff1a;程序会被加入调试信息&#xff0c;以便于进行调试。  2、release版本&#xff1a;不添加任何调试信息&#xff0c;是不可调试   确定一个可执行程序是debug&#xff0c;还是release [cxqiZ7xviiy0goapxtblgih6oZ test_g…

张大哥笔记:经济下行,这5大行业反而越来越好

现在人们由于生活压力大&#xff0c;于是就干脆降低自己的欲望&#xff0c;只要不是必需品就不买了&#xff0c;自然而然消费也就降低了&#xff0c;消费降级未必是不好的现象&#xff01; 人的生物本能是趋利避害&#xff0c;追求更好的生存和发展空间&#xff0c;回避对自己有…

Vue3+Vite报错:vite忽略.vue扩展名 Failed to resolve import ..... Does the file exist?

Vue3Vite报错&#xff1a;vite忽略.vue扩展名 Failed to resolve import … Does the file exist? 先看报错&#xff1a; 分析原因 原因是我们没有写后缀名 建议你在你的vite.config.js中加上如下配置 import { defineConfig } from "vite"; import vue from &qu…

华为坤灵路由器配置SSH

配置SSH服务器的管理网口IP地址。 <HUAWEI> system-view [HUAWEI] sysname SSH Server [SSH Server] interface meth 0/0/0 [SSH Server-MEth0/0/0] ip address 10.248.103.194 255.255.255.0 [SSH Server-MEth0/0/0] quit 在SSH服务器端生成本地密钥对。 [SSH Server…

SpringAI(二)

大模型:具有大规模参数和复杂计算结构的机器学习模型.通常由深度神经网络构建而成,拥有数十亿甚至数千亿个参数.其设计目的在于提高模型的表达能力和预测性能,应对复杂的任务和数据. SpringAI是一个AI工程领域的应用程序框架 大概推出时间是2023年7月份(不确定) 目的是将S…

单片机数码管时钟电路的设计

5 调试 数码管的引脚1&#xff5e;4&#xff0c;a&#xff5e;g以及小数点的排列都不是连续的&#xff0c;这就意味着难免需要飞线。数码管是分共阴和共阳的&#xff0c;起初我错把原理图中的共阳数码管当成了共阴数码管&#xff0c;焊上去了之后才发现&#xff0c;为了避免拆卸…

如何下载BarTender软件及详细安装步骤

BarTender是美国海鸥科技推出的一款优秀的条码打印软件&#xff0c;应用于 WINDOWS95 、 98 、 NT 、 XP 、 2000 、 2003 和 3.1 版本&#xff0c; 产品支持广泛的条形码码制和条形码打印机&#xff0c; 不但支持条形码打印机而且支持激光打印机&#xff0c;还为世界知名品牌条…

sqli-labs 靶场 less-11~14 第十一关、第十二关、第十三关、第十四关详解:联合注入、错误注入

SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它&#xff0c;我们可以学习如何识别和利用不同类型的SQL注入漏洞&#xff0c;并了解如何修复和防范这些漏洞。Less 11 SQLI DUMB SERIES-11判断注入点 尝试在用户名这个字段实施注入,且试出SQL语句闭合方式为单…