Pytorch深度学习笔记(六)用pytorch实现线性回归

news2024/9/20 5:50:14

目录

1.数据准备

2.设计模型

3.构造损失函数和优化器

4.训练周期(前馈—>反馈—>更新)


课程推荐:05.用PyTorch实现线性回归_哔哩哔哩_bilibili

线性通常是指变量之间保持等比例的关系,从图形上来看,变量之间的形状为直线,斜率是常数。

当要预测的变量 y 输出集合是无限且连续,我们称之为回归。比如,天气预报预测明天是否下雨,是一个二分类问题;预测明天的降雨量多少,就是一个回归问题。

1.数据准备

在pytorch中,计算图使用mini batch方式绘制,所以x和y是n*1的向量张量。

mini batch详解,如图,线性模型的x,y都是3*1的向量张量。

2.设计模型

model = LinearModel(x_data),用于自定义计算模块,复写forward(),返回预测值

torch.nn.Linear(in_features,out_features,bias=True),in_features,out_features为输入样本和输出样本的数量

 设计模型代码:

import torch

# 生成两个3*1的向量张量
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# 继承torch.nn.Module,定义自己的计算模块
class LinearModel(torch.nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造
        super(LinearModel, self).__init__()
        # Linear对象包含两个成员张量weight和bias
        # 定义输入样本和输出样本的维度
        self.linear = torch.nn.Linear(1, 1)

    # 前馈函数
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

# 实例化自定义模型
model = LinearModel()

3.构造损失函数和优化器

torch.nn.MSELoss(size_average=True, reduce=True),size_average是否要求均值,reduce结果是否要降维

torch.optim.SGD(params, lr=required, momentum=0, dampening=0,weight_decay=0, nesterov=False),params参数,lr学习率

# 实例化损失函数,返回损失值
criterion = torch.nn.MSELoss(size_average=False)
# 实例化优化器,优化权重w
# model.parameters(),取出模型中的参数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

 不同的优化器:

4.训练周期(前馈—>反馈—>更新)

一轮训练:

①获得预测值
②获得损失值
③梯度归零
④反向传播
⑤更新权重w

完整代码:

import torch

# 生成两个3*1的向量张量
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# 继承torch.nn.Module,定义自己的计算模块
class LinearModel(torch.nn.Module):
    # 构造函数
    def __init__(self):
        # 调用父类构造
        super(LinearModel, self).__init__()
        # 实例化的Linear对象包含两个成员张量weight和bias
        # 定义输入样本和输出样本的维度
        self.linear = torch.nn.Linear(1, 1)

    # 前馈函数
    def forward(self, x):
        # 返回x线性计算后的预测值
        y_pred = self.linear(x)
        return y_pred

# 实例化自定义模型,返回预测值
model = LinearModel()
# 实例化损失函数,返回损失值
criterion = torch.nn.MSELoss(size_average=False)
# 实例化优化器,优化权重w
# model.parameters(),取出模型中的参数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):
    # 获得预测值
    y_pred = model(x_data)
    # 获得损失值
    loss = criterion(y_pred, y_data)
    # 不会产生计算图,因为__str()__
    print(epoch, loss)
    # 梯度归零
    optimizer.zero_grad()
    # 反向传播
    loss.backward()
    # 更新权重w
    optimizer.step()

# 打印权重和偏值
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]])
# 测试模型
y_test = model(x_test)
print('y_pred = ', y_test.data)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/448447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么要学习微服务?

文章目录 1.认识微服务1.1微服务由来1.2为什么需要微服务? 2.两种架构2.1.单体架构2.2.分布式架构 3.微服务的特点4.SpringCloud5.总结最后说一句 1.认识微服务 随着互联网行业的发展,对服务的要求也越来越高,服务架构也从单体架构逐渐演变为…

类和对象(上篇)

类和对象----上篇 🔆面向过程和面向对象的初步认识🔆类的引入🔆类的定义🔆类的访问限定符及封装访问限定符封装 🔆类的作用域🔆类的实例化🔆类的对象大小的计算如何计算一个类的大小结构体内存对…

15天学习MySQL计划(多表联查)第四天

15天学习MySQL计划(多表联查)第四天 1.多表查询 1.1概述 ​ 指从多张表中查询数据 ​ 在项目开发中,在进行数据库表结构设计时,会根据业务需求及业务模块之间的关系,分析并设计表结构,由于业务之间相互…

【HCIP】Huawei设备下IPV4IPV6共存实验

目录 方法一、普通的GRE将V6基于V4通讯 方法二、6to4的tunnel 方法三、双栈 方法一、普通的GRE将V6基于V4通讯 //方法一和方法二的前提,搭个简单的V4网络就行 [r1]int g0/0/0 [r1-GigabitEthernet0/0/0]ip address 12.1.1.1 24 [r1]router id 1.1.1.1 [r1-Gigabi…

Spring Security 05 密码加密

