新手小白的pytorch学习第八弹------分类问题模型和简单预测

news2024/11/26 10:34:31

目录

  • 1 启动损失函数和优化器
  • 2 训练模型
    • 创建训练和测试循环
  • 3 预测和评估模型

这篇是接着新手小白的pytorch学习第七弹------分类问题模型这一篇的,代码也是哟~

1 启动损失函数和优化器

对于我们的二分类问题,我们经常使用 binary cross entropy 作为损失函数
可以使用torch.optim.SGD()torch.optim.Adam() 作为优化器

有两个 binary cross entropy 函数

  1. torch.nn.BCELoss()-在label(target)和features(input)之间进行衡量
  2. torch.nn.BCEWithLogitsLoss()-这个和上面这个一样,不过它有一个sigmoid嵌入层(nn.Sigmoid)[之后我们会看这个方式的]

下面我们会创建损失函数和优化器,优化器我们使用SGD,优化器使用模型的参数,学习率为0.1

import torch.nn as nn
import torch.optim as optim
# 创建一个损失函数
loss_fn = nn.BCEWithLogitsLoss() # 嵌入sigmoid()函数

# 创建一个优化器
optimizer = optim.SGD(params=model_0.parameters(),
                      lr=0.1)

我们再引入一个新的东西,评估标准,它也可以像损失函数一样,来衡量你的模型怎么样。毕竟使用多个角度来衡量模型,能够让模型更加的公正客观。Accuracy准确度,可以看出在总样本中正确样本的数量,所以100%是最好的,毕竟我们期望它全部预测对,对吧。

# 创建一个计算准确率的accuracy函数
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item() # torch.eq()计算两个相同的张量
    acc = (correct/len(y_pred))*100
    return acc

现在我们可以使用这个函数来衡量我们的模型啦。

2 训练模型

这里使用的损失函数是nn.BCEWithLogits(),因此这个损失函数的输入是logits.
什么是logits呢,我的理解就是我们的模型输出的原始值,不经过处理的值,由于这个损失函数是有一个torch.sigmoid()函数的,所以数据的转换有三个步骤:logits -> prediction probability -> prediction labels

# 查看测试数据的前5个输出
with torch.inference_mode():
    y_logits = model_0(X_test.to(device))[:5]
y_logits

tensor([[0.6003],
[0.6430],
[0.5095],
[0.6260],
[0.5431]], device=‘cuda:0’)

因为我们的模型没有被训练,因此这些输出都是随机的。
并且我们模型的原始输出是logits,这些数字难以解释,我们需要能够和真实数据相比较的数据。
我们可以使用torch.sigmoid()激活函数来将数据转换为我们需要的形式.

# 使用 torch.sigmoid() 激活函数
y_pred_probs = torch.sigmoid(y_logits)
y_pred_probs

tensor([[0.6457],
[0.6554],
[0.6247],
[0.6516],
[0.6325]], device=‘cuda:0’)

y_pred_probs 现在是 prediction probability 预测概率的形式,概率就是有多大的可能,有多大的几率。在我们的情况中,我们理想的输出是0或1,所以这些值可以被看做一个决定的边界。比如说值越靠近零,那模型就将这个样本分类为0, 值越接近1,模型就将这个样本分类为1.

更具体地说:
if y_pred_probs >= 0.5, y=1(class 1)
if y_pred_probs < 0.5, y=0(class 0)

将预测概率转变成预测标签,我们四舍五入torch.sigmoid()函数的输出即可

# 将概率转变为标签
y_preds = torch.round(y_pred_probs)

# 将刚才的过程连起来放在一起
y_preds_labels = torch.round(torch.sigmoid(model_0(X_test.to(device))[:5]))

# 查看预测值和标签相等
print(torch.eq(y_preds.squeeze(), y_preds_labels.squeeze()))

# 去掉额外的维度
y_preds.squeeze()

tensor([True, True, True, True, True], device=‘cuda:0’)
tensor([1., 1., 1., 1., 1.], device=‘cuda:0’)

y_test[:5]

y_test[:5]

创建训练和测试循环

# 设置随机种子,有利于代码的复现
torch.manual_seed(42)

epochs = 100

