深度学习——线性神经网络(三、线性回归的简洁实现)

news2024/11/25 13:50:50

目录

  • 3.1 生成数据集
  • 3.2 读取数据集
  • 3.3 定义模型
  • 3.4 初始化模型参数
  • 3.5 定义损失函数
  • 3.6 定义优化算法
  • 3.7 训练

  在上一节中,我们通过张量来自定义式地进行数据存储和线性代数运算,并通过自动微分来计算梯度。实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习框架已经为我们实现了这些组件,只需要调用即可。

3.1 生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
# 可以打印出来看一下
print(features,labels)

在这里插入图片描述

3.2 读取数据集

  我们可以通过调用框架中现有的API来读取数据,将features和labels作为API的参数传递,并通过数据迭代器指定batch_size,此外,布尔值is_train表示是否希望数据迭代器对象在每轮内打乱数据。

def load_array(data_arrays, batch_size, is_train=True):
    """构造一个Python数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


batch_size = 10
data_iter = load_array((features,labels), batch_size)

  提到的data.TensorDataset(*data_arrays)中,*号的用法与函数定义中的类似,它表示TensorDataset可以接受任意数量的参数。这些参数通常是torch.Tensor对象,其中最后一个参数默认被视为标签,其余的参数被视为特征。

  使用iter函数构造Python迭代器,并使用next函数从迭代器中获取第一项。

print(next(iter(data_iter)))

  我是用pycharm写的代码,和jupyter中有些不一样,jupyter中直接写next(iter(data_iter))就可以打印出来了,pycharm中必须要加上print

在这里插入图片描述

  因为布尔值shuffle=is_train表示数据迭代器对象在每轮内打乱数据,所以next函数取出来的第一批量10项数据,并不直接是生成的数据集中的前10项数据。这点大家可以注意一下!

3.3 定义模型

  对于标准深度学习模型,我们可以使用框架已经预定义好的层,这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。
  我们先定义一个模型变量net,它是一个Sequential类的实例。Sequential类将多个层串联在一起,当给定输入数据时,Sequential实例将数据传入第一层,然后将第一层的输出作为第二层的输入,以此类推。
  在线性神经网络中,模型只包含一个层,因此实际上不需要Sequential,但是由于以后几乎所有的模型都是多个层的,在这里使用Sequential类更方便理解“标准的流水线”。
在这里插入图片描述
  在单层网络架构中,这一单层称为“全连接层”,因为它的每个输入都通过矩阵-向量乘法得到它的每个输出。
  在pytorch中,全连接层在Linear类中定义,我们将两个参数传递到nn.Linear中,第一个参数指定输入特征的形状,即2;第二个参数指定输出特征形状,输出特征形状为单个标量,因此为1。

# nn是神经网络的缩写
from torch import nn

net = nn.Sequential(nn.Linear(2,1))

3.4 初始化模型参数

  在使用net之前,我们需要初始化模型参数,如在线性回归模型中的权重和偏置。深度学习框架通常由预定义的方法来初始化参数。
  在这里,我们指定每个权重系数应该从均值为0,标准差为0.01的正态分布中随机抽样,偏置参数将初始化为0.
  我们在构造nn.Linear时指定了输入和输出的尺寸,现在可以直接访问参数以设定它们的初始值。通过net[0]选择网络中的第一层,然后使用weight.data和bias.data方法访问函数。我们还可以使用替换方法normal_和fill_来重写参数值。

# 重写参数值之前的对比
print(net[0].weight.data)
print(net[0].bias.data)

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
print(net[0].weight.data)
print(net[0].bias.data)

  下面是重写参数值之前的对比在这里插入图片描述

3.5 定义损失函数

  计算均方误差使用的是MSELoss类,也称为平方 L 2 L_2 L2范数。默认情况下,它返回所有样本损失的平均值

loss = nn.MSELoss()

3.6 定义优化算法

  小批量随机梯度下降算法是一种优化神经网络的标准工具,Pytorch在optim模块中实现了该算法的许多变体。当我们实例化一个SGD实例时,我们要指定优化的参数(可以通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr的值,这里设置为0.03.

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

3.7 训练

  在每轮里,我们将完整遍历一次数据集(train_data),不断地从中获取一个小批量的输入和相应的标签。对于每个小批量,将执行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)
  • 通过反向传播来计算梯度
  • 通过调用优化器来更新模型参数

  为了 更好地度量训练效果,我们计算每轮后的损失,并打印出来监控训练过程。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features),labels)
    print(f'epoch{epoch + 1}, loss {l:f}')

在这里插入图片描述

几点注意:
l = loss(net(X), y)
loss函数中已经有了sum()操作,省略了原来实现过程中的 l.sum() 这一步骤
net(X)
net()本身就带了模型中的参数,就不需要把W,b写进去了
trainer.zero_grad()
优化器需要先把梯度清零
trainer.step()
调用step()函数进行模型更新
l = loss(net(features),labels)
模型参数更新完之后,再计算一遍均方误差

  下面比较一下生成数据集的真实参数和通过有限数据训练获得的模型参数。要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。如下所示,我们估计得到的参数与生成数据集的真实参数非常接近。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

在这里插入图片描述

小结:

  • 我们可以使用Pytorch中的高级API更简洁地实现模型;
  • 在Pytorch中,data模块提供了数据处理工具,nn 模块定义了大量的神经网络层和常见的损失函数;
  • 我们可以通过以"_"结尾的方法将参数替换,从而自定义初始化参数。

以下是完整代码:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
# nn是神经网络的缩写
from torch import nn

true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
# 可以打印出来看一下
# print(features,labels)


def load_array(data_arrays, batch_size, is_train=True):
    """构造一个Python数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)


