libtorch---day04[MNIST数据集]

news2024/11/14 19:27:06

参考pytorch。

数据集读取

MNIST数据集是一个广泛使用的手写数字识别数据集,包含 60,000张训练图像和10,000张测试图像。每张图像是一个 28 × 28 28\times 28 28×28像素的灰度图像,标签是一个 0到9之间的数字,表示图像中的手写数字。

MNIST 数据集的文件格式

  • 图像文件格式
    图像文件包含一系列28x28像素的灰度图像,文件的前16个字节是头信息,后面是图像数据。头信息的格式如下:

魔数(4 字节):用于标识文件类型。对于图像文件,魔数是 2051(0x00000803)。
图像数量(4 字节):表示文件中包含的图像数量。
图像高度(4 字节):表示每个图像的高度,通常为 28。
图像宽度(4 字节):表示每个图像的宽度,通常为 28。
图像数据紧随其后,每个像素占用 1 字节,按行优先顺序存储。

  • 标签文件格式
    标签文件包含一系列标签,文件的前8个字节是头信息,后面是标签数据。头信息的格式如下:

魔数(4 字节):用于标识文件类型。对于标签文件,魔数是 2049(0x00000801)。
标签数量(4 字节):表示文件中包含的标签数量。
标签数据紧随其后,每个标签占用 1 字节,表示图像的类别(0 到 9)。

std::string dataset_path = "../dataSet/MNIST";
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    int32_t magic_number;
    int32_t num;
    int32_t HEIGHT;
    int32_t WIDTH;

    std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";

    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs;
    fs.open(image_path.c_str(), std::ios::binary);
    if (fs.is_open()) {
        fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = _byteswap_ulong(magic_number);
        fs.read(reinterpret_cast<char*>(&num), sizeof(num));
        num = _byteswap_ulong(num);
        fs.read(reinterpret_cast<char*>(&HEIGHT), sizeof(HEIGHT));
        HEIGHT = _byteswap_ulong(HEIGHT);
        fs.read(reinterpret_cast<char*>(&WIDTH), sizeof(WIDTH));
        WIDTH = _byteswap_ulong(WIDTH);
        printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, HEIGHT, WIDTH);
        for (int i = 0; i < num; i++) {
            std::vector<unsigned char> image_data;
            image_data.resize(HEIGHT * WIDTH);
            fs.read(reinterpret_cast<char*>(image_data.data()), HEIGHT * WIDTH);

            cv::Mat image_cv(HEIGHT, WIDTH, CV_8UC1, image_data.data());
            torch::Tensor image_torch = torch::from_blob(image_data.data(), { static_cast<long long>(image_data.size()) }, torch::kUInt8).clone();
            image_torch = image_torch.to(torch::kF32) / 255.;
            images_tensor.push_back(image_torch);
            images.push_back(image_cv.clone()); // 使用 clone() 确保数据独立
        }
        printf("image vector size: %d\n", int(images.size()));
        fs.close();
    }
    else {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }

    fs.open(label_path.c_str(), std::ios::binary);
    if (fs.is_open()) {
        fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = _byteswap_ulong(magic_number);
        fs.read(reinterpret_cast<char*>(&num), sizeof(num));
        num = _byteswap_ulong(num);
        printf("magic number: %d, label number: %d\n", magic_number, num);
        labels.resize(num);
        fs.read(reinterpret_cast<char*>(labels.data()), num);
        train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone();
        fs.close();
    }
    else {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
}
  • 要点
    • 在读取 MNIST 数据集时,文件中的数据是以大端序存储的,而大多数现代计算机(如 x86 架构)使用小端序。因此,在读取文件中的数据时,需要进行字节序转换,以确保数据的正确性。
    • 读取数据的时候,用for循环进行读取,如果用while(!fs.eof()),如果不额外处理,会导致多一个数据出来。
    • void数据构造Tensor*的时候,要注意调用clone方法。

全部代码

这里就一个线性层,结合交叉熵函数;

#include <torch/torch.h>
#include <fstream>
#include <opencv2/opencv.hpp>