# 将数据放到指定的设备上
X_train, y_train = X_train.to(device), y_train.to(device)
X_test, y_test = X_test.to(device), y_test.to(device)

# 创建训练和测试循环
for epoch in range(epochs):
    # 进入训练模式
    model_0.train()
    
    # 预测
    y_logits = model_0(X_train).squeeze()
    y_pred_prob = torch.sigmoid(y_logits)
    y_pred = torch.round(y_pred_prob)
    
    # 计算损失函数和准确率
    loss = loss_fn(y_logits, 
                   y_train)
    acc = accuracy_fn(y_true = y_train, 
                      y_pred = y_pred)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # 测试
    model_0.eval()
    with torch.inference_mode():
        test_logits = model_0(X_test).squeeze()
        test_pred = torch.round(torch.sigmoid(test_logits))
        test_loss = loss_fn(test_logits, 
                            y_test)
        test_acc = accuracy_fn(y_true = y_test,
                               y_pred = test_pred)
    
    # 打印出内容
    if epoch % 10 == 0:
        print(f"Epoch:{epoch} | Loss:{loss:.5f} | Accuracy:{acc:.2f}% | Test loss:{test_loss:.2f} | Test accuracy:{test_acc:.2f}%")

Epoch:0 | Loss:0.69313 | Accuracy:51.75% | Test loss:0.69 | Test accuracy:48.50%
Epoch:10 | Loss:0.69310 | Accuracy:51.75% | Test loss:0.69 | Test accuracy:48.00%
Epoch:20 | Loss:0.69308 | Accuracy:51.25% | Test loss:0.69 | Test accuracy:49.00%
Epoch:30 | Loss:0.69307 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:48.00%
Epoch:40 | Loss:0.69306 | Accuracy:50.38% | Test loss:0.69 | Test accuracy:48.00%
Epoch:50 | Loss:0.69305 | Accuracy:51.12% | Test loss:0.69 | Test accuracy:47.50%
Epoch:60 | Loss:0.69304 | Accuracy:51.12% | Test loss:0.69 | Test accuracy:48.00%
Epoch:70 | Loss:0.69303 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:47.50%
Epoch:80 | Loss:0.69303 | Accuracy:50.75% | Test loss:0.69 | Test accuracy:47.00%
Epoch:90 | Loss:0.69303 | Accuracy:50.38% | Test loss:0.69 | Test accuracy:46.50%

通过上面的数据,损失函数几乎没变化,精确度50%左右,感觉模型啥也没有学到,这就意味着它分类是随机的

3 预测和评估模型

从上面的数据,感觉我们的模型好像是随机猜测,我们来可视化一下看看究竟是怎么个事儿。

我们接着会写代码下载并导入helper_functions.py script来自Learn PyTorch for Deep Learning repo.

在这里有一个叫做 plot_decision_boundary() 的函数,它来可视化我们模型的分类的不同的点

我们也会导入我们在 01 中自己写的 plot_predictions()

import requests
from pathlib import Path
# 从仓库下载文档
if Path("helper_functions.py").is_file():
    print("helper_functions.py already exists, skipping download")
else:
    print("Downloading helper_functions.py")
    request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")
    with open("helper_functions.py", "wb") as f:
        f.write(request.content)

from helper_functions import plot_predictions, plot_decision_boundary

Downloading helper_functions.py

这里可能需要科学上网,我把文件helper_functions.py的代码放到文末了,可以自己创建一个.py文件粘进去。

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Train")
plot_decision_boundary(model_0, X_train, y_train)
plt.subplot(1, 2, 2)
plt.title("Test")
plot_decision_boundary(model_0, X_test, y_test)

在这里插入图片描述
看中间这条白色的线,模型是通过这条直线来区分红色和蓝色的点,所以是50%的准确率,明显是不对的,因为我们的数据明显是圈圈。

从机器学习的方面看,我们的模型欠拟合(underfitting),即没有从数据中学习到数据的模式.

那我们如何改善呢?请听下回分解

终于把今天的学习整理出来了,BB啊,今天中午吃了个超级物美价廉的套餐,里面的土豆炖牛腩和我平常吃的不一样,它这个带汤,尊嘟很好吃,熏过的香干一定要尝尝啊,皮蛋也很八错。还喝了一杯瑞幸的美式,一般般吧,室友说苦,我就喜欢喝这种苦苦的,嘻嘻嘻。

