深度学习模型部署全流程-模型训练

news2025/4/17 17:29:22

文章目录

  • 前言
  • 模型训练全流程
    • 1.数据准备
    • 2.数据加载
    • 3.搭建神经网络
    • 4.设置损失函数,优化器
    • 5.训练网络模型
    • 6.模型测试
    • 7.完整代码
    • 9.训练结果
  • 小结

前言

该系列文章会介绍神经网络模型从训练到部署的全流程,对于已经参加工作的人可以快速的了解如何使用深度学习技术满足项目需求;对于学生群体可以实际使用算法,获得入门的成就感,有助于后续对深度学习的理论研究!
重点强调:本系列没有关于深度学习的详细理论介绍,关于理论部分推荐去看吴恩达,李沐等大佬的视频!!!
首先你要具备以下知识:

  • 深度学习理论基础(不懂的话去B站搜吴恩达)
  • pytorch框架使用(不懂的话去B站搜李沐)

模型训练全流程

以图像分类任务为例!

1.数据准备

深度学习技术的一切基础就是数据!数据!数据!小公司被大公司按在地上摩擦的主要原因就是数据层面完全比不过大公司,本片文章使用的是花类数据集,包括:daisy(雏菊)、dandelion(蒲公英)、roses(玫瑰)、sunflowers(向日葵)和 tulips(郁金香)5个类别,下载连接
在这里插入图片描述
下载完数据后还需要对数据进行处理,包括:

  • 制作标签
  • 对数据进行切分,训练集(用于训练模型),验证集(用于验证训练后的模型效果)

执行以下脚本即可得到训练集,验证集

import os
import random

# 根据自己数据的路径对应修改
root = "./flower_photos/"

file_name = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

for i in range(5):
    file_path = os.path.join(root, file_name[i])
    img_name_list = os.listdir(file_path)
    
    num = len(img_name_list)
    train_num = int(num * 0.8)

    train_id = random.sample(range(0, num), train_num)
    print(train_id)

    with open("./train.txt", "a+") as f:
        for ID in train_id:
            img_path = os.path.join(file_path, img_name_list[ID])
            data = img_path + " " + str(i) + "\n"
            f.write(data)
            print(data)

    with open("./val.txt", "a+") as f:
        for ID in range(num):
            if ID in train_id:
                continue
            else:
                img_path = os.path.join(file_path, img_name_list[ID])
                data = img_path + " " + str(i) + "\n"
                f.write(data)
                print(data)

最终会得到两个txt文件,其中包含了图像路径以及每张图像对应的标签(每行末尾处的0代表第0类daisy雏菊),到此数据准备完毕!
在这里插入图片描述

2.数据加载

这一部分确实不知道该怎么去讲解,因为pytorch已经把加载数据的API完全制作好了,我们只需要按照固定的步骤即可加载数据,挑几个关键部分介绍下吧
PS:个人并不太推荐花费大量的时间研究这类开源API,更加推荐学习下如何使用(看几个实例,搞清楚数据流的输入与输出就懂了),除非你需要自己实现一个类似的功能函数,再去仔细研究别人怎么写的!

  • 图像预处理,把读入的图像进行resize,归一化等操作,并转化为Tensor
'''
Resize:将入读的任意图像转化为固定分辨率
ToTensor:转化为适用于pytorch的tensor数据类型
Normalize:归一化操作,该参数由一些著名实验室实验得出
'''
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
  • getitem’,定义具体加载数据的方式
class FlowersClsDataset(torch.utils.data.Dataset):

    def __init__(self, list_path, img_transform=None):
        super(FlowersClsDataset, self).__init__()
        self.img_transform = img_transform
        with open(list_path, 'r') as f:
            self.list = f.readlines()

    def __getitem__(self, index):
    	# 指定路径
        name = self.list[index].split()[0]
        img_path = name
        # 读入图像
        img = loader_func(img_path)

        if self.img_transform is not None:
            img = self.img_transform(img)

        # 读入标签
        label = int(self.list[index].split()[-1])
        
        return img, label

    def __len__(self):
        return len(self.list)

