线性回归网络

news2025/1/10 23:32:25

  李沐大神的《动手学深度学习》,是我入门机器学习的首课,因此在这里记录一下学习的过程。

线性回归的从零开始实现

  线性回归是理解机器学习的基础,它经常用来表示输入和输出之间的关系。
  线性回归基于几个简单的假设: 首先,假设自变量X和因变量y之间的关系是线性的, 即y可以表示为X中元素的加权和,这里通常允许包含观测值的一些噪声。

  下面基于房屋价格price和房屋的面积area与年龄age之间的关系来构造出了一个线性模型
在这里插入图片描述
   w a r e a w_{area} warea w a g e w_{age} wage称为权重(weight),权重决定了每个特征对我们预测值的影响。 b b b称为偏置(bias)、偏移量(offset)或截距(intercept)。 偏置是指当所有特征都取值为0时,预测值应该为多少。

  基于如上的线性模型从零开始实现一个线性回归模型的示例代码
代码如下:

import random
import torch
from d2l import torch as d2l

# 生成合成数据集
def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    # 生成均值为0,标准差为1的随机数,形状为num_examples行len(w)列
    X = torch.normal(0, 1, (num_examples, len(w)))
    # matmul为两个张量的矩阵乘积
    y = torch.matmul(X, w) + b
    # 加上随机噪音
    y += torch.normal(0, 0.01, y.shape)
    # y作为列向量返回,x
    return X, y.reshape((-1, 1))

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print('查看第一个特征和标签:')
print('features:', features[0],'\nlabel:', labels[0])

d2l.set_figsize()
d2l.plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1)
d2l.plt.show()

# 读取数据集
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    # 生成每个样本的下标
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序(将下标打乱)
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        # 拿出batch_size个下标
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        # 拿出batch_size个随机的特征和标签
        yield features[batch_indices], labels[batch_indices]

batch_size = 10
print('查看1次拿出的10个样本标签:')
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

# 生成均值为0,标准差为0.01的随机数,形状为2行1列,需要计算梯度
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
# 返回一个全为标量 0 的张量,形状由可变参数sizes 定义,需要计算梯度
b = torch.zeros(1, requires_grad=True)

# 定义模型
def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

# 定义损失函数
# y_hat为预测值,y为真实值
def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

# 定义优化算法
def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    # 更新的时候不需要反向传播(进行梯度计算)
    with torch.no_grad():
        for param in params:
            # 此处/batch_size是因为下面计算了l.sum()
            param -= lr * param.grad / batch_size
            # 将梯度设为0
            param.grad.zero_()

# 学习率
lr = 0.03
# 数据扫描次数
num_epochs = 3
# 模型
net = linreg
# 损失函数
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    # 数据扫完一遍后,评价一下模型进度
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print('数据扫完一遍后,模型进度:')
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

生成的随机数据如下:
在这里插入图片描述
代码运行结果如下:
在这里插入图片描述

线性回归的简洁实现

  简洁实现相对来说只是可以使用一些封装好的函数,底层逻辑还是一样的。
代码如下:

import torch
from torch.utils import data
from d2l import torch as d2l


# 生成数据集
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

# 读取数据集
# 该函数使用 PyTorch 的 data.TensorDataset 类将输入数据和标签数据封装为一个矩阵。然后使用 PyTorch 的 data.DataLoader 类将 TensorDataset 对象作为数据源,并设置批次大小和是否打乱数据。
def load_array(data_arrays, batch_size, is_train=True):  #@save
    """构造一个PyTorch数据迭代器"""
    # 将数据和标签封装为一个 TensorDataset 对象
    dataset = data.TensorDataset(*data_arrays)
    # 返回一个 DataLoader 对象,使用 DataLoader 对象返回一个数据加载器,可以通过它对数据进行迭代。最后返回一个数据迭代器。
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

print('查看第一个特征和标签:')
# 使用iter构造Python迭代器,并使用next从迭代器中获取第一项
print(next(iter(data_iter)))