师姐通过一个电话说我喜欢一个男孩子,就说我喜欢他,哈哈哈哈,不知道咋听出来的,乌龙可是闹大了呢,话说,我们见面次数确实不多,但听到她的声音,莫名有点想她了,晚上就是多愁善感啊,别管我!

BB啊,今天就到这吧,不敢想象明天的学习有多开心,终于到了要改善模型啦~

如果文章对您有帮助的话,记得给俺点个呐!

靴靴,谢谢~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938577.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

秋招突击——7/19——队列同步器AQS学习

文章目录 引言基础知识简介AQS接口和示例第一类&#xff1a;访问和修改同步状态的方法第二类&#xff0c;5个重写的方法第三类&#xff0c;9个模板方法 队列同步器实现原理同步队列独占式同步获取和释放共享式同步获取和释放独占式同步获取和释放 相关面试题怎么理解Lock和AQS的…

如何解决热插拔时的电压过冲

摘要 热插拔是指将上电电压源连接到电子器件的输入电源或电池连接器。热插拔产生的电压瞬态尖峰会损坏器件内部的集成电路。该文解释了此类电压瞬变的根本原因&#xff0c;并提供了防止这些瞬变损坏电子产品中的集成电路(IC) 的可能设计。 1 引言 当将高于 5V 的 USB 适配器…

达梦数据库的系统视图v$dmwatcher

达梦数据库的系统视图v$dmwatcher 查询当前登录实例所对应的守护进程信息&#xff0c;注意一个守护进程可以同时守护多个组的实例&#xff0c;因此查询结果中部分字段&#xff08;N_GROUP、SWITCH_COUNT&#xff09;为守护进程的全局信息&#xff0c;并不是当前登录实例自身的…

BUUCTF - Web - 1

文章目录 1. [极客大挑战 2019]EasySQL 1【SQL注入-万能密码】2. [极客大挑战 2019]Havefun 1【前端代码审计-注释泄漏】3. [HCTF 2018]WarmUp 1【PHP代码审计】4. [ACTF2020 新生赛]Include 1【PHP伪协议】5. [ACTF2020 新生赛]Exec 1【命令注入-基础】6. [GXYCTF2019]Ping Pi…

张量网络碎碎念:CGC

在本系列 上一篇文章 中&#xff0c;我介绍了张量网络的一些基础概念。其中很大一部分来自 github 上一个教程。事实上&#xff0c;该教程的大部分内容来自 e3nn 官网。 除了上篇文章介绍的一些可视化技巧&#xff0c;官网还提供了其他一些可视化模块。使用这些功能能使我们更深…

windows USB 设备驱动开发-开发Type C接口的驱动程序(三)

编写 USB Type C 端口控制器驱动程序 如果 USB Type-C 硬件实现 USB Type-C 或电源传送 (PD) 物理层&#xff0c;但未实现供电所需的状态机&#xff0c;则需要编写 USB Type-C 端口控制器驱动程序。 在 Windows 10 版本 1703 中&#xff0c;USB Type-C 体系结构已得到改进&am…

云监控(华为) | 实训学习day5(10)

Gaussdb安装和连接idea GaussDB的安装 首先关闭防火墙 systemctl disable firewalld.service 永久关闭防火墙&#xff08;发生在下次启动&#xff09; systemctl stop firewalld.service 关闭本次防火墙 查看防火墙状态systemctl status firewalld.service 查询的状态是Dead表…

【算法】百钱买百鸡问题算法详解及多语言实现

问题描述 百钱买百鸡问题是一个经典的数学问题&#xff0c;题目要求用100文钱买100只鸡&#xff0c;公鸡5文钱一只&#xff0c;母鸡3文钱一只&#xff0c;小鸡3只一文钱&#xff0c;问公鸡、母鸡、小鸡各买多少只&#xff1f; 目录 问题描述​编辑 解决方案 Python实现 Ja…

选择Maya进行3D动画制作与渲染的理由