目录 DelegatingPasswordEncoder 使用 PasswordEncoder 密码加密实战 密码自动升级 实际密码比较是由PasswordEncoder完成的,因此只需要使用PasswordEncoder 不同实现就可以实现不同方式加密。 public interface PasswordEncoder {// 进行明文加密String encod…

如何搭建自己的博客网站(手把手教你搭建免费个人博客网站)

没有前言直接开始正文,搭建一个博客需要服务器,域名,博客程序。 博客程序常用的有wordpress,z-blog,typecho等等,其中wordpress和z-blog最为简单,typecho需要一定的技术含量,这里暂…

使用NPOI做Excel简单报表

文章目录 前言初版表格,单元格的合并进阶表格,单元格美化小结 前言 前面介绍了NPOI单元格样式的使用、单元格合并,以及NPOI提供的颜色。现在用上前面的一些知识点,做一个测试结果表格。 1、 介绍NPOI 的颜色卡、名称以及索引 ht…

图片去摩尔纹简述实现python代码示例

这篇文章主要为大家介绍了图片去摩尔纹简述实现的python代码示例,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪 1、前言 当感光元件像素的空间频率与影像中条纹的空间频率接近时,可能产生…

分布式系统需要关注的技术点和面试经常问的点

1、分布式系统概述 关于什么是分布式系统,有很多文章介绍,其实这个并不难理解,大白话讲就是:工厂活多了一个人撑不住,那就多找些工人一起干,要让这么多人为了一个目标干得快干得好,就需要一些规…

自主可控智能网联汽车操作系统

开发自主可控智能网联汽车操作系统的必要性 当下,传统汽车操作系统行业的核心技术几乎由国外的黑莓、谷歌、风河、Vector、ETAS等垄断。操作系统已成为我国智能网联汽车发展过程中的关键卡脖子技术,开发自主可控的智能网联汽车操作系统势在必行。 操作…

CVPR 2023 | 达摩院REALY头部重建榜单冠军模型HRN解读

团队模型、论文、博文、直播合集,点击此处浏览 前言 高保真 3D 头部重建在许多场景中都有广泛的应用,例如 AR/VR、医疗、电影制作等。尽管大量的工作已经使用 LightStage 等专业硬件实现了出色的重建效果,从单一或稀疏视角的单目图像估计高精…

微服务架构设计与实践

随着互联网的发展,软件开发已经成为各种企业发展的重要手段。然而,单体应用在长时间的维护中会变得复杂、难以扩展、难以修改。因此,为了满足业务需求,微服务架构应运而生。本篇文章将深入探讨微服务架构的设计与实践。 一、微服务…

C++中的类与对象

类与对象 我们在C语言中自定义的struct 叫做结构体,而在C中我们把struct升级为了类,并且还加入了一个class,也称为类,那么我们今天就来看一下结构体和类的不同和相同 1.结构体与类 我们在C语言中的结构体是struct,而…

QT学习笔记(持续更新)

QT 一、按钮 1.效果 2.代码 #include<QPushButton>//头文件myWidget::myWidget(QWidget *parent): QWidget(parent) {//方法1QPushButton *btnnew QPushButton;//btn->show();//以顶层方式显示btn->setParent(this);//在myWidget窗口中btn->setText("按钮…

JS编程中的API hook

JavaScript奇技淫巧&#xff1a;Hook与反Hook 作者&#xff1a;专注于JS混淆加密的 JShaman API HOOK技术&#xff0c;在PC时代曾盛行&#xff0c;是高端的技术。在JavaScript编程中&#xff0c;也可以应用API Hook技术实现不寻常的效果。 例&#xff0c;eval hook&#xff1a…

Kotlin 基础 笔记

这里写目录标题 变量函数条件语句if/else 语句when 语句if/else 表达式 和 when 表达式 Kotlin 中的null使用 ?: Elvis 运算符 类和对象构造函数类之间的关系可见性修饰符定义属性委托 变量 变量是存储单项数据的容器&#xff0c;必须先声明变量&#xff0c;才可以使用。 常见…

Centos7.5 如何安装Bacula 11.05详细教程

环境: 本地华为桌面云服务器环境 Centos 7.5 Bacula 11.05 问题描述: Centos7.5 如何安装Bacula 解决方案: 一、官网下载Bacula 1.下载Source Files11.0.5 2.先安装C 和 C++ 编译器 root@localhost ~]# yum install -y gcc gcc-c++ 1 已安装: 2 gcc.x86_64 0:…

梦想云图Node.JS服务(2023.4.19)

说明 后台提供梦想Node.JS服务&#xff0c;方便调用控件后台功能&#xff0c;Windows服务程序所在目录:Bin\MxDrawServer\Windows&#xff0c;Linux服务程序所在目录:Bin\Linux\MxDrawServer 启动服务 Windows:进入Bin\MxDrawServer\Windows目录&#xff0c;运行start.bat启动…

redis原理及进化之路

Redis 的主从复制经历了多次演进&#xff0c;本文将从最基本的原理和实现讲起&#xff0c;并层层递进&#xff0c;逐步呈现 Redis 主从复制的演进历史。大家将了解到 Redis 主从复制的原理&#xff0c;以及各个改进版本解决了什么问题&#xff0c;并最终看清 Redis 7.0 主从复制…

vue+uniapp疫苗预约接种系统 微信小程序

统计分析&#xff1a;查看用户&#xff0c;疫苗&#xff0c;订单数量&#xff1b;统计近7日&#xff0c;30日订单趋势 用户管理&#xff1a;查看注册用户信息&#xff0c;及删除&#xff08;数据库mysql) 疫苗管理&#xff1a;疫苗增删改查以及上下架 接种点管理&#xff1a;接…