# nn是神经网络的缩写
from torch import nn
# 定义模型
# nn.Linear(2, 1) 表示创建一个线性层,输入特征维度为 2,输出特征维度为 1。输入维度为 2,意味着在训练模型时需要提供一个 2 维的特征向量作为输入。它将输入的 2 维特征通过线性变换映射到一个 1 维输出值。
net = nn.Sequential(nn.Linear(2, 1))

# 初始化模型参数
# 通过net[0]选择网络中的第一个图层, 然后使用weight.data和bias.data方法访问参数
# weight.data 表示该层的权重参数,bias.data 表示该层的偏置参数。
# 使用 normal_() 方法对权重参数进行随机初始化,参数分布服从均值为 0、标准差为 0.01 的正态分布。
# 使用 fill_() 方法对偏置参数进行初始化,将偏置值设置为 0。
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

# 定义损失函数
# 计算均方误差使用的是MSELoss类,也称为平方L2范数。 默认情况下,它返回所有样本损失的平均值。
loss = nn.MSELoss()

# 定义优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

# 训练
num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        # 计算平均损失值
        l = loss(net(X) ,y)
        # 清零模型的梯度,以确保每次迭代前梯度的正确计算
        trainer.zero_grad()
        # 计算损失函数关于模型参数的梯度
        l.backward()
        # 更新模型的参数
        trainer.step()
    # 计算整个训练集的损失值l,并打印出当前轮数和对应的损失值
    l = loss(net(features), labels)
    print('数据扫完一遍后,模型进度:')
    print(f'epoch {epoch + 1}, loss {l:f}')

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

代码运行结果如下:
在这里插入图片描述

总结:
  通过这个代码也算了解了线性回归模型算法的基本步骤,目前这个代码是懂了,但是还无法真正自己实现一个,先慢慢理解,多回顾。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【计算机视觉】Vision and Language Pre-Trained Models算法介绍合集(一)

文章目录 一、ALIGN二、Contrastive Language-Image Pre-training(CLIP)三、Learning Cross-Modality Encoder Representations from Transformers(LXMERT)四、BLIP: Bootstrapping Language-Image Pre-training五、Vision-and-La…

Json-Jackson和FastJson

狂神: 测试Jackson 纯Java解决日期格式化 设置ObjectMapper FastJson: 知乎:Jackson使用指南 1、常见配置 方式一:yml配置 spring.jackson.date-format指定日期格式,比如yyyy-MM-dd HH:mm:ss,或者具体的…

机器学习 day35(决策树)

决策树 上图的数据集是一个特征值X采用分类值,即只取几个离散值,同时也是一个二元分类任务,即标签Y只有两个值 上图为之前数据集对应的决策树,最顶层的节点称为根节点,椭圆形节点称为决策节点,矩形节点称…

ffplay源码解析-FrameQueue队列

