【PyTorch学习4】《PyTorch深度学习实践》——线性回归(Linear Regression)

news2024/11/17 23:42:14

目录

    • 一、实现框架
    • 二、程序实现
    • 三、代码讲解
      • 1.`self.linear = torch.nn.Linear(1, 1)`
      • 2.`model(x_data)`
      • 3.`criterion = torch.nn.MSELoss(reduction='sum'),loss = criterion(y_pred, y_data)`

一、实现框架

1、Prepare dataset
2、Design model using Class (inherit from nn.Module)
3、Construct loss and optimizer (using PyTorch API)
loss是为了计算损失,optimizer是为了优化参数
4、Training cycle (forward,backward,update)

二、程序实现

import torch

# prepare dataset
# x,y是矩阵,3行1列 也就是说总共有3个数据,每个数据只有1个特征
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])

# design model using class
"""
our model class should be inherit from nn.Module, which is base class for all neural network modules.
member methods __init__() and forward() have to be implemented
class nn.linear contain two member Tensors: weight and bias
class nn.Linear has implemented the magic method __call__(),which enable the instance of the class can
be called just like a function.Normally the forward() will be called 
"""

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # (1,1)是指输入x和输出y的特征维度,这里数据集中的x和y的特征都是1维的
        # 该线性层需要学习的参数是w和b  获取w/b的方式分别是~linear.weight/linear.bias
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

model = LinearModel()
print(model)

# construct loss and optimizer
# criterion = torch.nn.MSELoss(size_average = False)
criterion = torch.nn.MSELoss(reduction='sum') # 误差平方和
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# training cycle forward, backward, update
for epoch in range(100):
    y_pred = model(x_data)  # forward:predict
    loss = criterion(y_pred, y_data)  # forward: loss
    print(epoch, loss.item())

    optimizer.zero_grad()  # 梯度清零
    loss.backward()  # backward: autograd,自动计算梯度
    optimizer.step()  # update参数,即更新w和b的值

print('w = ', model.linear.weight.item()) #将w以数值打印,只能是一个数的张量可以用.item()
print('b = ', model.linear.bias.item())

x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

模型:
在这里插入图片描述
输出(部分截图):
在这里插入图片描述

三、代码讲解

1.self.linear = torch.nn.Linear(1, 1)

torch.nn.Linear是一个类,(1,1)分别表示该线性层的输入和输出维度
来看下面这个例子(输入size是128x20,经过一个20x30的线性层,输出为128x30):

import torch.nn as nn
import torch

m = nn.Linear(20, 30)

print(type(m))
print(type(nn.Linear(20, 30)))

input = torch.randn(128, 20)
output = m(input)
print(output.size())

在这里插入图片描述
更多理解看这里:https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear

2.model(x_data)

       先讲一下torch.nn.Module,Base class for all neural network modules。Your models should also subclass this class.所有神经网络的基类,你的模型应该将其子类化。就是一般我们会继承这个类,主要完成初始化模型,然后前向传播等一些东西。
       来看一段重要代码和输出:

import torch
x_data = torch.tensor([1.0])

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1, bias=False)

    def forward(self, x):
        n = x
        y_pred = self.linear(x)
        return y_pred,x

model = LinearModel()
print(type(model))
print(type(LinearModel()))

print(model.forward(x_data))
print(model(x_data))
print(model.__call__(x_data))

在这里插入图片描述
       当你定义完LinearModel这个类后,你需要将其实例化为对象(当然对象也是类哈),才能进行调用。这个类需要一个参数,意思是,model = LinearModel(),model(x_data)才可以使用;而不可以直接 LinearModel(x_data)
       另外,model.forward(x_data),model(x_data),model.__call__(x_data)这三个输出一模一样,这和torch.nn.Module的内部封装有关系,具体可以看官方源代码:https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module

3.criterion = torch.nn.MSELoss(reduction='sum'),loss = criterion(y_pred, y_data)

