pyTorch入门(五)——训练自己的数据集

news2025/1/19 10:39:01

学更好的别人,

做更好的自己。

——《微卡智享》

419c279068cd36a2704020162893d3c7.jpeg

本文长度为1749,预计阅读5分钟

前言

前面四篇将Minist数据集的训练及OpenCV的推理都介绍完了,在实际应用项目中,往往需要用自己的数据集进行训练,所以本篇就专门介绍一下pyTorch怎么训练自己的数据集。

57242f16008f15a727e298f868261a8b.png

微卡智享

生成自己的训练图片

上一篇《pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别》中使用VS Studio实现了OpenCV的推理,介绍过在推理前需要将图片进行预处理,包括灰度、二值化,查找及排序轮廓都已经处理了,所以只要对上面的代码进行改造一下,将提取的信息保存出来,就是我们想要训练的数据了。先上源码:

#pragma once
#include<iostream>
#include<chrono>
#include<time.h>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn/dnn.hpp>


using namespace cv;
using namespace std;


//参数iType  0-提取图片保存   1-使用DNN推理
int iType = 1;


dnn::Net net;


//排序矩形
void SortRect(vector<Rect>& inputrects) {
  for (int i = 0; i < inputrects.size(); ++i) {
    for (int j = i; j < inputrects.size(); ++j) {
      //说明顺序在上方,这里不用变
      if (inputrects[i].y + inputrects[i].height < inputrects[i].y) {


      }
      //同一排
      else if (inputrects[i].y <= inputrects[j].y + inputrects[j].height) {
        if (inputrects[i].x > inputrects[j].x) {
          swap(inputrects[i], inputrects[j]);
        }
      }
      //下一排
      else if (inputrects[i].y > inputrects[j].y + inputrects[j].height) {
        swap(inputrects[i], inputrects[j]);
      }
    }
  }
}


//处理DNN检测的MINIST图像,防止长方形图像直接转为28*28扁了
void DealInputMat(Mat& src, int row = 28, int col = 28, int tmppadding = 5) {
  int w = src.cols;
  int h = src.rows;
  //看图像的宽高对比,进行处理,先用padding填充黑色,保证图像接近正方形,这样缩放28*28比例不会失衡
  if (w > h) {
    int tmptopbottompadding = (w - h) / 2 + tmppadding;
    copyMakeBorder(src, src, tmptopbottompadding, tmptopbottompadding, tmppadding, tmppadding,
      BORDER_CONSTANT, Scalar(0));
  }
  else {
    int tmpleftrightpadding = (h - w) / 2 + tmppadding;
    copyMakeBorder(src, src, tmppadding, tmppadding, tmpleftrightpadding, tmpleftrightpadding,
      BORDER_CONSTANT, Scalar(0));


  }
  resize(src, src, Size(row, col));
}


// 获取当时系统时间
const string GetCurrentSystemTime()
{
  auto t = chrono::system_clock::to_time_t(std::chrono::system_clock::now());
  struct tm ptm { 60, 59, 23, 31, 11, 1900, 6, 365, -1 };
  _localtime64_s(&ptm, &t);
  char date[60] = { 0 };
  sprintf_s(date, "%d%02d%02d%02d%02d%02d",
    (int)ptm.tm_year + 1900, (int)ptm.tm_mon + 1, (int)ptm.tm_mday,
    (int)ptm.tm_hour, (int)ptm.tm_min, (int)ptm.tm_sec);
  return move(std::string(date));
}


