Pytorch进阶教学——训练一个图像分类模型(GPU)

news2024/12/24 20:55:48

目录

1、前言 

2、数据集介绍

3、获取数据

4、创建网络

5、训练模型

6、测试模型

6.1、测试整个模型准确率

6.2、测试单张图片


1、前言 

  • 编写一个可以分类蚂蚁和蜜蜂图片的模型,使用数据集对卷积神经网络进行训练。训练后的模型可以对蚂蚁或蜜蜂的图片进行检测。
  • 使用anaconda新建一个虚拟环境,安装好pytorch。后续缺什么包就安装什么包即可。
  • 使用pycharm新建一个项目,配置好环境。

2、数据集介绍

  • 使用的数据集为蚂蚁和蜜蜂的图片,分为训练集和测试集
  • 【注】数据集下载地址。

3、获取数据

  • 代码中获取数据集使用的是txt文件,所以首先需要提取全部图片的地址和标签放入txt文件中。
  • 下述代码为python提取全部图片地址和标签导出为txt文件的脚本。(自行修改)
    • import os  # 导入os模块,用于操作文件路径等操作系统相关功能。
      
      
      def get_file_name(file_path, output_file, type):  # 绝对路径
          path_list = os.listdir(file_path)  # 列出指定路径下的所有文件和文件夹,并将结果存储在path_list中
          with open(output_file, 'a') as file:
              for filename in path_list:
                  all_file_path = os.path.join(file_path, filename)  # 拼接路径
                  file.write(all_file_path + ' ' + type + '\n')
      
      
      if __name__ == '__main__':
          ants_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\ants"
          bees_file_path = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train\bees"
          output_file = r"D:\BaiduNetdiskWorkspace\PyTorch\image_recognition\hymenoptera_data\train.txt"
          get_file_name(ants_file_path, output_file, 'ants')
          get_file_name(bees_file_path, output_file, 'bees')
    •  
  • 将全部地址修改为相对地址。
    • 使用替换操作实现。例如:
  • 最后txt文件的内容如下:
  • 新建一个dataset.py文件。
    • # 读取数据
      import torch
      import torchvision.transforms as transforms
      from PIL import Image
      
      
      # 读取数据类
      class MyDataset(torch.utils.data.Dataset):  # 继承构建自定义数据集的基类
          def __init__(self, datatxt, datatransform):
              datas = open(datatxt, 'r').readlines()  # 按行读取,每行包含图像路径和标签
              self.images = []
              self.labels = []
              self.transform = datatransform
              for data in datas:
                  item = data.strip().split(' ')  # 去除首尾空格并按空格分割
                  # 分别将图像路径和标签添加到self.images和self.labels列表中
                  self.images.append(item[0])  # 路径
                  self.labels.append(item[1])  # 标签
              return
      
          def __len__(self):
              return len(self.images)
      
          # 获取数据集中的一个样本。接收一个索引item,根据索引获取对应的图像路径和标签
          def __getitem__(self, item):
              imagepath, label = self.images[item], self.labels[item]
              image = Image.open(imagepath)  # 打开图片
              return self.transform(image), label  # 返回转换后的图像和对应的标签
      
      
      # 用于测试
      if __name__ == '__main__':
          # 利用txt文件读取图片信息,txt文件包括图片路径和标签
          traintxt = './hymenoptera_data/train.txt'
          valtxt = './hymenoptera_data/val.txt'
          # 图片转换形式
          traindata_transfomer = transforms.Compose([
              transforms.ToTensor(),  # 转为Tensor格式
              transforms.Resize(60),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整
              transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分
              transforms.RandomHorizontalFlip(),  # 随机水平翻转
              transforms.RandomRotation(10),  # 随机旋转
              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
          ])
          valdata_transfomer = transforms.Compose([
              transforms.ToTensor(),  # 转为Tensor格
              transforms.Resize(48),  # 调整图像大小,调整为高度或宽度为48像素,另一边按比例调整
              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
          ])
          # 加载数据
          traindataset = MyDataset(traintxt, traindata_transfomer)
          valdataset = MyDataset(valtxt, valdata_transfomer)
          print("测试集:" + str(traindataset.__len__()))
          print("训练集:" + str(valdataset.__len__()))
  • 单独运行结果:(只用于测试)

