PyTorch深度学习模型训练流程:(一、分类)

news2024/12/25 0:59:39

自己写了个封装PyTorch深度学习训练流程的函数,实现了根据输入参数训练模型并可视化训练过程的功能,可以方便快捷地检验一个模型的效果,有助于提高选择模型架构、优化超参数等工作的效率。发出来供大家参考,如有不足之处,欢迎批评讨论。

分类是人工智能的一个非常重要的应用,这篇文章分享的函数适用于实现分类的深度学习模型,包括以下功能:

  1. 根据输入的数据集、模型、优化器、损失函数等参数训练一个分类模型;
  2. 使用visdom可视化训练过程,实时输出精确度曲线、损失曲线、混淆矩阵和ROC曲线;
  3. 支持二分类和多分类;
  4. 输入数据集支持形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型,可以方便灵活地调用torch内置数据集或自定义数据集;
  5. 支持使用GPU加速深度学习模型的训练。

废话不多说,先来看下输出效果:

二分类
多分类

 深度学习的完整流程通常包括以下几个步骤:

  1. 收集数据
  2. 数据预处理
  3. 选择模型
  4. 训练模型
  5. 评估模型
  6. 超参数调优
  7. 测试模型

本函数封装了训练模型和评估模型的步骤,包括:

  1. 若数据集为(X,y)形式则分离训练集和测试集(测试集占20%),数据标准化,封装训练集和测试集;
  2. 将训练集和测试集设置为加载器;
  3. 遍历训练集加载器,计算每一批次的输出和损失,并反向传播更新神经网络参数;
  4. 每迭代100次评估一下模型,用测试集数据计算并画出精确度曲线、损失曲线、混淆矩阵和ROC曲线。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdom

from typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizer


def classify(
        data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        scaler: Optional[TransformerMixin] = None,
        batch_size: int = 64,
        epochs: int = 10,
        device: Optional[torch.device] = None
) -> nn.Module:
    """
    分类任务的训练函数。
    :param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型
    :param model: 分类模型
    :param optimizer: 优化器
    :param criterion: 损失函数
    :param scaler: 数据标准化器
    :param batch_size: 批大小
    :param epochs: 训练轮数
    :param device: 训练设备
    :return: 训练好的分类模型
    """
    if isinstance(data[0], np.ndarray):
        X, y = data
        # 处理类别
        classes = np.unique(y)
        classes_str = [str(i) for i in classes]
        num_classes = len(classes)
        # 分离训练集和测试集,指定随机种子以便复现
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        # 数据标准化
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        # 转换为tensor
        X_train = torch.from_numpy(X_train.astype(np.float32))
        X_test = torch.from_numpy(X_test.astype(np.float32))
        y_train = torch.from_numpy(y_train.astype(np.int64))
        y_test = torch.from_numpy(y_test.astype(np.int64))
        # 将X和y封装成TensorDataset
        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)

    elif isinstance(data[0], Dataset):
        train_dataset, test_dataset = data
        classes = list(train_dataset.class_to_idx.values())
        classes_str = train_dataset.classes
        num_classes = len(classes)
    else:
        raise ValueError('Unsupported data type')

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    model.to(device)
    vis = Visdom()
    # 训练模型
    for epoch in range(epochs):
        for step, (batch_x_train, batch_y_train) in enumerate(train_loader):
            batch_x_train = batch_x_train.to(device)
            batch_y_train = batch_y_train.to(device)
            # 前向传播
            output = model(batch_x_train)
            loss = criterion(output, batch_y_train)
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            niter = epoch * len(train_loader) + step + 1  # 计算迭代次数
            if niter % 100 == 0:
                # 评估模型
                model.eval()
                with torch.no_grad():
                    eval_dict = {
                        'test_loss': [],
                        'test_acc': [],
                        'test_cm': [],
                        'test_roc': [],
                    }
                    for batch_x_test, batch_y_test in test_loader:
                        batch_x_test = batch_x_test.to(device)
                        batch_y_test = batch_y_test.to(device)
                        test_output = model(batch_x_test)
                        predicted = torch.argmax(test_output, 1)
                        test_predicted_tuple = (batch_y_test.numpy(), predicted.numpy())
                        # 计算并记录损失、精确度、混淆矩阵、ROC曲线
                        eval_dict['test_loss'].append(criterion(test_output, batch_y_test))
                        eval_dict['test_acc'].append(accuracy_score(*test_predicted_tuple))
                        eval_dict['test_cm'].append(confusion_matrix(*test_predicted_tuple, labels=classes))
                        if num_classes == 2:
                            # eval_dict['test_roc']形状为(len,(fpr,tpr),3)
                            eval_dict['test_roc'].append(roc_curve(*test_predicted_tuple)[:2])
                        else:
                            # 多分类ROC曲线需要one-hot编码
                            y_test_one_hot, predicted_one_hot = map(partial(label_binarize, classes=classes), test_predicted_tuple)
                            fpr_list = []
                            tpr_list = []
                            for i in range(num_classes):
                                fpr, tpr, _ = roc_curve(y_test_one_hot[:, i], predicted_one_hot[:, i])
                                # 无(fpr,tpr)数据点时,插值填充(0,0)数据点
                                if len(fpr) != 3:
                                    fpr = np.insert(fpr, 0, 0)
                                    tpr = np.insert(tpr, 0, 0)
                                fpr_list.append(fpr)
                                tpr_list.append(tpr)
                            # eval_dict['test_roc']形状为(len,(fpr,tpr),num_classes,3)
                            eval_dict['test_roc'].append((fpr_list, tpr_list))

                    # 画出损失曲线
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),
                        win='loss',
                        update='append',
                        opts=dict(title='Loss', legend=['train_loss', 'test_loss']),
                    )
                    # 画出精确度曲线
                    train_acc = accuracy_score(batch_y_train.numpy(), torch.argmax(output, 1).numpy())
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.tensor((train_acc, np.mean(eval_dict['test_acc']))).unsqueeze(0),
                        win='accuracy',
                        update='append',
                        opts=dict(title='Accuracy', legend=['train_acc', 'test_acc'], ytickmin=0, ytickmax=1),
                    )
                    # 画出混淆矩阵
                    vis.heatmap(
                        X=np.add.reduce(eval_dict['test_cm']),
                        win='confusion_matrix',
                        opts=dict(title='Confusion Matrix', columnnames=classes_str, rownames=classes_str),
                    )
                    # 画出ROC曲线
                    test_roc_arr = np.array(eval_dict['test_roc'])
                    zeros_df = pd.DataFrame({'fpr': [0], 'tpr': [0]})  # 用于填充的(0,0)数据点
                    ones_df = pd.DataFrame({'fpr': [1], 'tpr': [1]})  # 用于填充的(1,1)数据点
                    if num_classes == 2:
                        plot_arr = test_roc_arr[:, :, 1]  # 提取(fpr,tpr)数据点,形状为(len,(fpr,tpr))
                        cats = pd.qcut(plot_arr[:, 0], q=10, labels=False, duplicates='drop')  # 按fpr大小分成10个数据一样多的区间
                        groups = pd.DataFrame(plot_arr, columns=['fpr', 'tpr']).groupby(cats).mean()  # 计算每个区间的平均值,形状为(10,(fpr,tpr))
                        plot_df = pd.concat([zeros_df, groups, ones_df])  # 头添加(0,0),尾添加(1,1)数据点,形状为(12,(fpr,tpr))
                        x = plot_df['fpr']
                        Y = plot_df['tpr']
                    else:
                        plot_df_list = []
                        plot_arr = test_roc_arr[:, :, :, 1].swapaxes(1, 2)  # 提取(fpr,tpr)数据点并换轴,形状为(len,num_classes,(fpr,tpr))
                        for i in range(num_classes):
                            cats = pd.qcut(plot_arr[:, i, 0], q=10, labels=False, duplicates='drop')
                            groups = pd.DataFrame(plot_arr[:, i, :], columns=['fpr', 'tpr']).groupby(cats).mean()  # 形状为(10,(fpr,tpr))
                            plot_df = pd.concat([zeros_df, groups, ones_df])  # 形状为(12,(fpr,tpr))
                            add_num = 12 - len(plot_df)
                            # 长度不足12时,插值填充(0,0)数据点
                            if add_num > 0:
                                plot_df = pd.concat([zeros_df] * add_num + [plot_df])
                            plot_df_list.append(plot_df)  # 形状为(num_classes,12,(fpr,tpr))
                        plot_arr_sum = np.stack(plot_df_list, axis=1)  # 形状为(12,num_classes,(fpr,tpr))
                        x = plot_arr_sum[:, :, 0]
                        Y = plot_arr_sum[:, :, 1]
                    vis.line(
                        X=x,
                        Y=Y,
                        win='ROC',
                        opts=dict(title='ROC', legend=classes_str),
                    )
    return model

