使用Tensorflow2和Pytorch实现线性回归

news2025/1/25 4:46:22

使用Tensorflow2和Pytorch实现线性回归

  • 步骤
  • Tensorflow2代码
    • 效果
  • Pytorch代码
    • 效果

步骤

准备步骤:
1. 创建数据集
2. 设置超参数
3. 创建模型(函数)
4. 选择损失函数
5. 选择优化器

训练步骤:
6. 通过模型(函数)前向传播
7. 计算损失
8. 对超参数求梯度
9. 使用优化器利用梯度调整超参数

测试步骤:
10. 创建测试集
11. 通过模型得到预测结果
12. 画出散点图和曲线图

Tensorflow2代码

import tensorflow as tf
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.losses import MeanSquaredLogarithmicError
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt

#初始化参数
x=tf.reshape(tf.range(0,15,dtype=tf.float32),[15,1])
y=3*x+tf.constant(np.random.randn(15,1).astype(np.float32))+4
w=tf.Variable(np.random.rand(),dtype=tf.float32)
b=tf.Variable(np.random.rand(),dtype=tf.float32)

print('x=',np.reshape(x,[1,15]),'\ny=',np.reshape(y,[1,15]),'\nw=',w,'\nb=',b)

#创建模型
class My_model(Model):
    def __init__(self):
        super().__init__()

    # 构建一个线性层
    def linear(self,x):
        return w*x+b

    def call(self,x):
        x=self.linear(x)
        return x

#定义超参数
epoch=1500 #迭代次数
Ir=0.01 #学习率
model=My_model() #初始化模型
optimizer=SGD(learning_rate=Ir) #初始化优化器
losser=MeanSquaredLogarithmicError() #初始化损失函数
all_loss=[] #用于存储loss

print('--------------训练------------------------------------------------------------------')

for i in range(1,epoch+1):
    with tf.GradientTape() as tape:
        cy = model(x)  # 前向传播,获得预测值
        loss = losser(cy, y) #计算loss
        grad=tape.gradient(loss,[w,b]) #求出w,b的梯度
    optimizer.apply_gradients(zip(grad,[w,b]))
    if i%10==0:
        all_loss.append(loss) #添加loss
        print('epoch:',i,'loss:',loss) #打印loss值

print('w:',w,'\nb:',b)
print('--------------测试------------------------------------------------------------------')
#画图
plt.rcParams['font.sans-serif']=['SimHei'] #载入字体
px=tf.reshape(tf.range(0,15,0.1,dtype=tf.float32),[150,1])
py=model(px)
plt.subplot(121)
plt.title('结果图')
plt.scatter(x,y)
plt.plot(px.numpy(),py.numpy())
plt.subplot(122)
plt.title('loss图')
plt.plot(px.numpy(),all_loss)
plt.show()

效果

效果如图:
随着迭代次数的增加,loss逐渐减小。
在这里插入图片描述
在这里插入图片描述

Pytorch代码

import torch as th
from torch.nn import Module,Linear,MSELoss
from torch.optim import SGD
import matplotlib.pyplot as plt

#初始化参数
x=th.arange(0,15,1,dtype=th.float32).view(15,1)
y=3*x+th.randn(15,1,dtype=th.float32)+4
w=th.rand(1)
b=th.rand(1)
print('x=',x.view(1,15),'\ny=',y.view(1,15),'\nw=',w,'\nb=',b)
print('--------------训练------------------------------------------------------------------')
#创建模型
class My_model(Module):
    def __init__(self,input_shape,output_shape):
        super().__init__()
        self.linear=Linear(input_shape,output_shape)

    def forward(self,x):
        x=self.linear(x)
        return x

#定义超参数
epoch=1500 #迭代次数
Ir=0.01 #学习率
model=My_model(1,1) #初始化模型
optimizer=SGD(model.parameters(),Ir) #初始化优化器
losser=MSELoss() #初始化损失函数
all_loss=[] #用于存储loss

for i in range(1,epoch+1):
    optimizer.zero_grad() #将优化器的梯度清零,防止叠加
    cy=model(x)  #前向传播,获得预测值
    loss=losser(cy,y)
    loss.backward() #计算loss和反向传播
    optimizer.step() #更新权重
    if i%10==0:
        all_loss.append(loss)
        print('epoch:',i,'loss:',loss) #打印loss值

print(optimizer.state)
print('--------------测试------------------------------------------------------------------')
#画图
plt.rcParams['font.sans-serif']=['SimHei'] #载入字体
px=th.arange(0,15,0.1,dtype=th.float32,requires_grad=False).view(150,1)
py=model(px)
plt.subplot(121)
plt.title('结果图')
plt.scatter(x,y)
plt.plot(px.detach().numpy(),py.detach().numpy())
plt.subplot(122)
plt.title('loss图')
plt.plot(px.detach().numpy(),all_loss)
plt.show()

效果

效果如下:
随着迭代次数的增加,loss逐渐减小。
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/90779.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【人脸识别】人脸实时检测与跟踪【含GUI Matlab源码 673期】

⛄一、简介 如何在视频流中检测到人脸以及人脸追踪。对象检测和跟踪在许多计算机视觉应用中都很重要,包括活动识别,汽车安全和监视。所以这篇主要总结MATLAB的人脸检测和跟踪。 首先看一下流程。检测人脸——>面部特征提取——>脸部追踪。 ⛄二、…

