机器学习-10-基于paddle实现神经网络

news2024/12/25 10:09:05

文章目录

    • 总结
    • 参考
    • 本门课程的目标
    • 机器学习定义
    • 第一步:数据准备
    • 第二步:定义网络
    • 第三步:训练网络
    • 第四步:测试训练好的网络

总结

本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。

参考

MNIST 训练_副本

本门课程的目标

完成一个特定行业的算法应用全过程:

懂业务+会选择合适的算法+数据处理+算法训练+算法调优+算法融合
+算法评估+持续调优+工程化接口实现

机器学习定义

关于机器学习的定义,Tom Michael Mitchell的这段话被广泛引用:
对于某类任务T性能度量P,如果一个计算机程序在T上其性能P随着经验E而自我完善,那么我们称这个计算机程序从经验E中学习
在这里插入图片描述

使用MNIST数据集训练和测试模型。

第一步:数据准备

MNIST数据集

import paddle
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor

train_dataset = MNIST(mode='train', transform=ToTensor())
test_dataset = MNIST(mode='test', transform=ToTensor())

展示数据集图片

import matplotlib.pyplot as plt
import numpy as np

train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 的标签为: ' + str(train_label_0))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))


train_data0 的标签为: [5]

第二步:定义网络

import paddle
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear

class MyModel(paddle.nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = Linear(in_features=16*5*5, out_features=120)
        self.linear2 = Linear(in_features=120, out_features=84)
        self.linear3 = Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

模型可视化

import paddle
mnist = MyModel()
paddle.summary(mnist, (1, 1, 28, 28))
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-1       [[1, 1, 28, 28]]      [1, 6, 28, 28]          156      
  MaxPool2D-1     [[1, 6, 28, 28]]      [1, 6, 14, 14]           0       
   Conv2D-2       [[1, 6, 14, 14]]     [1, 16, 10, 10]         2,416     
  MaxPool2D-2    [[1, 16, 10, 10]]      [1, 16, 5, 5]            0       
   Linear-1          [[1, 400]]            [1, 120]           48,120     
   Linear-2          [[1, 120]]            [1, 84]            10,164     
   Linear-3          [[1, 84]]             [1, 10]              850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.24
Estimated Total Size (MB): 0.30
---------------------------------------------------------------------------






{'total_params': 61706, 'trainable_params': 61706}

第三步:训练网络

import paddle
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')

# 用Model封装模型
model = paddle.Model(MyModel(), inputs, labels)

# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# 配置模型
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
# 训练模型
model.fit(train_dataset, test_dataset, epochs=3, batch_size=64, save_dir='mnist_checkpoint', verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/3
step 938/938 [==============================] - loss: 0.0208 - acc: 0.9456 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/0
Eval begin...
step 157/157 [==============================] - loss: 0.0041 - acc: 0.9777 - 19ms/step          
Eval samples: 10000
Epoch 2/3
step 938/938 [==============================] - loss: 0.0021 - acc: 0.9820 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/1
Eval begin...
step 157/157 [==============================] - loss: 2.1037e-04 - acc: 0.9858 - 19ms/step      
Eval samples: 10000
Epoch 3/3
step 938/938 [==============================] - loss: 0.0126 - acc: 0.9876 - 34ms/step          
save checkpoint at /home/aistudio/mnist_checkpoint/2
Eval begin...
step 157/157 [==============================] - loss: 4.7168e-04 - acc: 0.9884 - 19ms/step      
Eval samples: 10000
save checkpoint at /home/aistudio/mnist_checkpoint/final

第四步:测试训练好的网络

import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy
from paddle.static import InputSpec

inputs = InputSpec([None, 784], 'float32', 'x')
labels = InputSpec([None, 10], 'float32', 'x')
model = paddle.Model(MyModel(), inputs, labels)
model.load('./mnist_checkpoint/final')
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())

# results = model.evaluate(test_dataset, batch_size=64, verbose=1)
# print(results)

results = model.predict(test_dataset, batch_size=64)

test_data0, test_label_0 = test_dataset[0][0], test_dataset[0][1]
test_data0 = test_data0.reshape([28, 28])
plt.figure(figsize=(2,2))
plt.imshow(test_data0, cmap=plt.cm.binary)

print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(results[0][0])[0][-1])

