【pytorch练习】使用pytorch神经网络架构拟合余弦曲线

news2025/1/7 15:57:45

在本篇博客中,我们将通过一个简单的例子,讲解如何使用 PyTorch 实现一个神经网络模型来拟合余弦函数。本文将详细分析每个步骤,从数据准备到模型的训练与评估,帮助大家更好地理解如何使用 PyTorch 进行模型构建和训练。

一、背景

在机器学习中,拟合曲线是一个常见的任务,尤其是在函数预测和回归问题中。今天,我们使用一个简单的神经网络模型来拟合余弦曲线,具体步骤包括:

准备训练数据;
构建神经网络模型;
训练模型;
可视化预测结果与真实数据。
本例通过 PyTorch 实现了整个流程,我们将逐步展开。

二、代码解析

  1. 导入必要的库
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False

首先,我们导入了PyTorch相关的库 torch、torch.nn,以及用于数据加载的 DataLoader 和 TensorDataset。为了可视化结果,我们还引入了 matplotlib。

此外,为了避免某些系统环境下的警告信息,我们设置了 os.environ[“KMP_DUPLICATE_LIB_OK”] = “TRUE”,这有助于避免在多线程计算中遇到一些潜在的错误。

  1. 准备拟合数据
# 准备拟合数据
x = np.linspace(-2 * np.pi, 2 * np.pi, 400)  # 生成从 -2π 到 2π 的 400 个点
y = np.cos(x)  # 计算对应的余弦值

# 绘制生成的数据的散点图
plt.figure(figsize=(7, 5), dpi=160)
plt.scatter(x, y, color='red', label='生成数据')
plt.title('x 和 cos(x) 数据散点图', fontsize=15)
plt.xlabel('x', fontsize=12)
plt.ylabel('cos(x)', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True)
plt.show()

在这里插入图片描述
使用 numpy.linspace 生成一个包含 400 个点的 x 轴数据,范围从 -2π 到 2π,然后计算对应的 y 值,这里 y = cos(x)。

  1. 接下来,将数据整理成 PyTorch 能够接受的格式
# 将数据做成数据集的模样
X = np.expand_dims(x, axis=1)  # 使 X 变为二维数组
Y = y.reshape(400, -1)  # Y 为一列的数组
dataset = TensorDataset(torch.tensor(X, dtype=torch.float), torch.tensor(Y, dtype=torch.float))
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

通过 TensorDataset 将 x 和 y 数据捆绑成一个数据集,并使用 DataLoader 来批量加载数据,设置 batch_size=10,并启用数据打乱(shuffle=True)以增加模型训练的随机性。

  1. 构建神经网络
    接下来,我们将构建一个简单的神经网络来拟合这些数据。在这个例子中,我们使用了一个全连接的神经网络,并采用了 ReLU 激活函数。网络的结构如下:

输入层:1 个神经元(因为我们的输入是一个 1D 数值)。
隐藏层 1:10 个神经元,使用 ReLU 激活函数。
隐藏层 2:100 个神经元,使用 ReLU 激活函数。
隐藏层 3:10 个神经元,使用 ReLU 激活函数。
输出层:1 个神经元,输出拟合的结果。

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features=1, out_features=10), nn.ReLU(),
            nn.Linear(10, 100), nn.ReLU(),
            nn.Linear(100, 10), nn.ReLU(),
            nn.Linear(10, 1)
        )

    def forward(self, input: torch.FloatTensor):
        return self.net(input)

# 创建模型实例
net = Net()
net 
Net(
  (net): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=10, bias=True)
    (5): ReLU()
    (6): Linear(in_features=10, out_features=1, bias=True)
  )
)

这段代码定义了一个简单的神经网络类 Net,它继承自 nn.Module。通过 nn.Sequential 来堆叠多个层,使得网络的结构更加简洁和易于理解。每一层都紧跟着一个 ReLU 激活函数,用于引入非线性特征。

  1. 训练模型
    接下来,我们开始训练模型。我们选择 Adam 优化器,并使用均方误差(MSE)作为损失函数。在每个 epoch 中,我们都会迭代一次所有的训练数据,通过反向传播更新模型参数。
# 设置优化器和损失函数
optim = torch.optim.Adam(net.parameters(), lr=0.001)
Loss = nn.MSELoss()

# 训练模型
for epoch in range(100):
    loss = None
    for batch_x, batch_y in dataloader:
        # 前向传播
        y_predict = net(batch_x)
        
        # 计算损失
        loss = Loss(y_predict, batch_y)
        
        # 清空梯度
        optim.zero_grad()
        
        # 反向传播
        loss.backward()
        
        # 更新参数
        optim.step()
    
    # 每10步打印一次训练日志
    if (epoch + 1) % 10 == 0:
        print(f"训练步骤: {epoch+1}, 模型损失: {loss.item()}")