batch_size = 10
data_iter = load_array((features,labels), batch_size)
# print(next(iter(data_iter)))

net = nn.Sequential(nn.Linear(2,1))

# 重写参数值之前的对比
# print(net[0].weight.data)
# print(net[0].bias.data)

net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
# print(net[0].weight.data)
# print(net[0].bias.data)
loss = nn.MSELoss()

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X), y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    l = loss(net(features),labels)
    # print(f'epoch{epoch + 1}, loss {l:f}')

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2212137.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于深度学习的西红柿成熟度检测系统

简介: 基于深度学习技术的西红柿成熟度检测系统是一种利用人工智能算法对西红柿成熟程度进行自动识别和分类的智能系统。该系统通过采集西红柿的图像数据,运用深度学习模型对图像中的西红柿进行特征提取和分析,从而实现对西红柿成熟度的准确判…

【C】printf()与scanf()详介以及如何在VS中使用scanf(保姆级详细版)

printf() 基本用法 printf()的作用是将参数文本输出到屏幕,它名字里面的f 代表 format(格式化)&#xff0c;表示可以定制输出文本的格式。 1 #include <stdio.h>//standard input output标准输入输出操作 2 int main() 3{ 4 printf("Hello World"); 5 retu…

DVWA CSRF 漏洞实践报告

1. 漏洞简介 CSRF&#xff08;跨站请求伪造&#xff09;是一种攻击&#xff0c;使得攻击者能够以受害者的身份执行非预期的操作。在靶场DVWA中&#xff0c;我将尝试通过CSRF漏洞更改管理员密码。 2. 实验环境 DVWA版本&#xff1a;DVWA-old浏览器&#xff1a;火狐默认管理员账…

QtModel

QModelIndex index1 model->index(row,column,QModelIndex());QModelIndex index2 model->index(row.column,index2); QSqlQuery::size() 仅在使用了 QSqlQuery::exec() 后并且查询结果集的所有行都被读取时才有效。如果结果集很大或在使用游标的情况下&#xff0c;返回…

Linux 内核态,用户态,以及如何从内核态到用户态,交互方式有哪些

一、Linux 内核态&#xff0c;用户态 Linux 内核态&#xff0c;用户态&#xff0c;以及如何从内核态到用户态&#xff0c;我来说下我的理解 很多面试官&#xff0c;面试也是照搬照套&#xff0c;网上找的八股文面试题&#xff0c;面试的人也是背八股文&#xff0c;刚好背到了&…

全面讲解C++

数据类型 1.1 基本数据类型 1.1.1 整型&#xff08;Integer Types&#xff09; 整型用于表示整数值&#xff0c;分为以下几种类型&#xff1a; int&#xff1a;标准整数类型&#xff0c;通常为4字节&#xff08;32位&#xff09;。short&#xff1a;短整型&#xff0c;通常…

被装物联网系统|DW-S305系统是一套成熟系统

东识被装仓库管理系统&#xff08;智被装DW-S305&#xff09;作业管理软件系统包括收发管理、库房管理、库存统计、环境监测、预警管理、数据展示、系统管理等功能&#xff0c;主要功能如下&#xff1a; 收发管理&#xff1a;对库房收发物资进行管理&#xff0c;支持收发物单据…

通信工程学习:什么是TCP/IP(传输控制协议/互联网议)

TCP/IP&#xff1a;传输控制协议/互联网议 TCP/IP&#xff08;Transmission Control Protocol/Internet Protocol&#xff0c;传输控制协议/互联网协议&#xff09;是互联网的基本协议&#xff0c;也是国际互联网络的基础。它不仅仅是一个协议&#xff0c;而是一个协议族&#…

