视频与AI,与进程交互(二) pytorch 极简训练自己的数据集并识别

news2024/9/21 16:44:17

目标学习任务

检测出已经分割出的图像的分类

2 使用pytorch

pytorch 非常简单就可以做到训练和加载

2.1 准备数据

在这里插入图片描述
如上图所示,用来训练的文件放在了train中,验证的文件放在val中,train.txt 和 val.txt 分别放文件名称和分类类别,然后我们在代码中写名字就行

里面我就为了做一个例子,放了两种文件,1 是 卡宴保时捷,2 是工程车,如下图所示
在这里插入图片描述
train.txt 如下图所示
在这里插入图片描述
val.txt 也是同样如此

3 show me the code

3.1 装载数据类

新增一个loaddata.py 文件

import torch
import random
from PIL import Image
class LoadData(torch.utils.data.Dataset):
    def __init__(self, root, datatxt, transform=None, target_transform=None):
        super(LoadData, self).__init__()
        file_txt = open(datatxt,'r')
        imgs = []
        for line in file_txt:
            line = line.rstrip()
            words = line.split('|')
            imgs.append((words[0], words[1]))

        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        random.shuffle(self.imgs)
        name, label = self.imgs[index]
        img = Image.open(self.root + name).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = int(label)
        return img, label

    def __len__(self):
        return len(self.imgs)

LoadData 类是从torch.util.data.Dataset上继承下来的,需要一个transform类输入,实际上就是转化大小

3.2 网络类

定义一个网络类,只有两个输出

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.pool = nn.MaxPool2d((2, 2))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc1 = nn.Linear(36*36*32, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool1(F.relu(self.conv2(x)))
        x = x.view(-1, 36*36*32)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.3 主要流程

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
from loaddata import LoadData
from modelnet import Net

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


classes = ['工程车','卡宴']
transform = transforms.Compose(
   [transforms.Resize((152, 152)),transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_data=LoadData(root ='./data/train/',
                 datatxt='./data/'+'train.txt',
                 transform=transform)
test_data=LoadData(root ='./data/val/',
                datatxt='./data/'+'val.txt',
                transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=2, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=2)

def imshow(img):
   img = img / 2 + 0.5     # unnormalize
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()


net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
   running_loss = 0.0
   for i, data in enumerate(train_loader, 0):
       inputs, labels = data
       optimizer.zero_grad()
       outputs = net(inputs)
       loss = criterion(outputs, labels)
       loss.backward()
       optimizer.step()

       running_loss += loss.item()
       if i % 200 == 0:
           print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 200))
           running_loss = 0.0

print('Finished Training')

PATH = './test.pth'
torch.save(net.state_dict(), PATH)

net = Net()
net.load_state_dict(torch.load(PATH))

correct = 0
total = 0
with torch.no_grad():
   for data in test_loader:
       images, labels = data
       outputs = net(images)
       _, predicted = torch.max(outputs.data, 1)
       total += labels.size(0)
       correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
   100 * correct / total))

在这里插入图片描述
如上图所示,epoch为5时精确度为80%,为10时精确度为100%,各位不要当真,这这是训练集里面的数据集做识别,并不是真的精确度。

3.4 识别代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from modelnet import Net

PATH = './test.pth'
transform = transforms.Compose(
    [transforms.Resize((152, 152)),transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])



net = Net()
net.load_state_dict(torch.load(PATH))

img = Image.open("./data/val/102.jpg").convert('RGB')
img = transform(img)
with torch.no_grad():
    outputs = net(img)
    _, predicted = torch.max(outputs.data, 1)
    print("the 102 img lable is ",predicted)

如下图所示,102 为卡宴识别为1 正确
在这里插入图片描述

后记

后面我们准备是从视频中传递过来图像进行分类,同时使用我们的工具VT解码视频后进行内存共享来生成图像,而不是从磁盘加载。要用到我们的c++ 解码工具,和pytorch进行交互
以下是第一篇文章:视频与AI,与进程交互(一)
VT 工具准备开源,端午节节后开出来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/675416.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android之 弹框总结

一 简介 1.1 弹框即浮与页面之上的窗口,如键盘弹框,吐司弹框,确认弹框,下拉选择框,应用悬浮框等 1.2 弹框控件也很多,比如常用的Spinner,Dialog,Toast,PopWindow等&…

小主机折腾记14

1.m72e主机,3240t-2390t-3470t测试; 2390t官方参数 在m72e上 全核3.08Ghz 单核3.28-3.31Ghz 核显2帧 评分 3470t官方参数 在m72e上 全核睿频3.28 单核最高3.44 核显1.2帧???还不如那啥HD2000 最后评分 进入…

chatgpt赋能python:Python求累加的方法及其应用

Python求累加的方法及其应用 在Python编程中,经常需要对一系列数字进行求和或累加的操作。那么在Python中,我们可以通过哪些方法来实现这个功能呢?本文将为大家介绍Python求累加的方法及其应用。 1. Python中的for循环 首先,我…

05-事件循环

事件循环 以下知识点都涉及到事件循环 计时器,promise,ajax,node 明白此知识点,是前端的分水岭,可以提高效率,js中奇怪的东西都可以得到解决,整个过程是根据W3C和谷歌源码进行 浏览器的进程…

一文理解cast转换