这里只需注意一点,就是sum还是mean,sum是预测值与真实值平方和,而mean只不过求了个平均值。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/354397.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Leedcode】顺序表必备的三道面试题(附图解)

顺序表必备的三道面试题(附图解) 文章目录顺序表必备的三道面试题(附图解)前言一、第一题1.题目2.思路图解3.源码二、第二题1.题目2.思路图解3.源码三、第三题1.题目2.思路图解3.源码总结前言 本文给大家介绍三道顺序表学习过程中…

【项目精选】javaEE土地档案管理系统(源码+论文+视频)

技术:java、jsp、struts、spring、hibernate 数据库:oracle 集成开发工具:eclipse 点击下载源码 本土地项目管理系统在可行性研究的基础上,是为了进一步明确土地项目管理系统的软件需求,以便安排项目规划和进度&#x…

ARM+LINUX嵌入式学习路线

嵌入式学习是一个循序渐进的过程,如果是希望向嵌入式软件方向发展的话,目前最常见的是嵌入式Linux方向,关注这个方向,大概分3个阶段: 1、嵌入式linux上层应用,包括QT的GUI开发 2、嵌入式linux系统开发 3、…

Python|每日一练|字符串|递归|链表|单选记录:字符串转换整数 (atoi)|阶乘后的零| K 个一组翻转链表

1、字符串转换整数 (atoi)(字符串) 请你来实现一个 myAtoi(string s) 函数,使其能将字符串转换成一个 32 位有符号整数(类似 C/C 中的 atoi 函数)。 函数 myAtoi(string s) 的算法如下: 读入字符串并丢弃…

Golang 方法笔记

Golang中的方法是作用在指定的数据类型上的,因此自定义类型都可以有方法。方法定义func (recevier type) methodName (参数列表) (返回值列表) {方法体return 返回值}基本申明和调用type A struct {Num int}func (a A) test() {fmt.Println(a.Num)}说明:…

Python快速上手系列--邮件发送--详解篇

本章就来一起学习一下跑完自动化脚本后如何自动的发送邮件到指定的邮箱。zmail操作:1. 导包 import zmail2. 邮件内容,包含:主题(subject)、正文(content_text)、附件(attachments)3. 发件人信息,包含:发件人账号&…

什么牌的运动耳机比较好、运动耳机排行榜10强

现在运动健身的潮流持续不下,而且人们长期坐于办公室办公,严重影响身体的健康,这时不论是去健身房锻炼,还是户外跑步都是非常必要的了,而蓝牙耳机作为运动必备的一款数码产品,更是受到了大家的青睐&#xf…

2023/02/18 ES6对象属性的解读

1 属性的可枚举性和遍历 <script>const obj {userName: zhaoshuai-lc,userAge: 26,userSex: male}let res Object.getOwnPropertyDescriptors(obj);console.log(res); </script>描述对象的 [ enumerable ] 属性, 称为“可枚举性”, 如果该属性为 [ false ], 就表…

如何实现外网访问内网ip?公网端口映射或内网映射来解决

本地搭建服务器应用&#xff0c;在局域网内可以访问&#xff0c;但在外网不能访问。如何实现外网访问内网ip&#xff1f;主要有两种方案&#xff1a;路由器端口映射和快解析内网映射。根据自己本地网络环境&#xff0c;结合是否有公网IP&#xff0c;是否有路由权限&#xff0c;…

0.4如何使用cmake来管理项目

如何使用cmake来管理项目 【opencv源码解析0.1】VS如何优雅的配置opencv环境 【opencv源码解析0.2】如何编译opencv库源码 【opencv源码解析0.3】调试opencv源码以及使用cmake来管理项目 前面几篇文章我们都是围绕Visual Studio 2019这个IDE来展开的&#xff0c;IDE为我们做了…

【OJ】小熊猫的肉质品