4、创建网络

  • 新建一个net.py文件。
    • 其中创建了一个简单的三层卷积神经网络。
    • # 三层卷积神经网络
      import torch
      
      
      # 卷积神经网络类
      class SimpleConv3(torch.nn.Module):  # 继承创建神经网络的基类
          def __init__(self, classes):
              super(SimpleConv3, self).__init__()
              # 卷积层
              self.conv1 = torch.nn.Conv2d(3, 16, 3, 2, 1)  # 输入通道3,输出通道16,3*3的卷积核,步长2,边缘填充1
              self.conv2 = torch.nn.Conv2d(16, 32, 3, 2, 1)  # 输入通道16,输出通道32,3*3的卷积核,步长2,边缘填充1
              self.conv3 = torch.nn.Conv2d(32, 64, 3, 2, 1)  # 输入通道32,输出通道64,3*3的卷积核,步长2,边缘填充1
              # 全连接层
              self.fc1 = torch.nn.Linear(2304, 100)
              self.fc2 = torch.nn.Linear(100, classes)
      
          def forward(self, x):
              # 第一次卷积
              x = torch.nn.functional.relu(self.conv1(x))  # relu为激活函数
              # 第二次卷积
              x = torch.nn.functional.relu(self.conv2(x))
              # 第三次卷积
              x = torch.nn.functional.relu(self.conv3(x))
              # 展开成一维向量
              x = x.view(x.size(0), -1)
              x = torch.nn.functional.relu(self.fc1(x))
              x = self.fc2(x)
              return x
      
      
      # 用于测试
      if __name__ == '__main__':
          inputs = torch.rand((1, 3, 48, 48))  # 生成一个随机的3通道、48x48大小的张量作为输入
          net = SimpleConv3(2)  # 二分类
          output = net(inputs)
          print(output)
  • 单独运行结果:(只用于测试)

5、训练模型

  • 新建一个train.py文件。
    • 其中可自行设置的参数都有标出。 
    • # 训练模型
      import matplotlib
      
      matplotlib.use('TkAgg')
      import matplotlib.pyplot as plt
      from dataset import MyDataset
      from net import SimpleConv3
      import torch
      import torchvision.transforms as transforms
      from torch.optim import SGD  # 优化相关
      from torch.optim.lr_scheduler import StepLR  # 优化相关
      from sklearn import preprocessing  # 处理label
      
      # 图片转换形式
      traindata_transfomer = transforms.Compose([
          transforms.ToTensor(),  # 转为Tensor格式
          transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能
          transforms.RandomCrop(48),  # 裁剪图片,随机裁剪成高度和宽度均为48像素的部分
          transforms.RandomHorizontalFlip(),  # 随机水平翻转
          transforms.RandomRotation(10),  # 随机旋转
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])
      
      if __name__ == '__main__':
          traintxt = './hymenoptera_data/train.txt'
          valtxt = './hymenoptera_data/val.txt'
      
          # 加载数据
          traindataset = MyDataset(traintxt, traindata_transfomer)
      
          # 创建卷积神经网络
          net = SimpleConv3(2)  # 二分类
          # 使用GPU
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          net.to(device)
          # 测试GPU是否能使用
          # print("The device is gpu later?:", next(net.parameters()).is_cuda)
          # print("The device is gpu,", next(net.parameters()).device)
      
          # 将数据提供给模型使用
          traindataloader = torch.utils.data.DataLoader(traindataset, batch_size=128, shuffle=True,
                                                        num_workers=1)  # batch_size可以自行调节
          # 优化器
          optim = SGD(net.parameters(), lr=0.1, momentum=0.9)  # 使用随机梯度下降(SGD)作为优化器,学习率0.1,动量0.9,加速梯度下降过程,lr可自行调节
          criterion = torch.nn.CrossEntropyLoss()  # 使用交叉熵损失作为损失函数
          lr_step = StepLR(optim, step_size=200, gamma=0.1)  # 学习率调度器,动态调整学习率,每200个epoch调整一次,每次调整缩小为原来的0.1倍,step_size可自行调节
          epochs = 5  # 训练次数
          accs = []
          losss = []
          # 训练循环
          for epoch in range(0, epochs):
              batch = 0
              running_acc = 0.0  # 精度
              running_loss = 0.0  # 损失
              for data in traindataloader:
                  batch += 1
                  imputs, labels = data
                  # 将标签从元组转换为tensor类型
                  labels = preprocessing.LabelEncoder().fit_transform(labels)
                  labels = torch.as_tensor(labels)
                  # 利用GPU训练模型
                  imputs = imputs.to(device)
                  labels = labels.to(device)
                  # 将数据输入至网络
                  output = net(imputs)
                  # 计算损失
                  loss = criterion(output, labels)
                  # 平均准确率
                  acc = float(torch.sum(labels == torch.argmax(output, 1))) / len(imputs)
                  # 累加损失和准确率,后面会除以batch
                  running_acc += acc
                  running_loss += loss.data.item()
      
                  optim.zero_grad()  # 清空梯度
                  loss.backward()  # 反向传播
                  optim.step()  # 更新参数
      
              lr_step.step()  # 更新优化器的学习率
              # 一次训练的精度和损失
              running_acc = running_acc / batch
              running_loss = running_loss / batch
              accs.append(running_acc)
              losss.append(running_loss)
              print('epoch=' + str(epoch) + ' loss=' + str(running_loss) + ' acc=' + str(running_acc))
      
          # 保存模型
          torch.save(net, 'model.pth')  # 保存模型的权重和结构
          x = torch.randn(1, 3, 48, 48).to(device)  # # 生成一个随机的3通道、48x48大小的张量作为输入,新建的张量也要送到GPU中
          net = torch.load('model.pth')  # 从保存的.pth文件中加载模型
          net.train(False)  # 设置模型为推理模式,意味着不会进行梯度计算或反向传播
          torch.onnx.export(net, x, 'model.onnx')  # 使用ONNX格式导出模型
          # 接受模型net、示例输入x和导出的文件名model.onnx作为参数
      
          # 可视化结果
          fig = plt.figure()
          plot1, = plt.plot(range(len(accs)), accs)  # 创建一个图形对象plot1,绘制accs列表中的数据
          plot2, = plt.plot(range(len(losss)), losss)  # 创建另一个图形对象plot2,绘制losss列表中的数据
          plt.ylabel('epoch')  # 设置y轴的标签为epoch
          plt.legend(handles=[plot1, plot2], labels=['acc', 'loss'])  # 创建图例,指定图表中不同曲线的标签
          plt.show()  # 展示所绘制的图表
  • 【注】本项目使用的是GPU训练模型。如果GPU可以获得,但是无法使用,可能是pytorch的版本不对,需要重新安装。
  • 运行结果:
  • 保存后的模型如下:

6、测试模型

6.1、测试整个模型准确率

  • 利用测试集,测试整个模型的准确率。
  • 新建一个test.py文件。
    • # 测试整个模型的准确率
      import torch
      import torchvision.transforms as transforms
      from dataset import MyDataset  # 您的数据集类
      from sklearn import preprocessing  # 处理label
      
      # 定义测试集的数据转换形式
      valdata_transfomer = transforms.Compose([
          transforms.ToTensor(),  # 转为Tensor格式
          transforms.Resize(60, antialias=True),  # 调整图像大小,调整为高度或宽度为60像素,另一边按比例调整,antialias=True启用了抗锯齿功能
          transforms.CenterCrop(48),  # 中心裁剪图片,裁剪成高度和宽度均为48像素的部分
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像进行归一化处理。对每个通道执行了均值为0.5、标准差为0.5的归一化操作
      ])
      
      if __name__ == '__main__':
          valtxt = './hymenoptera_data/val.txt'  # 测试集数据路径
      
          # 加载测试集数据
          valdataset = MyDataset(valtxt, valdata_transfomer)
      
          # 加载已训练好的模型,利用GPU进行测试
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          net = torch.load('model.pth').to(device)
          net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播
      
          # 使用 DataLoader 加载测试集数据
          valdataloader = torch.utils.data.DataLoader(valdataset, batch_size=1, shuffle=False)
      
          correct = 0  # 被正确预测的样本数
          total = 0  # 测试样本数
      
          # 测试模型
          with torch.no_grad():
              for data in valdataloader:
                  images, labels = data
                  # 将标签从元组转换为tensor类型
                  labels = preprocessing.LabelEncoder().fit_transform(labels)
                  labels = torch.as_tensor(labels)
                  # 利用GPU训练模型
                  images, labels = images.to(device), labels.to(device)
                  outputs = net(images)  # 输入图像并获取模型预测结果
                  _, predicted = torch.max(outputs.data, 1)  # 获取预测值中最大概率的索引
                  total += labels.size(0)  # 累计测试样本数量
                  correct += (predicted == labels).sum().item()  # 计算正确预测的样本数量
      
          # 计算并输出模型在测试集上的准确率
          accuracy = 100 * correct / total
          print('Test Accuracy: {:.2f}%'.format(accuracy))
  • 运行结果:
    • 因为训练模型时只迭代了200次,所以准确率并不高。可以尝试提高训练次数,提高准确率。 