Github 2024-10-13php开源项目日报 Top10

根据Github Trendings的统计,今日(2024-10-13统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量PHP项目10Vue项目2JavaScript项目1TypeScript项目1Blade项目1Coolify: 开源自助云平台 创建周期:1112 天开发语言:PHP, Blade协议类型:Apach…

算法题总结(十四)——贪心算法(上)

贪心算法 什么是贪心 贪心的本质是选择每一阶段的局部最优&#xff0c;从而达到全局最优。 贪心的套路&#xff08;什么时候用贪心&#xff09; 刷题或者面试的时候&#xff0c;手动模拟一下感觉可以局部最优推出整体最优&#xff0c;而且想不到反例&#xff0c;那么就试一试…

Vscode+Pycharm+Vue.js+WEUI+django火锅(五)Django的API

如果只是嫌弃Djanggo的前台不好&#xff0c;用vue替换&#xff0c;只要在Djanggo项目里面创建一个Vue项目文件夹&#xff0c;然后 1.修改urls.py 修改路由 2.修改settings.py中&#xff0c;增加templates内容指向vue文件夹 3.静态文件staticfile_dir中也添加vue文件夹 但因为我…

深圳大学-Java程序设计-选实验3 包及继承应用

实验目的与要求&#xff1a; 实验目的&#xff1a;熟悉面向对象编程中package,import等语句的使用。 实验要求&#xff1a; (1).编写一个计算机与软件学院类CSSE、一个研究所/中心类Institute和一个教学系类Department。CSSE类中包含有多个Institute类的实例和多个Department…

信息技术 04 WPS文字处理 图书订购单

信息技术 04 WPS文字处理 图书订购单 素材下载 信息技术 04 WPS文字处理 图书订购单链接&#xff1a;https://pan.baidu.com/s/1_S9HMfmiC6JJcjk4nO-tKg?pwdi304 提取码&#xff1a;i304 成品样图 题目 任务实现具体要求如下&#xff1a; ① 根据设计好的表格的结构&#…

基于 PyQt5 和 Matplotlib 的医学图像处理应用开发

1. 引言 在医学领域&#xff0c;图像处理是一项非常重要的技术&#xff0c;特别是在医学成像&#xff08;如MRI、CT扫描等&#xff09;的数据处理上&#xff0c;可以帮助医生更加准确地进行诊断。本项目基于 Python 的 PyQt5 图形用户界面框架与 Matplotlib 数据可视化库&…

Variational Auto-Encoder(VAE)缺少数学推导未完结版

VAE是Diffusion的基础&#xff0c;在其中将输入的图片数据编码到潜在空间后再解码出来。 略显复杂&#xff0c;博主结合李宏毅视频、网上一些讲解以及自己的理解将其总结如下&#xff1a; 一、什么是VAE VAE&#xff08;变量自编码器&#xff09;最早在以上两篇文章被提出。 …

yakit使用教程(四,信息收集)

本文仅作为学习参考使用&#xff0c;本文作者对任何使用本文进行渗透攻击破坏不负任何责任。 前言&#xff1a;yakit下载安装教程。 一&#xff0c;基础爬虫。 在新建项目或新建临时项目后&#xff0c;点击安全工具&#xff0c;点击基础爬虫。 此工具并不是为了爬取网站上的一…

【零散技术】MAC 安装多版本node

时间是我们最宝贵的财富,珍惜手上的每个时分 不同前端项目运行的node版本不一致&#xff0c;会导致无法运行&#xff0c;就像Odoo也需要依据版本使用对应的python环境。python 可以用 conda随时切换版本&#xff0c;那么Node可以吗&#xff1f;答案是肯定的。 1、安装 n&#x…

k8s-资源管理、实战入门

资源管理 一、资源管理介绍 在kubernetes中&#xff0c;所有的内容都抽象为资源&#xff0c;用户需要通过操作资源来管理kubernetes。 &#xff08;1&#xff09;kubernetes的本质上就是一个集群系统&#xff0c;用户可以在集群中部署各种服务&#xff0c;所谓的部署服务&…

SpringBoot高校学科竞赛平台:性能优化与实践

3系统分析 3.1可行性分析 通过对本高校学科竞赛平台实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本高校学科竞赛平台采用SSM框架&#xff0c;JAVA作为开发语…

详细分析Redisson分布式锁中的renewExpiration()方法

目录 一、Redisson分布式锁的续期 整体分析 具体步骤和逻辑分析 为什么需要递归调用&#xff1f; 定时任务的生命周期&#xff1f; 一、Redisson分布式锁的续期 Redisson是一个基于Redis的Java分布式锁实现。它允许多个进程或线程之间安全地共享资源。为了实现这一点&…