[paddle] 非线性拟合问题的训练

news2025/1/8 17:08:02

利用paddlepaddle建立神经网络,模拟有限个数据的非线性拟合

本文仍然考虑 f ( x ) = sin ⁡ ( x ) x f(x)=\frac{\sin(x)}{x} f(x)=xsin(x) 函数在区间 [-10,10] 上固定数据的拟合。

import paddle
import paddle.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子以确保结果的可重复性
paddle.seed(1)

# 生成数据集
x_data = (np.random.rand(500) * 20 - 10).astype('float32')  # 生成500个随机x值,范围在-10到10之间
y_data = np.sin(x_data) / x_data  # 生成y值
y_data = y_data.reshape(-1, 1)  # 将y_data转换为二维数组

# 定义模型,一个具有2个隐藏层的多层感知器
class MyModel(nn.Layer):
    def __init__(self):
        super(MyModel, self).__init__()
        self.hidden1 = nn.Linear(in_features=1, out_features=50)
        self.bn = nn.BatchNorm1D(num_features=50)
        self.hidden2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        x = paddle.tanh(self.hidden1(x))
        x = self.bn(x)
        x = self.hidden2(x)
        return x

model = MyModel()

# 定义损失函数
loss_fn = nn.MSELoss()

# 设置优化器
optimizer = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())

# 训练数据
train_data = paddle.to_tensor(x_data).unsqueeze(-1), paddle.to_tensor(y_data)

# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):
    loss = loss_fn(model(train_data[0]), train_data[1])
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    if epoch % 100 == 0:
        print(f'Epoch {epoch}: Loss = {loss.numpy()}')

# 使用训练好的模型进行预测
y_pred = model(train_data[0]).numpy()

# 可视化结果
plt.scatter(x_data, y_data, label='True')
plt.scatter(x_data, y_pred, label='Predicted')
plt.legend()
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GWAS数据和软件下载

这部分主要是数据获取,以及软件配置方法。 一、配套数据和代码 数据和代码目前在不断的更新,最新的教程可以私信,我通过后手动发送最新版的pdf和数据代码。发送的压缩包,有电子版的pdf和数据下载链接,里面是最新的百度网盘的地址,下载到本地即可。然后根据pdf教程,结合配套的…

win32汇编环境,在对话框中画五边形与六边形

;运行效果 ;win32汇编环境,在对话框中画五边形与六边形 ;展示五边形与六边形的画法 ;将代码复制进radasm软件里,直接编译可运行.重要部分加备注。 ;下面为asm文件 ;>>>>>>>>>>>>>>>>>>>>>>>>>&g…

springcloud 介绍

Spring Cloud是一个基于Spring Boot的微服务架构解决方案集合,它提供了一套完整的工具集,用于快速构建分布式系统。在Spring Cloud的架构中,服务被拆分为一系列小型、自治的微服务,每个服务运行在其独立的进程中,并通过…

如何进行千万级别数据跑批优化

目录 背景问题分析解决方案 数据库问题分片广播分批获取事务控制充分利用服务器资源MQ消费任务并行动态调整并发度失败任务如何继续下游接口时间线程安全异常 & 监控 总结 背景 定义:跑批是指在特定日期对大量数据进行定时处理的过程。在金融领域,…

电脑提示wlanapi.dll丢失怎么办?wlanapi.dll丢失的多种解决方法

电脑提示wlanapi.dll丢失?别担心,这里有多种解决方法! 作为软件开发领域的从业者,我深知电脑在运行过程中可能会遇到的各种问题,其中“wlanapi.dll丢失”这一报错信息就常常让用户感到困惑和不安。今天,我…

刷服务器固件

猫眼淘票票 大麦 一 H3C通用IP 注:算力服务器不需要存储 二 刷服务器固件 1 登录固定IP地址 2 升级BMC版本 注 虽然IP不一致但是步骤是一致的 3 此时服务器会出现断网现象,若不断网等上三分钟ping一下 4 重新登录 5 断电拔电源线重新登录查看是否登录成功

深入Android架构(从线程到AIDL)_13 线程安全的化解之例

目录 7、 线程安全的化解之例 复习:Android单线程环境 非单线程环境的线程安全议题 范例-1 范例-2​编辑 同步(Synchronization)化解线程安全的问题 7、 线程安全的化解之例 复习:Android单线程环境 View是一个单线程的类;其意味着&…