6.2、测试单张图片

  • 使用训练后的模型,对单张图片进行预测。
  • 新建一个testone.py文件。
    • import torch
      from PIL import Image
      import torchvision.transforms as transforms
      
      # 定义图片预处理转换
      image_transforms = transforms.Compose([
          transforms.Resize(60, antialias=True),  # 调整图像大小
          transforms.CenterCrop(48),  # 中心裁剪
          transforms.ToTensor(),  # 转为Tensor格式
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化处理
      ])
      
      # 定义类别映射字典
      class_mapping = {
          0: "ant",
          1: "bee"
      }
      
      # 加载已训练好的模型,利用GPU测试
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
      net = torch.load('model.pth').to(device)
      net.eval()  # 将模型设置为评估模式,意味着不会进行梯度计算或反向传播
      
      # 加载要测试的图片
      image_path = './hymenoptera_data/val/bees/26589803_5ba7000313.jpg'  # 图片路径
      input_image = Image.open(image_path)  # 加载图片
      input_tensor = image_transforms(input_image).unsqueeze(0)  # 对图片进行预处理转换,并增加 batch 维度
      
      # 将输入数据移动到GPU上
      input_tensor = input_tensor.to(device)
      
      # 使用模型进行预测
      with torch.no_grad():
          output = net(input_tensor)
          _, predicted = torch.max(output, 1)  # 在张量中沿指定维度找到最大值及其对应的索引
      
      # 输出预测结果
      predicted_class = predicted.item()  # 得到预测的标签
      predicted_label = class_mapping[predicted_class]  # 将标签转换为文字
      print(f"The predicted class for the image is: {predicted_label}")
  • 运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1271704.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单片机学习12——电容

电容的作用: 1)降压作用: 容抗: Xc 1/2fc 串联分压原理。2100Ω的容量,50Hz的频率,可以得到1.5uF。断电之后,需要串联一个1MΩ的电阻放电。 那是不是可以使用2100欧姆的电阻来代替电容呢&am…

单宁对葡萄酒可饮用性和陈酿潜力会有影响吗?

当在酿酒过程中葡萄酒中的单宁过量时,酿酒师可以使用白蛋白、酪蛋白和明胶等各种细化剂,这些药物可以与单宁分子结合,并将其作为沉淀物沉淀出来。随着葡萄酒的老化,单宁将形成长长的聚合链,氧气可以与单宁分子结合&…

安全技术与防火墙

