【PyTorch】课堂测试一:线性回归的求解

news2024/9/21 16:27:49

作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍💬

        这个是我们的第一次课堂测试,共有四个挑战,本测试需要你利用前面所学到的 PyTorch 知识,完成线性回归问题的求解,时间为30min。(文末附有参考答案,请大家认真作答再自行校对!)

知识点📜

  • 损失的定义
  • 优化器的定义
  • 模型的训练

线性回归的求解

        首先,让我们来模拟一下,线性回归所需的数据集合:

import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
%matplotlib inline

X_numpy, y_numpy = datasets.make_regression(
    n_samples=100, n_features=1, noise=20, random_state=4)
plt.plot(X_numpy, y_numpy, 'ro')

        如上,我们初始化了一个数据集合。从图中可以看出,该数据集合大致上呈线性分布。

本挑战的目的就寻找一个良好的函数表达式(又叫做模型),该函数表达式能够很好的描述上面数据点的分布,即对上面数据点进行拟合。

        在使用 PyTorch 求解模型之前,我们需要将上面的数据集转为 PyTorch 认识的张量。

🚩挑战①:将 X_numpy,y_numpy 转为张量 。

📝要求转换后的张量用 X,y 表示。

🔔提示需要利用 tensor.view() 将 y 的维度转 为 2 维。

import torch
import torch.nn as nn

# 编写代码处


# 测试代码
X.size(), y.size()

重要说明

        本课程中,你需要自行补充上方单元格中缺失的代码并运行,如果输出结果和下方的期望输出结果一致,即代表此挑战顺利通过。完成全部内容后,点击「提交检测」即可通过,此说明后续不再出现。

 ✅ 期望输出

(torch.Size([100, 1]), torch.Size([100, 1]))

根据上面图像中数据点的分布情况,我们可以看出,该问题的解决模型应该是一个线性函数模型。接下来让我们使用 PyTorch 来初始化这个线性模型。

🚩 挑战②:线性函数模型的定义 。

📝 要求 :用 model 变量表示线性函数模型。

### 补充代码 ###


# 测试代码
model

期望输出

Linear(in_features=1, out_features=1, bias=True)

定义完模型后,接下来,让我们来定义学习率、损失函数和优化器。

🚩 挑战③:利用 PyTorch 定义学习率、损失函数和优化器 。

📝 要求 :损失采用均方差损失,学习率取 0.01 。

### 补充代码 ###


# 测试代码
optimizer

 期望输出

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0
)

        最后让我们进行模型的训练,即将数据传入模型中,然后利用梯度下降算法不断的迭代,找到最佳的模型。

🚩 挑战④:利用 PyTorch 训练线性模型 。

📝 提示 :可以循环迭代 100 次左右 。

### 补充代码 ###


# 测试代码:将通过模型预测出来的值展示到图像中
# 预测结果并转为 NumPy 的形式
predicted = model(X).detach().numpy()

plt.plot(X_numpy, y_numpy, 'ro')
plt.plot(X_numpy, predicted, 'b')
plt.show()

参考答案

本挑战的参考答案如下:

挑战 1 的参考答案

X = torch.from_numpy(X_numpy.astype(np.float32))
y = torch.from_numpy(y_numpy.astype(np.float32))
y = y.view(y.shape[0], 1)

挑战 2 的参考答案

n_samples, n_features = X.shape
input_size = n_features
output_size = 1
model = nn.Linear(input_size, output_size)

挑战 3 的参考答案

learning_rate = 0.01
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) 

挑战 4 的参考答案

num_epochs = 100
for epoch in range(num_epochs):
    # Forward pass and loss
    y_predicted = model(X)
    loss = criterion(y_predicted, y) 
    # Backward pass and update
    loss.backward()
    optimizer.step()
    # zero grad before new step
    optimizer.zero_grad()
    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')

