pytorch深度学习基础 6(简单的参数估计学习3)

news2024/11/15 8:06:52

上一篇博客说了如何拟合一条直线y=wx+b,今天我们现在使用PyTorch进行相同的曲线拟合
拟合y= x*x -2x +3 + 0.1(-1到1的随机值) 曲线
给定x范围(0,3)

生成数据

import numpy as np
import matplotlib.pyplot as plt
import torch as t
from torch.autograd import Variable as var

def get_data(x,w,b,d):  # w * x * x + b * x + d
    c,r = x.shape
    y = (w * x * x + b*x + d)+ (0.1*(2*np.random.rand(c,r)-1)) # 加入随机噪点
    return(y)

xs = np.arange(0,3,0.01).reshape(-1,1) # 创建的二维 NumPy 数组。这个数组包含了从 0 到 3(不包括3)的等差数列,步长为 0.01
ys = get_data(xs,1,-2,3)

xs = var(t.Tensor(xs))
ys = var(t.Tensor(ys))
plt.plot(xs,  ys)
plt.show()

生成数据后我们来看一下分布情况吧

这就按照等差数列,步长0.01生成的点集图了,接下来

搭建网络

class Fit_model(t.nn.Module):
    def __init__(self):
        super(Fit_model,self).__init__()  # 调用父类torch.nn.Module的初始化方法,这是创建PyTorch模型时的标准做法
        self.linear1 = t.nn.Linear(1,16) # 定义第一个线性层(也称为全连接层),它将输入数据的特征数量从1个扩展到16个
        self.relu = t.nn.ReLU() # 定义一个ReLU激活函数,用于增加模型的非线性
        self.linear2 = t.nn.Linear(16,1)

        self.criterion = t.nn.MSELoss() # 定义损失函数为均方误差损失(MSE),这是回归任务中常用的损失函数
        self.opt = t.optim.SGD(self.parameters(),lr=0.01) # 定义优化器为随机梯度下降(SGD),并设置学习率为0.01。
                                                          # self.parameters()会返回模型中所有可训练的参数
    def forward(self, input):
        y = self.linear1(input)
        y = self.relu(y)
        y = self.linear2(y)
        return y


model = Fit_model()
for e in range(20000):
    y_pre = model(xs)

    loss = model.criterion(y_pre, ys)
    if (e % 100 == 0):
        print(e, loss.data)

    # Zero gradients
    model.opt.zero_grad()
    # perform backward pass
    loss.backward()
    # update weights
    model.opt.step()
# 显示预测的结果
ys_pre = model(xs)

plt.title("curve")
plt.plot(xs.data.numpy(),ys.data.numpy())
# plt.plot(xs.data.numpy(),ys_pre.data.numpy())
plt.plot(xs.data.numpy(), ys_pre.data.numpy(), color='blue', label="ys_pre")
plt.legend("ys","ys_pre")

plt.show()

 这里的话小编训练了2w轮实际上从训练以后的loss反馈来看训练1w左右效果就很好了

上面的方法是固定轮数训练如果想要Loss达到某个标准,可以把训练部分改为如下代码即可

epoch = 1
# 使用while循环进行训练
while epoch:
    # 假设 model 的 forward 方法返回预测值
    y_pre = model(xs)

    # 计算损失
    loss = model.criterion(y_pre, ys)

    # 如果损失小于50,则跳出循环
    if loss.item() < 1:
        print(f"Epoch {epoch}: Loss is less than 50, stopping training.")
        break

        # 每100个epoch打印一次损失
    if (epoch % 100 == 0):
        print(f"Epoch {epoch}, Loss: {loss.item()}")

    # Zero gradients
    model.opt.zero_grad()
    # perform backward pass
    loss.backward()
    # update weights
    model.opt.step()
    epoch +=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2071465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序学习day13-API Promise化、全局数据共享(状态管理)、分包

44、API Promise化 &#xff08;1&#xff09;基于回调函数的一部API的缺点&#xff1a;小程序官方提供的异步API都是基于回调函数实现的&#xff0c;容易造成回调地狱的问题&#xff0c;代码可读性、可维护性差 &#xff08;2&#xff09;API Promise化概念&#xff1a; 指…

Qt 环境搭建

sudo apt-get upadte sudo apt-get install qt4-dev-tools sudo apt-get install qtcreator sudo apt-get install qt4-doc sudo apt-get install qt4-qtconfig sudo apt-get install qt-demos编译指令 qmake -projectqmakemake实现Ubuntu20,04 与Windows之间的复制粘贴 安装o…

在C#中如何监控其它应用全屏

原文链接&#xff1a;https://www.cnblogs.com/zhaotianff/p/18338275 在C#中判断其它应用全屏可以有多种方案。我这里提供两种思路 使用定时器 在定时器中定时判断当前窗口的状态是否是最大化或者宽高是否等于桌面窗口的宽高。 这种方法我没有去尝试&#xff0c;凭个人经验…

复杂的编辑表格

需求描述 表格可以整体编辑&#xff1b;也可以单行弹框编辑&#xff1b;且整体编辑的时候&#xff0c;依然可以单行编辑 编辑只能给某一列&#xff08;这里是参数运行值&#xff09;修改&#xff0c;且根据数据内容的参数范围来判断展示不同的形式&#xff1a;input/数字输入/单…

小波卷积:为计算机视觉任务开辟新的参数效率之路

论文复述 这篇论文介绍了一种创新的卷积神经网络层——WTConv&#xff0c;它通过小波变换技术显著扩展了CNN的感受野&#xff0c;同时保持了参数效率。WTConv层能够实现对输入数据的多频率响应&#xff0c;增强了模型对形状而非纹理的特征识别能力&#xff0c;提高了在图像分类…

