RK3568笔记三:基于ResNet18的Cifar-10分类识别训练部署

news2025/1/19 23:21:09

若该文为原创文章,转载请注明原文出处。

本篇文章参考的是野火-lubancat的rk3568教程,本篇记录了在正点原子的ATK-DLK3568部署。

一、介绍

        ResNet18 是一种卷积神经网络,它有 18 层深度,其中包括带有权重的卷积层和全连接层。它是ResNet 系列网络的一个变体,使用了残差连接(residual connection)来解决深度网络的退化问题。 

       ResNet(Residual Neural Network)由微软研究院的 Kaiming He 等人在 2015 年提出,ResNet 的结 构可以极快的加速神经网络的训练,模型的准确率也有比较大的提升。 ResNet 是一种残差网络,可以把它理解为一个子网络,这个子网络经过堆叠可以构成一个很深的 网络。ResNet 系列有多种变体,如 ResNet18,ResNet34,ResNet50,ResNet101 和 ResNet152,其 网络结构如下:

这里我们主要看下 ResNet18,ResNet18 基本含义是网络的基本架构是 ResNet,网络的深度是 18层,是带有权重的 18 层,不包括 BN 层,池化层。ResNet18 使用的基本残差单元,每个单元由两 个 3x3 卷积层组成,中间有一个 BN 层和一个 ReLU 激活函数。

PyTorch 中的 ResNet18 源码实现:https://github.com/pytorch/vision/blob/main/torchvision/models/ resnet.py

二、环境安装

环境分为两部分:一是训练的环境;二是rknn环境;

rknn环境前面有介绍,自行安装;训练的环境是windows电脑无gpu,使用的是CPU

1、创建虚拟环境

conda create -n ResNet18_env python=3.8 -y

2、激活环境

conda activate ResNet18_env

3、安装环境

pip install torchvision
pip install onnxruntime

三、训练

        自定义一个 ResNet18 网络结构,并使用 CIFAR-10 数据集进行简单测试。CIFAR-10 数据集由 10 个类别的 60000 张 32x32 彩色图像组成,每个类别有 6000 张图像,总共分为 50000 张训练图像和 10000 张测试图像。

resnet18.py

import os
import torchvision
import torch
import torch.nn as nn

#device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
device=torch.device("cpu")

# Transform configuration and data augmentation
transform_train=torchvision.transforms.Compose([
        torchvision.transforms.Pad(4),
        torchvision.transforms.RandomHorizontalFlip(), #图像一半的概率翻转,一半的概率不翻转
        torchvision.transforms.RandomCrop(32), #图像随机裁剪成32*32
        # torchvision.transforms.RandomVerticalFlip(),
        # torchvision.transforms.RandomRotation(15),
        torchvision.transforms.ToTensor(), #转为Tensor ,归一化
        torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010))
        ])
transform_test=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])
# epoch时才对数据集进行以上数据增强操作

num_classes=10
batch_size=128
learning_rate=0.001
num_epoches=100
classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")

# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)

# Data loader 
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Define 3*3 convolutional neural network
def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
    def forward(self, x):
        residual=x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if(self.downsample):
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
 
# 自定义一个神经网络,使用nn.model,,通过__init__初始化每一层神经网络。
# 使用forward连接数据
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = torch.nn.BatchNorm2d(16)
        self.relu = torch.nn.ReLU(inplace=True)
        self.layer1 = self._make_layers(block, 16, layers[0])
        self.layer2 = self._make_layers(block, 32, layers[1], 2)
        self.layer3 = self._make_layers(block, 64, layers[2], 2)
        self.layer4 = self._make_layers(block, 128, layers[3], 2)
        self.avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(128, num_classes)
        
    def _make_layers(self, block, out_channels, blocks, stride=1):
        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = torch.nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                torch.nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return torch.nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Make model,使用cpu
model=ResNet(ResidualBlock, [2,2,2,2], num_classes).to(device=device)

# 打印model结构
# print(f"Model structure: {model}\n\n")

# 优化器和损失函数
criterion = nn.CrossEntropyLoss()  #交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #优化器随机梯度下降

