实现多层感知机

news2024/9/22 23:37:44

目录

多层感知机:

介绍:

代码实现:

运行结果:

问题答疑:

线性变换与非线性变换

参数含义

为什么清除梯度?

反向传播的作用

为什么更新权重?


多层感知机:

介绍:

缩写:MLP,这是一种人工神经网络,由一个输入层、一个或多个隐藏层以及一个输出层组成,每一层都由多个节点(神经元)构成。在MLP中,节点之间只有前向连接,没有循环连接,这使得它属于前馈神经网络的一种。每个节点都应用一个激活函数,如sigmoid、ReLU等,以引入非线性,从而使网络能够拟合复杂的函数和数据分布。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Step 1: Define the MLP model
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # Input layer to hidden layer
        self.fc2 = nn.Linear(128, 64)   # Hidden layer to another hidden layer
        self.fc3 = nn.Linear(64, 10)    # Hidden layer to output layer
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)             # Flatten the input from 28x28 to 784
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Step 2: Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Step 3: Define loss function and optimizer
model = SimpleMLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Step 4: Train the model
num_epochs = 5
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

# Step 5: Evaluate the model on the test set (optional)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

运行结果:

问题答疑:

线性变换与非线性变换

在神经网络中

线性变换通常指的是权重矩阵和输入数据的矩阵乘法,再加上偏置向量。数学上,对于一个输入向量𝑥x和权重矩阵𝑊W,加上偏置向量𝑏b,线性变换可以表示为: 𝑧=𝑊𝑥+𝑏z=Wx+b

非线性变换是指在神经网络的每一层之后应用的激活函数,如ReLU、sigmoid或tanh等。这些函数引入了非线性,使神经网络能够学习和表达复杂的函数关系。没有非线性变换,无论多少层的神经网络最终都将简化为一个线性模型。

参数含义

在上述模型中,参数如784, 128, 64, 10并不是字节,而是神经网络层的尺寸,具体来说是神经元的数量:

  • 784: 这是输入层的神经元数量,对应于MNIST数据集中每个图片的像素数量。MNIST的图片是28x28像素,因此总共有784个像素点。
  • 128 和 64: 这是两个隐藏层的神经元数量。它们代表了第一层和第二层的宽度,即这一层有多少个神经元。
  • 10: 这是输出层的神经元数量,对应于MNIST数据集中的10个数字类别(0到9)。

为什么清除梯度?

在每一次前向传播和反向传播过程中,梯度会被累积在张量的.grad属性中。如果不手动清零,这些梯度将会被累加,导致不正确的梯度值。因此,在每次迭代开始之前,都需要调用optimizer.zero_grad()来清空梯度。

反向传播的作用

反向传播(Backpropagation)是一种算法,用于计算损失函数相对于神经网络中所有权重的梯度。它的目的是为了让神经网络知道,当损失函数值较高时,哪些权重需要调整,以及调整的方向和幅度。这些梯度随后被用于权重更新,以最小化损失函数。

为什么更新权重?

权重更新是基于梯度下降算法进行的。在反向传播计算出梯度后,权重通过optimizer.step()函数更新,以朝着减小损失函数的方向移动。

这是训练神经网络的核心,即通过不断调整权重和偏置,使模型能够更好地拟合训练数据,从而提高预测准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux: Mysql环境安装

Mysql环境安装(Centos) 前言一、卸载多余环境1.1 卸载mariadb1.2 查看并卸载系统mysql和mariadb安装包 二、换取mysql官方yum源三、安装并启动mysql服务3.1 yum源加载3.2 安装yum源3.3 安装mysql服务3.3.1 安装指令3.3.2 GPG密钥问题解决方法3.3.3 查看是…

LabVIEW液压数据采集测试系统

液压系统是装载机的重要组成部分,通过液压传动和控制实现各项作业功能,如提升、倾斜、转向等。液压系统的性能直接影响装载机的作业效率和稳定性。为了保证装载机液压系统的正常运行和优化设计,需要对其进行数据采集和测试。本文介绍了一套基…

Python酷库之旅-第三方库Pandas(022)

目录 一、用法精讲 55、pandas.lreshape函数 55-1、语法 55-2、参数 55-3、功能 55-4、返回值 55-5、说明 55-6、用法 55-6-1、数据准备 55-6-2、代码示例 55-6-3、结果输出 56、pandas.wide_to_long函数 56-1、语法 56-2、参数 56-3、功能 56-4、返回值 56-5…

Linux文件压缩与解压缩

在Linux中,tar实用程序是用于创建、管理和提取存档的常用命令。 tar实用程序的常用选项 执行tar操作需要以下tar命令操作之一: -c ,--create :创建存档文件(即压缩文件)。-t,--list&#xff1…

0708,LINUX目录相关操作 + LINUX全导图

主要是冷气太足感冒了,加上少吃药抗药性差,全天昏迷,学傻了学傻了 01:简介 02: VIM编辑器 04:目录 05:文件 03:常用命令 06:进程 07:进程间的通信 cat t_c…

数据结构(4.1)——串的存储结构

