Pytorch实现线性回归Linear Regression

news2025/1/14 18:14:32

借助 PyTorch 实现深度神经网络 - 线性回归 - 第 2 周 | Coursera

线性回归预测

用PyTorch实现线性回归模块

f070f2ec76df43abaf729d5793230134.png

创建自定义模块(内含一个线性回归)

6b99a23f99be443c99f8501dc8f9943f.png

24c6b66fe591473f92115b0246077091.png

训练线性回归模型

对于线性回归,特定类型的噪声是高斯噪声

4dd6e6e710d041118e8c6916ee5ab399.png

平均损失均方误差函数:

ca40bebf938f4d819d5169765b05fe2c.png

loss求解(导数=0):

823863a0a51d49c494927630586adcc1.png

梯度下降

783413bc85ec40649e6c44e137d022e4.png

eq?%5Ceta表示学习率

66a21da7c5ad46f1806059a03f38f82f.png

学习率过高,可能错过参数的最佳值

学习率过低,需要大量的迭代才能获得最小值

Batch Gradient Descent:使用整个训练集来更新模型的参数 

用Pytorch实现线性回归--梯度

493855116b9949559e0ff6ca31455924.png

45f47906ebf14783ac0a1fa75cbc857d.png

0a401308297e4e32bb8b0d351a6617fd.png

9dece871069a4a2891ddbd7478e07e0f.png

每个epoch就是一个iteration:

9802fbe8327d435ba6fb028827a0fcbb.png

画图版:

import torch
w=torch.tensor(-10.0,requires_grad=True)
X=torch.arange(-3,3,0.1).view(-1,1)
f=-3*X
# The class for plotting

class plot_diagram():
    
    # Constructor
    def __init__(self, X, Y, w, stop, go = False):
        start = w.data
        self.error = []
        self.parameter = []
        print(type(X.numpy()))
        self.X = X.numpy()
       
        self.Y = Y.numpy()
        self.parameter_values = torch.arange(start, stop)
        self.Loss_function = [criterion(forward(X), Y) for w.data in self.parameter_values] 
        w.data = start
        
    # Executor
    def __call__(self, Yhat, w, error, n):
        self.error.append(error)
        self.parameter.append(w.data)
        plt.subplot(212)
        plt.plot(self.X, Yhat.detach().numpy())
        plt.plot(self.X, self.Y,'ro')
        plt.xlabel("A")
        plt.ylim(-20, 20)
        plt.subplot(211)
        plt.title("Data Space (top) Estimated Line (bottom) Iteration " + str(n))
        # Convert lists to PyTorch tensors
        parameter_values_tensor = torch.tensor(self.parameter_values)
        loss_function_tensor = torch.tensor(self.Loss_function)

        # Plot using the tensors
        plt.plot(parameter_values_tensor.numpy(), loss_function_tensor.numpy())
  
        plt.plot(self.parameter, self.error, 'ro')
        plt.xlabel("B")
        plt.figure()
    
    # Destructor
    def __del__(self):
        plt.close('all')
gradient_plot = plot_diagram(X, Y, w, stop = 5)
# Define a function for train the model

def train_model(iter):
    LOSS=[]
    for epoch in range (iter):
        
        # make the prediction as we learned in the last lab
        Yhat = forward(X)
        
        # calculate the iteration
        loss = criterion(Yhat,Y)
        
        # plot the diagram for us to have a better idea
        gradient_plot(Yhat, w, loss.item(), epoch)
        
        # store the loss into list
        LOSS.append(loss.item())
        
        # backward pass: compute gradient of the loss with respect to all the learnable parameters
        loss.backward()
        
        # updata parameters
        w.data = w.data - lr * w.grad.data
        
        # zero the gradients before running the backward pass
        w.grad.data.zero_()
train_model(4)

15c5d775a43646ab97e2989e6726fede.png

cc0cff3af6a34921afbc9b5072d50532.png

9ed27917184b4175858fb2d17e559f1b.png

0e72fbd061894b10bbb3e5e9c1fba34b.png

用Pytorch实现线性回归--训练

与上文类似,只是多加了个b

梯度

c99b3440b11249348eefeba808ee2008.png

0aa14cf6f1714b678471e616405a2022.png

edfd17106a104718a2179867111b28b9.png

画函数图:

# The class for plot the diagram

