神经网络以及简单的神经网络模型实现

news2024/9/23 23:27:38

神经网络基本概念:

  1. 神经元(Neuron)

    神经网络的基本单元,接收输入,应用权重并通过激活函数生成输出
  2. 层(Layer)

    神经网络由多层神经元组成。常见的层包括输入层、隐藏层和输出层
  3. 权重(Weights)和偏置(Biases)

    权重用于调整输入的重要性,偏置用于调整模型的输出
  4. 激活函数(Activation Function)

    在神经元中引入非线性,如ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。
  5. 损失函数(Loss Function)

    用于衡量模型预测与实际结果之间的差异,如均方误差(MSE)、交叉熵损失等。
  6. 优化器(Optimizer)

    用于调整模型权重以最小化损失函数,如随机梯度下降(SGD)、Adam等。

简单的神经网络示例:

下面是一个使用PyTorch构建简单线性回归的神经网络示例代码。这个示例展示了如何定义一个具有一个隐藏层的前馈神经网络,并训练它来逼近一些随机生成的数据点。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 生成一些随机数据
np.random.seed(0)
X = np.linspace(0, 10, 100).reshape(-1, 1).astype(np.float32)
y = np.sin(X) + np.random.normal(0, 0.1, size=X.shape).astype(np.float32)

# 转换为PyTorch的张量
X_tensor = torch.tensor(X)
y_tensor = torch.tensor(y)

# 定义一个简单的神经网络模型
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(1, 10)  # 输入层到隐藏层
        self.relu = nn.ReLU()        # 激活函数
        self.fc2 = nn.Linear(10, 1)  # 隐藏层到输出层

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 实例化模型、损失函数和优化器
model = NeuralNet()
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam优化器

# 训练模型
epochs = 5000
losses = []
for epoch in range(epochs):
    optimizer.zero_grad()
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor)
    loss.backward()
    optimizer.step()
    
    losses.append(loss.item())
    if (epoch+1) % 1000 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.6f}')

