机器学习-11-卷积神经网络-基于paddle实现神经网络

news2024/12/23 17:26:25

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

train_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as np

train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))


train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear

class MyModel(paddle.nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = Linear(in_features=16*5*5, out_features=120)
        self.linear2 = Linear(in_features=120, out_features=84)
        self.linear3 = Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      
  MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       
   Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     
  MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       
   Linear-1          [[1, 400]]            [1, 120]           48,120     
   Linear-2          [[1, 120]]            [1, 84]            10,164     
   Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------






{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')

# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)

# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())

# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)

results = model.predict(test_dataset, batch_size=64)

test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)

print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])

Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1633266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# Form1.cs 控件全部丢失的问题解决

在应用C#开发程序时,代码写了一堆,等调试时,点开 Form1.cs窗体时,出现如下提示。点击忽略并继续是,整个窗体控件全部丢失。 初次遇到这个问题,很容易进入到误区,以为窗体控件真的全部丢失了&am…

算法入门ABC

前言 初学算法时真的觉得这东西晦涩难懂,貌似毫无用处!后来的后来,终于渐渐明白搞懂算法背后的核心思想,能让你写出更加优雅的代码。就像一首歌唱的那样:后来,我总算学会了如何去爱,可惜你早已远…

Linux磁盘分区与管理

标题日期版本说明署名 Linux磁盘分区与管理 2024.4.29 v0.0.0*******LGB 目标:能够熟练掌握Linux系统磁盘分区管 操作系统:Centos Stream 9 实验过程: 1.首先我们先新建一块磁盘。 2.我们先对新建磁盘进行分区。 3.输入n 创建分区&#xf…

Llama3 在线试用与本地部署

美国当地时间4月18日,Meta 开源了 Llama3 大模型,目前开源版本为 8B 和 70B 。Llama 3 模型相比 Llama 2 具有重大飞跃,并在 8B 和 70B 参数尺度上建立了 LLM 模型的新技术。由于预训练和后训练的改进,Llama3 模型是目前在 8B 和 …

React、React Router 和 Redux 常用Hooks 总结,提升您的开发效率!

Hooks 是 React 16.8 中引入的一种新特性,它使得函数组件可以使用 state 和其他 React 特性,从而大大提高了函数组件的灵活性和功能性。下面分别总结React、React Router 、Redux中常用的Hooks。 常用Hooks速记 React Hooks useState:用于…

vue为遍历生成的表单设置ref属性

最近在写表单重置的时候出现了问题&#xff0c;在this.$refs[formName].resetFields();的时候卡了很久。 经过网上的搜索终于解决的问题&#xff01; 对于不需要遍历的表单 这是vue代码&#xff1a; <el-dialog title"段落描述" :visible.sync"dialogFormV…

流水线工作流程

java编译命令&#xff1a; java -jar xxx.jar (其它参数已忽略) docker镜像构建命令&#xff1a; docker build -t [镜像名称:latest] -f 指定[Dockerfile] [指定工作目录] 推送镜像 jenkinsfile: 主要流程登录镜像仓库&#xff0c;打包镜像&#xff0c;推送到镜像仓库

MySql 主从同步-在原来同步基础上增加历史数据库

在MySql已经主从同步的后&#xff0c;由于有新的需求再增加1个历史数据库&#xff0c;要改原来的1个变成现在的2个数据库。在官网并没有找到类似的场景&#xff08;官方同步多个数据是从一开始就设置&#xff0c;不是后续增加的&#xff09;&#xff0c;只能结合以往的经验自己…

第三方软件测试机构-科技成果评价测试

科技成果评价测试是对科研成果的工作质量、学术水平、实际应用和成熟程度等方面进行的客观、具体、恰当的评价过程。这一评价过程有助于了解科技成果的质量和水平&#xff0c;以及其在学术和应用方面的价值和潜力。 科技成果评价测试主要包括以下几个方面&#xff1a; 工作质量…

OpenVoice: Versatile Instant Voice Cloning

