深度学习3.3 线性回归的简洁实现

news2025/4/24 7:52:40
步骤操作作用
前向计算net(X)计算预测值 y_hat = Xw + b
损失计算loss(y_hat, y)量化预测误差,驱动参数更新
反向传播l.backward()计算参数梯度
参数更新trainer.step()根据梯度调整参数,逼近最优解
梯度清零trainer.zero_grad()防止梯度累积(必须放在 backward() 之后,step() 之前)
训练监控loss(net(features), labels)评估模型整体性能,避免过拟合或欠拟合

3.3.1 生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

3.3.2 读取数据集

def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))

数据加载器 (DataLoader)
‌数据集封装‌:TensorDataset 将特征和标签包装为 PyTorch 数据集。‌
批量加载‌:DataLoader 按 batch_size=10 加载数据,训练时打乱数据 (shuffle=True)。

3.3.3 定义模型

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

3.3.4 初始化模型参数

net[0].weight.data.normal_(0, 0.01) # 权重初始化
net[0].bias.data.fill_(0) # 偏置初始化

3.3.5 定义损失函数

loss = nn.MSELoss() # 均方误差损失

3.3.6 定义优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # 随机梯度下降

3.3.7 训练

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)     # 前向计算损失
        trainer.zero_grad()      # 清零梯度
        l.backward()            # 反向传播
        trainer.step()          # 参数更新
    
    # 计算并输出整个训练集的损失
    l = loss(net(features), labels)
    print(f'epoch{epoch + 1}, loss{l:f}')


epoch1, loss0.000205
epoch2, loss0.000094
epoch3, loss0.000094

# 输出参数估计误差
w = net[0].weight.data
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
b = net[0].bias.data
print(f'b的估计误差:{true_b - b}')

w的估计误差:tensor([5.9402e-04, 4.6015e-05])
b的估计误差:tensor([0.0001])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2341289.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

APP动态交互原型实例|墨刀变量控制+条件判断教程

引言 不同行业的产品经理在绘制原型图时,拥有不同的呈现方式。对于第三方软件技术服务公司的产品经理来说,高保真动态交互原型不仅可以在开发前验证交互逻辑,还能为甲方客户带来更直观、真实的体验。 本文第三部分将分享一个实战案例&#…

色谱图QCPColorMap

一、QCPColorMap 概述 QCPColorMap 是 QCustomPlot 中用于绘制二维颜色图的类,可以将矩阵数据可视化为颜色图(热力图),支持自定义色标和插值方式。 二、主要属性 属性类型描述dataQCPColorMapData存储颜色图数据的对象interpol…

最新扣子(Coze)案例教程:飞书多维表格按条件筛选记录 + 读取分页Coze工作流,无限循环使用方法,手把手教学,完全免费教程

大家好,我是斜杠君。 👨‍💻 星球群里有同学想学习一下飞书多维表格的使用方法,关于如何通过按条件筛选飞书多维表格中的记录,以及如何使用分页解决最多一次只能读取500条的限制问题。 斜杠君今天就带大家一起搭建一…

Spring AI Alibaba-02-多轮对话记忆、持久化消息记录

Spring AI Alibaba-02-多轮对话记忆、持久化消息记录 Lison <dreamlison163.com>, v1.0.0, 2025.04.19 文章目录 Spring AI Alibaba-02-多轮对话记忆、持久化消息记录多轮对话对话持久-Redis 本次主要聚焦于多轮对话功能的实现&#xff0c;后续会逐步增加更多实用内容&…

联邦元学习实现个性化物联网的框架

随着数据安全和隐私保护相关法律法规的出台&#xff0c;需要直接在中央服务器上收集和处理数据的集中式解决方案&#xff0c;对于个性化物联网而言&#xff0c;训练各种特定领域场景的人工智能模型已变得不切实际。基于此&#xff0c;中山大学&#xff0c;南洋理工大学&#xf…

实验1 温度转换与输入输出强化

知识点&#xff1a;input()/print()、分支语句、字符串处理&#xff08;教材2.1-2.2&#xff09; 实验任务&#xff1a; 1. 实现摄氏温度与华氏温度互转&#xff08;保留两位小数&#xff09; 2. 扩展功能&#xff1a;输入错误处理&#xff08;如非数字输入提示重新输入&#x…

【AI】SpringAI 第五弹:接入千帆大模型

1. 添加依赖 <dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-qianfan</artifactId> </dependency> 2. 编写 yml 配置文件 spring:ai:qianfan:api-key: 你的api-keysecret-key: 你的secr…

[Godot] C#2D平台游戏基础移动和进阶跳跃代码

本文章给大家分享一下如何实现基本的移动和进阶的跳跃&#xff08;跳跃缓冲、可变跳跃、土狼时间&#xff09;以及相对应的重力代码&#xff0c;大家可以根据自己的需要自行修改 实现效果 场景搭建 因为Godot不像Unity&#xff0c;一个节点只能绑定一个脚本&#xff0c;所以我…