std::string dataset_path = "../dataSet/MNIST";
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    int32_t magic_number;
    int32_t num;
    int32_t HEIGHT;
    int32_t WIDTH;

    std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";

    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs;
    fs.open(image_path.c_str(), std::ios::binary);
    if (fs.is_open()) {
        fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = _byteswap_ulong(magic_number);
        fs.read(reinterpret_cast<char*>(&num), sizeof(num));
        num = _byteswap_ulong(num);
        fs.read(reinterpret_cast<char*>(&HEIGHT), sizeof(HEIGHT));
        HEIGHT = _byteswap_ulong(HEIGHT);
        fs.read(reinterpret_cast<char*>(&WIDTH), sizeof(WIDTH));
        WIDTH = _byteswap_ulong(WIDTH);
        printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, HEIGHT, WIDTH);
        for (int i = 0; i < num; i++) {
            std::vector<unsigned char> image_data;
            image_data.resize(HEIGHT * WIDTH);
            fs.read(reinterpret_cast<char*>(image_data.data()), HEIGHT * WIDTH);

            cv::Mat image_cv(HEIGHT, WIDTH, CV_8UC1, image_data.data());
            torch::Tensor image_torch = torch::from_blob(image_data.data(), { static_cast<long long>(image_data.size()) }, torch::kUInt8).clone();
            image_torch = image_torch.to(torch::kF32) / 255.;
            images_tensor.push_back(image_torch);
            images.push_back(image_cv.clone()); // 使用 clone() 确保数据独立
        }
        printf("image vector size: %d\n", int(images.size()));
        fs.close();
    }
    else {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }

    fs.open(label_path.c_str(), std::ios::binary);
    if (fs.is_open()) {
        fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = _byteswap_ulong(magic_number);
        fs.read(reinterpret_cast<char*>(&num), sizeof(num));
        num = _byteswap_ulong(num);
        printf("magic number: %d, label number: %d\n", magic_number, num);
        labels.resize(num);
        fs.read(reinterpret_cast<char*>(labels.data()), num);
        train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone();
        fs.close();
    }
    else {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
}
using namespace torch;

int main()
{
	std::vector<cv::Mat> images_show;
	std::vector<uint8_t> labels_train;
	std::vector<torch::Tensor> image_train;
	torch::Tensor label_train;
	 
	loadMNIST(images_show, image_train, labels_train, label_train, 0);
	torch::Tensor train_data = torch::stack(image_train);
	torch::Tensor train_data_label = label_train;
	torch::Tensor weights = torch::randn({ image_train[0].sizes()[0], 10 }).set_requires_grad(true);
	torch::Tensor bias = torch::randn({ 10 }).set_requires_grad(true);
	double lr = 1e-1;
	int iteration = 10000;
	torch::nn::CrossEntropyLoss criterion;
    torch::optim::SGD optim({weights, bias}, lr);
	for (int i = 0; i < iteration; i++)
	{
		auto predict = torch::matmul(train_data, weights) + bias;
		// std::cout << "predict data size: " << predict.sizes() << ", train_data_label data size: " << train_data_label.sizes() << std::endl;
		auto loss = criterion(predict, train_data_label);

		loss.backward();
        optim.step();
        optim.zero_grad();
        if((i+1) % 500 == 0)
		    printf("[%d /%d, loss: %lf]\n", i + 1, iteration, loss.item<double>());
	}

	loadMNIST(images_show, image_train, labels_train, label_train, 1);

	cv::Mat im_show;
	std::vector<cv::Mat> im_shows;
	for (int i = 0; i < 2; i++)
	{
		cv::Mat im_show_;
		std::vector<cv::Mat> im_shows_;
		for (int j = 0; j < 5; j++)
		{
            int index = torch::randint(0, images_show.size() - 1, {}).item<int>();
			cv::Mat im = images_show[index];
			torch::Tensor im_torch = image_train[index];
			uchar label = labels_train[index];
			auto predict = torch::matmul(im_torch, weights) + bias;
			auto label_predict = torch::argmax(predict.view({ -1 })).item<int>();
			cv::Mat im_resized;
			cv::resize(im, im_resized, cv::Size(16 * im.rows, 16 * im.cols));
			cv::cvtColor(im_resized, im, cv::COLOR_GRAY2RGB);
			cv::putText(im, "groud true: " + std::to_string(static_cast<int>(label)), cv::Point2f(40, 40), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 0, 255), 2);
			cv::putText(im, "predict true: " + std::to_string(label_predict), cv::Point2f(40, 90), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 255, 0), 2);
			im_shows_.push_back(im);
		}
		cv::hconcat(im_shows_, im_show_);
		im_shows.push_back(im_show_);
	}
	cv::vconcat(im_shows, im_show);
	try {
		/*cv::imshow("result", im_show);
		cv::waitKey(0);*/
        cv::imwrite("validation.png", im_show);
	}
	catch (cv::Exception& e)
	{
		printf("%s\n", e.what());
	}
	return 0;
}

