【pytorch 入门系列】02 手把手多分类从0到1

news2024/9/23 19:26:30

温故而知新,通过手把手写一个多分类任务来复习之前所学过的知识。

  1. 前置知识
factorize的妙用:把文本数据枚举化
labels, uniques = pd.factorize(['b', 'b', 'a', 'c', 'b'])
labels,uniques

(array([0, 0, 1, 2, 0]), array([‘b’, ‘a’, ‘c’], dtype=object))

  1. 数据集读取以及处理
    鸢尾花数据集相比大家都已经很熟悉了。
data = pd.read_csv("dataset/iris.csv")
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 6 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   Unnamed: 0    150 non-null    int64  
 1   Sepal.Length  150 non-null    float64
 2   Sepal.Width   150 non-null    float64
 3   Petal.Length  150 non-null    float64
 4   Petal.Width   150 non-null    float64
 5   Species       150 non-null    object 
dtypes: float64(4), int64(1), object(1)
memory usage: 7.2+ KB

unnamed列是序号列,没用
species时分类列,
一共150条数据,数据是初探
在这里插入图片描述
有三类鸢尾花,他们是类别,
但是因为torch只能处理数字,文本需要转换成数字类型 1前置就用到了

data.Species.unique()
array(['setosa', 'versicolor', 'virginica'], dtype=object)
labels,uniques = pd.factorize(data.Species.values)
data['Species'] = labels
data

在这里插入图片描述
Unnamed 列是用不到的,是序号列去掉,这样以来前几列是训练集,最后一列是标签 .values 返回的是numpy数据

# Unnamed 列是用不到的,是序号列去掉,这样以来前几列是训练集,最后一列是标签
X = data.iloc[:,1:-1].values
Y = data.iloc[:,-1].values
  1. test,train 数据划分

借助sklearn 完成划分,并转成torch格式, y必须为torch.int64 或者 torch.Long 类型,否则训练过程报错。因为只能计算Long类型的

train_x, test_x, train_y,test_y = train_test_split(X,Y)
# 切分数据集后,转成torch格式
train_x = torch.from_numpy(train_x).type(torch.float32)
train_y = torch.from_numpy(train_y).type(torch.int64)
test_x = torch.from_numpy(test_x).type(torch.float32)
test_y = torch.from_numpy(test_y).type(torch.int64)

转成 dataset 和 dataloader,这样转的原因我已经在模板那篇文章写清楚了,核心:1. train 数据集需要 shuffle 2.自动实现切片功能

batch_size=8
train_ds = TensorDataset(train_x, train_y)
train_dl = DataLoader(train_ds,batch_size=batch_size,shuffle=True)
test_ds = TensorDataset(test_x, test_y)
test_dl = DataLoader(test_ds,batch_size=batch_size)
  1. 设计网络和损失函数
    不需要多解释,但为啥不在这边就用了 self.softmax = nn.Softmax(3)呢
    是因为在损失函数中已经包含了这一部分,torch是这样的,tensorflow应该不是。
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(4,32)
        self.lin2 = nn.Linear(32,32)
        self.lin3 = nn.Linear(32,3)
        # self.softmax = nn.Softmax(3)
    
    def forward(self,x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        x = F.relu(x)
        x = self.lin3(x)
        return x
model = Model()
model

看一下它的结构:
在这里插入图片描述
损失函数:多分类当然是交叉熵损失了

# 定义损失函数,多分类当让是交叉熵损失了
loss_fn = nn.CrossEntropyLoss()

简单测试一下,这一步很重要!

input_batch, label_batch = next(iter(train_dl))
y_pred = model(input_batch)
torch.argmax(y_pred,dim=1)

在这里插入图片描述
5. 计算正确率&目标函数

	# 定义目标函数
   def accuracy(y_pred, y_true):
        y_pred = torch.argmax(y_pred,dim=1)
        acc = (y_pred ==y_true).float().mean()
        return acc
	optim = torch.optim.Adam(model.parameters(),lr=0.0001)

在这里插入图片描述
6. 训练

# 万事俱备,只差训练
train_loss =[]
train_acc = []
test_loss = []
test_acc= []
epochs = 200
for epoch in range(epochs):
    for x,y in train_dl:
        y_pred = model(x)
        loss = loss_fn(y_pred,y)
        optim.zero_grad()
        loss.backward()
        optim.step()
    with torch.no_grad():
        epoch_acc_train = accuracy(model(train_x),train_y)
        epoch_loss_train = loss_fn(model(train_x), train_y).data
        epoch_acc_test = accuracy(model(test_x),test_y)
        epoch_loss_test = loss_fn(model(test_x), test_y).data
        print('epoch: ', epoch, 'loss: ', round(epoch_loss_train.item(), 3),
                                'accuracy:', round(epoch_acc_train.item(), 3),
                                'test_loss: ', round(epoch_loss_test.item(), 3),
                                'test_accuracy:', round(epoch_acc_test.item(), 3)
             )
        
        train_loss.append(epoch_loss_train)
        train_acc.append(epoch_acc_train)
        test_loss.append(epoch_loss_test)
        test_acc.append(epoch_acc_test)

损失情况
在这里插入图片描述
7. 图的方式展示
···
import matplotlib.pyplot as plt
plt.plot(range(1, epochs+1), train_loss, label=‘train_loss’)
plt.plot(range(1, epochs+1), test_loss, label=‘est_loss’)
plt.legend()
···
在这里插入图片描述
···
plt.plot(range(1, epochs+1), train_acc, label=‘train_acc’)
plt.plot(range(1, epochs+1), test_acc, label=‘test_acc’)
plt.legend()
···
在这里插入图片描述
8. 需要整一个训练的核心函数 fit函数

模板代码

  1. 创建输入(dataloader)
  2. 创建模型(model)
  3. 创建损失函数
def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    for x, y in trainloader:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
        
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total
        
        
    test_correct = 0
    test_total = 0
    test_running_loss = 0 
    
    with torch.no_grad():
        for x, y in testloader:
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total
    
        
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
             )
        
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc
model = Model()
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)