【Unity笔记】Unity + OpenXR项目无法启动SteamVR的排查与解决全指南

图片为AI生成 一、前言 随着Unity在XR领域全面转向OpenXR标准&#xff0c;越来越多的开发者选择使用OpenXR来构建跨平台的VR应用。但在项目实际部署中发现&#xff1a;打包成的EXE程序无法正常启动SteamVR&#xff0c;或者SteamVR未能识别到该应用。本文将以“Unity OpenXR …

使用 rebase 轻松管理主干分支

前言 最近遇到一个技术团队的 dev 环境分支错乱&#xff0c;因为是多人合作大家各自提交信息&#xff0c;导致出现很多交叉合并记录&#xff0c;让对应 log 看起来非常混乱&#xff0c;难以阅读。 举例说明 假设我们有一个项目&#xff0c;最初develop分支有 3 个提交记录&a…

【愚公系列】《Python网络爬虫从入门到精通》063-项目实战电商数据侦探(主窗体的数据展示)

&#x1f31f;【技术大咖愚公搬代码&#xff1a;全栈专家的成长之路&#xff0c;你关注的宝藏博主在这里&#xff01;】&#x1f31f; &#x1f4e3;开发者圈持续输出高质量干货的"愚公精神"践行者——全网百万开发者都在追更的顶级技术博主&#xff01; &#x1f…

HttpSessionListener 的用法笔记250417

HttpSessionListener 的用法笔记250417 以下是关于 HttpSessionListener 的用法详解&#xff0c;涵盖核心方法、实现步骤、典型应用场景及注意事项&#xff0c;帮助您全面掌握会话&#xff08;Session&#xff09;生命周期的监听与管理&#xff1a; 1. 核心功能 HttpSessionLi…

火山RTC 5 转推CDN 布局合成规则

实时音视频房间&#xff0c;转推CDN&#xff0c;文档&#xff1a; 转推直播--实时音视频-火山引擎 一、转推CDN 0、前提 * 在调用该接口前&#xff0c;你需要在[控制台](https://console.volcengine.com/rtc/workplaceRTC)开启转推直播功能。<br> * 调…

Spark两种运行模式与部署

1. Spark 的运行模式 部署Spark集群就两种方式&#xff0c;单机模式与集群模式 单机模式就是为了方便开发者调试框架的运行环境。但是生产环境中&#xff0c;一般都是集群部署。 现在Spark目前支持的部署模式&#xff1a; &#xff08;1&#xff09;Local模式&#xff1a;在本地…

qt画一朵花

希望大家的生活都更加美好&#xff0c;画一朵花送给大家 效果图 void FloatingArrowPubshButton::paintEvent(QPaintEvent *event) {QPainter painter(this);painter.setRenderHints(QPainter::Antialiasing);QPen pen;pen.setColor("green");pen.setWidth(5);QBrush…

服务器上安装maven

1.安装 下载安装包 https://maven.apache.org/download.cgi 解压安装包 cd /opt/software tar -xzvf apache-maven-3.9.9-bin.tar.gz 安装目录(/opt/maven/) mv /opt/software/apache-maven-3.9.9 /opt/ 3.权限设置 把/opt/software/apache-maven-3.9.9 文件夹重命名为ma…

UOS+N 卡 + CUDA 环境下 X86 架构 DeepSeek 基于 vLLM 部署与 Dify 平台搭建指南

一、文档说明 本文档是一份关于 DeepSeek 在X86架构下通vLLM工具部署的操作指南&#xff0c;主要面向需要在UOSN卡CUDA环境中部署DeepSeek的技术人员&#xff0c;旨在指导文档使用者完成从 Python 环境升级、vLLM 库安装、模型部署到 Dify 平台搭建的全流程操作。 二、安装Pyt…

MySQL终章(8)JDBC

目录 1.前言 2.正文 2.1JDBC概念 2.2三种编码方式 2.2.1第一种 2.2.2第二种&#xff08;优化版&#xff09; 2.2.3第三种&#xff08;更优化版&#xff09; 3.小结 1.前言 哈喽大家好吖&#xff0c;今天来给大家带来Java中的JDBC的讲解&#xff0c;之前学习的都是操作…

Python 爬虫如何伪装 Referer?从随机生成到动态匹配

一、Referer 的作用与重要性 Referer 是 HTTP 请求头中的一个字段&#xff0c;用于标识请求的来源页面。它在网站的正常运行中扮演着重要角色&#xff0c;例如用于统计流量来源、防止恶意链接等。然而&#xff0c;对于爬虫来说&#xff0c;Referer 也可能成为被识别为爬虫的关…

【MySQL】表的约束(主键、唯一键、外键等约束类型详解)、表的设计

目录 1.数据库约束 1.1 约束类型 1.2 null约束 — not null 1.3 unique — 唯一约束 1.4 default — 设置默认值 1.5 primary key — 主键约束 自增主键 自增主键的局限性&#xff1a;经典面试问题&#xff08;进阶问题&#xff09; 1.6 foreign key — 外键约束 1.7…