深度学习记录1(线性回归的实现)

news2025/4/19 13:50:59

1、整体思路

根据线性回归的定义,

,建立线性回归模型,在损失函数的计算上,采用L2 Loss(均方误差)。同时,对于模型的优化采用随机梯度下降。

2、详细代码分析

import random
import torch
from d2l import torch as d2l
def synthetic_data(w, b, num_examples):
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))
 
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

由于未采用数据集,所以通过自己生成数据来验证模型的正确性。synthetic_data函数为实现的功能为根据相关参数,生成随机数据集,函数输入参数为真实的w,b以及所需要的样本数,通过该函数以及定义的w,b,可以按照

生成相关数据,样本总数为1000。X按照离散正态分布生成,它的每一行为一个样本,每一列为一个特征值。y根据X的值进行生成,为使数据集贴近真实情况,对其利用离散正态分布加上小扰动。

返回值为生成的样本以及对应的标签。(注:在此方法下,)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i:min(i+batch_size, num_examples)])
    yield features[batch_indices], labels[batch_indices]
 
batch_size = 10

features,labels本质上来看就是列表,所以为了实现数据的随机读取,只需要将读取列表时的标号打乱即可,同时可以保证完成对所有数据的读取。indices中的数据为0到num_examples-1中的自然数顺序排列,利用random.shuffle()实现对indices进行随机打乱。使用for循环完成对indices的切割,达到每次输出batch_size样本数的目的。(注:该函数使用生成器,在每次调用时读取batch_size的数据,而不是一次性全部输出)

w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

对待求参数进行初始化,由于待求的w有两个,输出结果一个,因此w的size为(2,1),由于之后需要进行梯度计算,因此requires_grad=True。

def linreg(X, w, b): 
    """线性回归模型。"""
    return torch.matmul(X, w) + b
 
def squared_loss(y_hat, y): 
    """均方损失。"""
    return (y_hat - y.reshape(y_hat.shape))**2 / 2
 
