手动实现线性回归例子

news2025/1/10 10:50:11

转自:https://www.cnblogs.com/BlairGrowing/p/15061912.html

刚开始接触深度学习和机器学习,由于是非全日制,也没有方向感,缺乏学习氛围、圈子,全靠自己业余时间瞎琢磨,犹如黑夜中摸索着石头过河。

本文只是顺着原作者的思路捋一下,说一下自己对代码的看法和了解,代码部分纯粹照搬原作者的源码。

希望自己也能在黑夜中,摸着石头,跟着前行者的微弱光芒,在狂风暴雨中,坚定信念和祈祷,努力前行,胜利的趟过人生之大河。





import torch
from IPython import display
from matplotlib import pyplot as plt #matplotlib包可用于作图,用来显示生成的数据的二维图。
import numpy as np
import random


feature_size = 2

example_count = 1000

true_w = [8.88888888, 8.88888888]

true_b = 3.14159265

#生成特征,生成均值为0,方差为1 的特征矩阵
features = torch.tensor(np.random.normal(0, 1, (example_count, feature_size)), dtype=torch.float)

#输出特征矩阵的维度
print("特征矩阵的维度=",list(features.shape))

#根据线性方程得出特征对应的labels
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
#print(labels)

# 添加随机噪声
labels += torch.tensor(np.random.normal(0, 1, size=labels.size()), dtype=torch.float)
#print(labels)



def use_svg_display(): 
    # 用矢量图显示 
    display.set_matplotlib_formats('svg') 
    
def set_figsize(figsize=(10, 5)): 
    use_svg_display() 
    # 设置图的尺寸 
    plt.rcParams['figure.figsize'] = figsize 
    
#绘制散点图
set_figsize() 
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);



def data_iter(batch_size, features, labels): 
    num_examples = len(features) 
    indices = list(range(num_examples)) 
    
    #print(indices)
    
    # 样本的读取顺序是随机的 
    random.shuffle(indices)  
    for i in range(0, num_examples, batch_size): 
        # 最后一次可能不足一个batch 
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) 
        yield  features.index_select(0, j), labels.index_select(0, j)
        
        

batch_sizes = 10

    
w = torch.tensor(np.random.normal(0, 1, (feature_size, 1)), dtype=torch.float) 
b = torch.zeros(1, dtype=torch.float64)

w.requires_grad_(requires_grad=True) 
b.requires_grad_(requires_grad=True)


def lineRegression(X, w, b): 
    return torch.mm(X, w) + b


def squared_loss(y_hat, y):  
    return (y_hat - y.view(y_hat.size())) ** 2 / 2



def sgd(params, lr, batch_size): 
    for param in params: 
        #print(param.grad);
        param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data
        
        
lr = 0.01

num_epochs = 10

net = lineRegression 

for epoch in range(num_epochs):       # 训练模型一共需要num_epochs个迭代周期 
    
    # 在每一个迭代周期中,会使用训练数据集中所有样本一次 
    for X, y in data_iter(batch_size, features, labels):       # x和y分别是小批量样本的特征和标签 
        l = squared_loss(net(X, w, b), y).sum()     # l是有关小批量X和y的损失 
        l.backward()     # 小批量的损失对模型参数求梯度 
        sgd([w, b], lr, batch_size)     # 使用小批量随机梯度下降迭代模型参数 
        w.grad.data.zero_()    # 梯度清零 
        b.grad.data.zero_() 
        
    train_l = squared_loss(net(features, w, b), labels) 
    print('epoch %d, loss %f, w %f b % f' % (epoch + 1, train_l.mean().item(), w.sum().mean(),b.sum().mean()))
    
    


运行结果如下:

在这里插入图片描述

注意:

  1. backward函数会计算,参与本参数运算的(包括本参数在内)其他参数的梯度。
  2. 最小二乘法squared_loss函数,这里必须要着重阐述一下自己的理解,这部分是整个回归算法的核心,明白了此处,才能真正理解回归算法的本质。该函数一方面用来计算损失,但是它的本质作用,是因为它的最小值就是梯度优化的目标位置,反向求导数,利用梯度下降算法逐步逼近该位置,就可以完成对参数w和b的拟合,最后,它还是损失的图像化表示,通过该函数,最终实现了回归算法理论的闭环。
  3. 从结果可以看到,最后拟合出来的w(17.769148)和b(3.177670)跟labels中实际值w1(8.88888888) + w2 (8.88888888) = 17.77777776和3.14159265两个数字的值非常接近了,这也显示了机器学习能力的强大和魅力。而且,这还是在29行添加了误差的情况下。若是删除第29行人为添加的正态分布误差,真实数据的拟合结果如下图,误差在万分之一量级。