训练步骤: 10, 模型损失: 0.12506699562072754
训练步骤: 20, 模型损失: 0.024437546730041504
训练步骤: 30, 模型损失: 0.08189699053764343
训练步骤: 40, 模型损失: 0.03138166293501854
训练步骤: 50, 模型损失: 0.00651053711771965
训练步骤: 60, 模型损失: 0.0032562180422246456
训练步骤: 70, 模型损失: 0.00018047125195153058
训练步骤: 80, 模型损失: 0.005476313643157482
训练步骤: 90, 模型损失: 0.0014593529049307108
训练步骤: 100, 模型损失: 0.0008746677194721997
  1. 可视化
    训练完成后,我们可以使用训练好的模型来进行预测,并将预测结果与真实数据进行比较。
# 绘制真实数据与预测数据的对比
plt.figure(figsize=(12, 7), dpi=160)
plt.plot(x, y, label="实际值", marker="X")
plt.plot(x, predict.detach().numpy(), label="预测值", marker='o')
plt.xlabel("x", size=15)
plt.ylabel("cos(x)", size=15)
plt.xticks(size=15)
plt.yticks(size=15)
plt.legend(fontsize=15)
plt.show()

在这里插入图片描述
通过绘制图表,我们可以清楚地看到,训练好的神经网络已经很好地拟合了余弦函数,并且与真实数据非常接近。

** 通过本篇教程,我们了解了如何使用 PyTorch 从零开始构建神经网络,并使用该网络拟合一个简单的余弦曲线。我们逐步演示了数据准备、网络构建、模型训练以及预测可视化的过程。希望通过这篇文章,你能够掌握神经网络的基本操作,并能够将其应用于其他任务中。**

如果你有任何问题或建议,欢迎在评论区留言交流!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271954.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电脑steam api dll缺失了怎么办?

电脑故障解析与自救指南:Steam API DLL缺失问题的全面解析 在软件开发与电脑维护的广阔天地里,我们时常会遇到各种各样的系统报错与文件问题,其中“Steam API DLL缺失”便是让不少游戏爱好者和游戏开发者头疼的难题之一。作为一名深耕软件开…

Conda 安装 Jupyter Notebook

文章目录 1. 安装 Conda下载与安装步骤: 2. 创建虚拟环境3. 安装 Jupyter Notebook4. 启动 Jupyter Notebook5. 安装扩展功能(可选)6. 更新与维护7. 总结 Jupyter Notebook 是一款非常流行的交互式开发工具,尤其适合数据科学、机器…

组合的能力

在《德鲁克最后的忠告》一书中,有这样一段话: 企业将由各种积木组建而成:人员、产品、理念和建筑。积木的设计组合至少和其供给一样重要。……对于一切程序、应用软件以及附件来说,重要的是掌握将已有的软件模块组合的能力&…

去掉el-table中自带的边框线

1.问题:el-table中自带的边框线 2.解决后的效果: 3.分析:明明在el-table中没有添加border,但是会出现边框线. 可能的原因: 由 Element UI 的默认样式或者表格的某些内置样式引起的。比如,<el-table> 会通过 border-collapse 或 border-spacing 等属性影响边框的显示。 4…

大模型与EDA工具

EDA工具&#xff0c;目标是硬件设计&#xff0c;而硬件设计&#xff0c;您也可以看成是一个编程过程。 大模型可以辅助软件编程&#xff0c;相信很多人都体验过了。但大都是针对高级语言的软件编程&#xff0c;比如&#xff1a;C&#xff0c;Java&#xff0c;Python&#xff0c…

【HarmonyOS之旅】基于ArkTS开发(一) -> Ability开发一

目录 1 -> FA模型综述 1.1 -> 整体架构 1.2 -> 应用包结构 1.3 -> 生命周期 1.4 -> 进程线程模型 2 -> PageAbility开发 2.1 -> 概述 2.1.1 ->功能简介 2.1.2 -> PageAbility的生命周期 2.1.3 -> 启动模式 2.2 -> featureAbility接…

BART:用于自然语言生成、翻译和理解的去噪序列到序列预训练

摘要&#xff1a; 我们提出了BART&#xff0c;一种用于预训练序列到序列模型的去噪自编码器。BART通过以下方式训练&#xff1a;(1) 使用任意的噪声函数对文本进行破坏&#xff0c;(2) 学习一个模型来重建原始文本。它采用了一种标准的基于Transformer的神经机器翻译架构&#…

Promise编码小挑战

题目 我们将实现一个 createImage 函数&#xff0c;该函数返回一个 Promise&#xff0c;用于处理图片加载的异步操作。此外&#xff0c;还会实现暂停执行的 wait 函数。 Part 1: createImage 函数 该函数会&#xff1a; 创建一个新的图片元素。将图片的 src 设置为提供的路径…

Dubbo扩展点加载机制