# 绘制损失函数变化图
plt.plot(losses, label='Training loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 测试模型
model.eval()
with torch.no_grad():
    test_x = torch.tensor([[5.0]])  # 测试输入
    predicted = model(test_x)
    print(f'预测值: {predicted.item()}')

运行结果展示: 

代码理解:

下面便是详细分解这段代码进行理解: 

  • 生成数据

    • 使用 numpy 生成一些随机的带有噪声的正弦函数数据。
      import numpy as np
      
      # 生成带有正态分布噪声的正弦函数数据
      def generate_data(n_samples):
          np.random.seed(0)  # 设置随机种子以确保结果可复现
          X = np.random.uniform(low=0, high=10, size=n_samples)
          y = np.sin(X) + np.random.normal(scale=0.3, size=n_samples)
          return X, y
      
      # 生成数据
      X_train, y_train = generate_data(100)
      

  • 定义神经网络模型

    • NeuralNet 类继承自 nn.Module,定义了一个具有一个隐藏层的前馈神经网络。使用ReLU作为隐藏层的激活函数。
      import torch
      import torch.nn as nn
      import torch.optim as optim
      
      # 定义神经网络模型
      class NeuralNet(nn.Module):
          def __init__(self):
              super(NeuralNet, self).__init__()
              self.fc1 = nn.Linear(1, 10)  # 输入大小为1(X),输出大小为10
              self.fc2 = nn.Linear(10, 1)  # 输入大小为10,输出大小为1
              self.relu = nn.ReLU()
      
          def forward(self, x):
              x = self.relu(self.fc1(x))
              x = self.fc2(x)
              return x
      
      # 实例化模型
      model = NeuralNet()
      
      # 打印模型结构
      print(model)
      

  • 实例化模型、损失函数和优化器

    • model 是我们定义的神经网络模型。
    • criterion 是损失函数,这里使用均方误差损失。
    • optimizer 是优化器,这里使用Adam优化器来更新模型参数。
      # 定义损失函数(均方误差损失)
      criterion = nn.MSELoss()
      
      # 定义优化器(Adam优化器)
      optimizer = optim.Adam(model.parameters(), lr=0.01)
      

  • 训练模型

    • 使用 X_tensor 和 y_tensor 进行训练,优化模型使其逼近 y_tensor
      # 将numpy数组转换为PyTorch张量
      X_tensor = torch.tensor(X_train, dtype=torch.float32).view(-1, 1)
      y_tensor = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
      
      # 训练模型
      def train_model(model, criterion, optimizer, X, y, epochs=1000):
          model.train()
          for epoch in range(epochs):
              optimizer.zero_grad()
              output = model(X)
              loss = criterion(output, y)
              loss.backward()
              optimizer.step()
              if (epoch+1) % 100 == 0:
                  print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
      
      train_model(model, criterion, optimizer, X_tensor, y_tensor)
      

  • 测试模型

    • 使用 model.eval() 将模型切换到评估模式,使用 torch.no_grad() 关闭梯度计算。
    • 测试输入为 5.0,打印预测结果。
      # 测试模型
      model.eval()
      with torch.no_grad():
          test_input = torch.tensor([[5.0]], dtype=torch.float32)
          predicted_output = model(test_input)
          print(f'预测输入为 5.0 时的输出: {predicted_output.item():.4f}')
      

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL 进阶】MySQL 程序 -- 详解

一、MySQL 程序简介 MySQL 安装完成通常会包含如下程序: 1、Linux 系统 程序⼀般在 /usr/bin 目录下,可以通过命令查看: 2、Windows系统 目录:你的安装路径\MySQL Server 8.0\bin,可以通过命令查看: 可…

Vue el-input 限制输入内容

&#x1f914;日常项目中经常遇到既要el-input的样式&#xff0c;又要el-input-number限制&#xff0c;所以需要绑定input事件进行约束输入限制。 以下使用自定义指令进行约束el-input输入的值&#xff0c;便于后期统一管理和拓展。 预览 代码 <!DOCTYPE html> <ht…

STM32入门开发操作记录(二)——LED与蜂鸣器

目录 一、工程模板二、点亮主板1. 配置寄存器2. 调用库函数 三、LED1. 闪烁2. 流水灯 四、蜂鸣器 一、工程模板 参照第一篇&#xff0c;新建工程目录ProjectMould&#xff0c;将先前打包好的Start&#xff0c;Library和User文件^C^V过来&#xff0c;并在Keil5内完成器件支持包的…

[MySQL][表操作]详细讲解

目录 1.创建表1.基本语法2.创建表案例 2.查看表结构3.修改表1.语法2.示例3.modify和change区别 4.删除表 1.创建表 1.基本语法 语法&#xff1a; CREATE TABLE table_name (field1 datatype,field2 datatype,field3 datatype ) character set 字符集 collate 校验规则 engin…

阿里云产品流转

本文主要记述如何使用阿里云对数据进行流转&#xff0c;这里只是以topic流转&#xff08;再发布&#xff09;为例进行说明&#xff0c;可能还会有其他类型的流转&#xff0c;不同服务器的流转也可能会不一样&#xff0c;但应该大致相同。 1 创建设备 具体细节可看&#xff1a;…

STM32F103定时器中断详解

目录 目录 目录 前言 一.什么是定时器 1.1 STM32F103定时器概述 1.2基本定时器 1.2通用定时器 1.3高级定时器 1.4 三种定时器区别 基本定时器&#xff08;Basic Timer&#xff09; 通用定时器&#xff08;General-Purpose Timer&#xff09; 高级定时器&#xff08;Advanced Ti…

企业网三层架构

企业网三层架构&#xff1a;是一种层次化模型设计&#xff0c;旨在将复杂的网络设计分成三个层次&#xff0c;每个层次都着重于某些特定的功能&#xff0c;以提高效率和稳定性。 企业网三层架构层次&#xff1a; 接入层&#xff1a;使终端设备接入到网络中来&#xff0c;提供…

提高使用安全,智慧校园在线用户功能概述

智慧校园系统融入了一个查看当前在线用户的功能&#xff0c;这一设计旨在为管理人员提供一个实时的窗口&#xff0c;洞悉校园平台的即时活跃情况&#xff0c;确保系统的高效运作与环境安全。通过这一功能&#xff0c;管理员能够一目了然地看到所有正活跃在平台上的用户群体&…

『 Linux 』匿名管道应用 - 简易进程池

文章目录 池化技术进程池框架及基本思路进程的描述组织管道通信建立的潜在问题 任务的描述与组织子进程读取管道信息控制子进程进程退出及资源回收 池化技术 池化技术是一种编程技巧,一般用于优化资源的分配与复用; 当一种资源需要被使用时这意味着这个资源可能会被进行多次使…

GuLi商城-商品服务-API-品牌管理-JSR303分组校验

注解:@Validated 实体类: package com.nanjing.gulimall.product.entity;import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.nanjing.common.valid.ListValue; import com.nanjing.common.valid.Updat…

代码随想录——不同路径Ⅱ(Leetcode 63)

题目链接 动态规划 class Solution {public int uniquePathsWithObstacles(int[][] obstacleGrid) {int m obstacleGrid.length;int n obstacleGrid[0].length;int[][] dp new int[m][n];// 遇到障碍则从(0,0)到达for(int i 0; i < m && obstacleGrid[i][0] …

【C++】初始化列表”存在的意义“和“与构造函数体内定义的区别“

构造函数是为了方便类的初始化而存在&#xff0c;而初始化时会遇到const成员变量、引用成员变量等&#xff0c;这些变量不允许函数内赋值&#xff0c;必须要在初始化时进行赋值&#xff0c;所以就有了初始化列表&#xff0c;初始化列表只能存在于类的构造函数中&#xff0c;用于…

百日筑基第二十天-一头扎进消息队列3-RabbitMQ

百日筑基第二十天-一头扎进消息队列3-RabbitMQ 如上图所示&#xff0c;RabbitMQ 由 Producer、Broker、Consumer 三个大模块组成。生产者将数据发送到 Broker&#xff0c;Broker 接收到数据后&#xff0c;将数据存储到对应的 Queue 里面&#xff0c;消费者从不同的 Queue 消费数…

PySide(PyQt),csv文件的显示

1、正常显示csv文件 import sys import csv from PySide6.QtWidgets import QApplication, QMainWindow, QTableWidget, QTableWidgetItem, QWidgetclass CSVTableWidgetDemo(QMainWindow):def __init__(self):super().__init__()# 创建显示控件self.widget QWidget(self)sel…

Hadoop-28 ZooKeeper集群 ZNode简介概念和测试 数据结构与监听机制 持久性节点 持久顺序节点 事务ID Watcher机制

章节内容 上节我们完成了&#xff1a; ZooKeeper 集群配置ZooKeeper 集群启动ZooKeeper 集群状况查看Follower 和 Leader 节点 背景介绍 这里是三台公网云服务器&#xff0c;每台 2C4G&#xff0c;搭建一个Hadoop的学习环境&#xff0c;供我学习。 之前已经在 VM 虚拟机上搭…

html表格账号密码备忘录:表格内容将通过JavaScript动态生成。点击查看密码10秒关闭

<!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8"><title>账号密码备忘录</title><style>body {background: #2c3e50;text-shadow: 1px 1px 1px #100000;}/* 首页样式开始 */.home_page {color: …

Qt:22.鼠标相关事件(实例演示——鼠标进入/离开某控件的事件、鼠标按下事件、鼠标释放事件、鼠标双击事件)

目录 1.实例演示——鼠标进入/离开某控件的事件&#xff1a; 2.鼠标按下事件&#xff1a; 3.鼠标释放事件&#xff1a; 4.鼠标双击事件&#xff1a; 1.实例演示——鼠标进入/离开某控件的事件&#xff1a; 首先创建一个C类文件 Label&#xff0c;填写好要继承的父类 QLabe…

【ARM】使用JasperGold和Cadence IFV科普

#工作记录# 原本希望使用CCI自带的验证脚本来验证修改过后的address map decoder&#xff0c;但是发现需要使用JasperGold或者Cadence家的IFV的工具&#xff0c;我们公司没有&#xff0c;只能搜搜资料做一下科普了解&#xff0c;希望以后能用到吧。这个虽然跟ARM没啥关系不过在…

HarmonyOS NEXT:一次开发,多端部署

寄语 这几年特别火的uni-app实现了“一次开发&#xff0c;多端使用”&#xff0c;它这个端指的是ios、安卓、各种小程序这些&#xff0c;而HarmonyOS NEXT也提出了“一次开发&#xff0c;多端部署”&#xff0c;而它这个端指的是终端设备&#xff0c;也就是我们的手机、平板、电…

基于大语言模型(LLM)的合成数据生成、策展和评估的综述

节前&#xff0c;我们星球组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学。 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 合集&#x…