从0开始深度学习(5)——线性回归的逐步实现

news2024/9/22 3:56:30

将从零开始实现整个方法, 包括数据流水线、模型、损失函数和小批量随机梯度下降优化器,但现代的深度学习框架几乎可以自动化地进行所有这些工作,但从零开始实现可以确保我们真正知道自己在做什么。

下一章会使用框架简洁的实现线性回归

# 提前导入的库
import random
import torch
import matplotlib.pyplot as plt

1 生成数据集

我们将根据带有噪声的线性模型构造一个人造数据集。
在这里插入图片描述

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))
    y = torch.matmul(X, w) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)# 生成1000个点


# 绘制散点图,我们可以简单看下每个点趋近于哪条线
plt.scatter(features[:, 0].numpy(), labels.numpy(), 1.0)
plt.xlabel('Feature')
plt.ylabel('Label')
plt.title('Scatter Plot of Generated Data')
plt.show()

在这里插入图片描述

2 读取数据集

每次抽取一小批量样本,并使用它们来更新我们的模型。

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))  # 创建一个包含所有样本索引的列表
    random.shuffle(indices)  # 随机打乱索引,以确保每次迭代时数据的顺序都是随机的
    
    # 遍历索引列表,步长为batch_size
    for i in range(0, num_examples, batch_size):
        # 根据当前的索引i和batch_size,计算出当前小批量的索引范围
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        # 使用当前小批量的索引从特征和标签中抽取对应的数据
        yield features[batch_indices], labels[batch_indices]

# 设置小批量的大小
batch_size = 10

for X, y in data_iter(batch_size, features, labels):
    # 打印第一个小批量的特征和标签
    print(X, '\n', y)
    break

第一个下批量的特征和标签
在这里插入图片描述
PS:在深度学习框架中实现的内置迭代器效率要高得多

3 定义模型

.matmul 函数是 PyTorch 中的一个方法,用于执行矩阵乘法。
定义一个简单的线性模型,即一个特征矩阵X和向量w进行矩阵-向量相乘后,再加上一个偏置参数b

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

4 定义损失函数

因为需要计算损失函数的梯度,所以我们应该先定义损失函数。这里使用均方误差(MSE)

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

5 定义优化算法

线性模型有解析解,但是为了模拟其他没有解析解的模型,这里使用梯度下降,即在每一步中,使用从数据集中随机抽取的一个小批量,然后根据参数计算损失的梯度,接下来,朝着减少损失的方向更新我们的参数。每一步更新的大小由学习速率lr决定。

def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

6 初始化模型参数

初始化参数

import torch

# 初始化权重w,使用正态分布,均值为0,标准差为0.01,形状为(2, 1)
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
# 初始化偏置b,值为0,形状为(1,),这里使用reshape或者直接创建一个标量
b = torch.zeros(1, requires_grad=True)

7 训练

# 设置模型参数
lr = 0.03 # 学习率
num_epochs = 3 # 迭代周期(迭代几次)
net = linreg # 线性回归模型
loss = squared_loss # 平方损失函数

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels): # 计算当前小批量的损失
        l = loss(net(X, w, b), y)  # 使用net函数计算预测值,然后计算与真实值y的损失
        l.sum().backward() # 将损失相加(因为损失是按小批量计算的),然后对所有参数求梯度
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}') # 打印每个epoch的损失

运行结果
在这里插入图片描述
比较真实参数和通过训练学到的参数来评估训练的成功程度。

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108159.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

孩子成长的黄金期,做好这件事,给孩子培养一个“超级大脑“!

唤醒我们的大脑能力,开发大脑的无限功能,可以提高自己的学习和工作效率,帮我们实现更好的生活状态。 而孩子在6-12岁这个阶段,正是具体想象思维向抽象思维过渡的关键时期,所以这个阶段正是训练孩子逻辑思维能力的好时…

如何在本地服务器部署SeaFile自托管文件共享服务结合内网穿透打造私有云盘?

文章目录 1. 前言2. SeaFile云盘设置2.1 Owncould的安装环境设置2.2 SeaFile下载安装2.3 SeaFile的配置 3. cpolar内网穿透3.1 下载安装3.2 Cpolar注册3.3 Cpolar云端设置3.4 Cpolar本地设置 4.公网访问测试5.结语 1. 前言 本文主要为大家介绍,如何使用两个简单软件…

钢铁百科:Q420DR力学性能、Q420DR执行标准、Q420DR低温容器钢板

Q420DR钢板是一种专为低温压力容器设计的优质钢材,其材质特性、执行标准、化学成分、力学性能、交货状态、应用范围以及常用规格等方面都具有显著的特点。 一、Q420DR钢板材质 Q420DR钢板的命名方式体现了其材质特性。其中,“Q”代表屈服强度&#xff…

米壳AI:做跨境电商欧美市场必备工具--AI图片翻译!

在竞争激烈的欧美跨境电商领域,如何脱颖而出? 对于欧美市场的跨境电商从业者来说,语言和文化的差异常常是一大挑战。 但有了米壳 AI 这个强大的工具,问题便迎刃而解。是一个无需下载安装的网站,打开就能用&#xff0c…

关于cookie和session的直观讲解(二)

前言 上一章,讲解了Cookie,本章介绍Session. 概念:服务器端会话技术,在一次会话的多次请求间共享数据,将数据保存在服务器端的对象中HttpSession。 Session 基础 获取HttpSession对象: HttpSession ses…

High-Resolution Image Synthesis with Latent Diffusion Models论文学习

