线性回归例子

news2024/10/6 12:21:08

转自:https://www.cnblogs.com/BlairGrowing/p/15061912.html

刚开始接触深度学习和机器学习,由于是非全日制,也没有方向感,缺乏学习氛围、圈子,全靠自己业余时间瞎琢磨,犹如黑夜中的摸索着过河。

只是顺着原作者的思路捋一下,代码部分纯粹照搬原作者的源码。

希望自己有一天也能在黑夜中,摸着石头,跟着前行者的微弱光芒,在狂风暴雨中,坚定信息和祈祷,努力前行,也能过去人生中的大河。



import torch
from IPython import display
from matplotlib import pyplot as plt #matplotlib包可用于作图,用来显示生成的数据的二维图。
import numpy as np
import random


feature_size = 2

example_count = 10000

true_w = [8.88888888, 8.88888888]

true_b = 3.14159265

#生成特征,生成均值为0,方差为1 的特征矩阵
features = torch.tensor(np.random.normal(0, 1, (example_count, feature_size)), dtype=torch.float)

#输出特征矩阵的维度
print("特征矩阵的维度=",list(features.shape))

#根据线性方程得出特征对应的labels
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
#print(labels)

# 添加随机噪声
labels += torch.tensor(np.random.normal(0, 1, size=labels.size()), dtype=torch.float)
#print(labels)



def use_svg_display(): 
    # 用矢量图显示 
    display.set_matplotlib_formats('svg') 
    
def set_figsize(figsize=(10, 5)): 
    use_svg_display() 
    # 设置图的尺寸 
    plt.rcParams['figure.figsize'] = figsize 
    
#绘制散点图
set_figsize() 
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);



def data_iter(batch_size, features, labels): 
    num_examples = len(features) 
    indices = list(range(num_examples)) 
    
    #print(indices)
    
    # 样本的读取顺序是随机的 
    random.shuffle(indices)  
    for i in range(0, num_examples, batch_size): 
        # 最后一次可能不足一个batch 
        j = torch.LongTensor(indices[i: min(i + batch_size, num_examples)]) 
        yield  features.index_select(0, j), labels.index_select(0, j)
        
        

batch_sizes = 10

    
w = torch.tensor(np.random.normal(0, 1, (feature_size, 1)), dtype=torch.float32) 
b = torch.zeros(1, dtype=torch.float32)

w.requires_grad_(requires_grad=True) 
b.requires_grad_(requires_grad=True)


def lineRegression(X, w, b): 
    return torch.mm(X, w) + b


def squared_loss(y_hat, y):  
    return (y_hat - y.view(y_hat.size())) ** 2 / 2



def sgd(params, lr, batch_size): 
    for param in params: 
        param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data
        
        
lr = 0.01

num_epochs = 10

net = lineRegression 

for epoch in range(num_epochs):       # 训练模型一共需要num_epochs个迭代周期 
    
    # 在每一个迭代周期中,会使用训练数据集中所有样本一次 
    for X, y in data_iter(batch_size, features, labels):       # x和y分别是小批量样本的特征和标签 
        l = squared_loss(net(X, w, b), y).sum()     # l是有关小批量X和y的损失 
        l.backward()     # 小批量的损失对模型参数求梯度 
        sgd([w, b], lr, batch_size)     # 使用小批量随机梯度下降迭代模型参数 
        w.grad.data.zero_()    # 梯度清零 
        b.grad.data.zero_() 
        
    train_l = squared_loss(net(features, w, b), labels) 
    print('epoch %d, loss %f, w %f b % f' % (epoch + 1, train_l.mean().item(), w.sum().mean(),b.sum().mean()))    
    


运行结果如下:

在这里插入图片描述

注意:

  1. backward函数会计算,参与本参数运算的(包括本参数在内)其他参数的梯度。
  2. 最小二乘法可以用来计算损失,它的最小值就是梯度优化的目标位置,因此,利用梯度下降算法逐步逼近该位置,就可以拟合w和b参数,同时用它来画图表示损失,从而实现回归算法。这部分是整个回归算法的核心,明白了此处,才能真正理解回归算法的本质。
  3. 从结果可以看到,最后拟合出来的w(17.769148)和b(3.177670)跟labels中实际值w1(8.88888888) + w2 (8.88888888) = 17.77777776和3.14159265两个数字的值非常接近了,这也显示了机器学习能力的强大和魅力。而且,这还是在29行添加了误差的情况下。若是删除第29行人为添加的正态分布误差,真实数据的拟合结果如下图,误差在万分之一量级。

