pytorch 实现线性回归(深度学习)

news2024/12/29 10:29:16

一 查看原始函数

        y=2x+4.2

初始化

%matplotlib inline
import random
import torch
from d2l import torch as d2l

1.1 生成原始数据

def synthetic_data(w, b, num_examples):
    x = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(x, w) + b
    print('x:', x)
    print('y:', y)
    y += torch.normal(0, 0.01, y.shape)  # 噪声
    return x, y.reshape((-1 , 1))
true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')

features, labels = synthetic_data(true_w, true_b, 10)

1.2 数据转换

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
for x, y in data_iter(batch_size, features, labels):
    print(f'x: {x}, \ny: {y}')

1.3 初始化权重

随机初始化,w使用 均值0,方差 0.01 的随机值, b 初始化为1

w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

二 执行训练

查看训练过程中的 参数变化:

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            # print('param:', param, 'param.grad:', param.grad)
            param -= lr * param.grad / batch_size
            param.grad.zero_()

lr = 0.03
num_epochs = 1000
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print('w:', w, 'b:', b)  # l:', l, '\n
        l.sum().backward()
        sgd([w, b], lr, batch_size)

 


三 测试梯度更新

初始化数据

%matplotlib inline
import random
import torch
from d2l import torch as d2l

def synthetic_data(w, b, num_examples):
    x = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(x, w) + b
    print('x:', x)
    print('y:', y)
    y += torch.normal(0, 0.01, y.shape)  # 噪声
    return x, y.reshape((-1 , 1))

true_w = torch.tensor([2.])
true_b = 4.2
print(f'true_w: {true_w}, true_b: {true_b}')

features, labels = synthetic_data(true_w, true_b, 10)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
for x, y in data_iter(batch_size, features, labels):
    print(f'x: {x}, \ny: {y}')
    