想要研究多模态数据生成就要研究diffusion架构。今天通过这个论文来学习一下。 在此之前先通过博文扩散模型 (Diffusion Model) 之最全详解图解-CSDN博客学习一下。 GAN就是生成对抗模型,让两个模型打擂台,一个造假一个打假。 而diffusion模型则是从原…

第十五届蓝桥杯图形化省赛题目及解析

第十五届蓝桥杯图形化省赛题目及解析 一. 单选题 1. 运行以下程序&#xff0c;角色会说( )? A、29 B、31 C、33 D、35 正确答案&#xff1a;C 答案解析&#xff1a; 重复执行直到m>n不成立&#xff0c;即重复执行直到m<n。所有当m小于或者 等于n时&…

数据结构(14)——哈希表(1)

欢迎来到博主的专栏&#xff1a;数据结构 博主ID&#xff1a;代码小豪 文章目录 哈希表的思想映射方法&#xff08;哈希函数&#xff09;除留余数法 哈希表insert闭散列负载因子扩容find和erase 哈希表的思想 在以往的线性表中&#xff0c;查找速度取决于线性表是否有序&#…

Origin 2024下载安装教程(中文版软件包) 百度网盘分享链接地址

Origin是什么软件&#xff1f; origin主要是绘图、数据分析、数据导入导出的功能。Origin 广泛应用于科学研究、工程技术、数据分析等领域&#xff0c;Origin 是一款功能强大、易于使用的科学绘图和数据分析软件&#xff0c;能够帮助你高效地处理和可视化数据&#xff0c;为你…

C程序设计——函数0

函数定义 前面说过C语言是结构化的程序设计语言&#xff0c;他把所有问题抽象为数据和对数据的操作&#xff0c;前面讲的变量、常量&#xff0c;都是数据。现在开始讲对数据操作——函数。 C语言的函数&#xff0c;定义方式如下&#xff1a; 返回值类型 函数名(参数列表) {…

论文速读|重新审视奖励设计与评估:用于强健人型机器人站立与行走控制的方法

论文地址&#xff1a;https://arxiv.org/pdf/2404.19173 这篇论文为类人机器人站立和行走&#xff08;SaW&#xff09;控制器的持续可衡量改进奠定了基础。通过引入一套定量实际基准测试方法&#xff0c;作者展示了现有控制器的优缺点&#xff0c;并通过基准测试指导新控制器的…

龙芯L2K0300开发板综合测试

CPU 查看cpu版本信息 cat /proc/cpuinfo可以看到cpu是64位的LoongsonArch架构 stress压力测试结果 RAM 久久派板载512MB DDR4-2666内存&#xff0c;查看内存信息 cat /proc/meminfo可以用memtester进行内存性能测试 memtester <size> <times>memtester测试结果…

Java 工程师转型大数据的优势——别小看自己!

时间&#xff1a;2024年09月05日 作者&#xff1a;小蒋聊技术 邮箱&#xff1a;wei_wei10163.com 微信&#xff1a;wei_wei10 音频地址: https://xima.tv/1_U3suSJ?_sonic0 希望大家帮个忙&#xff01;如果大家有工作机会&#xff0c;希望帮小蒋推荐一下&#xff0c;小蒋希…

2024国赛数学建模A题思路模型

完整的思路模型请查看文末名片

机器学习:opencv图像识别--模版匹配

目录 一、模版匹配的核心概念 1.图片模板匹配是一种用于在图像中查找特定模式或对象的技术。 2.模板图像 3.目标图像 4.滑动窗口 5.相似度度量 6.匹配位置 二、模版匹配的步骤 1.准备图像&#xff1a; 2.预处理&#xff1a; 3.匹配&#xff1a; 4.定位最佳匹配&…

【MySQL】初识MySQL—MySQL是啥,以及如何简单操作???

前言&#xff1a; &#x1f31f;&#x1f31f;本期讲解关于MySQL的简单使用和注意事项&#xff0c;希望能帮到屏幕前的你。 &#x1f308;上期博客在这里&#xff1a;http://t.csdnimg.cn/wwaqe &#x1f308;感兴趣的小伙伴看一看小编主页&#xff1a;GGBondlctrl-CSDN博客 目…

2024数学建模国赛题目A-E题

2024数学建模国赛题目A-E题已经发布 各个赛题题目如下 A题 板凳龙 闹元宵 B题 生产过程中的决策问题 C题 农作物的种植策略 D题 反潜航空深弹命中概率问题 E题 交通流量管控 Csdn在文末&#xff0c;关注云顶数模公众号&#xff0c;或者点击下方名片。

2024年高教社杯数学建模国赛赛题浅析——助攻快速选题

一图流——一张图读懂国赛 总体概述&#xff1a; A题偏几何与运动学模型&#xff0c;适合有几何与物理背景的队伍&#xff0c;数据处理复杂性中等。 B题侧重统计和优化&#xff0c;适合有运筹学和经济学背景的队伍&#xff0c;数据处理较为直接但涉及多步骤的决策优化。 C题…

新手c语言讲解及题目分享(十六)--文件系统专项练习

在我刚开始学习c语言的时候就跳过了这一章节&#xff0c;但在后面慢慢发现这一章节还是比较重要的,如果你报考了计算机二级c语言的话&#xff0c;你应该可以看到后面的三个大题有时会涉及到这章。所以说这章还是非常重要的。 目录 前言 一.打开文件 1.Fopen( )函数返回值 2&…

Keil发现Error: C9555E: Failed to check out a license

遇到这样的问题 解决办法&#xff1a; 换成这个版本 然后重新用keygen生成license