加载机制中已经存在的一些关键注解&#xff0c;如SPI、©Adaptive> ©Activateo然后介绍整个加载机制中最核心的ExtensionLoader的工作流程及实现原理。最后介绍扩展中使用的类动态编译的实 现原理。 Java SPI Java 5 中的服务提供商https://docs.oracle.com/jav…

【Web】软件系统安全赛CachedVisitor——记一次二开工具的经历

明天开始考试周&#xff0c;百无聊赖开了一把CTF&#xff0c;还顺带体验了下二开工具&#xff0c;让无聊的Z3很开心&#x1f642; CachedVisitor这题 大概描述一下&#xff1a;从main.lua加载一段visit.script中被##LUA_START##(.-)##LUA_END##包裹的lua代码 main.lua loca…

在不到 5 分钟的时间内将威胁情报 PDF 添加为 AI 助手的自定义知识

作者&#xff1a;来自 Elastic jamesspi 安全运营团队通常会维护威胁情报报告的存储库&#xff0c;这些报告包含由报告提供商生成的大量知识。然而&#xff0c;挑战在于&#xff0c;这些报告的内容通常以 PDF 格式存在&#xff0c;使得在处理安全事件或调查时难以检索和引用相关…

vscode代码AI插件Continue 安装与使用

“Continue” 是一款强大的插件&#xff0c;它主要用于在开发过程中提供智能的代码延续功能。例如&#xff0c;当你在编写代码并且需要进行下一步操作或者完成一个代码块时&#xff0c;它能够根据代码的上下文、语法规则以及相关的库和框架知识&#xff0c;为你提供可能的代码续…

leetcode(hot100)4

解题思路&#xff1a;双指针思想 利用两个for循环&#xff0c;第一个for循环把所有非0的全部移到前面&#xff0c;第二个for循环将指针放在非0的末尾全部加上0。 还有一种解法就是利用while循环双指针条件&#xff0c;当不为0就两个指针一起移动 &#xff0c;为0就只移动右指针…

vulnhub——Earth靶机

使用命令在kali查看靶机ip arp-scan -l 第一 信息收集 使用 nmap 进行 dns 解析 把这两条解析添加到hosts文件中去&#xff0c;这样我们才可以访问页面 这样网站就可以正常打开 扫描ip时候我们发现443是打开的&#xff0c;扫描第二个dns解析的443端口能扫描出来一个 txt 文件…

k8s基础(1)—Kubernetes-Pod

一、Pod简介 Pod是Kubernetes&#xff08;k8s&#xff09;系统中可以创建和管理的最小单元&#xff0c;是资源对象模型中由用户创建或部署的最小资源对象模型‌。Pod是由一个或多个容器组成的&#xff0c;这些容器共享存储和网络资源&#xff0c;可以看作是一个逻辑的主机‌。…

【FlutterDart】页面切换 PageView PageController(9 /100)

上效果&#xff1a; 有些不能理解官方例子里的动画为什么没有效果&#xff0c;有可能是我写法不对 后续如果有动画效果修复了&#xff0c;再更新这篇&#xff0c;没有动画效果&#xff0c;总觉得感受的丝滑效果差了很多 上代码&#xff1a; import package:flutter/material.…

使用 NestJS 构建高效且模块化的 Node.js 应用程序,从安装到第一个 API 端点:一步一步指南

一、安装 NestJS 要开始构建一个基于 NestJS 的应用&#xff0c;首先需要安装一系列依赖包。以下是必要的安装命令&#xff1a; npm i --save nestjs/core nestjs/common rxjs reflect-metadata nestjs/platform-express npm install -g ts-node包名介绍nestjs/coreNestJS 框…

第07章 存储管理(一)

一、磁盘简介 1.1 名称称呼 磁盘/硬盘/disk是同一个东西&#xff0c;不同于内存的是容量比较大。 1.2 类型 机械&#xff1a;机械硬盘即是传统普通硬盘&#xff0c;主要由&#xff1a;盘片&#xff0c;磁头&#xff0c;盘片转轴及控制电机&#xff0c;磁头控制器&#xff0…

Appium(一)--- 环境搭建

一、Android自动化环境搭建 1、JDK 必须1.8及以上(1) 安装&#xff1a;默认安装(2) 环境变量配置新建JAVA_HOME:安装路径新建CLASSPath%JAVA_HOME%\lib\dt.jar;%JAVA_HOME%\lib\tools.jar在path中增加&#xff1a;%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin&#xff1b;(3) 验证…

Framebuffer 驱动

实验环境: 正点原子alpha 开发板 调试自己编写的framebuffer 驱动,加载到内核之后,显示出小企鹅 1. Framebufer 总体框架 fbmem.c 作为Framebuffer的核心层,向上提供app使用的接口,向下屏蔽了底层各种硬件的差异; 准确来说fbmem.c 就是一个字符设备驱动框架的程序,对…