w = torch.normal(0, 0.01, size = (1,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
w, b

3.1 测试更新

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()

lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n
        l.sum().backward()  # 计算更新梯度
        sgd([w, b], lr, batch_size)

使用 l.sum().backward()  # 计算更新梯度:

不使用更新时:

print(f'true_w: {true_w}, true_b: {true_b}')

def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def linreg(x, w, b):
    return torch.matmul(x, w) + b

def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            print('param:', param, 'param.grad:', param.grad)
#             param -= lr * param.grad / batch_size
#             param.grad.zero_()

lr = 0.03
num_epochs = 2
for epoch in range(num_epochs):
    for x, y in data_iter(batch_size, features, labels):
        l = squared_loss(linreg(x, w, b), y)   # 计算总损失
        print(f'\nepoch: {epoch},w:', w, 'b:', b)  # l:', l, '\n
        # l.sum().backward()  # 计算更新梯度
        sgd([w, b], lr, batch_size)
        
#     break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1452637.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mysql第二关之存储引擎

简介 所有关于Mysql数据库优化的介绍仿佛都有存储引擎的身影。本文介绍Mysql常用的有MyISAM存储引擎和Innodb存储引擎,还有常见的索引。 Mysql有两种常见的存储引擎,MyISAM和Innodb,它们各有优劣,经过多次优化和迭代,…

【STM32 CubeMX】SPI HAL库编程

文章目录 前言一、CubeMX配置SPI Flash二、SPI HAL编程2.1 查询方式函数2.2 使用中断方式2.3 DMA方式 总结 前言 STM32 CubeMX 是一款由 STMicroelectronics 提供的图形化配置工具,用于生成 STM32 微控制器的初始化代码和项目框架。在 STM32 开发中,使用…

JDBC查询操作

目录 加载驱动获取连接创建会话发送SQL处理结果关闭资源测试 加载驱动 // 加载驱动Class.forName("com.mysql.cj.jdbc.Driver");获取连接 // 获取连接String url "jdbc:mysql://127.0.0.1:3306/book";String username "root" …

2024全新领域,适合新手发展的渠道,年后不愁资金问题!

我是电商珠珠 如今年已经过完了,不少人还在迷茫自己开工后要做些什么,部分人还在想着去做一些不用吃力就能赚钱的工作,或是一份能兼顾自己日常生活的兼职。 其实,任何赚钱的工作要么动脑要么费力。 费力的工作有很多&#xff0…

敦煌网怎么提升流量的?如何进行自养号测评提升转化率?

敦煌网作为中国领先的跨境电商平台,对于商家而言,提升流量是增加曝光和销售的重要手段。以下将介绍敦煌网提升流量的几种方法。 一、敦煌网怎么提升流量的? 首先,通过合理的商品定位和市场调研,选择有潜力和竞争优势的商品进行…

android获取sha1

1.cmd在控制台获取 切换到Android Studio\jre\bin目录下执行keytool -list -v -keystore 签名文件路径例如: 2.也可以在android studio中获取 在Terminal中输入命令:keytool -list -v -keystore 签名文件路径获取 获取到的sha1如下:

linux系统zabbix工具监控web页面

web页面监控 内建key介绍浏览器配置浏览器页面查看方式 监控指定的站点的资源下载速度,及页面响应时间,还有响应代码; web Scenario: web场景(站点)web page :web页面,一个场景有多…

【LeetCode: 103. 二叉树的锯齿形层序遍历 + BFS】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

OpenAl 视频生成模型 —— Sora技术报告解读

这里是陌小北,一个正在研究硅基生命的碳基生命。正在努力成为写代码的里面背诗最多的,背诗的里面最会写段子的,写段子的里面代码写得最好的…厨子。 写在前面 早上醒来,就看到OpenAl推出的视频模型Sora炸锅了,感觉所…

WordPress站点成功升级后的介绍页地址是什么?

我们一般在WordPress站点后台 >> 仪表盘 >> 更新中成功升级WordPress的话,最后打开的就是升级之后的版本介绍页。比如boke112百科前两天升级到WordPress 6.4.2后显示的介绍页如下图所示: 该介绍除了介绍当前版本修复了多少个问题及修补了多少…

BUUCTF第十九、二十题解题思路

目录 第十九题rome 第二十题rsa 第十九题rome 解压、查壳。 无壳,用32位IDA打开,检索字符串,找到一个字符串“You are correct!”,与flag相关,对其交叉引用找到函数,查看伪代码。 int func() {int resul…

二叉树前序中序后序遍历(非递归)

大家好,又和大家见面啦!今天我们一起去看一下二叉树的前序中序后序的遍历,相信这个对大家来说是信手拈来,但是,今天我们并不是使用常见的递归方式来解题,我们采用迭代方式解答。我们先看第一道前序遍历 1…

LabVIEW开发DUP实时监控系统

LabVIEW开发DUP实时监控系统 该项目采用虚拟仪器设计理念,以LabVIEW作为核心技术平台,开发了一套磁控溅射过程的实时监控系统。实现过程中关键参数的全面数据采集与处理,建立完整的历史数据库,以支持涂层技术的改进和系统向模糊控…

英文论文(sci)解读复现【NO.20】TPH-YOLOv5++:增强捕获无人机的目标检测跨层不对称变压器的场景

此前出了目标检测算法改进专栏,但是对于应用于什么场景,需要什么改进方法对应与自己的应用场景有效果,并且多少改进点能发什么水平的文章,为解决大家的困惑,此系列文章旨在给大家解读发表高水平学术期刊中的 SCI论文&a…

输入捕获模式PWM输入模式(PWMI)

一、概念介绍 输出比较: 比较电路输入的CNT、CCR大小关系 ,在通道引脚输出高低电平 二、频率知识、测量方法补充 N/fc得到标准频率的时长,也就是待测频率的周期 测频法代码实现:修改对射式红外传感器计次(上升沿计…

【Linux】管道文件 打包压缩 文本编辑器nano 进度条

目录 什么是管道文件? 打包和压缩 文本编辑器 nano的安装 nano的使用 退出nano编辑,ctrlx 普通用户无法sudo,该怎么解决 Linux小程序-进度条 预备知识 1.回车换行 2.缓冲区 准备工作 代码实现 1.processBar.h代码编写 2.main.c代…

优秀的电机驱动MCU:MM32SPIN360C

DC-DC电源布局注意点: 电源模块布局布线可提前下载芯片的datasheet(数据表),按照推荐的布局和布线进行设计。 1) 芯片电源接近原则: 对于为芯片提供电压的开关电源,应确保它尽量靠近芯片放置。这样可以避…

一周学会Django5 Python Web开发-项目配置settings.py文件-资源文件配置

锋哥原创的Python Web开发 Django5视频教程: 2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 Django5 Python web开发 视频教程(无废话版) 玩命更新中~共计17条视频,包括:2024版 Django5 Python we…

开关电源电路主要元器件基础知识详解

在学习电子电路过程中,电源我们无法绕开的一个重要部分,很多时候,故障就出现在电源部分,特别是开关电源。开关电源电路主要是由熔断器、热敏电阻器、互感滤波器、桥式整流电路、滤波电容器、开关振荡集成电路、开关变压器、光耦合…

使用 RAG 创建 LLM 应用程序

如果您考虑为您的文件或网站制作一个能够回应您的个性化机器人,那么您来对地方了。我可以帮助您使用Langchain和RAG策略来创建这样一个机器人。 了解ChatGPT的局限性和LLMs ChatGPT和其他大型语言模型(LLMs)经过广泛训练,以理解…