帧队列架构位置 结构体源码 FrameQueue结构体 /* 这是一个循环队列,windex是指其中的首元素,rindex是指其中的尾部元素. */ typedef struct FrameQueue {Frame queue[FRAME_QUEUE_SIZE]; // FRAME_QUEUE_SIZE 最大size, 数字太大时会占用大量的…

DPDK环境搭建

(1)虚拟环境:VMware Workstation 16 Pro 网上随便下载一个也行 (2)操作系统:ubuntu-22.04-beta-desktop-amd64.iso 下载地址:oldubuntu-releases-releases-22.04安装包下载_开源镜像站-阿里云…

Thymeleaf语法详解

目录 一、Thymeleaf介绍 (1)依赖 (2)视图 (3)控制层 二、变量输出 三、操作字符串 四、操作时间 五、条件判断 六、遍历集合 (1)迭代遍历 (2)将遍…

Java————数组

1 、数组 数组可以看成是相同类型元素的一个集合, 在内存中是一段连续的空间。 每个空间有自己的编号,其实位置的编号为0,即数组的下标。 数组是引用类型。 1. 数组的创建 T[] 数组名 new T[N];T:表示数组中存放元素的类型 …

Kakfa - Producer机制原理与调优

Producer是Kakfa模型中生产者组件,也就是Kafka架构中数据的生产来源,虽然其整体是比较简单的组件,但依然有很多细节需要细品一番。比如Kafka的Producer实现原理是什么,怎么发送的消息?IO通讯模型是什么?在实…

对Docker的认识和总结

Docker简介 Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中,然后发布到任何流行的 Linux或Windows操作系统的机器上,也可以实现虚拟化。容器是完全使用沙箱机制,相互之间不会有任何接…

数据结构入门 — 二叉树的概念、性质及结构

本文属于数据结构专栏文章,适合数据结构入门者学习,涵盖数据结构基础的知识和内容体系,文章在介绍数据结构时会配合上动图演示,方便初学者在学习数据结构时理解和学习,了解数据结构系列专栏点击下方链接。 博客主页&am…

学习记忆——英语——字母编码

字母编码表 A:苹果 ; B:一支笔或者小男孩boy ; C:月亮或者镰刀 ; D:笛子或者弟弟或者狗dog ; E:大白鹅 ; F:斧头 ; G:鸽子…

Python:安装Flask web框架hello world示例

安装easy_install pip install distribute 安装pip easy_install pip 安装 virtualenv pip install virtualenv 激活Flask pip install Flask 创建web页面demo.py from flask import Flask app Flask(__name__)app.route(/) def hello_world():return Hello World! 2023if _…

Spring注解家族介绍: @RequestMapping

前言: 今天我们来介绍RequestMapping这个注解,这个注解的内容相对来讲比较少,篇幅会比较短。 目录 前言: RequestMapping 应用场景: 总结: RequestMapping RequestMapping 是一个用于映射 HTTP 请求…

[Linux打怪升级之路]-缓冲区

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 本期学习目标&…

SpringCloud Ribbon--负载均衡 原理及应用实例

😀前言 本篇博文是关于SpringCloud Ribbon的基本介绍,希望你能够喜欢 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的文章可以帮助到大家,您的满意是我的动力…

深入理解线程安全

引言: 在多线程编程中,线程安全是一个至关重要的概念。线程安全可能到导致数据不一致,应用程序崩溃和其他不可预测的后果。本文将深入探讨线程安全问题的根本原因,并通过Java代码示例演示如何解决这些问题。 线程安全的根本原因 …

element plus Infinite Scroll 无限滚动

欢迎关注我的公众号:夜说猫,让一个贫穷的程序员不靠打代码也能吃饭~ element plus官网中,Infinite Scroll示例使用的是数字,在实际项目运用中,我们更多的是使用json数组进行渲染,所以我们改写v-infinite-sc…

Visual Studio2019报错

1- Visual Studio2019报错 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法 小伙伴们在更新到Visual Studio2019后编译项目时可能遇到过这个错误:“ 错误 MSB8036 找不到 Windows SDK 版本 10.0.19041.0的解决方法”,但是我们明明安装了该…

网络安全攻防对抗之隐藏通信隧道技术整理

完成内网信息收集工作后,渗透测试人员需要判断流量是否出得去、进得来。隐藏通信隧道技术常用于在访问受限的网络环境中追踪数据流向和在非受信任的网络中实现安全的数据传输。 一、隐藏通信隧道基础知识 (一)隐藏通信隧道概述 一般的网络通…

Python图像融合处理和 ROI 区域绘制基础

文章目录 一、图像融合二、图像 ROI 区域定位三、图像属性3.1 shape3.2 size3.3 dtype四、图像通道分离及合并4.1、split()函数4.2 merge()函数五、图像类型转换一、图像融合 图像融合通常是指多张图像的信息进行融合,从而获得信息更丰富的结果,能够帮助人们观察或计算机处理…