在这里插入图片描述

  1. np.random.normal函数当均值和方差不是0和1时,容易发生nan错误。原因未知。搞不懂,pytorch这么强大的框架,为何正态分布下,均值和方差到10以上,就会发生溢出错误。
  2. 另外,在sgd函数中,梯度计算时要除以batch_size。这是为什么呢?百度了一下,发现此文对此做了完美的解释:为什么梯度值要除以batch_size?,另外还可以关注此处的解释说明:https://github.com/ShusenTang/Dive-into-DL-PyTorch/issues/75
  3. 最后一点,加入example_count改为10000,那么loss就降为0,为何变化速率这么快呢?是正常的学习结果还是有什么问题呢?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/876502.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

探讨uniapp的数据缓存问题

异步就是不管保没保存成功,程序都会继续往下执行。同步是等保存成功了,才会执行下面的代码。使用异步,性能会更好;而使用同步,数据会更安全。 1 uni.setStorage(OBJECT) 将数据存储在本地缓存中指定的 key 中&#x…

Oracle切割字符串的方法,SQL语句完成。

Oracle用正则的方式循环切割字符串 需求:有一个这样子的 Str “‘CNJ-520-180500000001|CNJ-520-181200000001|CNJ-520-190300000001|CNJ-520-190100000001|CNJ-520-181200000002’” ,然后我需要拿到每一个单号,每一个单号都要走一遍固定的…

基于K8S环境部署Dolphinscheduler及简单应用

一、Dolphinscheduler简介 Apache DolphinScheduler 是一个分布式易扩展的可视化DAG工作流任务调度开源系统。适用于企业级场景,提供了一个可视化操作任务、工作流和全生命周期数据处理过程的解决方案。 Apache DolphinScheduler 旨在解决复杂的大数据任务依赖关系,并为应用…

SOPC之NIOS Ⅱ遇到的问题

记录NIOS Ⅱ中遇到的报错 一、NIOS II中Eclipse头文件未找到 问题:Unresolved inclusion: "system.h"等 原因:编译器无法找到头文件所在路径 解决方法: 在文件夹中找到要添加的头文件,并记录下其路径,如…

8.14 作业

1. .text .globl _start_start:mov r0,#0x9mov r1,#0xfbl loop loop:cmp r0,r1beq stopsubhi r0,r1subls r1,r0mov pc,lr stop:b stop 2.实现1-100的和 .text .globl _start_start:mov r0,#0x1bl loop loop:cmp r0,#0x64bhi stopaddls r1,r0addls r0,#0x1mov pc,lr stop:b sto…

Android app专项测试之耗电量测试

前言 耗电量指标 待机时间成关注目标 提升用户体验 通过不同的测试场景,找出app高耗电的场景并解决 01、需要的环境准备 1、python2.7(必须是2.7,3.X版本是不支持的) 2、golang语言的开发环境 3、Android SDK 此三个的环境搭建这里就不详细说了&am…

C++之map的emplace与pair插入键值对用法(一百七十四)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

【openwrt学习笔记】dnsmasq源码阅读

目录 一、DHCP(Dynamic Host Configuration Protocol)1.1 前置知识1.2 参考链接1.3 IP地址分配代码分析rfc2131.cdhcp-common.cdhcp.c 1.4 几个小问题1.4.1 连续IP模式(sequential_ip)1.4.2 重新连接使用IP地址1.4.3 续约租期1.4.4 不同的MAC地址分配到相…

Power Automate:筛选查找列表中的项并删除

目的:筛选出列表中所有符合条件的项并删除,如果一项都没有,就发邮件通知自己 1、首先获取多个项,添加筛选条件,比如设备列为1的项,无需加引号。也可以添加筛选条件 2、接下来不能直接循环刚得到的多个项并…

