在树莓派上实现numpy的conv2d卷积神经网络做图像分类,加载pytorch的模型参数,推理mnist手写数字识别,并使用多进程加速

news2024/12/23 23:01:46

这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是卷积识别mnist手写数字识别

训练代码在电脑上,cpu就能训练,很快的:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
# 设置随机种子
torch.manual_seed(42)

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练数据集
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# 构建卷积神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(10 * 12 * 12, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(-1, 10 * 12 * 12)
        x = self.fc(x)
        return x

model = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# 训练模型
def train(model, device, train_loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Train Epoch: {epoch+1} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                      f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# 在GPU上训练(如果可用),否则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 训练模型
train(model, device, train_loader, optimizer, criterion, epochs=5)

# 保存模型为NumPy数据
model_state = model.state_dict()
numpy_model_state = {key: value.cpu().numpy() for key, value in model_state.items()}
np.savez('model.npz', **numpy_model_state)
print("Model saved as model.npz")
 

然后需要自己在dataset里导出一些图片:我保存在了mnist_pi文件夹下,“_”后面的是标签,主要是在pc端导出保存到树莓派下

 树莓派推理端的代码,需要numpy手动重新搭建网络,并且需要手动实现conv2d卷积神经网络和maxpool2d最大池化,然后加载那些保存的矩阵参数,做矩阵乘法和加法

复制代码

import numpy as np
import os
from PIL import Image

def conv2d(input, weight, bias, stride=1, padding=0):
    batch_size, in_channels, in_height, in_width = input.shape
    out_channels, in_channels, kernel_size, _ = weight.shape

    # 计算输出特征图的大小
    out_height = (in_height + 2 * padding - kernel_size) // stride + 1
    out_width = (in_width + 2 * padding - kernel_size) // stride + 1

    # 添加padding
    padded_input = np.pad(input, ((0, 0), (0, 0), (padding, padding), (padding, padding)), mode='constant')

    # 初始化输出特征图
    output = np.zeros((batch_size, out_channels, out_height, out_width))

    # 执行卷积操作
    for b in range(batch_size):
        for c_out in range(out_channels):
            for h_out in range(out_height):
                for w_out in range(out_width):
                    h_start = h_out * stride
                    h_end = h_start + kernel_size
                    w_start = w_out * stride
                    w_end = w_start + kernel_size

                    # 提取对应位置的输入图像区域
                    input_region = padded_input[b, :, h_start:h_end, w_start:w_end]

                    # 计算卷积结果
                    x = input_region * weight[c_out]
                    bia = bias[c_out]
                    conv_result = np.sum(x, axis=(0,1, 2)) + bia

                    # 将卷积结果存储到输出特征图中
                    output[b, c_out, h_out, w_out] = conv_resu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1036415.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring-cloud-stream版本升级,告别旧注解@EnableBinding,拥抱函数式编程

spring-cloud-stream中,EnableBinding从3.1开始就被弃用,取而代之的是函数式编程模型 同期被废弃的注解还有下面这些注解 Input Output EnableBinding StreamListener 官方例子:GitHub - spring-cloud/spring-cloud-stream-samples: Sample…

电视访问群晖共享文件失败的设置方式,降低协议版本

控制面板-文件服务-SMB-高级设置,常规及其他里面配置即可。

微信公众号小说系统源码 漫画系统源码 可对接微信公众号 APP打包 对接个人微信收款

源码描述:修复版掌上阅读小说源码_公众号漫画源码可以打包漫画app ■产品介绍 掌上阅读小说源码支持公众号、代理分站支付功能完善强大的小说源码,公众号乙帅读者, 可以对接微信公众号、APP打包。支持对接个人微信收款。 ■产品优势 1新增…

免费好用的Mac电脑磁盘清理工具CleanMyMac

许多刚从Windows系统转向Mac系统怀抱的用户,一开始难免不习惯,因为Mac系统没有像Windows一样的C盘、D盘,分盘分区明显。因此这也带来了一些问题,关于Mac的磁盘的清理问题,怎么进行清理?怎么确保清理的干净&…

3+氧化应激+分型+预后模型

今天给同学们分享一篇3氧化应激分型预后模型的生信文章“An oxidative stress-related signature for predicting the prognosis of liver cancer”,这篇文章于2023年月4日发表在Front Genet 期刊上,影响因子为3.7。 越来越多的证据表明,肿瘤…

【MySQL数据库事务操作、主从复制及Redis数据库读写分离、主从同步的实现机制】

文章目录 MySQL数据库事务操作、主从复制及Redis数据库读写分离、主从同步的实现机制ACID及如何实现事务隔离级别:MVCC 多版本并发控制MySQL数据库主从复制主从同步延迟怎么处理Redis 读写分离1.什么是主从复制2.读写分离的优点 Redis为什么快呢? MySQL数…

【完美世界】天仙书院偷食也就算了,竟然还偷院长的孙女,美滋滋

Hello,小伙伴们,我是小郑继续为大家深度解析完美世界系列。 齐道临从天仙书院劫走石昊,为何天仙书院不仅没去找他麻烦,反而给他一块随意进入渡劫神莲池的令牌?石昊来到上界也是闹出不小的动静,先是在恶魔岛的神碑留名&…

C语言数组和指针笔试题(四)(一定要看)

目录 二维数组例题一例题二例题三例题四例题五例题六例题七例题八例题九例题十例题十一 结果 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 🐒🐒🐒个人主页 🥸🥸🥸C语言 🐿️…

算法通关村第16关【青铜】| 滑动窗口思想

1. 滑动窗口的基本思想 一句话概括就是两个快慢指针维护的一个会移动的区间 固定大小窗口:求哪个窗口元素最大、最小、平均值、和最大、和最小 可变大小窗口:求一个序列里最大、最小窗口是什么 2. 两个入门题 (1)子数组最大平…

#倍增 #国旗计划

文章目录 题目&#xff1a;题解代码 题目&#xff1a; 国旗计划 题解 三个技巧&#xff1a; 断环成链&#xff1a; 具体而言就是&#xff1a; if(w[i].R < w[i].L) w[i].R m; m是环的长度&#xff1b; 贪心&#xff1a; 选择一个区间i后&#xff0c;下一个区间只能从左端…

(c++)类和对象中篇

目录 1. 类的6个默认成员函数 2. 构造函数 3. 析构函数 4. 拷贝构造函数 5. 赋值运算符重载 6. const成员函数 7. 取地址及const取地址操作符重载 1. 类的 6 个默认成员函数 如果一个类中什么成员都没有&#xff0c;简称为空类。 空类中真的什么都没有吗&#xff1f;并…

线性绘制在NSDT 3D场布中的应用

什么是线性摆放&#xff1f; 线性摆放是指将一系列对象按照直线或者曲线进行排列&#xff0c;形成一条线或者弧线状的布局方式。在3D场布中&#xff0c;线性摆放可以应用于多个领域和场景&#xff0c;如展览设计、景观规划、商业空间布置等。 线性绘制在3D场布中的应用 展览设…

Postman全局配置变量token

Postman全局配置变量token 这里主要是介绍在 Postman 中全局配置token&#xff0c;以及方便以后查阅&#xff01;&#xff01;&#xff01; 一、简介 用户在开发或调试网络程序和网页B/S模式的程序时需要一些方法来跟踪网页请求&#xff0c;可使用一些网络的监视工具如Firebu…

多线程详解(下)

文章目录 常见锁策略乐观锁 vs 悲观锁重量级锁 vs 轻量级锁自旋锁 vs 挂起等待锁读写锁可重入锁 vs 不可重入锁公平锁 vs 非公平锁面试相关题 CAS什么是CASCAS 是怎么实现的CAS 有哪些应用1)实现原子类2)实现自旋锁 CAS的ABA问提什么是ABA问提ABA问提引来的BUG解决方法 相关面试…