通过调试查看读入的数据是否正确,可以看到图像数据已经转化为tensor类型了,大功告成!
在这里插入图片描述

3.搭建神经网络

深度学习中最重要的部分,这一部分的可解释性较低,通常由著名实验室通过大量的实验得出,这里给出一份网络模型集合网址,大家可以根据自己的项目需求,硬件条件自行选择,并且该网址也配备了每个网络模型的论文,代码实现,非常良心!
网络模型链接
PS:该网站集合了几乎所有网络模型结构,并且也包含各种最新的具体算法,如目标检测,语义分割,图像分类等。画重点:基本上都有具体代码链接!!!
在这里插入图片描述
在这里插入图片描述
本文只用于模型训练流程演示,因此网络模型随便搭建,并没有参考某个具体的网络模型结构,只是简单的卷积+BN+Relu层的堆叠,具体代码入下

在这里插入代码片# 整合卷积,bn,relu操作
class conv_bn_relu(torch.nn.Module):
    def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(conv_bn_relu,self).__init__()
        self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size, stride = stride, padding = padding)
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self,x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# 定义网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = conv_bn_relu(3, 8, 3, 1, 1)
        self.layer2 = conv_bn_relu(8, 16, 3, 1, 1)
        self.layer3 = conv_bn_relu(16, 32, 3, 1, 1)
        self.layer4 = conv_bn_relu(32, 64, 3, 1, 1)
        self.layer5 = conv_bn_relu(64, 96, 3, 1, 1)

        self.fc1 = nn.Linear(7 * 7 * 96, 1024)
        self.fc2 = nn.Linear(1024, 128)
        self.fc3 = nn.Linear(128, 5)

        self.maxpool = nn.MaxPool2d(2, 2)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.layer1(x)
        x = self.maxpool(x)

        x = self.layer2(x)
        x = self.maxpool(x)

        x = self.layer3(x)
        x = self.maxpool(x)

        x = self.layer4(x)
        x = self.maxpool(x)

        x = self.layer5(x)
        x = self.maxpool(x).view(-1, 7*7*96)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        x = self.softmax(x)
 
        return x
 
net = Net()
net.to(device)

模型搭建完毕后,最重要的是验证其输入与输出的维度是否和预期的一致,以本片为例输入的数据维度(N,3,224,224),输出数据为(N,5)
此时模型输入为(10,3,224,224),10是通过batch_size得出
在这里插入图片描述
此时模型输出为(10,5),与我们所需数据维度一致
在这里插入图片描述

4.设置损失函数,优化器

# 定义损失函数,分类损失
class ClsLoss(nn.Module):
    def __init__(self):
        super(ClsLoss, self).__init__()
        self.nll = nn.NLLLoss()

    def forward(self, pre, labels):
        pre = torch.log(pre)
        loss = self.nll(pre, labels)
        return loss

# 损失函数实例化
loss_func = ClsLoss()

# 网络模型实例化
net = Net()
# 模型加载到GPU中
net.to(device)

# 提取网络模型参数
training_params = filter(lambda p: p.requires_grad, net.parameters())
# 定义优化器
optimizer = torch.optim.Adam(training_params, lr=0.0003, weight_decay=0.0001)

5.训练网络模型

# 具体执行训练过程
for epoch in range(31):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 获取图像数据和标签
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        # 计算损失
        loss = loss_func(outputs, labels)
        # 优化模型参数
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    # 保存网络模型
    torch.save(net.state_dict(), "./model/" + str(epoch) + ".pth")
print('finished training!')

6.模型测试

#测试
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

7.完整代码

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

# 指定具体显卡设备
device = torch.device('cuda:0')

# 图像数据预处理步骤
'''
Resize:将入读的任意图像转化为固定分辨率
ToTensor:转化为适用于pytorch的tensor数据类型
Normalize:归一化操作,该参数由一些著名实验室实验得出
'''
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