Predict begin...
step 157/157 [==============================] - 27ms/step          
Predict samples: 10000
test_data0 的标签为: [7]
test_data0 预测的数值为:7


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))

在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1616011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深入剖析机器学习领域的璀璨明珠——支持向量机算法

在机器学习的广袤星空中,支持向量机(Support Vector Machine,简称SVM)无疑是一颗璀璨的明珠。它以其独特的分类能力和强大的泛化性能,在数据分类、模式识别、回归分析等领域大放异彩。本文将详细剖析SVM算法的原理、特…

MLLM | InternLM-XComposer2-4KHD: 支持336 像素到 4K 高清的分辨率的大视觉语言模型

上海AI Lab,香港中文大学等 论文标题:InternLM-XComposer2-4KHD: A Pioneering Large Vision-Language Model Handling Resolutions from 336 Pixels to 4K HD 论文地址:https://arxiv.org/abs/2404.06512 Code and models are publicly available at https://gi…

互联网扭蛋机小程序:打破传统扭蛋机的局限,提高销量

扭蛋机作为一种适合全年龄层的娱乐消费方式,深受人们的喜欢,通过一个具有神秘性的商品给大家带来欢乐。近几年,扭蛋机在我国的发展非常迅速,市场规模在不断上升。 经过市场的发展,淘宝线上扭蛋机小程序开始流行起来。…

一文讲透彻Redis 持久化

文章目录 ⛄1.RDB持久化🪂🪂1.1.执行时机🪂🪂1.2.RDB原理🪂🪂1.3.小结 ⛄2.AOF持久化🪂🪂2.1.AOF原理🪂🪂2.2.AOF配置🪂🪂2.3.AOF文件…

40+ Node.js 常见面试问题 [2024]

今天就开始你的Node.js生涯。在这里,我们探讨了最佳Node.js面试问题和答案,以帮助应届生和经验丰富的候选人获得理想的工作。 Node.js 是许多大公司技术堆栈的重要组成部分,例如 PayPal、Trello、沃尔玛和 NASA。 根据 ZipRecruiter 的数据&…

了解边缘计算,在制造行业使用边缘计算。

边缘计算是一种工业元宇宙技术,可以帮助组织实现其数据的全部潜力。 处理公司的所有数据可能具有挑战性,而边缘计算可以帮助公司更快地处理数据。在制造业中,边缘计算可以帮助进行预测性维护和自动驾驶汽车操作等工作。 什么是边缘计算? …

ruoyi-cloud-plus添加一个不要认证的公开新页面

文章目录 一、前端1. 组件创建2. src/router/index.ts3. src/permission.ts 二、后端1. 设计思想2. ruoyi-gateway.yml3. 开发Controller 版本RuoYiCloudPlusv2.1.2plus-uiVue3 ts 以新增一个公开的课程搜索页面为例。 一、前端 1. 组件创建 在view目录下创建一个页面的vue…

python--使用pika库操作rabbitmq实现需求

Author: wencoo Blog:https://wencoo.blog.csdn.net/ Date: 22/04/2024 Email: jianwen056aliyun.com Wechat:wencoo824 QQ:1419440391 Details:文章目录 目录正文 或 背景pika链接mqpika指定消费数量pika自动消费实现pika获取队列任务数量pi…

去哪儿网开源的一个对应用透明,无侵入的Java应用诊断工具

今天 V 哥给大家带来一款开源工具Bistoury,Bistoury 是去哪儿网开源的一个对应用透明,无侵入的java应用诊断工具,用于提升开发人员的诊断效率和能力。 Bistoury 的目标是一站式java应用诊断解决方案,让开发人员无需登录机器或修改…