int main(int argc, char** argv) {
  //定义onnx文件
  string onnxfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/torchminist/ResNet.onnx";


  //测试图片文件
  string testfile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/testpic/test3.png";


  //提取的图片保存位置
  string savefile = "D:/Business/DemoTEST/CPP/OpenCVMinistDNN/findcontoursMat";


  if (iType == 1) {
    net = dnn::readNetFromONNX(onnxfile);
    if (net.empty()) {
      cout << "加载Onnx文件失败!" << endl;
      return -1;
    }
  }


  //读取图片,灰度,高斯模糊
  Mat src = imread(testfile);
  //备份源图
  Mat backsrc;
  src.copyTo(backsrc);
  cvtColor(src, src, COLOR_BGR2GRAY);
  GaussianBlur(src, src, Size(3, 3), 0.5, 0.5);
  //二值化图片,注意用THRESH_BINARY_INV改为黑底白字,对应MINIST
  threshold(src, src, 0, 255, THRESH_BINARY_INV | THRESH_OTSU);


  //做彭账处理,防止手写的数字没有连起来,这里做了3次膨胀处理
  Mat kernel = getStructuringElement(MORPH_RECT, Size(3, 3));
  //加入开运算先去燥点
  morphologyEx(src, src, MORPH_OPEN, kernel, Point(-1, -1));
  morphologyEx(src, src, MORPH_DILATE, kernel, Point(-1, -1), 3);
  imshow("src", src);


  vector<vector<Point>> contours;
  vector<Vec4i> hierarchy;
  vector<Rect> rects;


  //查找轮廓
  findContours(src, contours, hierarchy, RETR_EXTERNAL, CHAIN_APPROX_NONE);
  for (int i = 0; i < contours.size(); ++i) {
    RotatedRect rect = minAreaRect(contours[i]);
    Rect outrect = rect.boundingRect();
    //插入到矩形列表中
    rects.push_back(outrect);
  }


  //按从左到右,从上到下排序
  SortRect(rects);
  //要输出的图像参数
  for (int i = 0; i < rects.size(); ++i) {
    Mat tmpsrc = src(rects[i]);
    DealInputMat(tmpsrc);


    if (iType == 1) {
      //Mat inputBlob = dnn::blobFromImage(tmpsrc, 0.3081, Size(28, 28), Scalar(0.1307), false, false);
      Mat inputBlob = dnn::blobFromImage(tmpsrc, 1, Size(28, 28), Scalar(), false, false);


      //输入参数值
      net.setInput(inputBlob, "input");
      //预测结果 
      Mat output = net.forward("output");


      //查找出结果中推理的最大值
      Point maxLoc;
      minMaxLoc(output, NULL, NULL, NULL, &maxLoc);


      cout << "预测值:" << maxLoc.x << endl;


      //画出截取图像位置,并显示识别的数字
      rectangle(backsrc, rects[i], Scalar(255, 0, 255));
      putText(backsrc, to_string(maxLoc.x), Point(rects[i].x, rects[i].y), FONT_HERSHEY_PLAIN, 5, Scalar(255, 0, 255), 1, -1);
    }
    else {
      string filename = savefile + "/" + GetCurrentSystemTime() + "-" + to_string(i) + ".jpg";
      cout << filename << endl;
      imwrite(filename, tmpsrc);
    }
  }


  imshow("backsrc", backsrc);




  waitKey(0);
  return 0;
}

划重点

696b35749fc2d08092aaea17c12bcf3e.png

加了一个参数,设置的时候0为提取保存的图片,1是上一篇的推理。

9b73b9524203bf40b821f0230a283166.png

增加了一个获取当前时间的函数,主要作用就是保存图片的时候在文件名加上时间。

a06bdbdd0839f168cd0eb8e90323b9e3.png

增加了一个保存图片的位置

634a67e4e50c2621215c694fa2b13718.png

根据上面的参数,设置为1时还是原来的DNN推理,0时通过imwrite将图片进行保存。

832dc656540202bd25a5f88deea31459.png

接下来我们自己做点数据集,用画图工具在上面写上数字,将0--9的数字分别做了10张图出来。

9663be5306a98fa4547d066a5315b660.png

6e3194a29168ced8937551dfed883c7a.png

79f8857319913e02716ae435ad402ca6.png

c6e6bd5458548371d1f7c98cbd0e86bd.png

运行的效果如下:

fa0b279295cfed140faa2d03a12cecfb.png

可以看出上图中我们将数字9的图片分开截取并保存到指定的目录了。