def loader_func(path):
    return Image.open(path).convert('RGB')

# 数据加载器,__getitem__模块最为重要,完成数据读入与标签读入
class FlowersClsDataset(torch.utils.data.Dataset):

    def __init__(self, list_path, img_transform=None):
        super(FlowersClsDataset, self).__init__()
        self.img_transform = img_transform
        with open(list_path, 'r') as f:
            self.list = f.readlines()

    def __getitem__(self, index):
        name = self.list[index].split()[0]
        img_path = name
        # 读入图像
        img = loader_func(img_path)

        if self.img_transform is not None:
            img = self.img_transform(img)

        # 读入标签
        label = int(self.list[index].split()[-1])
        
        return img, label

    def __len__(self):
        return len(self.list)

# 完成数据加载器实例化
train_dataset = FlowersClsDataset('train.txt', img_transform=data_transforms['train'])
test_dataset = FlowersClsDataset('val.txt', img_transform=data_transforms['val'])

# 制作DataLoader,设置batch_size
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)

# 定义损失函数,分类损失
class ClsLoss(nn.Module):
    def __init__(self):
        super(ClsLoss, self).__init__()
        self.nll = nn.NLLLoss()

    def forward(self, pre, labels):
        pre = torch.log(pre)
        loss = self.nll(pre, labels)
        return loss

# 损失函数实例化
loss_func = ClsLoss()

# 整合卷积,bn,relu操作
class conv_bn_relu(torch.nn.Module):
    def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(conv_bn_relu,self).__init__()
        self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size, stride = stride, padding = padding)
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self,x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# 定义网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = conv_bn_relu(3, 8, 3, 1, 1)
        self.layer2 = conv_bn_relu(8, 16, 3, 1, 1)
        self.layer3 = conv_bn_relu(16, 32, 3, 1, 1)
        self.layer4 = conv_bn_relu(32, 64, 3, 1, 1)
        self.layer5 = conv_bn_relu(64, 96, 3, 1, 1)

        self.fc1 = nn.Linear(7 * 7 * 96, 1024)
        self.fc2 = nn.Linear(1024, 128)
        self.fc3 = nn.Linear(128, 5)

        self.maxpool = nn.MaxPool2d(2, 2)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.layer1(x)
        x = self.maxpool(x)

        x = self.layer2(x)
        x = self.maxpool(x)

        x = self.layer3(x)
        x = self.maxpool(x)

        x = self.layer4(x)
        x = self.maxpool(x)

        x = self.layer5(x)
        x = self.maxpool(x).view(-1, 7*7*96)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        x = self.softmax(x)
 
        return x

# 网络模型实例化
net = Net()
# 模型加载到GPU中
net.to(device)

# 提取网络模型参数
training_params = filter(lambda p: p.requires_grad, net.parameters())
# 定义优化器
optimizer = torch.optim.Adam(training_params, lr=0.0003, weight_decay=0.0001)

# 具体执行训练过程
for epoch in range(31):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # 获取图像数据和标签
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        # 计算损失
        loss = loss_func(outputs, labels)
        # 优化模型参数
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    torch.save(net.state_dict(), "./model/" + str(epoch) + ".pth")
print('finished training!')
 
#测试
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))

9.训练结果

只训练了10论,可以看到该模型的精度为67%
在这里插入图片描述

小结

主要介绍了深度学习模型训练的全流程,其中最重要的是pytorch框架的熟练程度,这一部分多用几次,多看看官方文档就熟悉了;更为重要的是理论部分,需要看大量的论文,并且结合多次实验(前提是你有N卡有电费并且还有数据)才能有非常大的提升!

如果文章对你有用,请点个赞呗!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2922.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Studio入门之文本内容、大小、颜色的讲解及实战(附源码 超详细必看)

运行有问题或需要源码请点赞关注收藏后评论区留言或私信博主 一、设置文本的内容 1:在XML文件中通过属性android:text设置文本 <TextViewandroid:layout_width"wrap_content"android:layout_height"wrap_content"android:text"Hello World!"…

