线性回归从0到1实践

news2025/1/7 19:44:17

导入需要的包

from idlelib.configdialog import tracers
%matplotlib inline
import random
import torch
from d2l import torch as d2l

根据有噪声的线性模型构造一个人造数据集。我们使用线性模型参数 w = [ 2 , − 3 , 4 ] T w = [2,-3,4]^T w=[2,3,4]T、b=4.2 和噪声 ϵ \epsilon ϵ 生成数据集及标签 y = X w + b + ϵ y=Xw + b + \epsilon y=Xw+b+ϵ

def synthetic_data(w,b,num_examples):
    """生成 y = Xw + b + 噪声"""
    X = torch.normal(0,1,(num_examples,len(w)))
    y = torch.matmul(X,w) + b
    y += torch.normal(0,0.01,y.shape)
    return X,y.reshape(-1,1)

true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w,true_b,1000)

y.reshape(-1,1) 解释:

  • -1: 让库根据原始数据的大小自动推断这一维的大小,确保数据的总数保持不变。
  • 1: 强制将数组的列数设置为 1。
print("features:",features[0],'\nlabels:',labels[0])
features: tensor([-0.6632, -0.1771]) 
labels: tensor([3.4713])
# 这个不要理会
d2l.set_figsize()
d2l.plt.scatter(features[:,(1)].detach().numpy(),labels.detach().numpy(),1)

在这里插入图片描述

定义一个data_iter函数,该函数接受批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    for i in range(0,num_examples,batch_size):
        batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
batch_size = 10

for X,y in data_iter(batch_size,features,labels):
    print(X,"\n",y)
    break
tensor([[-0.7705, -0.1793],
        [ 0.6317,  1.4700],
        [-0.1015, -2.5528],
        [ 2.7295,  0.4477],
        [-0.0854,  1.0438],
        [-0.9627, -0.0421],
        [-2.6444,  0.5648],
        [ 0.2786,  1.0552],
        [-1.2454, -1.7555],
        [-0.8601, -0.8605]]) 
 tensor([[ 3.2769],
        [ 0.4713],
        [12.6764],
        [ 8.1384],
        [ 0.4931],
        [ 2.4080],
        [-3.0041],
        [ 1.1742],
        [ 7.6635],
        [ 5.4063]])
# 定义初始化模型参数
w = torch.normal(0,0.01,size=(2,1),requires_grad=True)
b = torch.zeros(1,requires_grad=True)
# 定义模型
def linreg(X,w,b):
    """线性回归模型"""
    return torch.matmul(X,w) + b
# 定义损失函数
def squared_loss(y_hat,y):
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape))**2 / 2
def sgd(params,lr,batch_size):
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()
# 训练过程
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X,y in data_iter(batch_size,features,labels):
        l = loss(net(X,w,b),y)
        l.sum().backward()
        sgd([w,b],lr,batch_size)
    with torch.no_grad():
        train_1 = loss(net(features,w,b),labels)
        print(f'epoch {epoch+1}, loss {float(train_1.mean()):f}')
epoch 1, loss 0.046590
epoch 2, loss 0.000186
epoch 3, loss 0.000050
# 比较真实参数和通过训练学到的参数来评估训练的成功程度
print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差:{true_b - b}」')
w的估计误差: tensor([ 2.8789e-04, -6.0797e-05], grad_fn=<SubBackward0>)
b的估计误差:tensor([0.0005], grad_fn=<RsubBackward1>)」

Summary

总结一下这一个总的线性回归的一个过程

  1. 准备了一下初始的数据,在这个实验的过程中也就是真实的线性回归的 w,和 b
  2. 然后根据真实的w和b,依靠正态分布生成人造数据,用于进行模型的一个训练
  3. 为数据创造迭代起,就是给数据分成一批一批的,因为如何数据很大,不分成一批一批的话就会很耗费资源(计算梯度的时候,是最耗费资源的,这里也涉及到一个超参数,批大小batch_size)
  4. 定义初始化模型参数,这里为什么要定义成形状是(2,1)的呢,要进行矩阵的乘法进行线性回归模拟,因为数据是 x 是(1000,2)所以 w 得是 (2,1)
  5. 定义模型,就是一个线性回归嘛 y = a x + b y = ax+b y=ax+b 这里的x可以是 x = [ x 1 , x 2 , x 3 … … ] x=[x_1,x_2,x_3……] x=[x1,x2,x3……]
  6. 然后定义了损失函数
  7. 定义了优化函数,也就是随机梯度下降函数,这个函数就是寻找较好的拟合的解,通过求导达到优解。(这里会涉及一个超参数lr学习率,也就是往梯度下降的方向前进的步长,这个不能太大,也不能太小,需要合适的选择)
  8. 然后就是进行多个轮次的训练
  9. 最后得到结果,进行模型结果的评价