在这里插入图片描述

  1. np.random.normal函数当均值和方差不是0和1时,容易发生nan错误。原因未知。搞不懂,pytorch这么强大的框架,为何正态分布下,均值和方差到10以上,就会发生溢出错误。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/871345.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

易服客工作室:WordPress是什么?初学者的解释

目录 什么是WordPress? WordPress可以制作什么类型的网站? 谁制作了WordPress?它已经存在多久了? 谁使用 WordPress? 白宫网站 微软 滚石乐队 为什么要使用 WordPress? WordPress 是免费且…

【不支持发行版本 5】错误解决

说明&#xff1a;启动项目报下面的错误&#xff0c;不支持发行版本 5 解决&#xff1a;在pom文件中添加下面这两行配置&#xff0c;修改成你自己安装的jdk版本 <properties><maven.compiler.source>11</maven.compiler.source><maven.compiler.target&g…

TienChin 新建业务菜单

首先是移动菜单&#xff0c;参考下图将菜单移动到下图结构&#xff1a; 我这里将系统监控&#xff0c;系统工具都移动到了系统管理下面&#xff0c;并且排了个序&#xff0c;将多级菜单放在了一起&#xff0c;这样看起来更加的清晰。 修改一下系统管理(100)与TienChin健身官网(…

Blazor:组件生命周期和刷新机制详解

文章目录 前言生命周期子组件设置事件刷新为什么传入非基础元素&#xff0c;会强制刷新 自动刷新逻辑如何解决委托事件强制刷新问题 前言 对于组件化来说&#xff0c;生命周期是必须掌握的知识&#xff0c;只有掌握了生命周期才能更好的去设置数据的变化。 Blazor 生命周期 微…

99. for循环练习题-3种方式输出0-9

【目录】 文章目录 99. for循环练习题-3种方式输出0-91. for循环和while循环的区别2. 输出 0~(n-1)的数字2.1 基础代码2.2 自定义函数代码2.3 异常处理语句代码 【正文】 99. for循环练习题-3种方式输出0-9 1. for循环和while循环的区别 for循环和while循环都用于重复执行特定…

AI Infra工具关键能力解析:数据准备、模型训练、模型部署与整合

在预训练大模型时代,我们可以从应用落地过程里提炼出标准化的工作流,AI Infra的投资机会得以演绎。传统ML时代AI模型通用性较低,项目落地停留在“手工作坊”阶段,流程难以统一规范。而大规模预训练模型统一了“从0到1”的技术路径,具备解决问题的泛化能力,能够赋能“从1到…

WiFi小工具homedale,可以切换同名WiFi节点

有一个很小众的需求&#xff0c;就是多个路由器组网时候&#xff0c;PC有时不会自动切换同名WiFi&#xff0c;homedale这个工具可以满足手动切换需求 这个界面可以看到所有节点列表&#xff0c;可以看到有很多同名的 可以选择自己想要的那个&#xff0c;比如信道/信号强度&am…

avd(emulator)设置代理以及与pc互访

默认pc127.0.0.1是还回ip&#xff0c;模拟器使用127.0.0.1指向了自己&#xff0c;模拟器使用10.0.2.2指代pc地址&#xff0c;这点在官方文档有说明可以查看,所以想要挂代理抓包就需要为模拟器设置代理为10.0.2.2 安卓模拟器设置代理 前提&#xff1a;本机开启了代理如&#xf…

nodejs+vue+elementui健康饮食美食菜谱分享网站系统

本系统采用了nodejs语言的vue框架&#xff0c;数据采用MySQL数据库进行存储。结合B/S结构进行开发设计&#xff0c;功能强大&#xff0c;界面化操作便于上手。本系统具有良好的易用性和安全性&#xff0c;系统功能齐全&#xff0c;可以满足饮食分享管理的相关工作。 语言 node.…

05 mysql innodb page