nordic 52832中添加RTT打印

JlinkRTT RTT是基于Jlink调试器的实时传输技术,可以代替串口打印一些调试信息,不需要额外接线。 nordic 52832官方例程中,会将RTT打印函数做进一步封装,下面就讲一下怎么开启52832中的RTT打印。 第一步 增加RTT代码 RTT源代码可以在segger官方网站下载,也可以在nordic 5…

使用 stream buffer 传递数据

使用 stream buffer 传递数据 概述 如前所述&#xff0c;队列虽然提供了任务之间传递数据的功能&#xff0c;但没有对通知机制进行优化&#xff0c;即不方便实现多次采集不同长度的数据&#xff0c;然后触发一次通知接收的机制。 特性概述 Streambuffer 的中文含意是“流式…

Chapter5.5:频率响应法

此系列属于胡寿松《自动控制原理题海与考研指导》(第三版)习题精选&#xff0c;仅包含部分经典习题&#xff0c;需要完整版习题答案请自行查找&#xff0c;本系列属于知识点巩固部分&#xff0c;搭配如下几个系列进行学习&#xff0c;可用于期末考试和考研复习。 自动控制原理(…

Hive与Hbase的区别与联系

一、概念 1&#xff0c;Hive hive是基于Hadoop的一个数据仓库工具&#xff0c;用来进行数据提取、转化、加载&#xff0c;这是一种可以存储、查询和分析存储在Hadoop中的大规模数据的机制。hive数据仓库工具能将结构化的数据文件映射为一张数据库表&#xff0c;并提供SQL查询…

网站中的经典,分享那些我用过的宝藏网站

前言 本篇将会具体分享我在最开始学习编程时了解到的网站&#xff0c;并分享自己使用这些网站的感受&#xff0c;当然&#xff0c;如果我有说的不正确的或者需要补充的&#xff0c;欢迎评论区补充纠正。还有各位来自优秀学校的伙伴们&#xff0c;或许其中一些资源在你们的学校…

安卓开发Android studio学习笔记15:关于如何使用Okhttp框架的网络请求(调用API接口)