使用大卫的k8s监控面板(k8s+prometheus+grafana)

问题 书接上回,对EKS(AWS云k8s)启用AMP(AWS云Prometheus)监控AMG(AWS云 grafana),上次我们只是配通了EKSAMPAMG的监控路径。这次使用一位大卫老师的grafana的面板,具体地址如下: ht…

Google Ads广告为Demand Gen推出生成式AI工具,可自动生成广告图片

谷歌今天宣布在Google Ads广告中为Demand Gen活动推出新的生成人工智能功能。 这些工具由谷歌人工智能提供支持,广告商只需几个步骤即可使用文本提示创建高质量的图片。 这些由人工智能驱动的创意功能旨在增强视觉叙事能力,帮助品牌在YouTube、YouTube…

【Hadoop】-Apache Hive概述 Hive架构[11]

目录 Apache Hive概述 一、分布式SQL计算-Hive 二、为什么使用Hive Hive架构 一、Hive组件 Apache Hive概述 Apache Hive是一个在Hadoop上构建的数据仓库基础设施,它提供了一个SQL-Like查询语言来分析和查询大规模的数据集。Hive将结构化查询语言(…

第十二届蓝桥杯C/C++ B组 杨辉三角形(二分查找+思维)

3418. 杨辉三角形 - AcWing题库 题目描述: 思路: 从上图片中,我们可以看出来这是一个对称图形,所以我们只看左半部分就可以了,我们一行一列去做数据量是1e9这样会很麻烦,所以我们这里做一个思想转换,斜着…

单片机 VS 嵌入式LInux (学习方法)

linux 嵌入式开发岗位需要掌握Linux的主要原因之一是,许多嵌入式系统正在向更复杂、更功能丰富的方向发展,需要更强大的操作系统支持。而Linux作为开源、稳定且灵活的操作系统,已经成为许多嵌入式系统的首选。以下是为什么嵌入式开发岗位通常…

申请IP地址SSL证书的七大步骤

申请IP地址SSL证书的目的是为了在使用IP地址作为访问地址而非域名的情况下,为您的服务提供HTTPS加密,确保数据传输的安全性。以下是申请IP地址SSL证书的一般步骤和注意事项: 一、选择合适的SSL证书类型: IP SSL证书:…

java锁常识

AQS框架 AQS(AbstractQueuedSynchronizer)是 Java 中用于构建锁和同步器的基础框架。它提供了一种实现同步器的方式,使得开发者可以基于 AQS 构建各种类型的同步工具,如独占锁、共享锁、信号量等。 AQS 主要基于 FIFO 队列&…

RoadBEV:鸟瞰视图下的路面重建

作者:Tong Zhao,Lei Yang,Yichen Xie等 编译:董亚微一点人工一点智能 RoadBEV:鸟瞰视图下的路面重建https://mp.weixin.qq.com/s/hDNHwvpFe39doiXlVc-d7Q 摘要:道路的路面状况,特别是几何轮廓…

线程池多线程在项目中的实际应用

一.发短信 发短信的场景有很多,比如手机号验证码登录注册,电影票买完之后会发送取票码,发货之后会有物流信息,支付之后银行发的付款信息,电力系统的电费预警信息等等 在这些业务场景中,有一个特征&#x…

Linux 网络编程项目--简易ftp

主要代码 config.h #define LS 0 #define GET 1 #define PWD 2#define IFGO 3#define LCD 4 #define LLS 5 #define CD 6 #define PUT 7#define QUIT 8 #define DOFILE 9struct Msg {int type;char data[1024];char secondBuf[128]; }; 服务器: #i…

231 基于matlab的北斗信号数据解析

基于matlab的北斗信号数据解析,多通道和单通道接收到的北斗信号数据,利用接收到的北斗数据(.dat .txt文件),进行解析,得到初始伪距,平滑伪距,载波相位,并计算其标准差&am…