深度学习 线性神经网络(线性回归 从零开始实现)

news2025/1/16 20:59:46

介绍:

在线性神经网络中,线性回归是一种常见的任务,用于预测一个连续的数值输出。其目标是根据输入特征来拟合一个线性函数,使得预测值与真实值之间的误差最小化。

线性回归的数学表达式为:
y = w1x1 + w2x2 + ... + wnxn + b

其中,y表示预测的输出值,x1, x2, ..., xn表示输入特征,w1, w2, ..., wn表示特征的权重,b表示偏置项。

训练线性回归模型的目标是找到最优的权重和偏置项,使得模型预测的输出与真实值之间的平方差(即损失函数)最小化。这一最优化问题可以通过梯度下降等优化算法来解决。

线性回归在深度学习中也被广泛应用,特别是在浅层神经网络中。在深度学习中,通过将多个线性回归模型组合在一起,可以构建更复杂的神经网络结构,以解决更复杂的问题。

 手动生成数据集:

%matplotlib inline
import torch
from d2l import torch as d2l
import random

#"""生成y=Xw+b+噪声"""
def synthetic_data(w, b, num_examples):  #生成num_examples个样本
    X = d2l.normal(0, 1, (num_examples, len(w)))#随机x,长度为特征个数,权重个数
    y = d2l.matmul(X, w) + b#y的函数
    y += d2l.normal(0, 0.01, y.shape)#加上0~0.001的随机噪音
    return X, d2l.reshape(y, (-1, 1))#返回

true_w = d2l.tensor([2, -3.4])#初始化真实w
true_b = 4.2#初始化真实b

features, labels = synthetic_data(true_w, true_b, 1000)#随机一些数据
print(features)
print(labels)

显示数据集:

print('features:', features[0],'\nlabel:', labels[0])

'''
features: tensor([ 2.1714, -0.6891]) 
label: tensor([10.8673])
'''

d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(features[:, 1]), d2l.numpy(labels), 1);

读取小批量数据集:

#每次抽取一批量样本
def data_iter(batch_size, features, labels):#步长、特征、标签
    num_examples = len(features)#特征个数
    indices = list(range(num_examples))
    
    random.shuffle(indices)# 这些样本是随机读取的,没有特定的顺序,打乱顺序
    for i in range(0, num_examples, batch_size):#随机访问,步长为batch_size
        batch_indices = d2l.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
        

定义模型:

#定义模型
def linreg(X, w, b):  
    """线性回归模型"""
    return d2l.matmul(X, w) + b

定义损失函数:

#定义损失和函数
def squared_loss(y_hat, y):  #@save
    """均方损失"""
    return (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2

定义优化算法(小批量随机梯度下降):

#定义优化算法  """小批量随机梯度下降"""
def sgd(params, lr, batch_size):  #参数、lr学习率、
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

模型训练:

#训练
lr = 0.03#学习率
num_epochs = 3#数据扫三遍
net = linreg#模型
loss = squared_loss#损失函数
#初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)#权重
b = torch.zeros(1, requires_grad=True)#b全赋为0


for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):#拿出一批量x,y
        l = loss(net(X, w, b), y)  # X和y的小批量损失,实际的和预测的
        
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数
        
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
'''
epoch 1, loss 0.037302
epoch 2, loss 0.000140
epoch 3, loss 0.000048
'''


print(f'w的估计误差: {true_w - d2l.reshape(w, true_w.shape)}')
print(f'b的估计误差: {true_b - b}')
'''
w的估计误差: tensor([0.0006, 0.0001], grad_fn=<SubBackward0>)
b的估计误差: tensor([-0.0003], grad_fn=<RsubBackward1>)
'''

print(w)
'''
tensor([[ 1.9994],
        [-3.4001]], requires_grad=True)
'''