结果

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2101434.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Aqua进行WebUI测试(Pytest)——介绍篇(附汉化教程)

一、在创建时选择Selenium with Pytest 如果选择的是Selenium&#xff0c;则只能选择Java类语言 选择selenium with Pytest&#xff0c;则可以选择Python类语言 Environment 其中的【Environment】可选New 和 Existing New &#xff1a;选择这个选项意味着你希望工具为你创…

常用企业技术架构开发速查工具列表

对于Java开发者来说,不光要关注业务代码也要注重架构的修炼。日常用到的工具组件都是我们架构中重要的元素,服务于应用系统。我们应该选择适合应用体量的架构避免过度设计,最简单的方式就是矩阵方式去分析每个组件的适用场景优缺点,从而综合评估做好决策。 程序员大多数时间…

一次性说清楚,微软Mos国际认证

简介&#xff1a; Microsoft Office Specialist&#xff08;MOS&#xff09;中文称之为“微软办公软件国际认证”&#xff0c;是微软为全球所认可的Office软件国际性专业认证&#xff0c;全球有168个国家地区认可&#xff0c;每年有近百万人次参加考试&#xff0c;它能有效证明…

Elasticsearch集群架构

Elasticsearch是一种分布式搜索引擎&#xff0c;基于Apache Lucene构建&#xff0c;支持全文搜索、结构化搜索、分析和实时数据处理。 节点&#xff08;Node&#xff09; 节点是集群中的一台服务器。根据节点的角色&#xff0c;可以分为以下几种类型&#xff1a; 主节点&#…

uniapp中slot插槽用法

1.slot的用法 1.1 简单概念 元素作为组件模板之中的内容分发插槽&#xff0c;<slot> 元素自身将被替换 是不是这段话听着有点迷? 那么直接开始上代码 此时创建一个简单的页面&#xff0c;在中间写上一个<slot></slot>标签&#xff0c;标签内并没有数据 …

MySQL——隔离级别及解决方案

CRUD不加控制&#xff0c;会有什么问题&#xff1f; 比如上图场景&#xff0c;当我们的客户端A发现还有一张票的时候&#xff0c;将票卖掉&#xff0c;嗨还没有执行更新数据库的时候&#xff0c;客户端B又检查票数&#xff0c;发现票数大于0&#xff0c;又卖掉了一张票。然后客…

基于FPGA实现SD NAND FLASH的SPI协议读写

基于FPGA实现SD NAND FLASH的SPI协议读写 在此介绍的是使用FPGA实现SD NAND FLASH的读写操作&#xff0c;以雷龙发展提供的CS创世SD NAND FLASH样品为例&#xff0c;分别讲解电路连接、读写时序与仿真和实验结果。 目录 1 FLASH背景介绍 2 样品申请 3 电路结构与接口协议 …

微信管理工具真的那么好用么?

01 多号一个界面聚合聊天 可以同时登录多个微信号&#xff0c;不再需要频繁切换账号或使用多台设备在一个界面聚合聊天。 02 多号朋友圈同步发朋友圈 多个微信号可以即时发布或定时发布朋友圈&#xff0c;省去了逐个发送的繁琐。 03 机器人自动回复 不仅可以自动通过好友…

Android Camera系列(三):GLSurfaceView+Camera

人类的悲欢并不相通—鲁迅 Android Camera系列&#xff08;一&#xff09;&#xff1a;SurfaceViewCamera Android Camera系列&#xff08;二&#xff09;&#xff1a;TextureViewCamera Android Camera系列&#xff08;三&#xff09;&#xff1a;GLSurfaceViewCamera 本系…

Telephony SMS

