动手学深度学习(pytorch)学习记录7-线性回归的从零开始实现[学习记录]

news2024/9/23 15:24:10

注:本代码在jupyter notebook上运行
封面图片来源

1、生成数据集

%matplotlib inline
import random
import torch
from d2l import torch as d2l

构造数据集:生成一个包含1000个样本的数据集, 每个样本包含从标准正态分布中采样的2个特征。

def synthetic_data(w, b, num_examples):  #@save
    """生成y=Xw+b+噪声"""
    X = torch.normal(0, 1, (num_examples, len(w))) # [1000,2]
    y = torch.matmul(X, w) + b # [1000,2]*[2, 1]=[1000, 1]
    y += torch.normal(0, 0.01, y.shape)
    # return X, y.reshape((-1, 1)) # 不用reshape应该也行
    return X, y
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)
print('features:', features[0],'\nlabel:', labels[0])

在这里插入图片描述
可视化一下

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);# x,y,点的大小

在这里插入图片描述
总是用沐神的d2l终归不太好,动手实现一下吧

import matplotlib.pyplot as plt  
import matplotlib


 # 设置Matplotlib的字体为支持中文的字体  
# 这里以'SimHei'为例,但具体可用的字体取决于你的系统  
matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体为黑体  
matplotlib.rcParams['axes.unicode_minus'] = False  # 解决保存图像时负号'-'显示为方块的问题 
 
 
  
# 准备数据  
x = features[:, 1].detach().numpy()  
y = labels.detach().numpy()  
  
# 绘制散点图  
plt.scatter(x, y, s=1)  
  
# 添加标题和坐标轴标签  
plt.title('散点图')  
plt.xlabel('features')  
plt.ylabel('y')  
  
# 显示图表(在Jupyter Notebook中通常是可选的)  
plt.show()

在这里插入图片描述

2、读取数据集

训练模型时要对数据集进行遍历,每次抽取一小批量样本,并使用它们来更新模型。 由于这个过程是训练机器学习算法的基础,所以有必要定义一个函数, 该函数能打乱数据集中的样本并以小批量方式获取数据。

在下面的代码中定义了一个data_iter函数, 该函数接收批量大小、特征矩阵和标签向量作为输入,生成大小为batch_size的小批量。 每个小批量包含一组特征和标签。

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
        # yield使得这个函数成为一个生成器,可以按需生成数据批次,而不是一次性将所有数据加载到内存中。

读取一个小批量数据样本并打印

batch_size = 50

for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    break

在这里插入图片描述

3、初始化模型参数

从均值为0、标准差为0.01的正态分布中采样随机数来初始化权重, 并将偏置初始化为0。

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

4、定义模型

def linreg(X, w, b):  #@save
    """线性回归模型"""
    return torch.matmul(X, w) + b

5、定义损失函数

def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
# 6、定义优化算法
def sgd(params, lr, batch_size):  #@save
    """小批量随机梯度下降"""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

7、训练

在每个迭代周期(epoch)中,我们使用data_iter函数遍历整个数据集, 并将训练数据集中所有样本都使用一次(假设样本数能够被批量大小整除)。 这里的迭代周期个数num_epochs和学习率lr都是超参数,分别设为3和0.03。

lr = 0.03 # 学习率
num_epochs = 3 # 训练轮数
net = linreg # 调用的上面模型
loss = squared_loss # 损失函数
for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)  # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        # print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
    with torch.no_grad(): # 输出训练完一个epoch的损失
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

在这里插入图片描述
因为我们使用的是自己合成的数据集,知道真正的参数是什么。 因此,可以通过比较真实参数和通过训练学到的参数来评估训练的成功程度。 真实参数和通过训练学到的参数确实非常接近。

print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')
print(f'b的估计误差: {true_b - b}')

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2041715.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaEE】线程池和定时器

🔥个人主页: 中草药 🔥专栏:【Java】登神长阶 史诗般的Java成神之路 ✏️一.线程池 在Java中,线程池(Thread Pool)是一种用于管理并发线程的机制,它提供了一种创建、复用和管理一组…

【C++】一文掌握C++的四种类型转换 --- static_cast、reinterpret_cast、const_cast、dynamic_cast

当面对两个选择时,抛硬币总能奏效。 并不是因为它总能给出对的答案, 而是在你把它抛在空中的那一秒里。 你突然就知道,你希望的结果是什么了。 --- 曾小贤 《爱情公寓》--- 一文掌握C的四种类型转换 1 C中的类型2 类型转换3 四种类型转换…

一次caffeine引起的CPU飙升问题

背景 背景是上游服务接入了博主团队提供的sdk,已经长达3年,运行稳定无异常,随着最近冲业绩,流量越来越大,直至某一天,其中一个接入方(流量很大)告知CPU在慢慢上升且没有回落的迹象&…

Godot《躲避小兵》实战之创建玩家场景

项目设置完之后,我们就可以开始处理玩家控制的角色。 这里我们将玩家放在一个单独的场景当中,这样做的好处是在游戏的其他部分做出来之前,我们就可以对其进行单独测试。 节点结构 场景是一个节点树结构,因此一个场景需要有一个…

设计模式六大原则之:依赖倒置原则

1. 依赖倒置原则简介 依赖倒置原则(Dependency Inversion Principle, DIP) 是面向对象设计的核心原则之一,由罗伯特马丁(Robert C. Martin)提出,旨在降低类间的依赖度,使之更易于维护和扩展。该原则主张高层模块不应该依赖于底层模块&#x…

江科大/江协科技 STM32学习笔记P23