def sgd(params, lr, batch_size):  
    """小批量随机梯度下降。"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

定义线性回归模型,均方损失,随机梯度下降函数,在进行损失计算的时候,按照公式,应当除以batch_size,由于该计算是线性的,因此在进行损失计算时或是在更新参数时计算不影响最后结果。由于pytorch中对于梯度是累积的,因此在每一次更新参数后,应该将保存的梯度清零。

lr = 0.03 #学习率
num_epochs = 3
net = linreg
loss = squared_loss
 
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # `X`和`y`的小批量损失
        # 因为`l`形状是(`batch_size`, 1),而不是一个标量。`l`中的所有元素被加到一起,
        # 并以此计算关于[`w`, `b`]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

此处需要进行l.sum(),是由于在对每一个batch进行损失计算时得到的y为一个列向量,根据反向传播,需要将当前所有样本计算后得到的输出相加并除以样本数的结果进行计算。在该种方法下,更新参数的次数为3*(1000 / 10)。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/698017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day7——Web安全基础下

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 回顾前言一、owasp top 10漏洞(了解)(四年一更)1.访问控制崩溃2.敏感数据暴露3.sql注入4.不安全的设计5.安全配置不当…

【单片机】STM32单片机的各个定时器的定时中断程序,标准库

文章目录 定时器1_定时中断定时器2_定时中断定时器3_定时中断定时器4_定时中断定时器5_定时中断 高级定时器和普通定时器的区别(https://zhuanlan.zhihu.com/p/557896041): 定时器1_定时中断 TIM1是高级定时器,使用的时钟总线是R…

使用Megascans,Blender和Substance 3D画家创建渔人旅馆(p2)

今天云渲染小编接着Polina Tarakanova分享的Fishermans Inn项目上篇分享,下篇主要是纹理和材料、组装场景、照明等方面的分享。 纹理和材料 随着酒馆的模块化建设完成,是时候进入贴图阶段了。我使用Substance 3D Painter进行了所有的贴图工作。在我的场…

【网站创建】网络杂谈(6)之web网站的创建

涉及知识点 如何创建web网站,web网站创建的步骤,手把手教你如何搭建web网站,web网站创建的过程,深入了解web网站创建。 原创于:CSDN博主-《拄杖盲学轻声码》,更多内容可去其主页关注下哈,不胜感…

基于Java+SpringBoot+Vue的计算机类考研交流平台设计与实现

博主介绍:擅长Java、微信小程序、Python、Android等,专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇🏻 不然下次找不到哟 Java项目精品实战案例…

React-View-UI组件库封装Loading加载中源码

目录 组件介绍Loading API能力组件源码组件测试源码组件库线上地址 组件介绍 Loading组件是日常开发用的很多的组件,这次封装主要包含两种状态的Loading,旋转、省略号,话不多说先看一下组件的文档页面吧: 正在上传…重新上传取…

掌握imgproc组件:opencv-图像变换

图像变换 1. 基于OpenCV的边缘检测1.1 边缘检测的一般步骤1.2 canny算子1.2.1 Canny边缘检测步骤:1.2.2 Canny边缘检测:Canny()函数1.2.3 Canny边缘检测案例 1.3 sobel算子1.3.1 sobel算子的计算过程1.3.2 使用Sobel算子:Sobel()函数1.3.3 示…

模拟高并发下RabbitMQ的削峰作用

在并发量很高的时候,服务端处理不过来客户端发的请求,这个时候可以使用消息队列,实现削峰。原理就是请求先打到队列上,服务端从队列里取出消息进行处理,处理不过来的消息就堆积在消息队列里等待。 可以模拟一下这个过…

生态+公链:中创面向未来的区块链建设!

未来的区块链市场,一定属于能够将区块链技术与应用完美结合在一起的产品。从互联网的发展历程来看,最后的竞争往往会集中到生态与兼容性。 如何将区块链的落地和应用更加有机地结合在一起,从而让区块链的功能和作用得到最大程度的发挥&#…

机器学习8:特征组合-Feature Crosses

特征组合也称特征交叉(Feature Crosses),即不同类型或者不同维度特征之间的交叉组合,其主要目的是提高对复杂关系的拟合能力。在特征工程中,通常会把一阶离散特征两两组合,构成高阶组合特征。可以进行组合的…

css:去除input和textarea默认边框样式并美化

input input默认样式和focus样式 参考element-ui的css&#xff0c;可以实现如下效果 实现代码 <style>/* 去除默认样式 */input {border: none;outline: none;padding: 0;margin: 0;-webkit-appearance: none;-moz-appearance: none;appearance: none;background-im…

ElasticSearch 8.0+ 版本Windows系统启动

下载地址&#xff1a;https://www.elastic.co/cn/downloads/past-releases/winlogbeat-8-8-1 解压\elasticsearch\elasticsearch-8.5.1 进入bin目录&#xff0c;启动elasticsearch.bat 问题1&#xff1a; warning: ignoring JAVA_HOMED:\jdk1.8.0_271; using bundled JDK J…

使用凌鲨连接SSH服务器

SSH&#xff08;Secure Shell&#xff09;是一种加密的网络协议&#xff0c;用于安全地连接远程服务器。它提供了一种安全的通信方式&#xff0c;使得用户可以在不受干扰的情况下远程访问服务器。SSH协议的加密技术可以保护用户的登录信息和数据传输过程中的安全性。 SSH对于服…

伦敦银同业拆借利率查询

伦敦银同业拆借利率&#xff08;London InterBank Offered rate&#xff09;简称Libor&#xff0c;它是伦敦银业之间在货币市场的无担保借贷利率&#xff0c;主要报价有五种币别&#xff1a;美元、欧元、英镑、日圆、瑞士法郎&#xff0c;分别有隔夜、一周、一个月、两个月、三…

密码学—Vigenere破解Python程序

文章目录 概要预备知识点学习整体流程技术名词解释技术细节小结代码 概要 破解Vigenere需要Kasiski测试法与重合指数法的理论基础 具体知识点细节看下面这两篇文章 预备知识点学习 下面两个是结合起来使用猜测密钥长度的&#xff0c;只有确认了密钥长度之后才可以进行破解。 …

Jupyter Notebook左侧大纲目录设置

在 Jupyter Notebook 中&#xff0c;可以通过安装jupyter_contrib_nbextensions插件来实现在页面左边显示大纲的功能。 1. 安装插件 pip install jupyter_contrib_nbextensions 1.1 如何安装 windows cmd小黑裙窗口&#xff1b; 1.查看目前安装了哪些库 conda list 2. 使用…

【Oracle】springboot连接Oracle写入blob类型图片数据

目录 一、表结构二、mapper 接口和sql三、实体类四、controller五、插入成功后的效果 springboot连接Oracle写入blob类型图片数据 一、表结构 -- 创建表: student_info 属主: scott (默认当前用户) create table scott.student_info (sno number(10) constraint pk_si…

Vue3 完整项目搭建 Vue3+Pinia+Vant3/ElementPlus+typerscript

❤ Vue3 项目 1、Vue3+Pinia+Vant3/ElementPlus+typerscript环境搭建 1、安装 Vue-cli 3.0 脚手架工具 npm install -g @vue/cli2、安装vite环境 npm init @vitejs/app报错 使用: yarn create @vitejs/app依然报错 转而使用推荐的: npm c

Redisson分布式锁原理

1、Redisson简介 一个基于Redis实现的分布式工具&#xff0c;有基本分布式对象和高级又抽象的分布式服务&#xff0c;为每个试图再造分布式轮子的程序员带来了大部分分布式问题的解决办法。 2、使用方法 引入依赖 <dependency><groupId>org.springframework.bo…

基于Python所写的Word助手设计

点击以下链接获取源码资源&#xff1a; https://download.csdn.net/download/qq_64505944/87959100?spm1001.2014.3001.5503 《Word助手》程序使用说明 在PyCharm中运行《Word助手》即可进入如图1所示的系统主界面。在该界面中&#xff0c;通过顶部的工具栏可以选择所要进行的…