黑神话悟空不只是玩游戏 有人用它3天赚了85W

这几天你是不是在想办法升级电脑配置&#xff0c;买PS5玩黑神话悟空游戏&#xff0c;每一个男人看到那么好的游戏画面&#xff0c;都控制不住想玩&#xff0c;今天分享给大家一些资料&#xff0c;让你快速玩游戏的同时&#xff0c;还能挣点外快&#xff0c;黑神话悟空不只是玩游…

MATLAB 计算两点沿某个方向的间距(81)

MATLAB 计算两点沿某个方向的间距(81) 一、算法介绍二、算法实现1.代码2.效果一、算法介绍 上一章介绍了如何计算点到空间直线的距离,这里进一步的,我们也可以计算两个点,沿着某个方向的距离,这在很多处理中都会使用到,实际上就是将两点投影到该方向的直线,再计算间距…

线性表复习之初始化顺序表操作

线性表的顺序表示-初始化顺序表 代码 #include <stdio.h> #define MaxSize 10 // 定义最大长度typedef struct{int data[MaxSize]; // 申请空间&#xff08;静态&#xff09;int length; // 当前长度 }SqList;void InitList(SqList &L){for (int i 0; i < MaxS…

java-队列--黑马

队列 别看这个&#xff0c;没用&#xff0c;还是多刷力扣队列题 定义 队列是以顺序的方式维护一组数据的集合&#xff0c;在一端添加数据&#xff0c;从另一端移除数据。一般来讲&#xff0c;添加的一端称之尾&#xff0c;而移除一端称为头 。 队列接口定义 // 队列的接口定…

河南萌新联赛2024第(六)场:郑州大学

目录 A-装备二选一&#xff08;一&#xff09;_河南萌新联赛2024第&#xff08;六&#xff09;场&#xff1a;郑州大学 (nowcoder.com) 思路&#xff1a; 代码&#xff1a; B-百变吗喽_河南萌新联赛2024第&#xff08;六&#xff09;场&#xff1a;郑州大学 (nowcoder.com) …

3DsMax将两个模型的UV展到一个UV上面

3DsMax将两个模型的UV展到一个UV上面 3Dmax中的准备工作 创建一个方块&#xff0c;一个球体&#xff0c;模拟两个模型 添加修改器 打开UV编辑器&#xff0c;快速剥 使用缩放工具&#xff0c;缩放UV&#xff0c;放到一个位置 选择正方形&#xff1a;添加修改器&#xff0…

8.3 数据库基础技术-关系代数

并、交、差 笛卡尔积、投影、选择 自然连接 真题

宝塔面板配置node/npm/yarn/pm2....相关全局变量 npm/node/XXX: command not found

1.打开终端 , cd 到根目录 cd / 2.跳转至node目录下,我的node版本是v16.14.2 cd /www/server/nodejs/v16.14.2/bin 2.1 如果不知道自己node版本多少就跳转到 cd /www/server/nodejs 然后查找当前目录下的文件 ls 确定自己的node版本 cd /node版本/bin 3.继续查看bin…

天润融通助力呷哺呷哺:AI技术赋能3000万会员精细化运营

呷哺集团于1998年11月在北京成立&#xff0c;以“一人一锅”台式小火锅的用餐模式&#xff0c;以及其推出的多样化套餐与良好的用餐服务赢得了众多消费者的青睐&#xff0c;并迅速在市场上占据了一席之地。经过20多年的发展&#xff0c;呷哺呷哺已成为一个多品牌经营、全产业链…

基于Android的安全知识学习APP的设计与实现(论文+源码)_kaic

基于Android的安全知识学习APP的设计与实现 摘 要 随着科技的进步&#xff0c;智能手机已经成为人们工作、学习和生活的必需品。基于Android系统的强大功能&#xff0c;使用Java语言、Linux操作系统&#xff0c;搭配Android Studio&#xff0c;并配备Android开发插件&#…

Unet改进3:在不同位置添加NAMAttention注意力机制

本文内容:在不同位置添加NAMAttention注意力机制 目录 论文简介 1.步骤一 2.步骤二 3.步骤三 4.步骤四 论文简介 识别不太显著的特征是模型压缩的关键。然而,它在革命性的注意机制中尚未得到研究。在这项工作中,我们提出了一种新的基于归一化的注意力模块(NAM),它抑制…

广州自闭症学校哪家好?

在广州&#xff0c;选择一家适合自闭症儿童的康复学校是一个需要慎重考虑的决定。在众多机构中&#xff0c;星启帆自闭症儿童康复机构以其专业的师资团队、全面的康复服务以及温馨的学习环境脱颖而出&#xff0c;成为众多家长信赖的选择。 星启帆自闭症康复中心&#xff0c;作…

敦煌智旅:Serverless 初探,运维提效 60%

作者&#xff1a; 百潼 行业新趋势 在后疫情时代&#xff0c;文旅行业开始复苏&#xff0c;在行业的发展趋势中&#xff0c;我们看到了一个充满机遇和挑战的未来。通过不断创新和适应市场需求&#xff0c;文旅行业继续不断发展壮大&#xff0c;为消费者提供更加丰富多样的旅游…

UnQLite:多语言支持的嵌入式NoSQL数据库深入解析

文章目录 1. 引言2. Key/Value 存储接口2.1 关键函数2.2 使用示例2.3 高级操作&#xff1a;批量文件存储 3. 游标的使用4. UnQLite-Python使用示例4. UnQLite数据库引擎架构5.1 Key/Value存储层5.2 文档存储层5.3 可插拔的存储引擎5.4 事务管理器与分页模块5.5 虚拟文件系统 6.…

右值引用与左值引用

目录 1. 左值与右值2. 左值引用与右值引用 1. 左值与右值 2. 左值引用与右值引用