在这里插入图片描述
在这里插入图片描述
至此fit就可以对付所有的多分类问题了,您只需要修改model的网络结构即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/383028.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++】-- 特殊类设计

对于类的思维境界提升&#xff0c;没有太大的实际意义&#xff0c;但是锻炼思想。 目录 单例模式 饿汉模式 懒汉模式 #&#xff1a;请设计一个类&#xff0c;不能被拷贝。 拷贝只会发生在两个场景中&#xff1a;拷贝构造函数赋值运算符重载因此想要让一个类禁止拷贝&#xf…

对称锥规划:对称锥的增广拉格朗日乘子法(Semi-Smooth Newton Method解无约束优化子问题)

文章目录对称锥规划&#xff1a;对称锥的增广拉格朗日乘子法&#xff08;Semi-Smooth Newton Method解无约束优化子问题&#xff09;对称锥的增广拉格朗日函数Semi-Smooth Newton Method半光滑牛顿法广义雅可比半光滑性半光滑牛顿算法参考文献对称锥规划&#xff1a;对称锥的增…

2023年最新阿里云服务器价格表出炉(精准收费标准及配置价格表)

阿里云在全球率先宣布了基于 Intel Ice Lake 处理器的第七代云服务器ECS&#xff0c;性能提升的同时降低了报价&#xff0c;性价比更高了。进入2023年阿里云服务器价格依然是大家关心的问题&#xff0c;事实上阿里云服务器租用价格和最新收费标准都可以通过官方云服务器计算器来…

【IoT】智能烟雾报警器

设计简介 硬件设计由AT89C51单片机、DS18B20温度传感器、4位共阳数码管、电源模块、报警模块、按键模块、MQ-2烟雾检测模块和ADC0832模数转换模块组成。 烟雾传感器MQ-2检测空气中的烟雾气体&#xff0c;通过ADC0832进行数据转换&#xff0c;经过单片机的运算处理后在数码管上…

【WEB前端进阶之路】 HTML 全路线学习知识点梳理(上)

前言 HTML 是一切Web开发的基础&#xff0c;本文专门为小白整理&#xff0c;针对前端零基础的朋友们&#xff0c;手把手教你学习HTML&#xff0c;让你轻松迈入WEB开发的行列。 首先&#xff0c;感谢 橙子_ 在HTML学习以及本文编写过程中对我的帮助。 文章目录前言一.HTML简介1.…

Java使用不同方式获取两个集合List的交集、补集、并集(相加)、差集(相减)

1 明确概念首先知道几个单词的意思&#xff1a;并集 union交集 intersection补集 complement析取 disjunction减去 subtract1.1 并集对于两个给定集合A、B&#xff0c;由两个集合所有元素构成的集合&#xff0c;叫做A和B的并集。记作&#xff1a;AUB 读作“A并B”例&#…

微纳制造技术——基础知识

1.什么是直接带隙半导体和间接带隙半导体 导带底和价带顶处以同一K值&#xff0c;称为直接带隙半导体 导带底和价带顶不处在同一K值&#xff0c;称为间接带隙半导体 2.扩散和漂移的公式 3.三五族半导体的性质 1.high mobility 2.wide bandgap 3.direct bandgap 4.三五族…

SWM181 串口功能使用介绍

SWM181 串口功能使用介绍&#x1f4cc;SDK固件包&#xff1a;https://www.synwit.cn/kuhanshu_amp_licheng/✨注意新手谨慎选择作为入门单片机学习。&#x1f33c;开发板如下图&#xff1a; &#x1f4cb;SWM181描述上写了有4个串口&#xff0c;在数据手册上&#xff0c;将引脚…

教你如何搭建设备-巡检管理系统,demo可分享