db8689d46c2755d651e916d14e92e4b1.png

同时在Dataset下创建mydata目录,并创建出train训练的目录,在目录下创建了0-9的文件夹,这样做的目录是在pyTorch调用时会直接根据train下不同的文件夹目录设置对应的label标签了,不用我们在每个进行对照,相应的,提取出的数字图片也要放到对应的目录中

4ce412583b43388e6f667e7fe32ea462.png

将刚才生成的数字9的图片都剪切到9的文件夹下,其余的数字也是用同样方式。

5cc7864cdaeb7777d1deca672abf98ab.png

test测试集也用相同的方式处理,只不过我们拷过来后删了一大部分,就做别的处理。做完这些,提取图片的准备工作就完成了,接下来就是通过pyTorch训练。

3e1769d38670b88898e73c7cf0744cb1.png

微卡智享

pyTorch训练自己数据集

420c8516d113837d9c52dde07fee8ab4.png

新建了一个trainmydata.py的文件,训练的流程其实和原来差不多,只不过我们是在原来的基础上进行再训练,所以这些的模型是先加载原来的训练模型后,再进行训练,还是先上代码

import torch
import time
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import matplotlib.pyplot as plt
from pylab import mpl
import trainModel as tm


##训练轮数
epoch_times = 15


##设置初始预测率,用于判断高于当前预测率的保存模型
toppredicted = 0.0


##设置学习率
learnrate = 0.01 
##设置动量值,如果上一次的momentnum与本次梯度方向是相同的,梯度下降幅度会拉大,起到加速迭代的作用
momentnum = 0.5


##自己训练的模型前面加个my
savemodel_name = "my" + tm.savemodel_name


##生成图用的数组
##预测值
predict_list = []
##训练轮次值
epoch_list = []
##loss值
loss_list = []


transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,))
]) ##Normalize 里面两个值0.1307是均值mean, 0.3081是标准差std,计算好的直接用了


##训练数据集位置
train_mydata = datasets.ImageFolder(
    root = '../datasets/mydata/train', 
    transform = transform
)
train_mydataloader = DataLoader(train_mydata, batch_size=64, shuffle=True, num_workers=0)


##测试数据集位置
test_mydata = datasets.ImageFolder(
    root = '../datasets/mydata/test', 
    transform = transform
)
test_mydataloader = DataLoader(test_mydata, batch_size=1, shuffle=True, num_workers=0)




##加载已经训练好的模型
model = tm.Net(tm.train_name)
model.load_state_dict(torch.load(tm.savemodel_name))


##加入判断是CPU训练还是GPU训练
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)


##优化器 
optimizer = optim.SGD(model.parameters(), lr= learnrate, momentum= momentnum)


##训练函数
def train(epoch):
    model.train()
    for batch_idx, data in enumerate(train_mydataloader, 0):
        inputs, target = data
        ##加入CPU和GPU选择
        inputs, target = inputs.to(device), target.to(device)


        optimizer.zero_grad()


        #前馈,反向传播,更新
        outputs = model(inputs)
        loss = model.criterion(outputs, target)
        loss.backward()
        optimizer.step()


    loss_list.append(loss.item())
    print("progress:", epoch, 'loss=', loss.item())






def test():
    correct = 0 
    total = 0
    model.eval()
    ##with这里标记是不再计算梯度
    with torch.no_grad():
        for data in test_mydataloader:
            inputs, labels = data
            ##加入CPU和GPU选择
            inputs, labels = inputs.to(device), labels.to(device)




            outputs = model(inputs)
            ##预测返回的是两列,第一列是下标就是0-9的值,第二列为预测值,下面的dim=1就是找维度1(第二列)最大值输出
            _, predicted = torch.max(outputs.data, dim=1)


            total += labels.size(0)
            correct += (predicted == labels).sum().item()


    currentpredicted = (100 * correct / total)
    ##用global声明toppredicted,用于在函数内部修改在函数外部声明的全局变量,否则报错
    global toppredicted
    ##当预测率大于原来的保存模型
    if currentpredicted > toppredicted:
        toppredicted = currentpredicted
        torch.save(model.state_dict(), savemodel_name)
        print(savemodel_name+" saved, currentpredicted:%d %%" % currentpredicted)


    predict_list.append(currentpredicted)    
    print('Accuracy on test set: %d %%' % currentpredicted)        