class plot_error_surfaces(object):
    
    # Constructor
    def __init__(self, w_range, b_range, X, Y, n_samples = 30, go = True):
        W = np.linspace(-w_range, w_range, n_samples)
        B = np.linspace(-b_range, b_range, n_samples)
        w, b = np.meshgrid(W, B)    
        Z = np.zeros((30,30))
        count1 = 0
        self.y = Y.numpy()
        self.x = X.numpy()
        for w1, b1 in zip(w, b):
            count2 = 0
            for w2, b2 in zip(w1, b1):
                Z[count1, count2] = np.mean((self.y - w2 * self.x + b2) ** 2)
                count2 += 1
            count1 += 1
        self.Z = Z
        self.w = w
        self.b = b
        self.W = []
        self.B = []
        self.LOSS = []
        self.n = 0
        if go == True:
            plt.figure()
            plt.figure(figsize = (7.5, 5))
            plt.axes(projection='3d').plot_surface(self.w, self.b, self.Z, rstride = 1, cstride = 1,cmap = 'viridis', edgecolor = 'none')
            plt.title('Cost/Total Loss Surface')
            plt.xlabel('w')
            plt.ylabel('b')
            plt.show()
            plt.figure()
            plt.title('Cost/Total Loss Surface Contour')
            plt.xlabel('w')
            plt.ylabel('b')
            plt.contour(self.w, self.b, self.Z)
            plt.show()
    
    # Setter
    def set_para_loss(self, W, B, loss):
        self.n = self.n + 1
        self.W.append(W)
        self.B.append(B)
        self.LOSS.append(loss)
    
    # Plot diagram
    def final_plot(self): 
        ax = plt.axes(projection = '3d')
        ax.plot_wireframe(self.w, self.b, self.Z)
        ax.scatter(self.W,self.B, self.LOSS, c = 'r', marker = 'x', s = 200, alpha = 1)
        plt.figure()
        plt.contour(self.w,self.b, self.Z)
        plt.scatter(self.W, self.B, c = 'r', marker = 'x')
        plt.xlabel('w')
        plt.ylabel('b')
        plt.show()
    
    # Plot diagram
    def plot_ps(self):
        plt.subplot(121)
        plt.ylim
        plt.plot(self.x, self.y, 'ro', label="training points")
        plt.plot(self.x, self.W[-1] * self.x + self.B[-1], label = "estimated line")
        plt.xlabel('x')
        plt.ylabel('y')
        plt.ylim((-10, 15))
        plt.title('Data Space Iteration: ' + str(self.n))

        plt.subplot(122)
        plt.contour(self.w, self.b, self.Z)
        plt.scatter(self.W, self.B, c = 'r', marker = 'x')
        plt.title('Total Loss Surface Contour Iteration' + str(self.n))
        plt.xlabel('w')
        plt.ylabel('b')
        plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1968409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mallet:一款针对任意协议的安全拦截代理工具

关于Mallet Mallet是一款功能强大的协议安全分析工具,该工具支持针对任意协议创建用于安全审计的拦截代理,该工具本质上与我们所熟悉的拦截Web代理类似,只是通用性更强。 工具运行机制 Mallet建立在Netty框架之上,并且依赖于Net…

文案人的梦工场,网易入职指南!

网易云对于咱们一些有点文艺的文案策划来说,简直就是梦中情司。 在这里工作锻炼机会很多,也很开拓眼界,能获得相当于在别处3倍能力的成长速度,福利待遇也是很好的。 要进入网易云音乐做文案策划,你可以按照以下步骤进…

数据结构的基本概念与算法2

线性表 : 线性表 是具有相同数据类型的 n (n > 0) 个数据元素的有限序列,其中 n 为表长,当 n 0 时线性表是一个空表;若用 L 命名线性表,则其一般表示为:L (a1 , a2 , ... ai , ai1 , ... an) 上述中&a…

月木学途开发 2.项目架构

1.项目介绍 月木学途是一款it在线学习网站,项目采用前后端分离架构。前端开发主要使用vue.js,后端使用Spring Cloud Alibaba技术栈。项目包含学习网站的大部分功能,分为管理员端和用户端。管理员端有权限管理、课程管理、网站管理、求职模块管…

Shell函数和Shell 输入/输出重定向

LInux:Shell函数和Shell 输入/输出重定向 Shell函数 参数说明: 可以带function fun() 定义,也可以直接fun() 定义,不带任何参数。参数返回,可以显示加:return 返回,如果不加,将以最后一条命令运…

[Vue warn]: data functions should return an object:

仔细检查你的代码肯定有一个data()内忘记方return{}了

C语言程序设计23