if __name__ == "__main__":
    # Train the model
    total_step = len(train_loader)
    for epoch in range(0,num_epoches):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device=device)
            labels = labels.to(device=device)
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Backward and optimize
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 更新参数
            optimizer.step()
            #sum_loss += loss.item()
            #_, predicted = torch.max(outputs.data, dim=1)
            #total += labels.size(0)
            #correct += predicted.eq(labels.data).cpu().sum()
            if (i+1) % total_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch+1, num_epoches, i+1, total_step, loss.item()))
    print("Finished Tranining")
    # 保存权重文件
    #torch.save(model.state_dict(), 'model_weights.pth')
    #torch.save(model, 'model.pt')

    print('\nTest the model')
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device=device)
            labels = labels.to(device=device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('在10000张测试集图片上的准确率:{:.4f} %'.format(100 * correct / total))

    # 导出onnx模型
    x = torch.randn((1, 3, 32, 32))
    torch.onnx.export(model, x, './resnet18.onnx', opset_version=12, input_names=['input'], output_names=['output'])

这里有个要注意的,数据集已经提前下载好了,所以没有在线下载

数据集下载是通过下面代码,数据集放在data目录下:

# load downloaded dataset
train_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10('./data', download=True, train=False, transform=transform_test)

类型分类为10类

classes = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")

等待大概1小时,训练结束后会在当前目录下生成resnet18.onnx模型

四、测试onnx模型

测试代码如下:

test_resnet18_onnx.py

import os, sys

sys.path.append(os.getcwd())
import onnxruntime
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 自定义的数据增强
def get_test_transform(): 
    return transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

# 推理的图片路径
image = Image.open('./horse5.jpg').convert('RGB')

img = get_test_transform()(image)
img = img.unsqueeze_(0)  # -> NCHW, 1,3,224,224
# 模型加载
onnx_model_path = "resnet18.onnx"
resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: to_numpy(img)}
outs = resnet_session.run(None, inputs)[0]

print("onnx weights", outs)
print("onnx prediction", outs.argmax(axis=1)[0])

测试预测结果7,对应的是马。

五、RKNN模型转换

打开ATK搭建好的虚拟机,进入环境rknn2_env,RKNN环境要确保安装好。

转换代码在rknn-toolkit2目录下的example的pytorch里也有,参考代码如下:

rknn_transfer.py

import numpy as np
import cv2
from rknn.api import RKNN
import torchvision.models as models
import torch
import os

def softmax(x):
    return np.exp(x)/sum(np.exp(x))

def torch_version():
    import torch
    torch_ver = torch.__version__.split('.')
    torch_ver[2] = torch_ver[2].split('+')[0]
    return [int(v) for v in torch_ver]