print(b)
'''
tensor([4.2003], requires_grad=True)
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1539280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openGauss学习笔记-249 openGauss性能调优-使用Plan Hint进行调优-Join顺序的Hint

文章目录 openGauss学习笔记-249 openGauss性能调优-使用Plan Hint进行调优-Join顺序的Hint249.1 功能描述249.2 语法格式249.3 参数说明249.4 示例 openGauss学习笔记-249 openGauss性能调优-使用Plan Hint进行调优-Join顺序的Hint 249.1 功能描述 指明join的顺序&#xff0…

内存卡的照片怎么突然就找不到了,内存卡照片突然找不到如何恢复

最近,我遇到了一个令人困惑的问题,就是我的内存卡中的照片突然间找不到了。作为一个热爱摄影的人,我经常使用内存卡来存储我的珍贵照片。然而,最近我发现,无论我如何搜索和浏览,这些照片似乎就像消失了一样。内存卡照片突然找不到如何恢复?虽然挺沮丧的,但幸好遇上了以…

五分钟快速搭建个人游戏网站(1Panel)

五分钟快速搭建个人游戏网站&#xff08;1Panel&#xff09; 环境要求&#xff1a;主流 Linux 发行版本&#xff08;基于 Debian / RedHat&#xff0c;包括国产操作系统&#xff09;&#xff1b; 如果是Windows OS的可以通过WSL来实现安装。 1 介绍 1Panel 是一个基于 Web 的 L…

OpenCV+OpenCV-Contrib源码编译

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、OpenCV是什么&#xff1f;二、OpenCV 源码编译1.前期准备1.1 源码下载1.2 cmake安装1.3 vscode 安装1.4 git 安装1.5 mingw安装 2.源码编译2.1 打开cmake2.…

[ESP32]:基于HTTP实现百度AI识图

[ESP32]&#xff1a;基于HTTP实现百度AI识图 测试环境&#xff1a; esp32-s3esp idf 5.1 首先&#xff0c;先配置sdk&#xff0c;可以写入到sdkconfig.defaults CONFIG_IDF_TARGET"esp32s3" CONFIG_IDF_TARGET_ESP32S3yCONFIG_PARTITION_TABLE_CUSTOMy CONFIG_PA…

vue.config.js配置项

vue.config.js配置项 vue-cli3 脚手架搭建完成后&#xff0c;项目目录中没有 vue.config.js 文件&#xff0c;需要手动创建 创建vue.config.js vue.config.js(相当于之前的webpack.config.js) 是一个可选的配置文件&#xff0c;如果项目的 (和 package.json 同级的) 根目录中存…

数组不初始化带来的问题及解决、动态分配

C中数组不初始化会输出什么结果 在C中&#xff0c;如果你声明了一个数组但没有对其进行初始化&#xff0c;数组的元素将具有未定义的值。这意味着数组元素的值是不确定的&#xff0c;可能是垃圾值。 当你访问未初始化的数组元素时&#xff0c;可能会得到以下结果&#xff1a;…

力扣---全排列---回溯

思路&#xff1a; 递归做法&#xff0c;一般会有visit数组来判断第 i 位是否被考虑了。我们先考虑第0位&#xff0c;再考虑第1位&#xff0c;再考虑第2位...dfs函数中还是老套路&#xff0c;先判定特殊条件&#xff0c;再从当下的角度&#xff08;决定第 j 位是哪个元素&#x…

Docker 【通过Dockerfile构建镜像】【docker容器与镜像的关系】

文章目录 前言一、前期的准备工作二、上手构建一个简单的镜像三、DcokerFile1 指令总览2 指令详情 四、Dockerfile文件规范五、docker运行build时发生了什么?六、调试手段1. 修改镜像打包后&#xff0c;如何验证新内容已更新至镜像 七、Dockerfile优化方案 前言 docker构建镜…

优化选址问题 | 基于鹈鹕算法求解基站选址问题含Matlab源码

目录 问题代码问题 鹈鹕算法(Pelican Optimization Algorithm, POA)是一种相对较新的启发式优化算法,模拟了鹈鹕鸟觅食的行为。这种算法通常用于解决复杂的优化问题,如函数优化、路径规划、调度问题等。基站选址问题通常是一个复杂的优化问题,需要考虑覆盖范围、干扰、成…

AI修复老照片的一些参数设置

很久没更新CSDN文章了&#xff0c;这次给粉丝带来老照片修复流程 1>用ps修图 图章工具 笔刷 画笔修复 2>高清放大 3>lineattile 重绘 4>上色 具体可参考我的B站视频。 下面是一些笔记。 best quality,masterpiece,photorealistic,8k,ultra high res,solo,ext…

Fabric Measurement

Fabric Measurement 布料测量

通信网络设计[PTA]

文章目录 题目描述思路AC代码 题目描述 输入样例1 6 10 1 2 6 1 5 10 1 6 12 2 4 5 2 6 8 2 3 3 3 4 7 4 5 9 4 6 11 5 6 16 输出样例1 YES! Total cost:31输入样例2 5 4 1 2 3 1 3 11 2 3 8 4 5 9 输出样例2 NO! 1 part:1 2 3 2 part:4 5思路 最小生成树–kruskal算法 具体做法…

RbbitMQ使用教程

RabbitMQ 是一个可靠且成熟的消息传递和流代理&#xff0c;易于部署在云环境、本地和本地计算机上。它目前被全球数百万人使用。通过确认消息传递和跨集群复制消息的功能&#xff0c;您可以使用 RabbitMQ 确保您的消息是安全的。 RabbitMQ安装 推荐使用docker安装部署&#xf…

VMware Workstation Pro 17虚拟机超级详细搭建(含redis,nacos,docker)(一)

今天从零搭建一下虚拟机的环境&#xff0c;把nacos&#xff0c;redis等微服务组件还有数据库搭建到里面&#xff0c;首先看到的是我们最开始下载VMware Workstation Pro 17 之后的样子&#xff0c;总共一起应该有三部分因为篇幅太长了 下载地址 : VMware - Delivering a Digit…

pytest全局配置+前后只固件配置

pytest全局配置前后只固件配置 通过读取pytest.ini配置文件运行通过读取pytest.ini配置文件运行无条件跳过pytest.initest_mashang.pyrun.py 有条件跳过test_mashang.py pytest框架实现的一些前后置&#xff08;固件、夹具&#xff09;处理方法一&#xff08;封装&#xff09;方…

VUE3.0(一):vue3.0简介

Vue 3 入门指南 什么是vue Vue (发音为 /vjuː/&#xff0c;类似 view) 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS 和JavaScript 构建&#xff0c;并提供了一套声明式的、组件化的编程模型&#xff0c;帮助你高效地开发用户界面。无论是简单还是复杂的界…

[ C++ ] STL---仿函数与priority_queue

目录 仿函数 示例一&#xff1a; 示例二 : 常见的仿函数 priority_queue简介 priority_queue的常用接口 priority_queue的模拟实现 基础接口 push() 堆的向上调整算法 堆的插入 pop() 堆的向下调整算法 堆的删除 priority_queue最终实现 仿函数 仿函数&#xff…

数据结构 - 二叉树非递归遍历

文章目录 前言一、前序二、中序三、后序 前言 本文实现二叉树的前中后的非递归遍历&#xff0c;使用栈来模拟递归。 文字有点简略&#xff0c;需要看图和代码理解 树节点&#xff1a; typedef char DATA; //树节点 typedef struct Node {DATA data; //数据struct Node* left…

MySQL进阶-----Linux系统安装MySQL

目录 前言 一、准备工作 1. 准备一台Linux服务器 2. 下载Linux版MySQL安装包 3. 上传MySQL安装包 二、安装操作指令 1. 创建目录,并解压 2.安装mysql的安装包 三、开启mysql与密码修改 1.启动MySQL服务 2. 查询自动生成的root用户密码 3.修改root用户密码 四、创…