&#x1f4da;Description: 自从可爱的小熊猫来到浙商大后便再也不想吃那些膳食纤维了&#xff0c;比如&#xff1a;竹子。所以&#xff0c;骞哥&#xfeff;不得不帮助国宝寻找一些肉类来维持能量&#xff0c;使得小熊猫不至于饿死在全球某工商。但是&#xff0c;你知道的淘…

Spring Cloud Alibaba--seata微服务详解之分布式事务(三)

上篇讲述gateway的部署和使用&#xff0c;gateway统一管理和转发了HTTP请求&#xff0c;在互联网中大型项目一定存在复杂的业务关系&#xff0c;尤其在商城类软件中如淘宝、PDD等商城&#xff0c;尤其在秒杀场景中&#xff0c;并发量可以到达千万级别&#xff0c;此时数据库就会…

第五十六章 树状数组(一)

第五十六章 树状数组一、前缀和的缺陷二、树状数组1、作用2、算法分析3、算法实现&#xff08;1&#xff09;lowbits()&#xff08;2&#xff09;插入&#xff08;3&#xff09;查询三、例题1、问题题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1提示2、代码一、前缀和…

【SCL】1200应用项目: 四节传送带机械手模拟控制

使用SCL编写小应用:四节传送带模拟和机械手模拟控制 文章目录 目录 文章目录 前言 一、四节传送带模拟控制 1.控制要求 2.参考图 3.接线图和I/0分配 4.编写程序 1.逆序启动 2.顺序停止 3.故障输入 4.调试 5.完整代码 二、机械手控制 1.控制要求 2. 接线图和I/0分配 3.编写程序 …

JVM内存模型深度剖析与优化

1. Java语言的跨平台特性 2. JVM整体结构及内存模型 堆存放着对象信息每个线程都会分配一块属于自己的内存空间&#xff08;栈空间&#xff09; 每个方法都会分配一块内存空间&#xff08;栈桢&#xff09;&#xff0c;上图 compute()方法 和 main()方法 都会分配到各自的栈桢空…

git 学习笔记

Git 是 Linus Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件,可以敏捷高效地处理任何或小或大的项目。Git 与常用的版本控制工具 CVS, Subversion 等不同&#xff0c;它采用了分布式版本库的方式&#xff0c;不需要服务器端软件支持。 一、安装配置…

线性神经网络(sotfmax回归)

sotfmax回归定义网络架构softmax运算softmax回归实现&#xff08;MNIST数据集&#xff09;数据集的处理读取数据集查看形状数据可视化读取小批量整合所有组件神经网络的搭建加载数据集初始化模型参数定义softmax函数定义模型定义损失函数&#xff08;难点&#xff09;分类精度训…

有序表之跳表

文章目录1、前言2、跳表简介3、理解“跳表”4、用跳表查询到底有多快5、跳表是不是很浪费内存6、高效的动态插入和删除7、跳表索引动态更新8、跳表代码实现1、前言 在开始讲解跳表之前&#xff0c;先来说一说积压结构。 何为积压结构&#xff1f;就是当数据达到了一定程度&am…

【ROS2实践】Vmware17下安装ubuntu22.04和ros2-humble

一、简介 ROS2-foxy已经不再维护&#xff0c;ROS2-humble成为主角&#xff0c;因而该转变一下开发场景了。如何安装&#xff1f;官方文档没有错&#xff0c;然而&#xff0c;照着做却无法进行。实超中遇到的需要变通的地方&#xff0c;官网是不给你提供解决的&#xff0c;本文给…

宽刈幅干涉高度计SWOT(Surface Water and Ocean Topography)卫星进展(待完善)

以下信息搬运自SWOT官方网站等部分文献资料&#xff0c;如有侵权请联系&#xff1a;sunmingzhismz163.com 排版、参考文献、部分章节待完善 概况 2022年12月16日地表水与海洋地形卫星SWOT (Surface Water and Ocean Topography)在加利福尼亚州范登堡航天基地由SpaceX猎鹰9号(Sp…