if __name__ == '__main__':

    if torch_version() < [1, 9, 0]:
        import torch
        print("Your torch version is '{}', in order to better support the Quantization Aware Training (QAT) model,\n"
              "Please update the torch version to '1.9.0' or higher!".format(torch.__version__))
        exit(0)

    MODEL = './resnet18.onnx'

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Pre-process config
    print('--> Config model')
    rknn.config(mean_values=[127.5, 127.5, 127.5], std_values=[255, 255, 255], target_platform='rk3568')
    #rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.395, 58.395, 58.395], target_platform='rk3568')
    #rknn.config(mean_values=[125.307, 122.961, 113.8575], std_values=[51.5865, 50.847, 51.255], target_platform='rk3568')
    print('done')

    # Load model
    print('--> Loading model')
    #ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
    ret = rknn.load_onnx(model=MODEL)
    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')

    # Build model
    print('--> Building model')
    ret = rknn.build(do_quantization=False)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # Export rknn model
    print('--> Export rknn model')
    ret = rknn.export_rknn('./resnet_18_100.rknn')
    if ret != 0:
        print('Export rknn model failed!')
        exit(ret)
    print('done')

    # Set inputs
    img = cv2.imread('./0_125.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,(32,32))
    #img = np.expand_dims(img, 0)
    
    # Init runtime environment
    print('--> Init runtime environment')
    ret = rknn.init_runtime()
    if ret != 0:
        print('Init runtime environment failed!')
        exit(ret)
    print('done')

    # Inference
    print('--> Running model')
    outputs = rknn.inference(inputs=[img])
    np.save('./pytorch_resnet18_qat_0.npy', outputs[0])
    #show_outputs(softmax(np.array(outputs[0][0])))
    print(outputs)
    print('done')

    rknn.release()

运行python rknn_transfer.py,正常生成rknn文件。

会在当前目录下生成rknn模型

六、部署

导出rknn后,把rknn和测试图片通过adb上传到开发板。

rknnlite_inference0.py

import numpy as np
import cv2
import os
from rknnlite.api import RKNNLite

IMG_PATH = '2_67.jpg'
RKNN_MODEL = './resnet18.rknn'
img_height = 32
img_width = 32
class_names = ["plane","car","bird","cat","deer","dog","frog","horse","ship","truck"]

# Create RKNN object
rknn_lite = RKNNLite()

# load RKNN model
print('--> Load RKNN model')
ret = rknn_lite.load_rknn(RKNN_MODEL)
if ret != 0:
    print('Load RKNN model failed')
    exit(ret)
print('done')

# Init runtime environment
print('--> Init runtime environment')
ret = rknn_lite.init_runtime()
if ret != 0:
    print('Init runtime environment failed!')
    exit(ret)
print('done')

# load image
img = cv2.imread(IMG_PATH)
img = cv2.resize(img,(32,32))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.expand_dims(img, 0)


# runing model
print('--> Running model')
outputs = rknn_lite.inference(inputs=[img])
print("result: ", outputs)
print(
    "This image most likely belongs to {}."
    .format(class_names[np.argmax(outputs)])
)

rknn_lite.release()

如有侵权,或需要完整代码,请及时联系博主。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1114271.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mysql架构解析,InnoDB架构概述。

MySQL架构解析 Mysql整体架构 MySQL整体架构如下图所示&#xff1a; MySQL逻辑系统架构分为4层: 应用层MySQL服务层存储引擎层系统文件层 下面将对各层的功能和组件进行介绍&#xff0c;并探讨一条语句的执行过程。 应用层 应用层是MySQL体系架构的最上层&#xff0c;它…

Docker——【部署项目的最优解】使用DockerCompose部署项目

目录 前言 1、安装docker-compose 2、为什么使用docker-compose&#xff1f; 3、如何使用DockerCompose 3.1、创建docker-compose文件 3.2、docker-compose相关命令&#xff1a; 前言 对Docker常规操作部署项目不了解的伙伴&#xff0c;可以先看看这篇文章&#xff1a;h…

多商户进驻小程序商城的作用是什么

多商户进驻商城简单来说就是在一个商城里&#xff0c;由经营者邀请同行、异业商家进驻到商城里&#xff08;子商户&#xff09;&#xff0c;可丰富商城经营业态&#xff0c;满足客户多方购物需求&#xff0c;打造购物商圈及经营者获得更多收益等。 通过【雨科】平台的多商户进驻…

Anaconda安装第三方库

一定要使用国内镜像源来进行下载&#xff0c;否则会非常慢&#xff01; 有兴趣的可以看看下面的文章^~^ 新版PyCharm安装第三方库更换国内下载镜像地址 OK!安装完成&#xff01;

Oracle数据中如何在 where in() 条件传参

一、问题场景描述 在sql 条件中&#xff0c;如何在 where in()中想传入参数&#xff0c;如果直接 where in(:seqList)&#xff0c;当传入单个值&#xff0c;seqList: ‘80’ 是没问题的&#xff0c;但是初入多个值时&#xff0c;seqList: ‘80,90’ &#xff0c;因缺少单引号&…

Windows重启开机在不登录系统情况下自启指定程序

问题前言&#xff1a; 项目开发完成后需要部署上线&#xff0c;首次肯定是手动部署跑项目&#xff0c;后期如果出现断电或其他原因导致服务器关机需要重启服务器的情况&#xff0c;这个时候再远程过去手动跑项目是很鸡肋的&#xff0c;通常会设置程序开机自启动&#xff0c;这…

eNSP-OSPF协议其他区域不与骨干区域相连解决方法3

virtual-link技术 AR1 [ar1]int g0/0/0 [ar1-GigabitEthernet0/0/0]ip add 192.168.1.1 24 [ar1-GigabitEthernet0/0/0]quit [ar1]ospf [ar1-ospf-1]area 0 [ar1-ospf-1-area-0.0.0.0]net 192.168.1.0 0.0.0.255 [ar1-ospf-1-area-0.0.0.0]quit AR2 [ar2]int g0/0/0 [ar2-Gig…

C语言的输入输出和条件判断

目录 数据类型、运算符与表达式 1.数据类型 基本数据类型包括 取值范围 2.常量和变量 常量 变量 定义变量 变量的分类 为什么要用变量 3.输入与输出 格式输出函数printf&#xff08;&#xff09; 打印时的输出类型 格式输入函数scanf&#xff08;&#xff09; 4…

C++设计模式_10_ Prototype 原型模式(小模式,不太常用)

Prototype 原型模式仍然属于“对象创建模式”模式的一种。前面两篇介绍的工厂方法模式和抽象工厂模式的流行程度要远大于Prototype 原型模式和builder构建器模式&#xff0c;后两种由于较为简单&#xff0c;介绍篇幅也会少一些。 文章目录 1. 动机 (Motivation)2. 代码演示Prot…

解决方案 | 法大大电子签助力融资租赁突围数字化

融资租赁作为我国非银金融市场的重要组成部分&#xff0c;具有融资和融物两方面功能&#xff0c;不仅能够拓宽市场主体的融资渠道&#xff0c;而且也是促进先进制造业、战略性新兴产业、绿色产业等领域高质量发展的重要助力。 2023年以来&#xff0c;多地相继出台了一系列鼓励…

众和策略:股票中总量和现量是什么意思?

股票商场是出资者最常用的一种出资办法之一&#xff0c;股票的价格动摇与供需联系有很大的联系。而供需联系中&#xff0c;总量和现量被广泛关注&#xff0c;它们别离指的是某一时期内的股票发行总量和现有交易量。在本文中&#xff0c;咱们将从多个角度分析股票中总量和现量的…

[每周一更]-(第68期):Excel常用函数及常用操作

日常工作&#xff0c;偶尔也会存在excel表格入库的情况&#xff0c;针对复杂的入库情况&#xff0c;一般都是代码编号&#xff0c;读文件-写db形式&#xff1b;但是有些简单就直接操作&#xff0c;但是 这些简单的入库不仅仅是直接入库&#xff0c;而是内容中有部分需要进行映射…

接口测试(jmeter和postman 接口使用)

接口测试基础知识 接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。把前端&#xff08;client&#xff09;和后端&#xff08;server&#xff09;联系起来&#xff0c;测试的重点是要检查数据的交换&#xff0c;传递和控制管理过程&#xff0c;以及系…

Keeplived安装部署(单机双机)

Keeplived官网&#xff1a;https://www.keepalived.org/download.html 一 单机安装配置: 1.上传keepalived安装包并且安装 [rootmaster1 local]# tar -zxvf keepalived-2.2.8.tar.gz [rootmaster1 local]# mv keepalived-2.2.8 keepalived [rootmaster1 local]# chown root:r…

docker安装es分词插件ik详情步骤

1.下载ik查询 根据es版本去下载对应的版本&#xff0c;游览器中输入下面下载链接 https://github.com/medcl/elasticsearch-analysis-ik/releases/download/v8.7.1/elasticsearch-analysis-ik-8.7.1.zip 2.2.若有对应版本跳过&#xff0c;若没有对应版本&#xff08;比如我需…

将语义分割的标注mask转为目标检测的bbox

1. 语义分割标签 1.1 labelme工具 语义分割的标签是利用labelme工具进行标注的,标注的样式如下: 1.2 语义分割的标签样式 2. 转换语义分割的标注到目标检测的bbox 实现步骤 (1) 利用标注的json文件生成mask图片(2) 在mask图片中找到目标的bbox矩形框的左上角点和右下角点(…

TCP通信-使用线程池优化

下面的通信架构存在问题&#xff1a; 客户端与服务端的线程模型是&#xff1a; N-N的关系&#xff0c;客户端并发越多&#xff0c;系统瘫痪的越快。 引入线程池处理多个客户端消息 代码实现 public class ClientDemo1 {public static void main(String[] args) {try {Syste…

C++是不是最容易产生猪队友的编程语言之一?

C是不是最容易产生猪队友的编程语言之一&#xff1f; 猪队友不是什么编程语言产生的&#xff0c;而是其做派本身就是猪队友&#xff0c;比如说自己一知半解的东西用得飞 起&#xff0c;而且不愿意深层次去学;再比如说不愿意写单元测试&#xff0c;甚至普通的测试都懒得做。最近…

在chrom浏览器安装Vue.js devtools插件,遇到恶意扩展程序字样,附百度网盘下载链接

遇到的问题 拖拽下载好的 Vue.js devtools 插件到谷歌扩展程序&#xff0c; 百度网盘下载地址 链接&#xff1a;https://pan.baidu.com/s/1FeK6pwc2UzRUUlMFN3rW5w?pwdw361 提取码&#xff1a;w361 提示&#xff1a; 解决办法 将Vue.js devtools 插件的后缀从.crx改为.zi…

C# 文件 校验:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64

文件 校验 算法:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64 编程语言:C# 文件属性内容 校验算法:MD5、SHA1、SHA256、SHA384、SHA512、CRC32、CRC64。 核心代码: using System; using System.Collections.Generic; using System; using System.Text; using Syst…