NuGet包离线安装方法

在某项情况下,我们的计算机是无法直接连接外网的,这个时候就只能用离线安装的方法了。 一、直接区NUGET.org网页下载: 二、先下载nuget.exe工具,然后用这个工具下载 把下载的nuget.exe放在任意目录下,然后在此目录用…

MongoDB升级经历(4.0.23至5.0.19)

MongoDB从4.0.23至5.0.19升级经历 引子:为了解决MongoDB的两个漏洞决定把MongoDB升级至最新版本,期间也踩了不少坑,在这里分享出来供大家学习与避坑~ 1、MongoDB的两个漏洞 漏洞1:MongoDB Server 安全漏洞(CVE-2021-20330) 漏洞2…

如何在Stream流中分组统计

上面是今天碰到需求,之前就做过类似的分组统计,这个相对来说比较简单,统计的也少,序号和总预约人数这两部分交给前端了,不需要由后端统计,后端统计一下预约日期和检查项目和预约人数就行; Overridepublic List<ItemStatisticsVo> statistics(ItemStatisticsModel itemSta…

checkbox post参数接收

checkbox 定义 <div class"check-box"> <label for"ck1">batchInsert:</label><input type"checkbox" id"ck1" checkedname"ckFn" value"batchInsert" > </div> <div class&qu…

Python爬虫获取美女头像并保存本地(观山篇一)

Python爬虫获取美女头像并保存本地&#xff08;观山篇一&#xff09; 前言步骤一步骤二步骤三步骤四步骤五最终效果完整代码结言 前言 最近某短视频平台上经常刷到&#xff0c;人生四大雅事&#xff1a;“品茗、抚琴、观山、听雨”。那么今天我们就利用python观山所看到的美景给…

如何做好一名网络工程师?具备的技能有哪些?

支持属于网络工程师的工作范围的企业网络&#xff0c;此网络与支持它的铜或光纤基础架构一样性能良好。网络工程师及其布线厂区需要为支持最新网络技术做好准备。网络工程师作为任何性能问题的解决者&#xff0c;需要拥有必要的工具来确定问题所在 — 在网络中还是在其他地方。…

在众多单片机市场中STM32系列为何能脱颖而出?

回顾单片机市场&#xff0c;除了传统的51系列外&#xff0c;早些年主要有PIC、TI、Nxp、ATMEL、Freescale等厂商。然而&#xff0c;这些厂商普遍存在一些问题&#xff1a;资料难以获取&#xff0c;文档数量有限且大多是英文的&#xff0c;开发板价格昂贵&#xff0c;调试器成本…

Mac 新手10个小窍门

即便你是 Mac 新手&#xff0c;也会发现它易学好用。你可以点按程序坞上的访达&#xff0c;快速查看到所有文件&#xff1b;把你喜爱的文件夹拖入边栏&#xff1b;你可以自定义查看文件的方式&#xff0c;甚至可以按下空格键&#xff0c;就能一键预览文档&#xff1b;还能在多台…

Linux 复制进程fork

一、父进程和子进程 当前的一个进程在fork的时候可以复制当前的进程产生一个进程&#xff0c;这时产生出来的这个进程就是子进程&#xff0c;被复制的进程叫做父进程。子进程会将环境变量从父进程继承过来&#xff0c;或者说被拷贝过来。父进程也会有它的父进程&#xff0c;一…

【动态规划基础】数字三角形(IOI1994)

题目描述 数字三角形 输入输出样例 输入样例#1&#xff1a; 5 7 3 8 8 1 0 2 7 4 4 4 5 2 6 5输出样例#1&#xff1a; 30思路&#xff1a; 这题可能看到的第一眼——直接贪心然后一层一层判断呀&#xff01;&#xff01;&#xff01;不过很快又会发现&#xff0c;额___好…

小白如何轻松制作产品帮助中心页面?

产品帮助中心是每个网站/产品必不可少的页面&#xff0c;产品帮助中心页面成为了企业提供客户支持和解决方案的重要组成部分。对于初次接触建立帮助中心页面的小白来说&#xff0c;也许会感到一些困惑和无从下手。本文将为小白介绍如何轻松制作产品帮助中心页面&#xff0c;帮助…