Android studio一、安卓基于HTTP网络编程(一)、两种请求方式(二&#xff09;、安卓基于HTTP网络编程的两种方式1、使用HttpURLConnection访问网络资源**2、利用HttpClient访问网络资源**&#xff08;1&#xff09;HttpGet&#xff08;2&#xff09;HttpPost二、基础Okhttp的网络…

修改 echarts 默认样式记录

1、修改折线图上的数据标记点 showSymbol:false , 表示不展示数据点&#xff0c;只有鼠标 hover 时&#xff0c; tooltip 展示。 series: [{name: 进场, // 名称&#xff0c;图例和 tooltip 中展示showSymbol: false, // 不展示数据标记点type: line, // 类型color: #0091FF…

大学解惑10 - CSS中的content怎么换行,以及使用before伪类的优点

大学解惑09 - 单独用HTML javascript CSS 实现三版99乘法表&#xff0c;你就是班里最靓的仔https://blog.csdn.net/xingyu_qie/article/details/127631612 ☆ 上一篇文章用前端HTML CSS JS基础写了3版99乘法表&#xff0c;有同学说终于把99乘法表写透了&#xff0c;但是紧接着就…

Linux关于JDK、Tomcat以及MySQL安装

目录 一、JDK安装 1、 上传jdk、tomcat安装包 2、解压两个工具包 3、配置环境 4、在配置文件中加入java环境变量&#xff1a; 5、保存&#xff0c;让新设置的环境变量生效 二、Tomcat安装 1、将tomcat解压到/opt下 2、配置环境变量 3、启动tomcat 4、创建启动脚本 三…

入门学习XSS漏洞,这一篇就够了

入门学习XSS漏洞&#xff0c;这一篇就够了1.XSS简介2.XSS的类型反射型XSS存储型XSSDOM型XSS1.XSS简介 XSS攻击&#xff0c;通常指黑客通过“HTML注入”篡改了网页&#xff0c;插入了恶意的脚本&#xff0c;从而在用户浏览网页时&#xff0c;控制用户浏览器的一种攻击。在一开始…

【沐风老师】怎么在3DMAX中使用MAXScript脚本动画编程?

大家可能对3dmax都抱有很浓厚的兴趣,但如果你接触到max脚本(MAXScript),你会觉得它比max本身更让人着迷,因为它更能拓展我们的想象力,或者帮助我们更好的提高工作效率。不过,MAXScript是解释语言,不适合编写过于复杂的功能,因为这将大大影响执行的速度。 言归正传,就…

jmeter模拟多IP访问

1. 前言&#xff1a; 今天一同事在压测时提到怎么用jmeter里虚拟多个ip来发送请求&#xff0c;我想了一下以前用LR时用过虚拟ip地址&#xff0c;jmeter还没有使用过。想着原理应该是相通的&#xff0c;既然LR都能支持的话&#xff0c;那Jmeter应该也是支持&#xff0c;于是就有…

ARM pwn 入门 (1)

最近笔者刚刚加入了一个项目组&#xff0c;需要用到ARM架构的东西&#xff0c;和ARM pwn也有一定关系&#xff0c;因此一不做二不休&#xff0c;决定开始学习ARM pwn&#xff0c;顺便熟悉项目前置知识&#xff0c;一举两得。 ARM与x86分属不同架构&#xff0c;指令集不同&…

用frp搞个内网穿透

使用场景&#xff1a; 在公司用电脑敲代码&#xff0c;环境都是localhost&#xff0c;有时候你要接第三方接口比如支付、或者企业微信的事件回调等&#xff0c;都需要一个公网地址&#xff0c;因为这时候是开发阶段&#xff0c;你即想要公司电脑上运行的环境又想要回调能找到你…

2022年首家民营征信机构浙江同信获企业征信备案公示

2022年首家民营征信机构浙江同信获企业征信备案公示 2022年11月1日&#xff0c;中国人民银行杭州中心支行公示了浙江同信企业征信服务有限公司企业征信机构备案&#xff0c;该机构为浙江省进行备案公示的第九家机构。其他八家分别为芝麻信用管理有限公司、浙江有数数智科技有限…

Transform介绍(1)

文章目录1. transform 方法2. transform 增量模式3. 注册 Transform使用Transform的常见场景有埋点统计、耗时监控、方法替换 通过上图以我们了解下transform的作用&#xff0c;transform在 class 到 dex 之间工作&#xff0c;处理包括 javac 编译后的字节码文件&#xff0c;每…

【Linux内核系列】进程调度

目录 一、为什么要调度 二、调度均衡 三、进程调度框架 3.1 调度队列 3.2 进程唤醒 3.3 调度时机 主动调度&#xff1a; 被动调度&#xff1a; 四、调度算法 4.1 先来先服务调度算法 4.2 最短作业优先调度算法 4.3 高响应比优先调度算法 4.4 时间片轮转调度算法 …

洛谷千题详解 | P1007 独木桥【C++、Pascal语言】

博主主页&#xff1a;Yu仙笙 专栏地址&#xff1a;洛谷千题详解 目录 题目背景 题目描述 输入格式 输出格式 输入输出样例 解析&#xff1a; C源码&#xff1a; Pascal源码&#xff1a; ------------------------------------------------------------------------------------…

NFT 推荐|辛迪加黑市系列第一弹

由 Planet Rift 呈现&#xff01; 塞巴星球的辛迪加已经洗劫了政府&#xff0c;现在是时候揭开补给品的神秘面纱了&#xff01; 辛迪加黑市系列的第一弹包括 30 个由 Planet Rift 宇宙设计的资产。其中首次发售的包含 4 套未来风格的盔甲、3 台彩色自动售货机和其他装备。 别忘…