注意:代码运行前要先在命令行输入python -m visdom.server,在浏览器中打开提供的链接:

 成功运行的效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2079970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iPhone 16要发布了,iPhone 13 咋办啊

iPhone 16要发布了,iPhone 13 咋办啊? Phone 16的屏幕尺寸和分辨率是多少? iPhone 16采用了6.1英寸的超视网膜XDR显示屏,与iPhone 15相同。屏幕分辨率达到了25561179像素,像素密度为460ppi,为用户提供了清…

一、菜单扩展

一、创建文件夹 创建一个名为Editor的文件夹。unity会默认这个名字为工程文件夹 二、创建代码 实现点击unity菜单,对应代码的方法 引用命名空间;使用这个menuitem 注:必须有一个子路径,不然会报错 这里是这个方法的参数 每一个…

并发式服务器

并发式服务器是一种设计用来同时处理多个客户端请求的服务器。这种服务器能够提高资源利用率和响应速度,适用于需要服务大量用户的网络应用。以下是并发式服务器的一些关键特点: 多任务处理:并发式服务器能够同时处理多个任务或请求&#xff…

【Python】成功解决 ModuleNotFoundError: No module named ‘lpips‘

【Python】成功解决 ModuleNotFoundError: No module named ‘lpips’ 下滑即可查看博客内容 🌈 欢迎莅临我的个人主页 👈这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地!🎇 🎓 博主简介:98…

LDR ,DTR 指令详解 (17)

arm 的 大致的架构。 LDR{条件} 目的寄存器&#xff0c; <存储器地址> 功能&#xff1a; 将存储器的一个32的数据&#xff0c;保存到寄存器中。 但是这条指令还有另外一个作用。 如果 目的寄存器是PC的话&#xff0c;而从内存中读到的数据是一块内存的地址&#xff0…

正则表达式模块re及其应用

正则表达式是一种强大的文本处理工具&#xff0c;能够用来匹配、查找、替换复杂的文本模式。Python中的正则表达式由re模块提供。 以下是一些常用的方法及示例&#xff1a; 一. 常用方法 re.match() 从头开始匹配re.search() 搜索第一个匹配串re.findall() 查找所有匹配项re…

代码随想录Day 27|贪心算法,题目:455.分发饼干、376.摆动序列、53.最大子序和

提示&#xff1a;DDU&#xff0c;供自己复习使用。欢迎大家前来讨论~ 文章目录 贪心算法Part01一、理论基础1.1 什么是贪心贪心算法解法&#xff1a;动态规划解法&#xff1a; 1.2 贪心一般解题步骤 二、题目题目一&#xff1a;455.分发饼干解题思路&#xff1a;其他思路 题目…

【Datawhale AI夏令营第五期】 CV方向 Task02学习笔记 精读Baseline 建模方案解读与进阶

【Datawhale AI夏令营第五期】 CV方向 Task02学习笔记 精读Baseline 建模方案解读与进阶 教程&#xff1a; 链接&#xff1a; https://linklearner.com/activity/16/16/68 传送门 之前我看原画课的时候&#xff0c;造型的部分就跟我们说&#xff0c;让我们日常观察事物的时候…

海运系统:海运拼箱 小批量货物的海运奥秘

在国际海运运输的广阔领域中&#xff0c;海运拼箱作为一种灵活且经济的运输方式&#xff0c;尤其适用于那些货物量不大或体积不足以单独填满一个标准集装箱的场景。这种运输模式不仅促进了国际贸易的便捷性&#xff0c;还通过资源共享的方式&#xff0c;有效降低了物流成本&…