前言 最近看到了 何登成 大佬的 "深入MySQL源码 -- Step By Step" 的 pdf 呵呵 似乎是找到了一些 方向 之前对于 mysql 方面的东西, 更多的仅仅是简单的使用[业务中的各种增删改查], 以及一些面试题的背诵 这里会参照 MySQL Internals Manual 来大致的看一下 i…

06_Hudi案例实战

本文来自"黑马程序员"hudi课程 6.第六章 Hudi案例实战 6.1 案例架构 6.2 业务数据 6.2.1 消息数据格式 6.2.2 数据生成 6.3 七陌数据采集 6.3.1 Apache Flume 是什么 6.3.2 Apache Flume 运行机制 6.3.3 Apache Flume 安装部署 6.3.4 Apache Flume 入门程序 6.3.5 七…

springboot项目重启的shell命令

大家好&#xff0c;我是雄雄&#xff0c;微信公众号&#xff1a;雄雄的小课堂&#xff0c;欢迎关注。 前言 我们都知道&#xff0c;springboot项目启动的时候&#xff0c;需要如下过程&#xff1a; 查找 服务的进程id杀掉该进程启动服务 并且每一步都有对应的shell命令&…

torch.cat() stack()函数使用说明,含实例及运行结果

torch.cat和stack函数使用说明&#xff0c;含实例及运行结果 torch.cat() 函数torch.cat() 函数定义参数及功能二维数据实例解释参数dim0参数dim1参数dim-1 torch.stack() 函数torch.stack() 函数定义参数及功能二维数据实例解释参数dim0参数dim1参数dim2参数dim-1 参考博文及感…

关于Neo4j的使用及其基本命令

关于Neo4j的使用 文章目录 关于Neo4j的使用1、启动方式2、创建新节点&#xff0c;节点内有属性3、创建关系4、查询节点5、查询关系6、删除两个节点的关系7、删除节点8、删除某个标签的全部关系9、某个节点添加属性10、删除节点某个属性 1、启动方式 进入bin目录&#xff1a; …

成人自考-英语二-大纲要求及考试题型及分值详细介绍

感谢内容提供者&#xff1a;金牛区吴迪软件开发工作室 文章目录 一、大纲要求二、考试题型及分值1. 总览2. 样卷【2015年】(1) 阅读判断(2)阅读选择(3)概括段落大意(4)补全句子(5)填句补文(6)填词补文(7)完形补文(8)短文写作 一、大纲要求 二、考试题型及分值 1. 总览 2. 样卷…

视野狭窄--程序员的解决之道

为什么会发生这种情况&#xff1f; 这是我学到的最艰难的一课&#xff1a;辛勤工作和意图并不等同于商业影响力。我太专注于对给定问题的出色解决&#xff0c;而没有停下来考虑我是否在解决正确的问题。我在工程师身上投入的所有时间并没有使我们解决的问题变得更重要。你的主…

射频入门知识-混频器-1

5.4混频电路-视频_哔哩哔哩_bilibili ​​​​​​​

图·c++

数据结构&#xff1a; 邻接矩阵&#xff0c;邻接表 1.图的存储方式&#xff1a;邻接矩阵&#xff0c;邻接表 1.稀疏图和稠密图 2.无向图&#xff1a; n n n 个点&#xff0c;最多 n ( n − 1 ) / 2 n(n-1)/2 n(n−1)/2 条边&#xff0c; 当 m m m 接近 n ( n − 1 ) / 2 …

【面试题】1、总结面试题1

1、Java语言有哪些特点&#xff1f;❀ &#xff08;1&#xff09;【面向对象】Java是一种面向对象的语言&#xff0c;支持封装、继承和多态等面向对象的特性。Java特别强调类和对象的关系&#xff0c;要求所有代码都必须位于类中。和Java一样很流行的Python也是面向对象的语言…

NanoPi NEO移植LVGL8.3.5到1.69寸ST7789V屏幕

移植前准备 移植好fbtft屏幕驱动 参考链接&#xff1a;友善之臂NanoPi NEO利用fbtft驱动点亮1.69寸ST7789V2屏幕 获取源码 名称地址描述lvglhttps://github.com/lvgl/lvgl.gitlvgl-8.3.5lv_drivershttps://github.com/lvgl/lv_drivers.gitlv_drivers-6.1.1 创建工程目录 创…