模型训练流程:

  1. 数据如何读取
  2. 模型的定义
  3. 参数的初始化
  4. 损失函数
  5. 训练模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2271914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从摩托罗拉手机打印短信的简单方法

昨天我试图从摩托罗拉智能手机上打印短信&#xff0c;但当我通过USB将手机连接到电脑时&#xff0c;我在电脑上找不到它们。由于我的手机内存已达到限制&#xff0c;并且我想保留短信的纸质版本&#xff0c;您能帮我将短信从摩托罗拉手机导出到计算机吗&#xff1f; 如您所知&…

elementui table 表格 分页多选,保持选中状态

elementui多选时分页&#xff0c;解决选中状态无法保留选中项问题&#xff1a; 在el-table标签中加入row-key&#xff0c;row-key的值取当前数据里的唯一key在el-table-column selection 项中加入以下:reserve-selection“true” 完成后&#xff0c;将需要清空的地方 ( 如返回…

《掌握 C/C++ 动态内存管理,让编程更高效灵活》

这里写目录标题 一、回顾C/C内存分布1. 三道基础的练习题2. 内存区域划分图 二、C 语言中动态内存的管理方式&#xff08;malloc/calloc/realloc/free&#xff09;1. malloc() 和 calloc() 的区别和注意事项2. realloc() 的用法和注意事项 三、C 中的动态内存管理方式&#xff…

网络安全抓包

#知识点&#xff1a; 1、抓包技术应用意义 //有些应用或者目标是看不到的&#xff0c;这时候就要进行抓包 2、抓包技术应用对象 //app,小程序 3、抓包技术应用协议 //http&#xff0c;socket 4、抓包技术应用支持 5、封包技术应用意义 总结点&#xff1a;学会不同对象采用…

今日头条ip属地根据什么显示?不准确怎么办

在今日头条这样的社交媒体平台上&#xff0c;用户的IP属地信息对于维护网络环境的健康与秩序至关重要。然而&#xff0c;不少用户发现自己的IP属地显示与实际位置不符&#xff0c;这引发了广泛的关注和讨论。本文将深入探讨今日头条IP属地的显示依据&#xff0c;并提供解决IP属…

CSS3——3. 书写格式二

<!DOCTYPE html> <html><head><meta charset"UTF-8"><title></title></head><body><!--css书写&#xff1a;--><!--1. 属性名:属性值--><!--2.属性值是对属性的相关描述--><!--3.属性名必须是…

C# OpenCV机器视觉:双目视觉-深度估计

在一个阳光欢快得仿佛要蹦迪的日子里&#xff0c;阿强像个即将踏上神秘星际旅行的宇航员&#xff0c;雄赳赳气昂昂地坐在实验室那张堆满奇奇怪怪小玩意儿的桌子前。桌上&#xff0c;两台摄像头宛如两个严阵以待的机甲战士&#xff0c;镜头闪烁着冷峻的光&#xff0c;仿佛在向阿…

网络IP协议

IP&#xff08;Internet Protocol&#xff0c;网际协议&#xff09;是TCP/IP协议族中重要的协议&#xff0c;主要负责将数据包发送给目标主机。IP相当于OSI&#xff08;图1&#xff09;的第三层网络层。网络层的主要作用是失陷终端节点之间的通信。这种终端节点之间的通信也叫点…

springboot566健美操评分系统(论文+源码)_kaic

摘 要 健美操评分系统采用B/S架构&#xff0c;数据库是MySQL。系统的搭建与开发采用了先进的JAVA进行编写&#xff0c;使用了springboot框架。该系统从三个对象&#xff1a;由管理员、裁判员和用户来对系统进行设计构建。主要功能包括首页&#xff0c;个人中心&#xff0c;裁…

【深度学习之空洞卷积】空洞卷积和普通卷积的比较包括哪些优势?从感受野、计算复杂度方面分析。