目录 写在前边 1. what?又报错: 2. 靠,难道是这样? 3. 小试牛刀 4. 实际中的“坑” 写在后边 写在前边 关于$cast转换的结论无外乎以下四条: 如果将子类句柄复制给父类句柄,可以实现父类句柄的向下转换…

翻筋斗觅食策略改进灰狼算法

目录 一、动态扰动因子策略 二、翻筋斗觅食策略 三、改进灰狼算法收敛曲线图 灰狼优化算法(grey wolf optimization,GWO)存在收敛的不合理性等缺陷,目前对GWO算法的收敛性改进方式较少,除此之外,当GWO迭代至后期,所有灰狼个体…

企业版:Select.PDF Library for .NET

HTML 到 PDF API SelectPdf提供了一个REST API,可用于通过我们的专用云服务将html转换为任何语言的pdf。 另存为 PDF 链接 以非常简单的方式将“转换为PDF”功能添加到您的网站或博客。只需添加一个指向您的网页的链接,您就完成了。 适用于 .NET 的 PD…

Redis 通用命令

通用命令介绍 Redis 通用命令是一些 Redis 下可以作用在常用数据结构上的常用命令和一些基础的命令,比如删除键、对键进行改名、判断键是否存在等。简单说,就是 keys 分类的命令,如下图。 上图中圈中的部分,就是所谓的通用的命令…

chatgpt赋能python:Python对于SEO的重要性:浏览网页的技术分析

Python对于SEO的重要性:浏览网页的技术分析 越来越多的网站需要搜索引擎优化(SEO),以便他们的网站上的内容能够被更多人浏览与访问。这就要求我们使用一些工具和技术,例如Python,来帮助我们分析网页的技术…

通过调整图像hue值并结合ImageEnhance库以实现色调增强

前言 PIL库中的ImageEnhance类可用于图像增强,可以调节图像的亮度、对比度、色度和锐度。 通过RGB到HSV的变换加调整可以对图像的色调进行调整。 两种方法结合可以达到更大程度的图像色调增强。 调整hue值 __author__ TracelessLe __website__ https://blog…

linux 下查看 USB 设备

文章目录 前言目录内容详解usb11-0:1.01-1.1:1.0 结构图设备信息bDeviceClassversionbusnum & devnumdevbMaxPoweridVendor & idProductproductmanufacturerbcdDevicespeedueventbmAttributesdrivers_autoprobe 前言 在 sysfs 文件系统下,查看 USB 设备&am…

PaddleOCR #使用PaddleOCR进行光学字符识别(PP-OCR文本检测识别)

引言: PaddleOCR 是一个 OCR 框架或工具包,它提供多语言实用的 OCR 工具,帮助用户在几行代码中应用和训练不同的模型。PaddleOCR 提供了一系列高质量的预训练模型。这包含三种类型的模型,使 OCR 高度准确并接近商业产品。它提供文…

【Unity 2D AABB碰撞检测】铸梦之路

作者介绍:铸梦xy。IT公司技术合伙人,IT高级讲师,资深Unity架构师,铸梦之路系列课程创始人。 目录1.AABB 碰撞介绍2.常用2D碰撞盒3.为什么要学习如何编写碰撞检测4.2D BOX & BOX 碰撞检测原理和代码5.2D BOX &Shpere 碰撞检…

Linux信号编程、signal函数范例详解( 4 ) -【Linux通信架构系列 】

系列文章目录 C技能系列 Linux通信架构系列 C高性能优化编程系列 深入理解软件架构设计系列 高级C并发线程编程 期待你的关注哦!!! 现在的一切都是为将来的梦想编织翅膀,让梦想在现实中展翅高飞。 Now everything is for the…

chatgpt赋能python:Python求1是什么?Python求1在SEO中的应用

Python求1是什么?Python求1在SEO中的应用 介绍Python求1 Python求1,也叫做1-bit计数器,是一种用来统计网页浏览量的技术。在Web开发中,我们需要记录网页的浏览次数,以便了解网站的流量和用户的使用情况。传统的做法是…

chatgpt赋能python:Python求绝对值:从初学者到高级工程师的必备知识

Python求绝对值:从初学者到高级工程师的必备知识 Python是一种有趣且功能强大的编程语言。它非常易于学习,同时又具有广泛的应用领域,比如Web开发、数据分析、机器学习和人工智能等。在Python的数学运算中,求绝对值是一个常见的需…

chatgpt赋能python:Python浮点型的两种表示方法

Python浮点型的两种表示方法 Python是一种解释型的动态语言,可以处理多种数据类型。其中,浮点型是其中一种数据类型,它包括十进制和科学计数法两种表示方法。 十进制表示法 十进制浮点数是Python的基本浮点类型,可以表示实数。…

2023 hnust 湖南科技大学 大数据技术与应用 期末考试 复习资料

前言 感谢:lqx(主要内容来源),hqh 有自己的理解和魔改 可以参考的资料 课后题答案我爬取的老师布置的学习通课后题往年资料csdn里面找到的:1、2老师ppt上课划重点录音 不提供pdf文件,方便修改&#xff0…

探索技术极致,未来因你出‘粽’

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

chatgpt赋能python:Python游戏——为什么它成为未来最热门的游戏开发工具

Python游戏——为什么它成为未来最热门的游戏开发工具 在游戏开发中,Python一直是非常强大和受欢迎的语言。Python具有很多吸引人的特点和实用功能,它为游戏开发者提供了多种可能,我们在本文中将介绍Python游戏以及为什么它越来越受欢迎。 …