(深度学习快速入门)第三章第二节:通过一个二分类任务介绍完整的深度学习项目

news2024/11/15 19:40:33

文章目录

  • 一:数据集介绍
  • 二:一个完整的深度学习项目必备文件
  • 三:项目代码
    • (1)config.py——超参数文件
    • (2)preprocess——数据预处理文件
    • (3)dataloader——数据集封装
    • (4)model——网络模型
    • (5)trainer——训练脚本
    • (6)inference——推理脚本

之前的波士顿房价预测案例非常简单,所以我们使用了几十行代码便完成了需求。但是在实际项目编写中可没有这么简单,如果把所有的内容都写入到一个文件中,那势必会让人十分混乱,因此本节通过一个十分简单的二分类任务来介绍完整的深度学习项目,本节中所涉及的损失函数、优化器、标准化等深度学习“组件”将会在下一节一一介绍,如果你现在对其中有些组件感觉不熟悉,那么可以线性略过,只关注整体逻辑即可

一:数据集介绍

anknote Dataset(chao票数据集):这是从zhi币鉴别过程中的图像里提取的数据,用来预测chao票的真假的数据集 。本数据集所给数据并非是原始图像数据,而是经过小波变化后的等价数据

  • uci链接:banknote authentication Data Set

在这里插入图片描述

3.6216,8.6661,-2.8073,-0.44699,0
4.5459,8.1674,-2.4586,-1.4621,0
3.866,-2.6383,1.9242,0.10645,0
3.4566,9.5228,-4.0112,-3.5944,0
0.32924,-4.4552,4.5718,-0.9888,0
4.3684,9.6718,-3.9606,-3.1625,0
3.5912,3.0129,0.72888,0.56421,0
2.0922,-6.81,8.4636,-0.60216,0
3.2032,5.7588,-0.75345,-0.61251,0
1.5356,9.1772,-2.2718,-0.73535,0
...

该数据集中含有1372个样本,每个样本由5个数值型变量构成,4个输入变量和1个输出变量这是一个二元分类问题

  • 第一列:图像经小波变换后的方差(variance)(连续值),用于描述分布的离散的程度

  • 第二列:图像经小波变换后的偏度(skewness)(连续值),用于描述分布的偏移中心的程度

  • 第三列:图像经小波变换后的峰度(curtosis)(连续值),用于描述概率密度分布曲线在平均值处峰值高低的特征数

  • 第四列:图像的(entropy)(连续值)

  • 第五列:chao票所属的类别(整数,0或1)

二:一个完整的深度学习项目必备文件

一个完整的项目由以下文件构成(假设项目文件夹叫做"test"),具体实施按需选择

  • README.md:项目说明
  • config.py(必有):配置文件(模型配置、数据集配置、参数配置等等)
  • data:数据集文件夹
  • dataset_loader.py:数据集的dataloader脚本,pytorch中专门用于管理数据的加载等操作
  • inference:推理脚本,模型训练完毕后运行进行测试,交给测试团队调用
  • log(必有):存放日志(Pytorch中是TensorBoardX)
  • loss.py:损失函数的设计
  • model.py(必有):模型的设计
  • model_save(必有):模型检查点保存(模型训练时间一般会很长,所以要进场保存,以防意外事件发生还需要重复训练)
  • preprocess.py(必有):数据集预处理,例如数据集的划分工作
  • tranier.py(必有):训练
  • utils.py:工具

三:项目代码

整体结构如下

在这里插入图片描述

(1)config.py——超参数文件

该文件下需要配置项目所使用到的超参数,例如devicedata_path等等,这些参数需要预先设定。主要目的是为了方便统一修改

class Parameters:
    ########################################## 数据 #############################################
    device = 'cpu'  # 指定设备(如果有GPU则为'cuda')
    data_dir = r'./data/'  # 所有数据所在文件夹
    data_path = r'./data/data_banknote_authentication.txt'  # 源数据路径
    trainset_path = r'./data/train.txt'  # 训练数据集路径
    valset_path = r'./data/val.txt'  # 验证集路径
    testset_path = r'./data/test.txt'  # 测试集路径



    in_features = 4  # 输入数据的特征数
    out_dim = 2  # 输出结果,由于是二分类问题,所以输出为2

    seed = 1234  # 随机种子


    ########################################## 网络结构 #############################################
    layer_list = [in_features, 64, 128, 64, out_dim]  # 层次顺序:输入层:三个隐藏层:输出层

    ########################################## 环境 #############################################
    batch_size = 64  # batch_size大小
    init_lr = 1e-3  # 初始学习率
    epochs = 100  # 训练轮数
    verbose_step = 10  # 每10步打印一次
    save_step = 200  # 每200步保存一次