《C程序设计教程(第四版)——谭浩强》 例题2.11 从键盘输入B、O、Y三个字符,然后把他们输出到屏幕上 代码: //《C程序设计教程(第四版)——谭浩强》 //例题2.11 从键盘输入B、O、Y三个字符,然…

RabbitMQ:MQ的可靠性

MQ的可靠性 在默认情况下,RabbitMQ会将接收到的信息保存在内存中以降低消息收发的延迟。这样会导致两个问题: 一旦MQ宕机,内存中的消息会丢失 内存空间有限,当消费者故障或处理过慢时,会导致消息积压,引发MQ阻塞。 …

高效、安全、共享|济南市升级教育城域网,重塑教育网络生态

文/济南市电化教育馆 电教教研室主任 张承强 导语: 近年来,济南市教育局以前瞻性的视野,将教育数字化转型视为推动教育高质量发展的基石,全力加速教育现代化进程。在这一蓝图下,教育城域网的升级改造项目被赋予了基础性、先导性和战略性的重要意义,成为探索教育数字化转型新路…

一键搞定PDF翻译,这四款是职场达人常备翻译工具!!!

作为外贸搬砖人的一份子,虽说外语功底还说地过去,但是每天过目大量pdf文件的翻译,难免还有些吃力,这个时候如果有可以辅助翻译的工具那就再好不过了,今天给大家带来四款非常适合pdf文件翻译的工具,总有一款…

C#中的通信

上位机应用开发-串口通信1、基于C#的串口通信对象:SerialPort 2、字段属性 PortName:获取或设置通信端口 BaudRate:获取或设置串行波特率-DataBits:获取或设置每个字节的标准数据位长度 Parity:获取或设置奇偶校验检查协仪I-StopBits;获取或设置每个字节的标准停止位数 3、…

你需要的Node版本管理神器NVM

在做项目的时候,很多人本地的node都是装一个固定版本,一旦有些项目要下的依赖需要更高版本的node支持的时候,此时需要升级node就得把已经安装的低版本node卸载了,然后再重新下载、安装高版本的node,既费时间又抓狂,特别…

大模型算法面试题(十九)

本系列收纳各种大模型面试题及答案。 1、SFT(有监督微调)、RM(奖励模型)、PPO(强化学习)的数据集格式? SFT(有监督微调)、RM(奖励模型)、PPO&…

网工内推 | 云运维工程师,最高19K,五险一金加补充医疗险

01 云计算运维工程师 🔷岗位职责 1、负责客户云计算解决方案的运维,负责云计算解决方案中云、虚拟化工作; 2、负责客户现场H3C产品的日常问题处理、变更维护、巡检、版本升级等工作,保障客户网络的稳定运行; 3、协调…

yolo数据集格式按照每一个类别的比例划分数据集

写在前面: 写脚本不易,写博客不易,请多点赞关注,谢谢。10多年来,我一直免费给大家毫无保留的分享技术等,不但从来没被打赏过,而且在分享有些模型转化处理的高级脚本中,有些同胞由于自…

pmp证书实用性怎么样,考这个性价比高不高,难度?

要是 PMP 证书没有价值,还会有那么多人愿意去考吗? 我觉得一个原因是因为行业/岗位需求高,还有就是拿证后能不能用得上,看人看公司,很大一部分考证的人都是因为应聘跟投标书要求。 据我了解,PMP 证书目前…

八戒会修特斯拉 特斯拉如何磨合制动器

--------------------------------------------------------------------------------------------------------------------------------- -------------------------------------- 作者: 八戒会修特斯拉 -------------------------…

动态注意力机制新突破!11个最新idea,看了就能发顶会!

在处理复杂数据时,可以通过引入动态注意力机制,让模型根据输入数据的特点动态调整关注点,聚焦最关键的信息,来提高模型的处理能力和效率。 这种比传统方法更高效、灵活的技术足以应对各种复杂任务和挑战,具有强大的适…

C语言程序设计22

《C程序设计教程&#xff08;第四版&#xff09;——谭浩强》 例题2.10 先后输出B、O、Y三个字符 代码&#xff1a; //《C程序设计教程&#xff08;第四版&#xff09;——谭浩强》 //例题2.10 先后输出B、O、Y三个字符#include <stdio.h> int main() {char a B;char …

E22.【C语言】练习:“详解函数递归”文中青蛙跳台阶的答案

点击查看原文 代码实现 jump(n)jump(n-1)jump(n-2) #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> int jump(int n) {if (1 n){return 1;}else if (2 n){return 2;}else{return jump(n - 1) jump(n - 2);} }int main() {int n 0;printf("请输入台阶总数…