【深度学习之空洞卷积】空洞卷积和普通卷积的比较包括哪些优势&#xff1f;从感受野、计算复杂度方面分析。 【深度学习之空洞卷积】空洞卷积和普通卷积的比较包括哪些优势&#xff1f;从感受野、计算复杂度方面分析。 文章目录 【深度学习之空洞卷积】空洞卷积和普通卷积的比…

【机器遗忘之UNSIR算法】2023年IEEE Trans期刊论文:Fast yet effective machine unlearning

1 介绍 年份&#xff1a;2023 期刊&#xff1a;IEEE Transactions on Neural Networks and Learning Systems 引用量&#xff1a;170 Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Le…

VSCode 在Windows下开发时使用Cmake Tools时输出Log乱码以及CPP文件乱码的终极解决方案

在Windows11上使用VSCode开发C程序的时候&#xff0c;由于使用到了Cmake Tools插件&#xff0c;在编译运行的时候&#xff0c;会出现输出日志乱码的情况&#xff0c;那么如何解决呢&#xff1f; 这里提供了解决方案&#xff1a; 当Settings里的Cmake: Output Log Encoding里设…

程序的环境(预处理详解)

一.程序的翻译环境和执行环境 在ANSI C&#xff08;标准c&#xff09;的任何一种实现中&#xff0c;存在两个不同的环境。 计算机是能够执行二进制指令的&#xff0c;但是我们写出的c语言代码是文本信息&#xff0c;计算机不能直接理解 第1种是翻译环境&#xff0c;在这个环境…

Kafka 消费者专题

目录 消费者消费者组消费方式消费规则独立消费主题代码示例&#xff08;极简&#xff09;代码示例&#xff08;独立消费分区&#xff09; offset自动提交代码示例&#xff08;自动提交&#xff09;手动提交代码示例&#xff08;同步&#xff09;代码示例&#xff08;异步&#…

解决 :VS code右键没有go to definition选项(转到定义选项)

问题背景&#xff1a; VScode 右键没有“go to definition”选项了&#xff0c;情况如图所示&#xff1a; 问题解决办法&#xff1a; 第一步&#xff1a;先检查没有先安装C/C插件&#xff0c;没有安装就先安装下。 第二步&#xff1a; 打开VS CODE设置界面&#xff1a;文件->…

网络安全的学习与实践经验(附资料合集)

学习资源 在线学习平台&#xff1a; Hack This Site&#xff1a;提供从初学者到高级难度的挑战任务&#xff0c;适合练习各种网络安全技术。XCTF_OJ&#xff1a;由XCTF组委会开发的免费在线网络安全网站&#xff0c;提供丰富的培训材料和资源。SecurityTube&#xff1a;提供丰…

《Rust权威指南》学习笔记(五)

高级特性 1.在Rust中&#xff0c;unsafe是一种允许绕过Rust的安全性保证的机制&#xff0c;用于执行一些Rust默认情况下不允许的操作。unsafe存在的原因是&#xff1a;unsafe 允许执行某些可能被 Rust 的安全性检查阻止的操作&#xff0c;从而可以进行性能优化&#xff0c;如手…

使用R语言绘制标准的中国地图和世界地图

在日常的学习和生活中&#xff0c;有时我们常常需要制作带有国界线的地图。这个时候绘制标准的国家地图就显得很重要。目前国家标准地图服务系统向全社会公布的标准中国地图数据&#xff0c;是最权威的地图数据。 今天介绍的R包“ggmapcn”&#xff0c;就是基于最新公布的地图…

Flutter踩坑记-第三方SDK不兼容Gradle 8.0,需适配namespace

最近需要集成Flutter作为Module&#xff0c;Flutter依赖了第三方库&#xff0c;Gradle是8.0版本。 编译报错&#xff1a; 解决办法是在.android根目录下的build.gradle下新增一行代码&#xff1a; buildscript {ext.kotlin_version "1.8.22"repositories {google()…

golang 编程规范 - 项目目录结构

原文&#xff1a;https://makeoptim.com/golang/standards/project-layout 目录结构 Go 目录 cmdinternalpkgvendor 服务端应用程序目录 api Web 应用程序目录 web 通用应用程序目录 buildconfigsdeploymentsinitscriptstest 其他目录 assetsdocsexamplesgithooksthird_par…