实验总结

        通过对线性问题的求解,我想你已经了解了如何利用 PyTorch 训练模型的整个过程。当然,本次挑战只是利用梯度下降算法进行了简单的线性回归。在下一个实验中,我们会尝试使用该算法进行非线性问题的求解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440766.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在自定义数据集上训练YOLOv8的各个模型

YOLOv8效果图(可以应用到图片和视频): 四个模式命令 yolo taskdetect modepredict modelmodel/yolov8n.pt sourceinput/test.mp4 showTrueyolo tasksegment modepredict modelmodel/yolov8x-seg.pt sourceinput/zidane.jpg showTrueyolo tas…

JavaSE-part2

文章目录 Day07 IO流1.IO流1.1背景介绍1.2File类1.2.1常用方法 1.3IO流原理1.4IO流的分类1.4.1InputStream 字节输入流1.4.1.1FileInputStream1.4.1.2FileOutPutStream1.4.1.3练习 1.4.2Reader and Writer1.4.2.1FileReader1.4.2.2FileWriter 1.4.3节点流和处理流1.4.3.1处理流…

MSNet网络结构与代码搭建深入解读

模型结构 1、首先,将多光谱遥感图像的波段分为可见光和不可见光两组,然后进行分组同步特征提取; 代码 先看总体结构,主代码 __init__定义了声明MSNet模型有哪些类,MSNet的forward方法规定数据如何在层之间流动。 1、首先是获得图片的输入尺寸input_size = (rgbnnd.size(…

Python数据结构与算法-动态规划(钢条切割问题)

一、动态规划(DP)介绍 1、从斐波那契数列看动态规划 (1)问题 斐波那契数列递推式: 练习:使用递归和非递归的方法来求解斐波那契数列的第n项 (2)递归方法的代码实现 import time # 递…

Spark----RDD(弹性分布式数据集)

RDD 文章目录 RDDRDD是什么?为什么需要RDD?RDD的五大属性WordCount中的RDD的五大属性如何创建RDD?RDD的操作两种基本算子/操作/方法/API分区操作重分区操作聚合操作四个有key函数的区别 关联操作排序操作 RDD的缓存/持久化cache和persistchec…

Java学习-MySQL-DQL数据查询-联表查询JOIN

Java学习-MySQL-DQL数据查询-联表查询JOIN 1.分析需求,查找那些字段 2.分析查询的字段来自哪些表 3.确定使用哪种连接查询 4.确定交叉点 5.确定判断条件 操作描述inner join返回左右表的交集left join返回左表,即使右表没有right join返回右表&#xf…

iptables深度总结--基础篇

iptables 五表五链 链:INPUT OUTPUT FORWARD PREROUTING POSTROUTING 表:filter、nat、mangle、raw、security 数据报文刚进网卡,还没有到路由表的时候,先进行了prerouting,进入到路由表,通过目标地址判…

FFMPEG 关于smaple_fmts的理解及ffplay播放PCM

问题 当我将一个aac的音频文件解码为原始的PCM数据后,使用ffplay播放测试是否成功时,需要提供给ffplay 采样率,通道数,PCM的格式类型 3个参数,否则无法播放! 所以使用ffprobe 查看原来的aac文件信息&…

Python手写板 画图板 签名工具

程序示例精选 Python手写板 画图板 签名工具 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<Python手写板 画图板 签名工具>>编写代码&#xff0c;代码整洁&#xff0c;规则&am…

别再回答面试官,toFixed采用的是四舍五入啦!

四舍五入大家都知道&#xff0c;但你知道银行家舍入法么&#xff1f;你知道JS里的toFixed实现用的是哪种吗&#xff1f; 前两天我写了篇《0.1 0.2 不等于 0.3&#xff1f;原来是因为这个》&#xff0c;大概就是说&#xff0c;0.1 0.2不等于0.3是因为浮点数精度问题。 结果在…

LinkedList 的特点及优缺点

现在来讲 LinkedList LinkedList 是链表集合&#xff0c;基于链表去存储数据&#xff0c;每一个数据视作一个节点 private static class Node<E> {// 存放的数据E item;// 下一个节点Node<E> next;// 上一个节点Node<E> prev;Node(Node<E> prev, E ele…

【unity实战】2D横版实现人物移动跳跃2——用对象池设计制作冲锋残影的效果(包含源码)

基于上一篇人物移动二段跳进一步优化完善 先看看最终效果 什么是对象池? 在Unity中,对象池是一种重复使用游戏对象的技术。使用对象池的好处是可以减少游戏对象的创建和销毁,从而提高游戏的性能。如果不使用对象池,每次需要创建游戏对象时,都需要调用Unity的Instantiate函…

国内几大技术网站,你最爱和哪个玩耍?

所谓“物以类聚&#xff0c;人以群分” 所谓“士为知己者死&#xff0c;女为悦己者容” 所谓“世上的乌鸦都一般黑&#xff0c;鸽子却各有各的白” CSDN&#xff0c;掘金&#xff0c;博客园等&#xff0c;说起来都是“技术”社区&#xff0c;每个却都有着不同的姿色和用处。至于…

初识Spring——IoC及DI详解

目录 一&#xff0c;什么是Spring Spring设计核心 Spring核心定义 Spring官网 二&#xff0c;什么是IoC IoC思想 控制权的反转 三&#xff0c;什么是DI DI的定义 DI和IoC的关系 一&#xff0c;什么是Spring Spring设计核心 我们常说的Spring其实指的是Spring Framewo…

ABP vNext电商项目落地实战(一)——项目搭建

一、落地条件&#xff1a; 1. .NET5版本 2. DDD 3. ABP vNext 4.ABP CLI &#xff08;ABP的命令行工具&#xff0c;包括ABP的各种模板&#xff09; 5.SQL Server 写在前面&#xff1a;我觉得这个框架的文件分层很凌乱&#xff0c;在企业的实际业务场景中&#xff0c;一般…

vscode+git浅尝

git 安装git以后初始化仓库分支重命名合并分支连接远程仓库推送项目 安装git以后 第一次使用git需要配置用户名和邮箱 任意处打开git终端&#xff0c;譬如鼠标右击点击git bash here 命令分别为&#xff1a; 设置用户名和邮箱 git config --global user.name “username” …

【QA】Python代码调试之解决Segmentation fault (core dumped)问题

Python代码调试之解决Segmentation fault 问题 问题描述排查过程1. 定位错误&#xff0c;2. 解决办法 参考资料 问题描述 Python3执行某一个程序时&#xff0c;报Segmentation fault (core dumped)错&#xff0c;且没有其他任何提示&#xff0c;无法查问题。 Segmentation fa…

jenkins gitlab asp.net core持续集成

什么是jenkins Jenkins直接取自其官方文档&#xff0c;是一个独立的开源自动化服务器&#xff0c;您可以使用它来自动执行与构建、测试、交付或部署软件相关的各种任务。 jenkins可以干什么 Jenkins 通过自动执行某些脚本来生成部署所需的文件来工作。这些脚本称为JenkinsFi…

叶酸聚乙二醇羟基FA-PEG-OH;了解高分子试剂 Folate-PEG-OH

FA-PEG-OH&#xff0c;叶酸-聚乙二醇-羟基 中文名称&#xff1a;叶酸聚乙二醇羟基 英文名称&#xff1a;FA-PEG-OH HO-PEG-FA Folate-PEG-OH 性状&#xff1a;黄色液体或固体&#xff0c;取决于分子量 溶剂&#xff1a;溶于水&#xff0c;DMSO、DMF等常规性有机溶剂 活性基…

【NestJs】使用连接mysql企业级开发规范

本篇将介绍如何建立 NestJs 的数据库连接、并使用数据库联表查询。 简介 Nest 与数据库无关&#xff0c;允许您轻松地与任何 SQL 或 NoSQL 数据库集成。根据您的偏好&#xff0c;您有许多可用的选项。一般来说&#xff0c;将 Nest 连接到数据库只需为数据库加载一个适当的 No…