p10 容器的基本命令

首先先拉取一个centos的镜像 命令&#xff1a;docker pull centos 新建容器并且启动 这里直接参考老师的命令吧 接下来是启动并且进入到容器当中去输入docker run -it centos /bin/bash这里是以交互的方式进入到容器中可以看到接下来的ls命令输出的东西就是一个Linux系统最…

Ansys Speos | 挡风玻璃光学畸变分析

附件下载 联系工作人员获取附件 此示例介绍了基于 TL 957 标准和43号法规&#xff08;ECE R43&#xff09;的挡风玻璃光学畸变分析的工作流程&#xff0c;以及 GitHub Ansys 光学自动化中提供的分析自动化工具。 如果您从未使用过任何 GitHub 仓库&#xff0c;可以根据光学自…

数据结构(邓俊辉)学习笔记】串 07——KMP算法:分摊分析

文章目录 1.失之粗糙2.精准估计 1.失之粗糙 以下&#xff0c;就来对 KMP 算法的性能做一分析。我们知道 KMP 算法的计算过程可以根据对齐位置相应的分为若干个阶段&#xff0c;然而每一个阶段所对应的计算量是有很大区别的。很快就会看到&#xff0c;如果只是简单地从最坏的角…

K8S的持久化存储

文章目录 一、持久化存储emptyDir实际操作 hostPath建立过程 NFS存储NFS 存储的优点NFS 存储的缺点具体操作 pv和pvcPersistent Volume (PV)使用场景 Persistent Volume Claim (PVC)使用场景 使用 PV 和 PVC 的场景实际操作 StorageClassStorageClass 概述应用场景实际应用 一、…

实用攻略:亲身试用,高效数据恢复软件推荐!

今天要跟大家分享一下我使用几款数据恢复软件的经历。如果你曾经丢失过重要的文件&#xff0c;那除了注意备份外&#xff0c;也可以尝试一下这些非常棒的免费数据恢复软件&#xff01; 第一款&#xff1a;福昕数据恢复 链接&#xff1a;www.pdf365.cn/foxit-restore/ 首先聊…

Nginx+ModSecurity(3.0.x)安装教程及配置WAF规则文件

本文主要介绍ModSecurity v3.0.x在Nginx环境下的安装、WAF规则文件配置、以及防御效果的验证&#xff0c;因此对于Nginx仅进行简单化安装。 服务器操作系统&#xff1a;linux 位最小化安装 一、安装相关依赖工具 Bash yum install -y git wget epel-release yum install -y g…

大模型企业应用落地系列二》基于大模型的对话式推荐系统》核心技术架构设计图

注&#xff1a;此文章内容均节选自充电了么创始人&#xff0c;CEO兼CTO陈敬雷老师的新书《自然语言处理原理与实战》&#xff08;人工智能科学与技术丛书&#xff09;【陈敬雷编著】【清华大学出版社】 文章目录 大模型企业应用落地系列二基于大模型的对话式推荐系统》心技术架…

【精品】计算机毕业设计之:springboot游戏分享网站(源码+文档+辅导)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

win11,vscode上用docker环境跑项目

1.首先用dockerfile创建docker镜像 以下是dockerfile文件的内容&#xff1a; FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel LABEL Service"SparseInstanceActivation"ENV TZEurope/Moscow ENV DETECTRON_TAGv0.6 ARG DEBIAN_FRONTENDnoninteractiveRUN apt-…

JavaScript:js;知识回顾;笔记分享

一&#xff0c;js前奏 1&#xff0c;js简介&#xff1a; Javascript是一种由Netscape(网景)的LiveScript发展而来的原型化继承的面向对象的动态类型的区分大小写的客户端脚本语言&#xff0c;主要目的是为了解决服务器端语言&#xff0c;比如Perl&#xff0c;遗留的速度问题&a…

数据结构与算法(1)

抽象数据类型定义 算法的效率 时间效率 一个算法的运行时间是指一个算法在计算机上运行所耗费的时间 大致可以等于计算机执行一种简单的操作(如赋值、比较、移动等) 所需的时间与算法中进行的简单操作次数乘积。 比较时间复杂度&#xff08;看数量级&#xff09; 空…