1、短信的协议架构 如下图,参考3GPP 23.040 4.9节 Protocols and protocol architecture 1、SM-AL : 应用层 2、SM-TL :传输层 3、SM-RL :中继层 4、SM-LL :链路层 由于我们只关注手机终端,因此只需要关注SM-TL这一层即可 2、SM-TL分类 短信的协议架构参考3GPP 23.04…

猛兽财经:在股价创下历史新高后,5个因素将使Netflix股价进一步上涨

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 股价创三年来新高后&#xff0c; Netflix股价还会继续上涨 作为流媒体领域无可争议的领导者&#xff0c;Netflix(NFLX)的股价在上周再次创下了新高&#xff08;每股超过了700美元&#xff0c;这一涨幅已经超过了2021年底创…

[Linux] 项目自动化构建工具-make/Makefile

标题&#xff1a;[Linux] 项目自动化构建工具-make/Makefile 水墨不写bug 目录 一、什么是make/makefile 二、make/makefile语法 补充&#xff08;多文件标识&#xff09;&#xff1a; 三、make/makefile原理 四、make/makefile根据时间对文件选择操作 正文开始&#xff…

基于SpringBoot的校园闲置物品租售系统

你好呀&#xff0c;我是计算机学姐码农小野&#xff01;如果有相关需求&#xff0c;可以私信联系我。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBootMyBatis 工具&#xff1a;IDEA/Eclipse、Navicat、Maven 系统展示 首页 用户管理界面 …

华为云征文|Flexus云服务X实例应用,通过QT连接华为云MySQL,进行数据库的操作,数据表的增删改查

引出 4核12G-100G-3M规格的Flexus X实例使用测评第3弹&#xff1a;Flexus云服务X实例应用&#xff0c;通过QT连接华为云MySQL&#xff0c;进行数据库的操作&#xff0c;数据表的增删改查 什么是Flexus云服务器X实例 官方解释&#xff1a; Flexus云服务器X实例是新一代面向中…

【python】如何用python代码快速生成二维码

✨✨ 欢迎大家来到景天科技苑✨✨ &#x1f388;&#x1f388; 养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; &#x1f3c6; 作者简介&#xff1a;景天科技苑 &#x1f3c6;《头衔》&#xff1a;大厂架构师&#xff0c;华为云开发者社区专家博主&#xff0c;…

【算法思想·二叉树】思路篇

本文参考labuladong算法笔记[东哥带你刷二叉树&#xff08;思路篇&#xff09; | labuladong 的算法笔记] 本文承接 【算法思想二叉树】纲领篇&#xff0c;先复述一下前文总结的二叉树解题总纲&#xff1a; 二叉树解题的思维模式分两类&#xff1a; 1、是否可以通过遍历一遍二…

数据结构——单链表相关操作

zhuzhu1、结构框图&#xff1a; 2、增删改查&#xff1a; 定义链表节点和对象类型 /*************************************************************************> File Name: link.h> Author: yas> Mail: rage_yashotmail.com> Created Time: Tue 03 Sep 2024…

ServiceStage集成Sermant实现应用的优雅上下线

作者&#xff1a;聂子雄 华为云高级软件工程师 摘要 优雅上下线旨在确保服务在进行上下线操作时&#xff0c;能够平滑过渡&#xff0c;避免对业务造成影响&#xff0c;保证资源的高效利用。Sermant基于字节码增强的技术实现了应用优雅上下线能力&#xff0c;应用发布与运维平…

摩博会倒计时!OneOS操作系统抢先了解!

2024年第二十二届中国国际摩托车博览会&#xff08;摩博会&#xff09;临近&#xff0c;中移物联OneOS与智能硬件领域佼佼者恒石智能宣布强强合作&#xff0c;与9月13日至16日在重庆国家会展中心共同展现多款Model系列芯片&#xff08;Model3、Model4、Model3C、Model3A&#x…

I2C软件模拟时序的基本要素

目录 前言 一、关于I2C 二、正文 1.引脚的配置 2.I2C的起始和终止时序 3.发送一个字节 4.接收一个字节 5.应答信号 6.指定地址写和指定地址读 总结 前言 环境&#xff1a; 芯片&#xff1a;STM32F103C8T6 Keil&#xff1a;V5.24.2.0 本文主要参考江科大教程&#…