参考pytorch。
数据集读取
MNIST数据集是一个广泛使用的手写数字识别数据集,包含 60,000 张训练图像和 10,000 张测试图像。每张图像是一个 $28\times 28$ 像素的灰度图像,标签是一个 0 到 9 之间的数字,表示图像中的手写数字。
MNIST 数据集的文件格式
- 图像文件格式
图像文件包含一系列28x28像素的灰度图像,文件的前16个字节是头信息,后面是图像数据。头信息的格式如下:
魔数(4 字节):用于标识文件类型。对于图像文件,魔数是 2051(0x00000803)。
图像数量(4 字节):表示文件中包含的图像数量。
图像高度(4 字节):表示每个图像的高度,通常为 28。
图像宽度(4 字节):表示每个图像的宽度,通常为 28。
图像数据紧随其后,每个像素占用 1 字节,按行优先顺序存储。
- 标签文件格式
标签文件包含一系列标签,文件的前8个字节是头信息,后面是标签数据。头信息的格式如下:
魔数(4 字节):用于标识文件类型。对于标签文件,魔数是 2049(0x00000801)。
标签数量(4 字节):表示文件中包含的标签数量。
标签数据紧随其后,每个标签占用 1 字节,表示图像的类别(0 到 9)。
// Root directory containing the MNIST IDX files (relative to the working directory).
std::string dataset_path = "../dataSet/MNIST";
// Loads one MNIST split (images + labels) from the IDX files under dataset_path.
//
// images        out: one 28x28 CV_8UC1 cv::Mat per image (deep copies).
// images_tensor out: one flattened float32 tensor per image, scaled to [0, 1].
// labels        out: raw uint8 class labels (0-9).
// train_labels  out: 1-D int64 tensor of the labels; widened to Long because
//                    torch::nn::CrossEntropyLoss requires int64 class targets.
// train_data    in : true -> train-* files, false -> t10k-* files.
//
// On any open/parse failure the function prints a message and returns early,
// possibly leaving the outputs partially filled.
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    // IDX header fields are big-endian 32-bit ints; swap with plain shifts so
    // the code is portable (the original _byteswap_ulong is MSVC-only).
    auto be32 = [](int32_t v) -> int32_t {
        const uint32_t u = static_cast<uint32_t>(v);
        return static_cast<int32_t>(((u & 0xFF000000u) >> 24) | ((u & 0x00FF0000u) >> 8) |
                                    ((u & 0x0000FF00u) << 8) | ((u & 0x000000FFu) << 24));
    };

    const std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    const std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";
    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs(image_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }
    int32_t magic_number = 0, num = 0, height = 0, width = 0;
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    magic_number = be32(magic_number);
    fs.read(reinterpret_cast<char*>(&num), sizeof(num));
    num = be32(num);
    fs.read(reinterpret_cast<char*>(&height), sizeof(height));
    height = be32(height);
    fs.read(reinterpret_cast<char*>(&width), sizeof(width));
    width = be32(width);
    if (!fs || magic_number != 2051) {  // 2051 == 0x00000803, the IDX image-file magic
        printf("bad image file header in %s (magic %d)\n", image_path.c_str(), magic_number);
        return;
    }
    printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, height, width);

    const std::streamsize image_bytes = static_cast<std::streamsize>(height) * width;
    std::vector<unsigned char> image_data(static_cast<size_t>(image_bytes));  // reused for every image
    for (int i = 0; i < num; i++) {
        if (!fs.read(reinterpret_cast<char*>(image_data.data()), image_bytes)) {
            printf("truncated image file %s: got %d of %d images\n", image_path.c_str(), i, num);
            break;
        }
        // cv::Mat and from_blob only wrap image_data's buffer, so clone() both
        // to give each stored image its own storage.
        cv::Mat image_cv(height, width, CV_8UC1, image_data.data());
        torch::Tensor image_torch = torch::from_blob(image_data.data(), { image_bytes }, torch::kUInt8).clone();
        images_tensor.push_back(image_torch.to(torch::kF32) / 255.);
        images.push_back(image_cv.clone());
    }
    printf("image vector size: %d\n", int(images.size()));
    fs.close();

    fs.open(label_path, std::ios::binary);  // since C++11, ifstream::open clears the stream state on success
    if (!fs.is_open()) {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    magic_number = be32(magic_number);
    fs.read(reinterpret_cast<char*>(&num), sizeof(num));
    num = be32(num);
    if (!fs || magic_number != 2049) {  // 2049 == 0x00000801, the IDX label-file magic
        printf("bad label file header in %s (magic %d)\n", label_path.c_str(), magic_number);
        return;
    }
    printf("magic number: %d, label number: %d\n", magic_number, num);
    labels.resize(static_cast<size_t>(num));
    fs.read(reinterpret_cast<char*>(labels.data()), num);
    // Widen to int64: CrossEntropyLoss rejects uint8 (Byte) class targets.
    train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone().to(torch::kInt64);
    fs.close();
}
- 要点
- 在读取 MNIST 数据集时,文件中的数据是以大端序存储的,而大多数现代计算机(如 x86 架构)使用小端序。因此,在读取文件中的数据时,需要进行字节序转换,以确保数据的正确性。
- 读取数据时建议用 for 循环按头信息给出的数量读取;如果用 while(!fs.eof()) 且不做额外处理,会多读出一条数据,因为 eof 标志只有在一次读取失败之后才会置位。
- 用 void* 原始数据(torch::from_blob)构造 Tensor 的时候,要注意调用 clone 方法,否则 Tensor 只是借用外部缓冲区,不拥有底层内存。
全部代码
这里只用一个线性层,结合交叉熵损失函数。
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
#include <torch/torch.h>
// Root directory containing the MNIST IDX files (relative to the working directory).
std::string dataset_path = "../dataSet/MNIST";
// Loads one MNIST split (images + labels) from the IDX files under dataset_path.
//
// images        out: one 28x28 CV_8UC1 cv::Mat per image (deep copies).
// images_tensor out: one flattened float32 tensor per image, scaled to [0, 1].
// labels        out: raw uint8 class labels (0-9).
// train_labels  out: 1-D int64 tensor of the labels; widened to Long because
//                    torch::nn::CrossEntropyLoss requires int64 class targets.
// train_data    in : true -> train-* files, false -> t10k-* files.
//
// On any open/parse failure the function prints a message and returns early,
// possibly leaving the outputs partially filled.
void loadMNIST(std::vector<cv::Mat>& images, std::vector<torch::Tensor>& images_tensor, std::vector<uint8_t>& labels, torch::Tensor& train_labels, bool train_data) {
    // IDX header fields are big-endian 32-bit ints; swap with plain shifts so
    // the code is portable (the original _byteswap_ulong is MSVC-only).
    auto be32 = [](int32_t v) -> int32_t {
        const uint32_t u = static_cast<uint32_t>(v);
        return static_cast<int32_t>(((u & 0xFF000000u) >> 24) | ((u & 0x00FF0000u) >> 8) |
                                    ((u & 0x0000FF00u) << 8) | ((u & 0x000000FFu) << 24));
    };

    const std::string image_path = train_data ? dataset_path + "/train-images.idx3-ubyte" : dataset_path + "/t10k-images.idx3-ubyte";
    const std::string label_path = train_data ? dataset_path + "/train-labels.idx1-ubyte" : dataset_path + "/t10k-labels.idx1-ubyte";
    images.clear();
    labels.clear();
    images_tensor.clear();

    std::ifstream fs(image_path, std::ios::binary);
    if (!fs.is_open()) {
        printf("can not open file %s\n", image_path.c_str());
        return;
    }
    int32_t magic_number = 0, num = 0, height = 0, width = 0;
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    magic_number = be32(magic_number);
    fs.read(reinterpret_cast<char*>(&num), sizeof(num));
    num = be32(num);
    fs.read(reinterpret_cast<char*>(&height), sizeof(height));
    height = be32(height);
    fs.read(reinterpret_cast<char*>(&width), sizeof(width));
    width = be32(width);
    if (!fs || magic_number != 2051) {  // 2051 == 0x00000803, the IDX image-file magic
        printf("bad image file header in %s (magic %d)\n", image_path.c_str(), magic_number);
        return;
    }
    printf("magic number: %d, image number: %d, image height: %d, image width: %d\n", magic_number, num, height, width);

    const std::streamsize image_bytes = static_cast<std::streamsize>(height) * width;
    std::vector<unsigned char> image_data(static_cast<size_t>(image_bytes));  // reused for every image
    for (int i = 0; i < num; i++) {
        if (!fs.read(reinterpret_cast<char*>(image_data.data()), image_bytes)) {
            printf("truncated image file %s: got %d of %d images\n", image_path.c_str(), i, num);
            break;
        }
        // cv::Mat and from_blob only wrap image_data's buffer, so clone() both
        // to give each stored image its own storage.
        cv::Mat image_cv(height, width, CV_8UC1, image_data.data());
        torch::Tensor image_torch = torch::from_blob(image_data.data(), { image_bytes }, torch::kUInt8).clone();
        images_tensor.push_back(image_torch.to(torch::kF32) / 255.);
        images.push_back(image_cv.clone());
    }
    printf("image vector size: %d\n", int(images.size()));
    fs.close();

    fs.open(label_path, std::ios::binary);  // since C++11, ifstream::open clears the stream state on success
    if (!fs.is_open()) {
        printf("can not open file %s\n", label_path.c_str());
        return;
    }
    fs.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
    magic_number = be32(magic_number);
    fs.read(reinterpret_cast<char*>(&num), sizeof(num));
    num = be32(num);
    if (!fs || magic_number != 2049) {  // 2049 == 0x00000801, the IDX label-file magic
        printf("bad label file header in %s (magic %d)\n", label_path.c_str(), magic_number);
        return;
    }
    printf("magic number: %d, label number: %d\n", magic_number, num);
    labels.resize(static_cast<size_t>(num));
    fs.read(reinterpret_cast<char*>(labels.data()), num);
    // Widen to int64: CrossEntropyLoss rejects uint8 (Byte) class targets.
    train_labels = torch::from_blob(labels.data(), { num }, torch::kUInt8).clone().to(torch::kInt64);
    fs.close();
}
using namespace torch;
// Trains a single linear layer (logits = x @ W + b) with cross-entropy on one
// MNIST split, then draws a 2x5 grid of random samples annotated with the
// ground-truth and predicted digits and saves it as validation.png.
int main()
{
    std::vector<cv::Mat> images_show;
    std::vector<uint8_t> labels_train;
    std::vector<torch::Tensor> image_train;
    torch::Tensor label_train;
    // NOTE(review): train_data=false here, so the model is fitted on the 10k
    // t10k-* split and later checked on the train-* split — presumably to keep
    // the demo fast; swap the flags of the two loadMNIST calls for a real setup.
    loadMNIST(images_show, image_train, labels_train, label_train, 0);
    torch::Tensor train_data = torch::stack(image_train);  // [N, 784] float32
    // CrossEntropyLoss requires int64 (Long) class indices as targets.
    torch::Tensor train_data_label = label_train.to(torch::kLong);
    torch::Tensor weights = torch::randn({ image_train[0].sizes()[0], 10 }).set_requires_grad(true);
    torch::Tensor bias = torch::randn({ 10 }).set_requires_grad(true);
    double lr = 1e-1;
    int iteration = 10000;
    torch::nn::CrossEntropyLoss criterion;
    torch::optim::SGD optim({ weights, bias }, lr);
    for (int i = 0; i < iteration; i++)
    {
        auto predict = torch::matmul(train_data, weights) + bias;  // full-batch logits [N, 10]
        auto loss = criterion(predict, train_data_label);
        loss.backward();
        optim.step();
        optim.zero_grad();
        if ((i + 1) % 500 == 0)
            printf("[%d /%d, loss: %lf]\n", i + 1, iteration, loss.item<double>());
    }
    // Reload the other split for a qualitative check (see NOTE above).
    loadMNIST(images_show, image_train, labels_train, label_train, 1);
    cv::Mat im_show;
    std::vector<cv::Mat> im_shows;
    for (int i = 0; i < 2; i++)
    {
        cv::Mat im_show_;
        std::vector<cv::Mat> im_shows_;
        for (int j = 0; j < 5; j++)
        {
            // randint's upper bound is exclusive, so pass size() directly —
            // the original size() - 1 could never select the last image.
            int index = torch::randint(0, static_cast<int64_t>(images_show.size()), {}).item<int>();
            cv::Mat im = images_show[index];
            torch::Tensor im_torch = image_train[index];
            uchar label = labels_train[index];
            auto predict = torch::matmul(im_torch, weights) + bias;
            auto label_predict = torch::argmax(predict.view({ -1 })).item<int>();
            cv::Mat im_resized;
            // cv::Size is (width, height): width = cols, height = rows (the
            // original order only worked because MNIST images are square).
            cv::resize(im, im_resized, cv::Size(16 * im.cols, 16 * im.rows));
            cv::cvtColor(im_resized, im, cv::COLOR_GRAY2RGB);
            cv::putText(im, "ground truth: " + std::to_string(static_cast<int>(label)), cv::Point2f(40, 40), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 0, 255), 2);
            cv::putText(im, "predict: " + std::to_string(label_predict), cv::Point2f(40, 90), cv::FONT_HERSHEY_PLAIN, 3, cv::Scalar(0, 255, 0), 2);
            im_shows_.push_back(im);
        }
        cv::hconcat(im_shows_, im_show_);
        im_shows.push_back(im_show_);
    }
    cv::vconcat(im_shows, im_show);
    try {
        /*cv::imshow("result", im_show);
        cv::waitKey(0);*/
        cv::imwrite("validation.png", im_show);
    }
    catch (const cv::Exception& e)
    {
        printf("%s\n", e.what());
    }
    return 0;
}