springcloud3 EurekaClient集群的搭建2

一 概述 1.1 概述 本文主要是搭建集成eurekaserver的几个客户端,即服务提供者,消费者。架构图如下所示 1.2 使用eureka整合的优点 使用Eureka管理注册的好处:消费者直接调用服务名称而不用在关系地址和端口,且该服务还有负载均…

[附源码]Nodejs计算机毕业设计基于的仓库管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

DBCO-PEG-Mesylate,Mesylate-PEG-DBCO,甲磺酸酯聚乙二醇环辛炔

一、试剂基团反应特点(Reagent group reaction characteristics): DBCO-PEG-Mesylate属于高分子PEG,甲磺酸酯是甲磺酸与醇酯化而成的酯类化合物。“点击化学"一般由叠氮化物(azide)和炔烃(…

React - 组件样式模块化

React - 组件样式模块化一. 存在的问题二. 解决样式冲突,组件样式模块化当多个组件使用相同类名时,设置的css样式会存在冲突渲染。 一. 存在的问题 例如有Page1、Page2两个组件,在 Page1 组件引入了css样式,Page2 组件未引入。 组…

用Excel写个摸球模拟器玩玩

用Excel写个摸球模拟器玩玩背景代码实现相关资料背景 最近对象有个需求,想要帮忙写个程序,实现功能:模拟两种颜色的球,随机摸球N次后,摸到不同颜色的次数。 考虑到非程序员的环境配置问题,直接用Excel中的…

【配电网规划】SOCPR和基于线性离散最优潮流(OPF)模型的配电网规划( DNP )(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

深兰科技接连斩获工业设计奖!出众产品设计,助AI产品一路领先

十余年来,第三代AI浪潮奔腾汹涌,中国AI产业从全面追赶到部分实现超越。两年前,AI更是正式成为国家七大新基建之一。从国家战略到基础设施,AI正全面地从文件走向现实,国内人工智能的市场规模也迅速扩大。这背后&#xf…

简易聊天室代码分享 js+socket.io

先言 这我以前写的,这里就是单纯分享下代码,不算正经文章。效果如下,前端用一个单html文件。然后后端用node.js和socket.io,也是只用一个单js文件就好。这里可以看下代码的实现逻辑就好,因为来连数据库才能运行的。有…

HTML网页设计制作大作业 基于HTML+CSS+JavaScript实现炫丽口红网化妆品网站(10页)

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

《天天学敏捷:Scrum团队转型记》读书笔记

读书给人以快乐、给人以光彩、给人以才干。 —— 培根 基本信息 作者:杨蕾,郑江著推荐值:76.7%微信读书:天天学敏捷:Scrum团队转型记 收获 & 思考 阅读目标:提前明确目标,有助于提升阅读效…

营销新赛道:虚拟数字人

2021年10月Facebook改名Meta,引爆全球范围的元宇宙热,和Web 3.0相比较,元宇宙是一个完整的生态,而Web 3.0特指一种交互方式和实现方法,两者之间的关系类似于移动互联网与HTML 5。在元宇宙生态下,营销的3要素…

[附源码]Python计算机毕业设计-高校人事管理系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等…

Vue--》路由vue-router的使用讲解

目录 路由简述 vue-router vue-router的安装配置与使用 路由重定向 嵌套路由 嵌套路由重定向 命名路由 动态路由 路由简述 路由(英文:router)就是对应关系。单页面应用程序(SPA)指的是一个web网站只有唯一一个…

故障分析 | MySQL死锁案例分析

作者:杨奇龙 网名“北在南方”,资深 DBA,主要负责数据库架构设计和运维平台开发工作,擅长数据库性能调优、故障诊断。 本文来源:原创投稿 *爱可生开源社区出品,原创内容未经授权不得随意使用,转…

[附源码]Python计算机毕业设计高校教材网上征订系统Django(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程 项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等…

密码改造技术路径大比拼--“免”改造太理想,“重”改造太复杂,“易”改造是王道

随着《密码法》的颁布施行,密码产业进入爆发式增长期。市场用户侧、供给侧、监管侧对于“密评密改”的标准路径和部署方式共识度低,有唱专业的,有唱商业的,有唱便捷的,有唱可持续发展的,有唱单品的&#xf…

ANSYS Mechanical 2020 R1 版本新特性-CABLE 280单元分析索结构

导读:3D 缆索单元,可用的产品:Pro | Premium | Enterprise | PrepPost | Solver | AS add-on 一、CABLE 280 单元概述 CABLE280适用于分析中等至极细的缆索结构(如海底电缆)。该单元是三维三节点二次线单元。每个节点有x , y , z三个平动自…

【Anime.js】——JavaScript动画库:Anime.js

官方文档 官网定义: anime.js 是一个简便的JS动画库,用法简单而且适用范围广,涵盖CSS,DOM,SVG还有JS的对象,各种带数值属性的东西都可以动起来。 一、搭建开发环境 1、新建一个文件夹 ,用vs c…

CpG ODN丨艾美捷ODN 1982 (synthetic)参数说明

艾美捷CpG ODN系列——ODN 1982 (synthetic):具有硫代磷酸酯骨架的GpC寡脱氧核苷酸。 艾美捷CpG ODN丨ODN 1982 (synthetic)化学性质: 序列:5-tccatgagcttcctgagct-3(小写字母表示硫代磷酸酯键)。 MW:638…