OpenVoice&#xff1a;多功能即时语音克隆 摘要 OpenVoice是一种多功能的即时声音克隆方法&#xff0c;它只需要参考说话者的一小段音频就可以复制他们的声音并以多种语言生成语音。OpenVoice 在解决以下领域中的开放性挑战方面代表了重大进展&#xff1a;1) 灵活的声音风格控…

【1762】java校园单车投放系统Myeclipse开发mysql数据库web结构jsp编程servlet计算机网页项目

一、源码特点 java校园单车投放管理系统是一套完善的java web信息管理系统 采用serlvetdaobean&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S 模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#…

面试题:两阶段提交与三阶段提交的区别?

主要区别有以下几点&#xff1a; 增加了一个询问阶段&#xff0c;问了下&#xff0c;你能不不能行&#xff1f;加入了超时机制 2PC&#xff08;二阶段提交协议&#xff09; 2PC&#xff0c;两阶段提交&#xff0c;将事务的提交过程分为资源准备和资源提交两个阶段&#xff0c;…

Linux配置双网卡,1NAT 2桥接,ARM板上网

1、简介 版本型号&#xff1a;ubuntu18.04 ARM板型号&#xff1a;6ull本文主要记录配置第一次ubuntu与arm板连接的nfs配置和ARM板上网的配置&#xff0c;按照配置网络、配置nfs系统、给板子连网 顺序进行。该配置的前提是创建ubuntu系统的网络配置选择的是NAT模式&…

算法设计优化——起泡排序

文章目录 0.概述1 起泡排序&#xff08;基础版&#xff09;1.1 算法分析1.2 算法实现1.3 重复元素与稳定性1.4 复杂度分析 2 起泡排序&#xff08;改进版&#xff09;2.1 目标2.2 改进思路2.3 实现2.4 复杂度分析 3 起泡排序&#xff08;改进版2&#xff09;3.1 目标3.1 改进思…

鸿蒙内核源码分析(汇编基础篇) | CPU在哪里打卡上班

本篇通过拆解一段很简单的汇编代码来快速认识汇编&#xff0c;为读懂鸿蒙汇编打基础.系列篇后续将逐个剖析鸿蒙的汇编文件. 汇编很简单 第一&#xff1a; 要认定汇编语言一定是简单的&#xff0c;没有高深的东西&#xff0c;无非就是数据的搬来搬去&#xff0c;运行时数据主要…

基于Spring Boot的体质测试数据分析及可视化系统设计与实现

基于Spring Boot的体质测试数据分析及可视化系统的设计与实现 开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/idea 系统部分展示 前台首页界面图&#xff0c;体质测试…

day17-day20_项目实战项目部署

万信金融 项目部署 目标&#xff1a; 理解DevOps概念 能够使用Docker Compose部署项目 理解持续集成的作用 会使用Jenkins进行持续集成 1 DevOps介绍 1.1 什么是DevOps DevOps是Development和Operations两个词的缩写&#xff0c;引用百度百科的定义&#xff1a; DevOps…

68.网络游戏逆向分析与漏洞攻防-利用数据包构建角色信息-自动生成CPP函数解决数据更新的问题

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 如果看不懂、不知道现在做的什么&#xff0c;那就跟着做完看效果&#xff0c;代码看不懂是正常的&#xff0c;只要会抄就行&#xff0c;抄着抄着就能懂了 内容…

Seata-server配置

首先先兼容一下版本看看所用的版本是否都兼容 版本兼容查看 建立seata-server数据库 数据库DDL 给每个业务库建立undo.log表 undo.log 然后在虚拟机安装seata-server 创建文件路径&#xff0c;并创建docker-compose.yml文件 创建完成后先启动一下seata docker run -rm seata…

linux远程访问及控制

一、SSH远程管理 1.SSH的简介 SSH远程管理是一种通过 SSH 协议安全地管理远程计算机的方法。允许管理员通过加密的连接从本地计算机或其他远程位置连接到远程计算机&#xff0c;并执行管理任务、配置设置、故障排除等操作。 远程链接的两种方法&#xff1a;SSH 、Telnet S…