每日AIGC最新进展(80): 重庆大学提出多角色视频生成方法、Adobe提出大视角变化下的人类视频生成、字节跳动提出快速虚拟头像生成方法

Diffusion Models专栏文章汇总:入门与实战 Follow-Your-MultiPose: Tuning-Free Multi-Character Text-to-Video Generation via Pose Guidance 在多角色视频生成的研究中,如何实现文本可编辑和姿态可控的角色生成一直是一个具有挑战性的课题。现有的方法往往只关注单一对象的…

【多线程初阶篇¹】线程理解| 线程和进程的区别

目录 一、认识线程Thread 1.为啥引入线程 2.线程理解 🔥 3.面试题:线程和进程的区别 一、认识线程Thread 1.为啥引入线程 为了解决进程太重量的问题 解释(为什么说线程比进程更轻量?/为什么说线程创建/销毁开销比进程小&#…

Cursor 实战技巧:好用的提示词插件Cursor Rules

你好啊,见字如面。感谢阅读,期待我们下一次的相遇。 最近在小红书发现了有人分享这款Cursor提示词的插件,下面给各位分享下使用教程。简单来说Cursor Rules就是可以为每一个我们自己的项目去配置一个系统级别的提示词,这样在我们…

Tomcat解析

架构图 核心功能 Tomcat是Apache开源的轻量级Java Servlet容器,其中一个Server(Tomcat实例)可以管理多个Service(服务),一个Service包含多个Connector和一个Engine,负责管理请求到应用的整个流…

List-顺序表--2

目录 1、ArrayList 2、ArrayList构造方法 3、ArrayList常见方法 4、ArrayList的遍历 5、ArrayList的扩容机制 6、ArrayList的具体使用 6.1、杨辉三角 6.2、简单的洗牌算法 1、ArrayList 在集合框架中,ArrayList 是一个普通的类,实现了 List 接口…

【C++】字符数|组的输出详解与拓展

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: C 文章目录 💯前言💯字符数组的输出:三种方法解析方法1:直接输出字符串代码示例解析与特点 方法2:使用while循环逐字符输出代码示例解析与特点 方法3&#x…

解决iNodeClient客户端出现查询SSL VPN网关参数失败的问题

一、问题: 使用iNodeClient连接VPN报错,校验网关、用户名、密码都没问题,仍然抱错查询SSL VPN网关参数失败,请检查网络配置或联系管理员。 二、解决方案: 2.1 方案一 重启iNodeAuthService服务 sudo /Library/Star…

数树数(中等难度)

题目: 解题代码: n,qmap(int,input().split())#分别输入层数和路径数量 for i in range(q):sinput()#输入“L”或“R”x1for j in s:if j "L":xx*2-1 #!!!规律else:xx*2print(x)

CAN201 Introduction to Networking(计算机网络)Pt.5 网络安全

文章目录 6. Network Security(网络安全)6.1 What is network security(什么是网络安全)6.2 Principles of cryptography(密码学的原则)6.2.1 Breaking an encryption scheme(破解加密方案&…

【ArcGIS Pro二次开发实例教程】(2):BSM字段赋值

一、简介 一般的数据库要素或表格都有一个BSM字段,用来标识唯一值。 此工具要实现的功能是:按一定的规律(前缀中间的填充数字OBJECT码)来给BSM赋值。 主要技术要点包括: 1、ProWindow的创建,Label,Comb…

ros2笔记-2.5.3 多线程与回调函数

本节体验下多线程。 python示例 在src/demo_python_pkg/demo_python_pkg/下新建文件,learn_thread.py import threading import requestsclass Download:def download(self,url,callback):print(f线程:{threading.get_ident()} 开始下载:{…

C语言练习:求数组的最大值与最小值

文章目录 1. 提出任务2. 完成任务2.1 方法一:通过返回结构体指针来间接返回结果2.1.1 编写程序,实现功能2.1.2 运行程序,查看结果 2.2 方法二:通过参数传递数组,并在函数中修改传入的参数2.2.1 编写程序,实…

conda安装及demo:SadTalker实现图片+音频生成高质量视频

1.安装conda 下载各个版本地址:https://repo.anaconda.com/archive/ win10版本: Anaconda3-2023.03-1-Windows-x86_64 linux版本: Anaconda3-2023.03-1-Linux-x86_64 Windows安装 环境变量 conda -V2.配置conda镜像源 安装pip conda…