基于LLMs构建产业多智能体

前言 随着信息技术的发展以及产业数字化的发展&#xff0c;在产业端&#xff0c;信息系统的建设和应用场景的搭建日渐完善&#xff0c;如何从完备的业务系统中挖掘数据价值以及如何从业务互联走向数据驱动决策成为产业数字化的新发展阶段。目前主要由数据中台承担数据汇聚、数…

Kettle安装初始化问题

1、Kettle启动闪退: 原因&#xff1a;自己的JDK是16 8.0的Kettle适配JDK1.8 【Spoon.bat 双击后闪退】解决办法 - 知乎 2、KettleDB连接中文命名 Unexpected problem reading shared objects from XML file : null Error reading information from input stream Invalid …

解读未知--文档图像大模型的探索与应用

前言&#xff1a; 近日&#xff0c;合合信息在多模态大模型与文档图像智能理解专题论坛上进行了分享。多模态大模型指的是能够处理多种语义信息的一种深度学习模型。文档图像智能理解则是指对文档和图像进行智能化解析和理解的技术。合合信息在这个领域的分享&#xff0c;无疑将…

PHP 变动:PHP 8 版本下字符串与数值的弱比较

文章目录 参考环境声明弱比较隐式类型转换字符串连接数学运算布尔判断相等运算符 字符串与数值的弱比较字符串转化为数值的具体规则字符串与数值的弱比较一般情况科学计数法前缀 0E 与 0e PHP8 在字符串与数值的弱比较方面做出的改动数值字符串优化 参考 项目描述搜索引擎Bing…

栈的应用(C++,进制转化、括号匹配)

十进制转化八进制&#xff0c;利用栈 #include<iostream>//十进制转八进制&#xff0c;利用栈 using namespace std; typedef struct stack {int data;stack* next; }stack, * linkstack; void Initstack(linkstack& s) {s NULL; } int Emptystack(linkstack s) {i…

华为云云耀云服务器L实例评测|基于开源库 Stable Diffusion web UI部署AI绘画应用

前言 随着云计算时代的进一步深入&#xff0c;越来越多的中小企业企业与开发者需要一款简单易用、高能高效的云计算基础设施产品来支撑自身业务运营和创新开发。基于这种需求&#xff0c;华为云焕新推出华为云云服务器实例新品。 华为云云服务器具有智能不卡顿、价优随心用、…