1、简介1.1、案例简介本文将介绍&#xff0c;如何搭建设备-巡检管理。1.2、应用场景设备管理员进行制定设备巡检时间/内容计划、记录设备巡检信息、可以查看今日待巡检设备。2、设置方法2.1、表单搭建1&#xff09;新建表单【设备档案-履历表】&#xff0c;字段设置如下&#x…

activiti7执行流程详解

什么是工作流&#xff1f; 官方定义&#xff1a;工作流是将一组任务组织起来以完成某个经营过程&#xff1a;定义了任务的触发顺序和触发条件&#xff0c;每个任务可以由一个或多个软件系统完成&#xff0c;也可以由一个或一组人完成&#xff0c;还可以由一个或多个人与软件系统…

阿赵的MaxScript学习笔记分享八《文件操作》

大家好&#xff0c;我是阿赵。继续分享MaxScript学习笔记第八篇 。这一篇主要讲文件操作&#xff0c;包括文件的I/O和导入导出。 1、获得3DsMax指定的一些目录路径 如果在电脑上安装了3DsMax软件&#xff0c;那么在文档里面会有一个3dsMax的文件夹&#xff0c;里面有一些3dsMa…

《C++ Primer Plus》(第6版)第8章编程练习

《C Primer Plus》&#xff08;第6版&#xff09;第8章编程练习《C Primer Plus》&#xff08;第6版&#xff09;第8章编程练习1. 打印字符串2. CandyBar3. 将string对象的内容转换为大写4. 设置并打印字符串5. max5()6. maxn()7. SumArray()《C Primer Plus》&#xff08;第6版…

【C++】C++11 异常

目录 1. C语言传统的处理错误的方式 2. C异常概念 3. 异常的使用 3.1. 异常的抛出和捕获 3.2. 在函数调用链中异常栈展开匹配原则 3.3. 异常的重新抛出 3.4. 异常安全 3.5. 异常规范 4.自定义异常体系 5. C标准库的异常体系 6. 异常的优缺点 6.1. C异常的优点&…

Spark性能优化二 Shuffle机制分析

&#xff08;一&#xff09; 什么情况下发生shuffle 在MapReduce框架中&#xff0c;Shuffle是连接Map和Reduce之间的桥梁&#xff0c;Map阶段通过shuffle读取数据并输出到对应的Reduce&#xff1b;而Reduce阶段负责从Map端拉取数据并进行计算。在整个shuffle过程中&#xff0c…

Linux 学习整理(使用 iftop 查看网络带宽使用情况 《端口显示》)

一、命令简介 iftop 是实时流量监控工具&#xff0c;可以用来监控网卡的实时流量&#xff08;可以指定网段&#xff09;、反向解析IP、显示端口信息等。 二、命令安装 yum install -y iftop 三、命令相关参数及说明 3.1、相关参数说明 -i&#xff1a;设定监测的网卡&#…

python未来应用前景怎么样

Python近段时间一直涨势迅猛&#xff0c;在各大编程排行榜中崭露头角&#xff0c;得益于它多功能性和简单易上手的特性&#xff0c;让它可以在很多不同的工作中发挥重大作用。 正因如此&#xff0c;目前几乎所有大中型互联网企业都在使用 Python 完成各种各样的工作&#xff0…

CAD中如何将图形对象转换为三维实体?

有些小伙伴在CAD绘制完图纸后&#xff0c;想要将图纸中的某些图形对象转换成三维实体&#xff0c;但却不知道该如何操作&#xff0c;其实很简单&#xff0c;本节CAD绘图教程就和小编一起来了解一下浩辰CAD软件中将符合条件的对象转换为三维实体的相关操作步骤吧&#xff01; 将…

HID协议详解 - Report Descriptor报告描述符构建与解析

USB相关基础知识简述 报告描述符是HID协议里比较复杂的一部分&#xff0c;在理解报告描述符之前&#xff0c;可以对USB协议数据传输的一些基础知识做一些了解&#xff0c;更方便理解后续内容。 报告是USB协议里数据传输&#xff08;Data Transfer&#xff09;的一种&#xff…

Android正确使用资源res文件

观看此文注意首先有的UI改颜色&#xff0c;没用&#xff0c;发现无法更改按钮背景颜色。我的AS下载的是最新版本&#xff0c;Button按钮的背景颜色一直都是亮紫色&#xff0c;无法更改。为什么呢&#xff1f;首先在你的清单文件中看你应用的是哪个主题。我现在用的是这个可能你…

PYthon组合数据类型的简单使用

Python的数据类型有两种&#xff0c;基本数据类型和组合数据类型&#xff0c;组合数据类型在Python的使用中特别重要。 1.组合数据类型的分类&#xff1a; 2.序列类型 序列类型中元素存在顺序关系&#xff0c;可以存在数值相同但位置不同的元素。序列类型支持成员关系操作符&…