如果你对3D动画充满热情并追求成为专业3D动画师的梦想&#xff0c;你一定听说过Maya——近年来3D动画的行业标准。Maya被3D艺术家广泛使用&#xff0c;你是否想知道为什么Maya总是他们的首选&#xff1f;下面一起来了解下。 一、什么是Maya&#xff1f; 由Autodesk开发的Maya是…

wxid转微信号

7.21由于微信的再一次调整&#xff0c;能够转出微信号的接口已经和谐&#xff0c;根据客户要求琢磨了几个小时 发现新的接口也是可以批量转换的

springcolud学习06Hystrix

Hystrix Hystrix是Netflix开发的一个用于处理分布式系统中延迟和容错问题的库。它主要用于防止分布式系统中的雪崩效应,通过在服务之间添加延迟容错和故障处理机制来增强系统的弹性。 服务熔断 类似于电路中的断路器,当失败率超过阈值时,Hystrix 可以自动地开启断路器,停…

c++习题12-开关灯

目录 一&#xff0c;题目 二&#xff0c;思路 三&#xff0c;代码 一&#xff0c;题目 用例输入 1 10 10 用例输出 1 1,4,9 二&#xff0c;思路 创建可以存放路灯亮灭情况的数组&#xff0c;路灯的编号从1开始&#xff0c;因此在使用for循环去初始化数组时&#xff…

初识模板【C++】

P. S.&#xff1a;以下代码均在VS2022环境下测试&#xff0c;不代表所有编译器均可通过。 P. S.&#xff1a;测试代码均未展示头文件stdio.h的声明&#xff0c;使用时请自行添加。 博主主页&#xff1a;LiUEEEEE                        …

编写小程序用什么软件

编写小程序时&#xff0c;可以使用多种软件或工具&#xff0c;这些工具通常提供了丰富的开发功能和组件&#xff0c;方便开发者进行小程序的创建、开发和调试。以下是一些常用的编写小程序的软件和工具&#xff1a; DIY官网可视化工具 可视化拖拽开发神器|无须编程 零代码基础…

HashMap原理详解,HashMap源码解析

HashMap是一个数组链表和红黑树的结合体 HashMap的第一层表现是数组&#xff0c;HashMap默认创建一个长度为十六的数组来储存数据&#xff0c;但不同的是&#xff0c;它并非是先放在第0个索引&#xff0c;然后第一个索引那么放置&#xff0c;而是通过key获取对应的32位hash值&a…

OAuth2.0 or Spring Session or 单点登录流程

1.社交登录 2.微博社交登录 第三方登录 1.登录微博 2.点击网站接入 3.填写完信息&#xff0c;到这里&#xff0c;写入成功回调 和 失败回调 是重定向&#xff0c;所以可以写本地的地址 3.认证 分布式Session spring-session 域名不一样 发的 jSessionId 就不同&#xff0c…

uniapp,vue3上传图片组件封装

首先创建一个 components 文件在里面进行组件的创建 下面是 vip组件的封装 也就是图片上传组件 只是我的命名是随便起的 <template><!--图片 --><view class"up-page"><!--图片--><view class"show-box" v-for"(item,ind…

STM32的串口(RS485)数据收发

一、前言 我们的单片机串口一般常用RS232、RS485、TTL这几种通讯方式&#xff0c;日常调试可能RS232、TTL比较多&#xff0c;真正和其它厂家数据交互的时候&#xff0c;还是RS485用的比较多&#xff0c;因为它是差分信号等电气属性&#xff0c;所以比较稳定&#xff0c;传输距…

Matlab演示三维坐标系旋转

function showTwo3DCoordinateSystemsWithAngleDifference() clear all close all % 第一个三维坐标系 origin1 [0 0 0]; x_axis1 [1 0 0]; y_axis1 [0 1 0]; z_axis1 [0 0 1];% 绕 x 轴旋转 30 度的旋转矩阵 theta_x 30 * pi / 180; rotation_matrix_x [1 0 0; 0 cos(th…

SpringBoot使用本地缓存——Caffeine

SpringBoot使用本地缓存——Caffeine 缓存&#xff0c;想必大家都用过&#xff0c;将常用的数据存储在缓存上能在一定程度上提升数据存取的速度。这正是局部性原理的应用。之前用的缓存大多是分布式的&#xff0c;比如Redis。使用Redis作为缓存虽然是大多数系统的选择&#xf…