串的顺序存储 串(String)的顺序存储是指使用一段连续的存储单元来存储字符串中的字符。 计算串的长度 静态存储(定长顺序存储) #define MAXLEN 255//预定义最大串为255typedef struct {char ch[MAXLEN];//每个分量存储一个字符int length;//串的实际长…

接口安全配置

问题点: 有员工在工位在某个接口下链接一个集线器,从而扩展上网接口,这种行为在某些公司是被禁止的,那么网络管理员如何控制呢?可以配置接口安全来限制链接的数量,切被加入安全的mac地址不会老化&#xff…

开源模型应用落地-工具使用篇-Spring AI-Function Call(八)

​​​​​​​一、前言 通过“开源模型应用落地-工具使用篇-Spring AI(七)-CSDN博客”文章的学习,已经掌握了如何通过Spring AI集成OpenAI和Ollama系列的模型,现在将通过进一步的学习,让Spring AI集成大语言模型更高阶…

Linux的世界 -- 初次接触和一些常见的基本指令

一、Linux的介绍和准备 1、简单介绍下Linux的发展史 1991年10月5日,赫尔辛基大学的一名研究生Linus Benedict Torvalds在一个Usenet新闻组(comp.os.minix)中宣布他编制出了一种类似UNIX的小操作系统,叫Linux。新的操作系统是受到另一个UNIX的…

【Python】爬虫实战01:获取豆瓣Top250电影信息

本文中我们将通过一个小练习的方式利用urllib和bs4来实操获取豆瓣 Top250 的电影信息,但在实际动手之前,我们需要先了解一些关于Http 请求和响应以及请求头作用的一些知识。 1. Http 请求与响应 HTTP(超文本传输协议)是互联网上…

C#创建windows服务程序

步骤 1: 创建Windows服务项目 打开Visual Studio。选择“创建新项目”。在项目类型中搜索“Windows Service”并选择一个C#模板(如“Windows Service (.NET Framework)”),点击下一步。输入项目名称、位置和其他选项,然后点击“创…

C++ | Leetcode C++题解之第232题用栈实现队列

题目&#xff1a; 题解&#xff1a; class MyQueue { private:stack<int> inStack, outStack;void in2out() {while (!inStack.empty()) {outStack.push(inStack.top());inStack.pop();}}public:MyQueue() {}void push(int x) {inStack.push(x);}int pop() {if (outStac…

秋招突击——7/9——MySQL索引的使用

文章目录 引言正文B站网课索引基础创建索引如何在一个表中查看索引为字符串建立索引全文索引复合索引复合索引中的排序问题索引失效的情况使用索引进行排序覆盖索引维护索引 数据库基础——文档资料学习整理创建索引删除索引创建唯一索引索引提示复合索引聚集索引索引基数字符串…

网络安全——防御课实验二

在实验一的基础上&#xff0c;完成7-11题 拓扑图 7、办公区设备可以通过电信链路和移动链路上网(多对多的NAT&#xff0c;并且需要保留一个公网IP不能用来转换) 首先&#xff0c;按照之前的操作&#xff0c;创建新的安全区&#xff08;电信和移动&#xff09;分别表示两个外网…

基础小波降噪方法(Python)

主要内容包括&#xff1a; Stationary wavelet Transform (translation invariant) Haar wavelet Hard thresholding of detail coefficients Universal threshold High-pass filtering by zero-ing approximation coefficients from a 5-level decomposition of a 16Khz …

win10系统更新后无法休眠待机或者唤醒,解决方法如下

是否使用鼠标唤醒 是否使用鼠标唤醒 是否使用键盘唤醒

【Java开发实训】day03——方法的注意事项

目录 一、方法的基本概念 二、void和return关键字 三、单一返回点原则 四、static方法使用说明 &#x1f308;嗨&#xff01;我是Filotimo__&#x1f308;。很高兴与大家相识&#xff0c;希望我的博客能对你有所帮助。 &#x1f4a1;本文由Filotimo__✍️原创&#xff0c;首发于…

《Windows API每日一练》9.25 系统菜单

/*------------------------------------------------------------------------ 060 WIN32 API 每日一练 第60个例子POORMENU.C&#xff1a;使用系统菜单 GetSystemMenu函数 AppendMenu函数 (c) www.bcdaren.com 编程达人 -------------------------------------------…

Java02--基础概念

一、注释 注释是在程序指定位置添加的说明性信息 简单理解&#xff0c;就是对代码的一种解释 1.单行注释 格式: //注释信息 2.多行注释 格式: /*注释信息*/ 3.文档注释 格式: /**注释信息*/ 注释使用的细节: 注释内容不会参与编译和运…

九盾安防丨如何判断叉车是否超速?

在现代物流和生产流程中&#xff0c;叉车是提高效率和降低成本的关键工具。然而&#xff0c;叉车的高速行驶也带来了安全隐患&#xff0c;这就要求我们对其进行严格的安全管理。九盾安防&#xff0c;作为业界领先的安防专家&#xff0c;今天就为大家揭晓如何判断叉车是否超速&a…