parameters = Parameters()

(2)preprocess——数据预处理文件

该文件下需要对数据进行预处理,比如在这个例子中我把原始数据集划分为三个部分(训练、验证和测试)

import numpy as np
from config import parameters
import os

# 训练集、验证集和测试集划分比例
trainset_ratio = 0.7
valset_ratio = 0.2
test_set = 0.1

# 设置随机种子,读取源数据并打乱
np.random.seed(parameters.seed)
dataset = np.loadtxt(parameters.data_path, delimiter=',')
np.random.shuffle(dataset)

# 样本数量
n_items = np.shape(dataset)[0]

# 划分数据

trainset = dataset[:int(trainset_ratio*n_items), ]
valset = dataset[int(trainset_ratio*n_items):int(trainset_ratio*n_items)+int(valset_ratio*n_items), :]
testset = dataset[int(trainset_ratio*n_items)+int(valset_ratio*n_items):, ]


# 存储
np.savetxt(os.path.join(parameters.data_dir, 'train.txt'), trainset, delimiter=',')
np.savetxt(os.path.join(parameters.data_dir, 'val.txt'), valset, delimiter=',')
np.savetxt(os.path.join(parameters.data_dir, 'test.txt'), testset, delimiter=',')

(3)dataloader——数据集封装

Pytorch中使用DataLoader加载数据集时,需要统一继承torch.utils.data.Dataset类,所以这里书写如下

import torch
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from config import parameters
import numpy as np

class BankfakeDataset(torch.utils.data.Dataset):
    def __init__(self, data_path):
        self.dataset = np.loadtxt(data_path, delimiter=',')

    #  返回其X和y(输入数据和标签)
    #  至少要重写__getitem__和__len__方法
    def __getitem__(self, idx):
        item = self.dataset[idx]
        X, y = item[:parameters.in_features], item[parameters.in_features:]
        """
            to(parameters.device):会把数据送到CPU或GPU
            squeeze():会把维度为1的那个维度去掉
        """
        return torch.Tensor(X).float().to(parameters.device), \
               torch.Tensor(y).squeeze().long().to(parameters.device)
    def __len__(self):
        return np.shape(self.dataset)[0]

(4)model——网络模型

该文件用于存放你的模型设计,具体问题有具体的模型结构。对于本案例,是一个简单的二分类问题,所以搭建几个全连接层就ok了

import torch
from config import parameters
from torch import nn


class BankfakeModel(nn.Module):
    def __init__(self, ):
        super(BankfakeModel, self).__init__()

        self.linear_layer = nn.ModuleList([
            nn.Linear(in_features=in_dim, out_features=out_dim)
            for in_dim, out_dim in zip(parameters.layer_list[:-1], parameters.layer_list[1:])
        ])

    def forward(self, input_x):
        for layer in self.linear_layer:
            input_x = layer(input_x)
            input_x = nn.functional.relu(input_x)

        return input_x


# 测试
if __name__ == '__main__':
    model = BankfakeModel()
    X = torch.randn(size=(16, parameters.in_features)).to(parameters.device)
    y_pred = model(X)
    print(y_pred)
    print(y_pred.size())


(5)trainer——训练脚本

该文件用于对模型进行训练和验证,对于大多数问题来说,该文件的写法比较固定,有的也是些许的改动

import os
import torch
import random
import numpy as np
from tensorboardX import SummaryWriter
from argparse import ArgumentParser
from torch.utils.data import DataLoader


from config import parameters
from dataset_loader import BankfakeDataset
from model import BankfakeModel

# 日志记录
logger = SummaryWriter('./log')

# 随机种子
torch.manual_seed(parameters.seed)  # CPU随机种子
#  torch.cuda.manual_seed(parameters.seed)  # GPU随机种子(若有)
random.seed(parameters.seed)  # random随机种子
np.random.seed(parameters.seed)  # numpy随机种子