目录 安全技术 防火墙 按保护范围划分: 按实现方式划分: 按网络协议划分. 数据包 四表五链 规则链 默认包括5种规则链 规则表 默认包括4个规则表 四表 查询 格式: 规则 面试题 NFS常见故障解决方法 安全技术 入侵检测系统 (Intrusion Detection Sy…

高并发架构——网页爬虫设计:如何下载千亿级网页?

Java全能学习面试指南:https://javaxiaobear.cn 在互联网早期,网络爬虫仅仅应用在搜索引擎中。随着大数据时代的到来,数据存储和计算越来越廉价和高效,越来越多的企业开始利用网络爬虫来获取外部数据。例如:获取政府公…

【23真题】快跑,考太偏了这所211!

今天分享的是23年湖南师范997的信号与系统试题及解析。 小马哥Tips: 本套试卷难度分析:22年湖南师范997考研真题,我也发布过,若有需要,戳这里自取!本套试题难度中等,题量适中,但是…

百度推送收录工具-免费的各大搜索引擎推送工具

在互联网时代,网站收录是网站建设的重要一环。百度推送工具作为一种提高网站收录速度的方式备受关注。在这个信息爆炸的时代,对于网站管理员和站长们来说,了解并使用一些百度推送工具是非常重要的。本文将重点分享百度批量域名推送工具和百度…

四、shell - 字符串

目录 1、单引号 2、双引号 3、拼接字符串 3.1 使用双引号拼接 3.2 使用单引号拼接 4、获取字符串长度 ​​​​​​​5、提取子字符串 ​​​​​​​6、查找子字符串 ​​​​​​​字符串是shell编程中最常用最有用的数据类型(除了数字和字符串&#xff0…

Flyway 数据库版本管理 | 专业解决方案

前言 目前很多公司都是通过人工去维护、同步数据库脚本,但经常会遇到疏忽而遗漏的情况,同时也是非常费力耗时 比如说我们在开发环境对某个表新增了一个字段,而提交测试时却忘了提交该 SQL 脚本,导致出现 bug 而测试中断&#xf…

fiddler弱网测试实践

准备工作 1、fiddler安装包 2、一部安卓手机 一、fiddler安装 安装fiddler到电脑上,傻瓜式安装即可 二、fiddler环境配置 三、手机端环境配置 1、获取电脑的IP地址:WindowsR,输入cmd弹出命令窗口,输入命令ipconfig 或者鼠标…

百度地图JavaScript API GL获取经纬度,标记,添加文本标注,点击事件,封装

百度地图JavaScript API GL常用方法封装 引入百度js库 <script type"text/javascript" src"https://api.map.baidu.com/api?v1.0&typewebgl&ak自己的百度应用ak"></script>封装方法 <template><div class"map"&…

【性能测试】性能测试监控关键指标

系统指标 检测性能测试是否有bug的关键指标 1、系统指标——与用户场景及需求直接相关。 并发用户数&#xff1a;某一物理时刻同时向系统提交请求的用户数。平均响应时间&#xff1a;系统处理事务的响应时间的平均值&#xff0c;对于系统快速响应类页面&#xff0c;一般响应…

Yolov8实现瓶盖正反面检测

一、模型介绍 模型基于 yolov8n数据集采用SKU-110k&#xff0c;这数据集太大了十几个 G&#xff0c;所以只训练了 10 轮左右就拿来微调了 基于原木数据微调&#xff1a;训练 200 轮的效果 10 轮SKU-110k 20 轮原木 200 轮瓶盖正反面 微调模型下载地址https://wwxd.lanzouu.co…

北斗卫星助力乡村治理,走进数字化新时代

北斗卫星助力乡村治理&#xff0c;走进数字化新时代 随着国家对乡村治理越来越重视&#xff0c;为了进一步提升乡村治理水平&#xff0c;我国已经启动了全面建设现代化强国的大计划&#xff0c;其中数字化成为了重要的一环。而北斗卫星作为我国自主研制的卫星导航系统&#xff…

【漏洞复现】通达OA inc/package/down.php接口存在未授权访问漏洞 附POC

漏洞描述 通达OA(Office Anywhere网络智能办公系统)是由通达信科科技自主研发的协同办公自动化软件,是与中国企业管理实践相结合形成的综合管理办公平台。通达OA为各行业不同规模的众多用户提供信息化管理能力,包括流程审批、行政办公、日常事务、数据统计分析、即时通讯、…

西班牙Wallapop是什么?原来欧洲版闲鱼也很好用!

说到国内的闲鱼大家肯定不陌生&#xff0c;那国外的二手闲置平台大家知道吗&#xff1f;在西班牙&#xff0c;最受欢迎的移动购物APP是Wallapop和速卖通。Wallapop是西班牙第一大二手商品网站&#xff0c;网站上丰富的性价比高的商品正好满足了西班牙人的需求。今天龙哥就和大家…

VirtualBox上安装CentOS7

基础环境&#xff1a;宿主机是64位Windows10操作系统&#xff0c;通过无线网访问网络。 macOS可以以类似方式进行安装&#xff0c;不同之处见最后补充。 Step1 安装VirtualBox VirtualBox是一款免费、开源、高性能的虚拟机软件&#xff0c;可以跨平台运行&#xff0c;支持Wi…

VMware Linux(Centos)虚拟机扩容根目录磁盘空间

给VMWare虚拟机根目录扩容&#xff0c;简单有效&#xff01;_迷倒万千少女的Csir的博客-CSDN博客 https://blog.csdn.net/m0_64206944/article/details/131453844?spm1001.2014.3001.5506 上述链接融合参考下面文章 VMware Linux(Centos)虚拟机扩容根目录磁盘空间 centosli…

开启新零售时代,引领消费革命

开启新零售时代&#xff0c;引领消费革命 新零售的魅力在于它将线上线下融合&#xff0c;打破了传统零售的界限。以往&#xff0c;消费者需要亲自前往实体店面购物&#xff0c;但如今他们可以通过电子商务平台随时随地进行购物。这种便捷的消费方式不仅节省了时间和精力&#x…

‘tsc‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件。

最近在用nodejs typescript 某游戏服务器在做一些研究 nodejs-tcs 问题描述&#xff1a; 1.使用命令npm install -g typescript安装typescript后&#xff0c;输入 tsc命令&#xff0c;一直报错 tsc 不是内部或外部命令&#xff0c;也不是可运行的程序 或批处理文件。 2.目…