##开始训练
timestart = time.time()
for epoch in range(epoch_times):
    train(epoch)
    test()
timeend = time.time() - timestart
print("use time: {:.0f}m {:.0f}s".format(timeend // 60, timeend % 60))






##设置画布显示中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
##设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False


##创建画布
fig, (axloss, axpredict) = plt.subplots(nrows=1, ncols=2, figsize=(8,6))


#loss画布
axloss.plot(range(epoch_times), loss_list, label = 'loss', color='r')
##设置刻度
axloss.set_xticks(range(epoch_times)[::1])
axloss.set_xticklabels(range(epoch_times)[::1])


axloss.set_xlabel('训练轮数')
axloss.set_ylabel('数值')
axloss.set_title(tm.train_name+' 损失值')
#添加图例
axloss.legend(loc = 0)


#predict画布
axpredict.plot(range(epoch_times), predict_list, label = 'predict', color='g')
##设置刻度
axpredict.set_xticks(range(epoch_times)[::1])
axpredict.set_xticklabels(range(epoch_times)[::1])
# axpredict.set_yticks(range(100)[::5])
# axpredict.set_yticklabels(range(100)[::5])


axpredict.set_xlabel('训练轮数')
axpredict.set_ylabel('预测值')
axpredict.set_title(tm.train_name+' 预测值')
#添加图例
axpredict.legend(loc = 0)


#显示图像
plt.show()

划重点

22cb244fdcf4fa43784ef000cdc6e357.png

自己训练的模型文件前面加上一个my,用于不覆盖原来的训练模型。

加载训练集和测试集

e9f3aebe4640a7592188864523459b18.png

在transform中,增加了一行transforms.Grayscale(num_output_channels=1),主要原因是在OpenCV中使用imwrite保存的文件,虽然是二值化的图片,但是是3通道的,而在pyTorch我们的训练数据都是1X28X28,即是单通道的图像,所以这里加上这一句是将读取的图片设置为单通道。

使用datasets.ImageFolder直接读取train目录下的数据,自动将图像及对应的标签加载进来了。

加载已训练的模型

71ae4d6479603740bce6094f709f95c2.png

这里的model模型直接通过load_state_dict加载进来,然后再训练自己的数据,下面的训练方式和原来train都一样了。

e00c31aa1abff71781568e86146dd6ac.png

4aa80b9bb402d525f1e20c6a8e52c462.png

因为我这边保存的数据很少,而且测试集的图片和训练集的一样,只训练了15轮,所以训练到第3轮的时候已经就到100%了。简单的训练自己的数据集就完成了。

f7d7a0741f48d3512b9cf069c4706a1d.png

922088850941b926acc627227ca65eee.png

往期精彩回顾

 

d8f73635e4ed1490e4abaca17a1c26c2.jpeg

pyTorch入门(四)——导出Minist模型,C++ OpenCV DNN进行识别

 

 

8a9bd2f215a3d2935dbf77425f29f973.jpeg

pyTorch入门(三)——GoogleNet和ResNet训练

 

 

1ed72c8a20a8e0f7295b23e5f0478060.jpeg

pyTorch入门(二)——常用网络层函数及卷积神经网络训练

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/116752.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

UCOS-III任务堆栈溢出检测及统计任务堆栈使用量的方法

1、说在前 在操作系统任务设计的时候&#xff0c;通常会遇到一个比较麻烦的问题&#xff0c;也就是任务堆栈大小设定的问题&#xff0c;为此我们我需要知道一些问题&#xff1a; 1.1. 任务堆栈一但溢出&#xff0c;意味着系统的崩溃&#xff0c;在有MMU或者MPU的系统中&#xf…

linux centons安装cpolar内网穿透

Linux操作系统在个人电脑上并不多见&#xff0c;但在需要集中资源处理信息交互的服务器上&#xff0c;Linux系统却几乎是唯一的存在。而cpolar凭借极极低的资源占用和便捷操作&#xff0c;十分适合在linux系统上使用。今天&#xff0c;我们就为大家介绍&#xff0c;如何在linux…

判断变量是否为数组及通用判断数据类型方法

判断变量是不是数组类型 function fn() {console.log(Array.isArray(arguments)); //false; 因为arguments是类数组&#xff0c;但不是数组console.log(Array.isArray([1,2,3,4])); //trueconsole.log(arguments instanceof Array); //fasleconsole.log([1,2,3,4] instance…

python中利用tkinter和ImageTK进行圣诞快乐图片的显示

一、前言 python中使用tkinter加载“Merry Christmas“ 图片。 二、用python显示Merry Christmas图片 1. python中&#xff0c;tkinter中可以进行图形界面编程。tkinter库提供了各种控件&#xff0c;其中&#xff0c;可以使用PhotoImage和Label组合&#xff0c;进行“Merry Chr…

【树莓派不吃灰】网络篇 Tcpdump iptables

目录1、一台主机上只能保持最多 65535 个 TCP 连接吗&#xff1f;2、tcpdump3、iptables❤️ 博客主页 单片机菜鸟哥&#xff0c;一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2022-12-26 ❤️❤️ 本篇更新记录 2022-12-26 ❤️&#x1f389; 欢迎关注 &#x1f50e;点…

RT-Thread 学习笔记(十四)--- 开启基于RTGUI的LCD显示功能(4)<demo组件的按键响应和焦点支持>

软件环境&#xff1a;Win7&#xff0c;Keil MDK 4.72a, IAR EWARM 7.2, GCC 4.2&#xff0c;Python 2.7 ,SCons 2.3.2 硬件环境&#xff1a;Armfly STM32F103ZE-EK v3.0开发板 参考文章&#xff1a;RT-Thread编程指南 RT-Thread_1.2.0lwiprtgui0.8.0 移植心得 RT-Thread RT…

2022/12/26 请你谈谈数据库事务机制?

1 事务四大特征 一般来说&#xff0c;事务是必须满足4个条件&#xff08;ACID&#xff09;&#xff1a;原子性&#xff08;Atomicity&#xff09;、一致性&#xff08;Consistency&#xff09;、隔离性&#xff08;Isolation&#xff09;、持久性&#xff08;Durability&#…

软件测试工程职场发展细谈

前言 今天几个测试圈子的大佬约了饭局&#xff0c;席间彼此交流了很多关于职场工作上测试相关的话题&#xff0c;听了他们的一些观点很有启发&#xff0c;我自己对于聊的话题也做了一些描述和实际的案例说明。下面是聊的一些关键话题&#xff0c;我将交流的内容和个人观点整理…

(二)JavaScript

JavaScript 是一门跨平台、面向对象的脚本语言。JavaScript 是用来控制网页行为的&#xff0c;它能使网页可交互。 一、JavaScript 引入方式&#xff08;P71&#xff09; &#xff08;1&#xff09;内部脚本&#xff1a;将JS代码定义在HTML页面中 &#xff08;2&#xff09;外部…

ActiveMQ集群模式

目录 一、面试题 二、多节点集群是什么 三、zookeeperreplicated-leveldb-store的主从集群 四、官网集群原理图 五、部署规划和步骤 六、集群可用性测试 一、面试题 引入消息队列之后该如何保证其高可用性 二、多节点集群是什么 基于ZooKeeper和LevelDB搭建ActiveMQ 集…

API签名鉴权设计

鉴权作用 在实际的业务中&#xff0c;必然会存在和其他平台系统进行数据传输。这个时候出于对数据的保密要求&#xff0c;都会对接口&#xff08;API&#xff09;添加鉴权机制&#xff0c;识别调用方的真实身份&#xff0c;对未通过鉴权的请求不做任何业务处理&#xff0c;以帮…

国科大模式识别导论作业3:神经网络

目录题目代码data.pyutils.pynetwork.pymain.py结果整理一下近期作业中的编程题&#xff0c;仅供交流学习题目 本题使用的数据如下&#xff1a; 第一类 10 个样本&#xff08;三维空间&#xff09;&#xff1a; [ 1.58, 2.32, -5.8], [ 0.67, 1.58, -4.78], [ 1.04, 1.01, -3…

OpenCV 图像旋转、平移、缩放

本文是 OpenCV图像视觉入门之路的第7篇文章&#xff0c;本文详细的进行了图像的缩放 cv2.resize()、旋转 cv2.flip()、平移 cv2.warpAffine()等操作。 OpenCV 图像旋转、平移、缩放目录 1 缩放图片 2 翻转图片 2.1 垂直翻转 2.2 水平翻转 2.3 水平垂直翻转 ​编辑 3 平移…

百度离线人脸识别SDK

1&#xff0c;采坑备忘 &#xff08;1&#xff09;8.1版本的SDK在spring-boot接口访问第一次正常&#xff0c;第二次之后JVM会奔溃&#xff0c;可能是java gc 处理C开出的内存有问题。 换6.1.3版本的SDK。 javaWindows百度离线人脸识别SDK6.1.3-Java文档类资源-CSDN下载javaW…

Harmony/OpenHarmony应用开发-转场动画页面间转场

在全局pageTransition方法内配置页面入场和页面退场时的自定义转场动效。 说明&#xff1a;从API Version 7开始支持。开发语言ets. 名称 参数 参数描述 PageTransitionEnter { type: RouteType, duration: number, curve: Curve | string, delay: number } 设置当前页面…

1998-2014年企业绿色发展数据库

1998-2014年工业企业的排放排污和环境治理等信息数据 1、时间&#xff1a;1998-2014年 2、数据来源&#xff1a;原环保部。 3、统计字段&#xff1a;主要有企业基本信息、生产信息、水环境、大气环境&#xff0c;内容涵盖了资源利用类指标&#xff08;工业用水量、煤炭消费量…

YGG 与 Thirdverse 达成合作,将《足球小将》IP 带入 Web3

YGG 与 Thirdverse 建立了合作关系&#xff0c;Thirdverse 是一家专注于多人 VR 和 Web3 游戏的游戏工作室&#xff0c;在日本和美国分别设有办事处。 YGG 通过购买未来股权的简单协议&#xff08;SAFE&#xff09;参与了 Thirdverse 近期的 1500 万美元融资。这种合作关系将使…

FastAPI从入门到实战(16)——依赖项

依赖注入是耳熟能详的一个词了&#xff0c;听起来很复杂&#xff0c;实际上并没那么复杂&#xff0c;正常的访问需要接受各种参数来构造一个对象&#xff0c;依赖注入就变成了只接收一个实例化对象&#xff0c;主要用于共享业务逻辑、共享数据库连接、实现安全、验证、权限等相…

原油投资怎么样赚钱?原油投资赚钱技巧有哪些?

以前没有交易过原油的投资者&#xff0c;看到其他投资者从中获得了较好的盈利&#xff0c;也想通过原油投资来赚钱。那么原油投资到底能不能赚钱&#xff0c;是很多新手投资者比较想了解的问题。其实原油投资想盈利并不能全部依靠运气&#xff0c;只有掌握了原油投资赚钱技巧&a…

【Java基础】Java日志—什么是日志级别?如何配置数据源到不同的位置?配置文件内容都是什么含义?

目录 一、log4j1详情&#xff1a;记录器和日志级别 二、 log4j1详情&#xff1a;输出源【输出到不同的位置】 1、ConsoleAppender【将日志输出到控制台】 2、FileAppender【将日志输出到文件】 3、DailyRollingFileAppender【每日输出到一个新文件】 4、JDBCAppender【输…