# 使用验证集对模型进行评估
def evaluate(model, val_loader, loss_func):
    #  进入eval模式
    model.eval()
    sum_loss = 0.
    #  with torch.no_grad()含义:https://blog.csdn.net/qq_42251157/article/details/124101436
    with torch.no_grad():
        for batch in val_loader:
            X, y = batch
            pred = model(X)
            loss = loss_func(pred, y)
            sum_loss += loss.item()
    #  特别注意返回train模式
    model.train()
    return sum_loss / len(val_loader)

# 保存模型
def save_checkpoint(model, epoch, optimizer, checkpoint_path):
    save_dict = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }

    torch.save(save_dict, checkpoint_path)



# 训练函数
def train():
    # 有关argparse.ArgumentParser用法:https://blog.csdn.net/u011913417/article/details/109047850
    # 其作用是解析命令行参数,目的是在终端窗口(ubuntu是终端窗口,windows是命令行窗口)输入训练的参数和选项
    parser = ArgumentParser(description='Model Training')
    parser.add_argument(
        '--c',
        # 当模型再次训练时选择从头开始还是从上次停止的地方开始
        default=None,  # 当参数未在命令行中出现时使用的值
        type=str,  # 参数类型
        help='from head or last checkpoint?'  # 参数说明
    )
    args = parser.parse_args()

    #  模型实例
    model = BankfakeModel()
    model = model.to(parameters.device)

    # 损失函数(这里比较简单所以直接定义,否则需要新建文件loss.py存放)
    loss_func = torch.nn.CrossEntropyLoss()  # 交叉熵损失函数

    # 优化器
    optimizer = torch.optim.Adam(model.parameters(), parameters.init_lr)

    # 训练数据加载
    trainset = BankfakeDataset(parameters.trainset_path)
    train_loader = DataLoader(trainset, batch_size=parameters.batch_size, shuffle=True, drop_last=True)
    # 验证数据加载(在evaluation函数中进行评估)
    valset = BankfakeDataset(parameters.valset_path)
    val_loader = DataLoader(valset, batch_size=parameters.batch_size, shuffle=True, drop_last=False)

    # 起始训练轮数, 步数
    start_epoch, step = 0, 0

    # 判断参数,是否需要从检查点开始训练
    # 主要针对大型数据,可能会训练几个小时或几天,所以容易出现问题
    if args.c:
        checkpoint = torch.load(args.c)  # 加载模型
        #  加载参数(权重系数、偏置值、梯度等等)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("参数加载成功")

    else:
        print("从头开始训练")

    # 关于model.train()说明:https://blog.csdn.net/weixin_44211968/article/details/123774649
    model.train()  # 启用 batch normalization 和 dropout

    # 训练过程
    for epoch in range(start_epoch, parameters.epochs):
        print("-----------当前epoch:{}-----------".format(epoch))
        for i, batch in enumerate(train_loader):
            print("-----------当前batch:{}/{}-----------".format(i, len(trainset)//(parameters.batch_size)))
            X, y = batch
            pred = model(X)
            loss = loss_func(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            logger.add_scalar('loss/train', loss, step)

            # 每10步进行验证集评估并保存
            if not step % parameters.verbose_step:
                eval_loss = evaluate(model, val_loader, loss_func)
                logger.add_scalar('loss/val', eval_loss, step)
            if not step % parameters.save_step:
                model_path = "epoch-{}_step-{}.pth".format(epoch, step)
                save_checkpoint(model, epoch, optimizer, os.path.join('movel_save', model_path))

            step += 1
            logger.flush()
            print("当前step:{};当前train_loss:{:.5f};当前val_loss:{:.5f}".format(step, loss.item(), eval_loss))
    logger.close()

if __name__ == '__main__':
    train()


(6)inference——推理脚本

模型训练完毕,然后选择出最佳模型后,就可以在该文件中加载模型对测试集上的数据做出预测

import torch
from torch.utils.data import DataLoader
from dataset_loader import BankfakeDataset
from model import BankfakeModel
from config import parameters

#  网络实例
model = BankfakeModel()
#  加载模型:观察tensorboard可知,迭代600次时模型收敛
checkpoint = torch.load('./movel_save/epoch-40_step-600.pth')
#  加载模型参数
model.load_state_dict(checkpoint['model_state_dict'])

#  加载测试数据
testset = BankfakeDataset(parameters.testset_path)
test_loader = DataLoader(testset, batch_size=parameters.batch_size, shuffle=True, drop_last=False)

#  预测时,进入eval模式
model.eval()

#  分别表示总的数据个数和预测正确的个数
total_num = 0
correct_num = 0

with torch.no_grad():
    for batch in test_loader:
        X, y = batch
        pred = model(X)
        total_num += pred.size(0)
        """
            pred是一个有2个元素的列表,分别表示当前纸币真假的概率,所以
            我们只需选择最大概率即可,这里选择索引后正好就和真实标签中的
            0/1对应了
        """
        correct_num += (torch.argmax(pred, 1) == y).sum()

print("测试数据{}个,正确预测{}个,预测准确率:{}%".format(total_num, correct_num, (correct_num / total_num) * 100))


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/138447.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

后端人眼中的Vue(一)

一、简介 1.1、Vue简介 ​ Vue是渐进式 JavaScript 框架,啥叫渐进式?渐进式意味着你可以将Vue作为你应用的一部分嵌入其中,或者如果你希望将更多的业务逻辑使用Vue实现,那么Vue的核心库以及其生态系统。比如CoreVue-routerVuexax…

Homekit智能家居DIY之智能灯泡

一、什么是智能灯 传统的灯泡是通过手动打开和关闭开关来工作。有时,它们可以通过声控、触控、红外等方式进行控制,或者带有调光开关,让用户调暗或调亮灯光。 智能灯泡内置有芯片和通信模块,可与手机、家庭智能助手、或其他智能…

RabbitMQ、Kafka、RocketMQ消息中间件对比总结

文章目录前言侧重点架构模型消息通讯其他对比总结参考文档前言 不论Kafka还是RabbitMQ和RocketMQ,作为消息中间件,其作用为应用解耦、异步通讯、流量削峰填谷等。 拿我之前参加的一个电商项目来说,订单消息通过MQ从订单系统到支付系统、库存…

ORB-SLAM2 --- KeyFrame::UpdateConnections 函数

目录 一、函数作用 二、函数流程 三、code 四、函数解析 一、函数作用 更新关键帧之间的连接图。 更新变量 mConnectedKeyFrameWeights:当前关键帧的共视信息,记录当前关键帧共视关键帧的信息(哪一帧和当前关键帧有共视,共视…

用C++实现十大经典排序算法

作者:billy 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 简介 排序算法可以分为内部排序和外部排序,内部排序是数据记录在内存中进行排序,而外部排序是因排序的数据很大…

喜报|知道创宇连续两年获评北京市企业创新信用领跑企业!

近日,2022年度北京市企业创新信用领跑名单正式发布。知道创宇凭借过硬的技术实力、创新能力及良好的企业信用记录成功入选2022年度北京市企业创新信用领跑企业。值得一提的是,这是知道创宇继2021年以来,连续两年获得此项殊荣。连续两年蝉联双…

CPU是如何执行程序的?

CPU是如何执行程序的?1、硬件结构介绍1.1、CPU1.2、内存1.3、总线1.4、输入/输出设备2、程序执行的基本过程3、a11执行的详细过程现代计算机的基本结构为五个部分:CPU、内存、总线、输入/输出设备。或许你了解了这些概念,但是你知道a11在计算…

【Kubernetes | Pod 系列】Pod 的镜像下载策略和 Pod 的生命周期 Ⅰ—— 理论

目录4. 镜像下载策略5. Pod 的生命周期5.1 Pod 生命期与特性说明5.2 Pod Phase 阶段说明备注5.3 容器状态说明(1)Waiting (等待)(2)Running(运行中)(3)Termin…

【回答问题】ChatGPT上线了!给我推荐20个比较流行的nlp预训练模型

目录给我推荐20个比较流行的nlp预训练模型给我推荐20个比较流行的nlp预训练模型源码给我推荐20个比较流行的nlp预训练模型 BERT (谷歌) GPT-2 (OpenAI) RoBERTa (Facebook) ALBERT (谷歌) ELECTRA (谷歌) XLNet (谷歌/纽约大学) T5 (OpenAI) Transformer-XL (谷歌/香港中文大学…

Qt音视频开发09-ffmpeg内核音视频同步

一、前言 用ffmpeg来做音视频同步,个人认为这个是ffmpeg基础处理中最难的一个,无数人就卡在这里,怎么也不准,本人也是尝试过网上各种demo,基本上都是渣渣,要么仅仅支持极其少量的视频文件比如收到的数据包…

【EdgeBox_tx1_tx2_E100】 PyTorch v1.8.0 torchvision v0.9.0 环境部署

简介:介绍PyTorch 环境 在 EHub_tx1_tx2_E100载板,TX1核心模块环境(Ubuntu18.04)下如何实现部署和测试,准备安装的环境是(PyTorch v1.8.0 torchvision v0.9.0)。 关于测试硬件EHub_tx1_tx2_E1…

文献学习04_Deep contextualized word representations 深度语境化的单词表示_20230102

论文信息 Subjects: Computation and Language (cs.CL) (1)题目:Deep contextualized word representations (深度语境化的单词表示) (2)文章下载地址: https://doi.org/10.48550/…

Telemetry网络监控技术讲解

目录 Telemetry基本概念 设备监测数据的数据类型 为么要提出Telemetry Telemetry网络模型 广义Telemetry 狭义Telemetry 狭义Telemetry框架 数据源(Yang) 数据生成(GPB) 数据订阅(gRPC、UDP) 数…

跟着开源项目学java7-从操作日志排除敏感字段的提交看基于注解的日志记录实现

这次 commit 主要解决日志信息中可能存在 password 等敏感字段,需要在保存前排除掉 主要涉及两个类的修改,添加实现了一个 PropertyPreExcludeFilter,集成 fastjson2 的 SimplePropertyPreFilter 实现 /*** 排除JSON敏感属性* * author ruo…

两种方法设置Word文档的“只读模式”

防止Word文档被意外更改,我们可以将Word设置成“只读模式”来保护文档。根据需要,还可以将Word可以设置成无密码和有密码的“只读模式”,下面来说说具体方法。 方法一:无密码的“只读模式” 打开Word文档后,点击菜单…

C进阶_C语言_大小端_C语言大小端

现在调试以下代码&#xff0c;并对变量a和b进行监视&#xff1a; #include <stdio.h> int main() {int a 20;int b -10;return 0; } 右键&#xff0c;勾选十六进制显示&#xff1a; 可以看到&#xff0c;变量a和变量b的十六进制值分别为0x00000014和0xfffffff6。 那么…

MySQL之数据库设计范式

数据库设计范式&#xff1a; 第一范式&#xff1a; 要求任何一张表必须有主键&#xff0c;每一个字段原子性不可再分&#xff0c;第一范式是最核心&#xff0c;最重要的范式&#xff0c;所有的表的设计都需要满足 举例&#xff1a; 第二范式&#xff1a; 建立在第一范式的基…

一款基于SSH的反向Shell工具

一款基于SSH的反向Shell工具。 Reverse_SSH上一款基于SSH的反向Shell工具&#xff0c;在该工具的帮助下&#xff0c;广大研究人员可以使用SSH来实现反向Shell&#xff0c;并同时拥有下列功能&#xff1a; 1、使用原生SSH语句管理和连接反向Shell&#xff1b; 2、动态、本地和…

<UDP网络编程>——《计算机网络》

目录 1. 网络基础知识 1.1 理解源IP地址和目的IP地址 1.2 认识端口号 1.3 理解 "端口号" 和 "进程ID" 1.3.1 理解源端口号和目的端口号 1.4 认识TCP协议 1.5 认识UDP协议 1.6 网络字节序 2. socket编程接口 2.1 socket 常见API 2.2 sockaddr结构 2.3 socka…

广告刷屏世界杯,联想Filez助力海信全球营销运营

相信每个世界杯球迷在看球的同时也被世界杯球场上不断滚动的“Hisense&#xff0c;世界第二&#xff0c;中国第一”的广告牌吸引目光。在这28天&#xff0c;64场比赛中&#xff0c;卡塔尔的比赛场地不仅随处可见海信的围栏广告&#xff0c;同时场外也随处可见海信的身影。从备受…