文章目录 DMA直接存储器存取DMA简介存储器映像DMA框图DMA基本结构存储器到存储器的数据转运ADC扫描模式和DMA配合使用流程 DMA直接存储器存取 DMA简介 DMA进行存储器到存储器的数据转运,比如Flash里的一批数据转运到SRAM里,需要软件触发,使用…

JQuery实现的时间插件源码附注释

HTML页面代码 <!DOCTYPE HTML> <html> <meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> <meta content="width=device-width, initial-scale=1, maximum-scale=1,user-scalable=no;" name="…

【方案】SRM系统整体设计方案(解决方案+实现源码)

一、项目理解 二、总体解决方案概述 三、业务解决方案详述 四、端到端的采购流程管理 1. 采购计划 2. 采购寻源(招投标、询报价、竞价) 3. 合同管理 4. 订单执行与供应商协同 5. 采购分析与评估 五、支撑流程 1. 物资主数据管理 2. 供应商管理 3. 目录管理 4. 评标专家库管理 5…

grom接入Prometheus,grafana

在同级目录下分别创建 docker-compose.yml&#xff0c;与prometheus.yml 配置文件 version: 3.8services:prometheus:image: prom/prometheuscontainer_name: prometheusports:- "9090:9090" # Prometheus Web UI 端口volumes:- ./prometheus.yml:/etc/prometheus…

opencv-python实战项目八:根据颜色抠出图片中感兴趣区域

文章目录 一&#xff0c;简介二&#xff0c;实现方案三、算法实现步骤3.2 处理颜色蒙版&#xff1a;3.3 取出图片中蒙版对应区域 四&#xff0c;整体代码五&#xff0c;效果&#xff1a; 一&#xff0c;简介 本项目旨在开发一个基于OpenCV的图像处理工具&#xff0c;实现根据颜…

商贸城小程序系统开发制作方案

商贸城作为集批发、零售、展示、交流于一体的综合性商业体。通过商贸城小程序系统促进商家与消费者之间的互动&#xff0c;实现线上线下流量的无缝对接。一、用户需求分析 1、顾客需求&#xff1a; 快速查找店铺信息&#xff1b; 在线浏览商品和服务&#xff1b; 实现线上预约、…

【网站项目】SpringBoot675学生心理压力咨询评判

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

Vue启动时报异常 ‘error:03000086:digital envelope routines::initialization error‘

问题描述 启动Vue项目时&#xff0c;突发报如下异常&#xff1a; opensslErrorStack: [error:03000086:digital envelope routines::initialization error,error:0308010C:digital envelope routines::unsupported],library: digital envelope routines,reason: unsupported,…

CentOS7下载与安装 即配置网卡

CentOS 7是什么? CentOS7是基于RHEL的企业级Linux操作系统&#xff0c;引入了Systemd、XFS文件系统和Docker支持。它提供了新的软件包、工具和性能调优选项&#xff0c;同时加强了系统安全和稳定性。总的来说&#xff0c;CentOS7是一个稳定、安全、长期支持的操作系统&#xf…

【wiki知识库】09.欢迎页面展示(浏览量统计)SpringBoot部分

&#x1f34a; 编程有易不绕弯&#xff0c;成长之路不孤单&#xff01; 大家好&#xff0c;我是熊哈哈&#xff0c;这个项目从我接手到现在有了两个多月的时间了吧&#xff0c;其实本来我在七月初就做完的吧&#xff0c;但是六月份的时候生病了&#xff0c;在家里休息了一个月的…

【和可被 K 整除的子数组】python刷题记录

R4-前缀和专题 class Solution:def subarraysDivByK(self, nums: List[int], k: int) -> int:ret0# 存储当前位置的上一个位置的前缀和的余数加上当前位置的值对K的余数pre_mod0dictdefaultdict(int)dict[0]1for i in range(len(nums)):pre_mod(pre_modnums[i])%k# 如果能在…

Kubernetes 基础概念介绍

1. 应用部署方式 传统部署&#xff1a;互联网早期&#xff0c;会直接将应用程序部署在物理机上 优点&#xff1a;简单&#xff0c;不需要其它技术的参与 缺点&#xff1a;不能为应用程序定义资源使用边界&#xff0c;很难合理地分配计算资源&#xff0c;而且程序之间容易产生影…

投票小程序App功能开发源码技术实现

随着移动互联网的快速发展&#xff0c;小程序作为一种轻量级应用&#xff0c;因其无需安装、即用即走的特点&#xff0c;在各类应用场景中迅速普及。投票小程序作为其中的一种&#xff0c;因其便捷性和实用性&#xff0c;广泛应用于各类活动、问卷调查及意见收集中。本文将围绕…

Linux-Shell入门-05

1.Shell的概念 1.1 什么是shell Shell脚本语言是实现Linux/UNIX系统管理及自W动化运维所必备的重要工具&#xff0c; Linux/UNIX系统的底层及基础应用软件的核心大都涉及Shell脚本的内容。Shell是一种编程语言, 它像其它编程语言如: C, Java, Python等一样也有变量/函数/运算符…

上海“创投十九条”明确政府引导基金带动作用,银行如何挖掘投贷联动增长潜力?

发展创业投资是促进科技、产业、金融良性循环的重要举措。为促进创业投资高质量发展&#xff0c;近几个月来&#xff0c;“国创投十七条”“上海创投十九条”等政策陆续发布&#xff0c;明确指出发挥政府引导基金带动作用&#xff0c;进一步加大对战略性新兴产业和未来产业的支…