LibTorch实战二:MNIST的libtorch代码

news2024/11/28 6:35:09

目录

一、前言

二、另一种下载数据集方式

三、MNIST的Pytorch源码

四、MNIST的Libtorch源码

一、前言

        前面介绍过了MNIST的python的训练代码、和基于torchscript的模型序列化(导出模型)。今天看看,如何使用libtorch C++来实现手写数字训练。     

二、另一种下载数据集方式

        同时,我已经说过了,对你MNIST数据集该如何下载。有关数据集的下载,这种不重要的问题卡了很久,简直浪费时间,差评。这里再介绍一种下载方式,在官方仓库中,有个脚本可以直接下载https://github.com/pytorch/examples/blob/main/cpp/tools/download_mnist.py,直接在命令行窗口执行就可以下载,如下,可能网络会很卡,不过下载好了。

        这里直接把download_mnist.py源码贴出来吧:

from __future__ import division
from __future__ import print_function

import argparse
import gzip
import os
import sys
import urllib

try:
    from urllib.error import URLError
    from urllib.request import urlretrieve
except ImportError:
    from urllib2 import URLError
    from urllib import urlretrieve

RESOURCES = [
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
]


def report_download_progress(chunk_number, chunk_size, file_size):
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = '#' * int(64 * percent)
        sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


def download(destination_path, url, quiet):
    if os.path.exists(destination_path):
        if not quiet:
            print('{} already exists, skipping ...'.format(destination_path))
    else:
        print('Downloading {} ...'.format(url))
        try:
            hook = None if quiet else report_download_progress
            urlretrieve(url, destination_path, reporthook=hook)
        except URLError:
            raise RuntimeError('Error downloading resource!')
        finally:
            if not quiet:
                # Just a newline.
                print()


def unzip(zipped_path, quiet):
    unzipped_path = os.path.splitext(zipped_path)[0]
    if os.path.exists(unzipped_path):
        if not quiet:
            print('{} already exists, skipping ... '.format(unzipped_path))
        return
    with gzip.open(zipped_path, 'rb') as zipped_file:
        with open(unzipped_path, 'wb') as unzipped_file:
            unzipped_file.write(zipped_file.read())
            if not quiet:
                print('Unzipped {} ...'.format(zipped_path))


def main():
    parser = argparse.ArgumentParser(
        description='Download the MNIST dataset from the internet')
    parser.add_argument(
        '-d', '--destination', default='.', help='Destination directory')
    parser.add_argument(
        '-q',
        '--quiet',
        action='store_true',
        help="Don't report about progress")
    options = parser.parse_args()

    if not os.path.exists(options.destination):
        os.makedirs(options.destination)

    try:
        for resource in RESOURCES:
            path = os.path.join(options.destination, resource)
            url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
            download(path, url, options.quiet)
            unzip(path, options.quiet)
    except KeyboardInterrupt:
        print('Interrupted')


if __name__ == '__main__':
    main()

 执行下载过程中,可能会很卡,下载信息如下:

(base) C:\Users\Administrator\Desktop\examples-master_2\examples-master\cpp\tools>python download_mnist.py              
.\train-images-idx3-ubyte.gz already exists, skipping ...                                                               
.\train-images-idx3-ubyte already exists, skipping ...                                                                  
.\train-labels-idx1-ubyte.gz already exists, skipping ...                                                               
.\train-labels-idx1-ubyte already exists, skipping ...                                                                  
.\t10k-images-idx3-ubyte.gz already exists, skipping ...                                                                
.\t10k-images-idx3-ubyte already exists, skipping ...                                                                   
.\t10k-labels-idx1-ubyte.gz already exists, skipping ...                                                                
.\t10k-labels-idx1-ubyte already exists, skipping ... 

python代码训练5个epoch结果。

Test set: Average loss: 0.0287, Accuracy: 9907/10000 (99%)

三、MNIST的Pytorch源码

MNIST 的python源码:

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self): # self指的是类实例对象本身(注意:不是类本身)。
    # self不是关键词
        # super 用于继承,https://www.runoob.com/python/python-func-super.html
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # input:28*28
        x = self.conv1(x) # -> (28 - 3 + 1 = 26),26*26*32
        x = F.relu(x)
        # input:26*26*32
        x = self.conv2(x) # -> (26 - 3 + 1 = 24),24*24*64
        # input:24*24*64
        x = F.relu(x)
        x = F.max_pool2d(x, 2)# -> 12*12*64 = 9216
        x = self.dropout1(x) #不改变维度
        x = torch.flatten(x, 1) # 9216*1
        # w = 128*9216
        x = self.fc1(x) # -> 128*1
        x = F.relu(x)
        x = self.dropout2(x)
        # w = 10*128
        x = self.fc2(x) # -> 10*1
        output = F.log_softmax(x, dim=1) # softmax归一化
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    # 在使用pytorch构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),
    # 作用是启用batch normalization和drop out。
    # 测试过程中会使用model.eval(),这时神经网络会沿用batch normalization的值,并不使用drop out。
    model.train()
    # 可以查看下卷积核的参数尺寸
    #model.conv1.weight.shape torch.Size([32, 1, 3, 3]
    #model.conv2.weight.shape torch.Size([64, 32, 3, 3])

    for batch_idx, (data, target) in enumerate(train_loader):
        # train_loader.dataset.data.shape
        # Out[9]: torch.Size([60000, 28, 28])

        # batch_size:64
        # data:64个样本输入,torch.Size([64, 1, 28, 28])
        # target: 64个label,torch.Size([64])
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # output:torch.Size([64, 10])
        output = model(data)
        # 类似于交叉熵
        # reference: https://blog.csdn.net/qq_22210253/article/details/85229988
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # 我们打印一个卷积核参数看看
        # print(model.conv2._parameters)

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=5, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True, # 锁页内存,可以加快内存到显存的速度
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)
    # torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起
    #
    transform = transforms.Compose([
        transforms.ToTensor(), # (H x W x C)、[0, 255]  -> (C x H x W)、[0.0, 1.0]
        transforms.Normalize((0.1307,), (0.3081,)) # 数据的归一化
        ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    # 固定步长衰减
    # reference: https://zhuanlan.zhihu.com/p/93624972
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        #torch.save(model.state_dict(), "pytorch_mnist.pt")
        torch.save(model, "pytorch_mnist.pth")


if __name__ == '__main__':
    main()

四、MNIST的Libtorch源码

以下是C++代码(官方的C++代码的网络结果似乎和python代码不能完全对应上,所以我作了修改,其实就是改了网络模型,请看struct Net : torch::nn::Module):可以对一下struct Net : torch::nn::Module和上述python代码中的 class Net(nn.Module):

#include<torch/torch.h>
#include<cstddef>
#include<iostream>
#include<vector>
#include<string>
// 继承自Module模块
struct Net : torch::nn::Module
{
    // 构造函数
    Net() :
        conv1(torch::nn::Conv2dOptions(1, 32, 3)), // kernel_size = 5
        conv2(torch::nn::Conv2dOptions(32, 64, 3)),
        fc1(9216, 128),
        fc2(128, 10)
    {
        register_module("conv1", conv1);
        register_module("conv2", conv2);
        register_module("conv2_drop", conv2_drop);
        register_module("fc1", fc1);
        register_module("fc2", fc2);
    }
    // 成员函数:前向传播
    torch::Tensor forward(torch::Tensor x)
    {
        // input:1*28*28
        x = torch::relu(conv1->forward(x)); //conv1:(28 - 3 + 1 = 26), 26*26*32
        // input:26*26*32
        x = torch::max_pool2d(torch::relu(conv2->forward(x)), 2);//conv2:(26 - 3 + 1 = 24),24*24*64; max_poolded:12*12*64 = 9216
        x = torch::dropout(x, 0.25, is_training());
        x = x.view({ -1, 9216 });// 9216*1
        // w:128*9216
        x = torch::relu(fc1->forward(x)); //fc1:w = 128*9216,w * x ->128*1
        x = torch::dropout(x, 0.5, is_training());
        // w:10*128
        x = fc2->forward(x);//fc2:w = 10*128,w * x -> 10*1
        x = torch::log_softmax(x, 1);
        return x;

    }


    // 模块成员
    torch::nn::Conv2d conv1;
    torch::nn::Conv2d conv2;
    torch::nn::Dropout2d conv2_drop;
    torch::nn::Linear fc1;
    torch::nn::Linear fc2;
};

//train
template<typename DataLoader>
void train(size_t epoch, Net& model, torch::Device device, DataLoader& data_loader, torch::optim::Optimizer& optimizer, size_t dataset_size)
{
    //set "train" mode
    model.train();
    size_t batch_idx = 0;
    for (auto& batch: data_loader)
    {
        auto data = batch.data.to(device);
        auto targets = batch.target.to(device);
        optimizer.zero_grad();
        auto output = model.forward(data);
        auto loss = torch::nll_loss(output, targets);
        AT_ASSERT(!std::isnan(loss.template item<float>()));
        loss.backward();
        optimizer.step();

        // 每10个batch_size打印一次loss
        if (batch_idx++ % 10 == 0)
        {
            std::printf("\rTrain Epoch: %ld [%5ld/%5ld] Loss: %.4f",
                epoch,
                batch_idx * batch.data.size(0),
                dataset_size,
                loss.template item<float>());
        }
    }
}

template<typename DataLoader>
void test(Net& model, torch::Device device, DataLoader& data_loader, size_t dataset_size)
{
    torch::NoGradGuard no_grad;
    // set "test" mode
    model.eval();
    double test_loss = 0;
    int32_t correct = 0;
    for (const auto& batch: data_loader)
    {
        auto data = batch.data.to(device);
        auto targets = batch.target.to(device);
        auto output = model.forward(data);
        test_loss += torch::nll_loss(output, targets, /*weight=*/{}, torch::Reduction::Sum).template item<float>();
        auto pred = output.argmax(1);
        // eq = equal 判断prediction 是否等于label
        correct += pred.eq(targets).sum().template item<int64_t>();
    }
    test_loss /= dataset_size;
    std::printf(
        "\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
        test_loss,
        static_cast<double>(correct) / dataset_size);
}

int main()
{
    torch::manual_seed(1);
    torch::DeviceType device_type;
    if (torch::cuda::is_available())
    {
        std::cout << "CUDA available! Training on GPU." << std::endl;
        device_type = torch::kCUDA;
    }
    else
    {
        std::cout << "Training on CPU." << std::endl;
        device_type = torch::kCPU;
    }

    torch::Device device(device_type);

    Net model;
    model.to(device);
    // load train data
    auto train_dataset = torch::data::datasets::MNIST("D://MNIST//")
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());

    const size_t train_dataset_size = train_dataset.size().value();
    std::cout << train_dataset_size << std::endl;
    auto train_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
        std::move(train_dataset), 64);
    // load test data
    auto test_dataset = torch::data::datasets::MNIST(
        "D://MNIST//", torch::data::datasets::MNIST::Mode::kTest)
        .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
        .map(torch::data::transforms::Stack<>());
    const size_t test_dataset_size = test_dataset.size().value();
    auto test_loader =
        torch::data::make_data_loader(std::move(test_dataset), 1000);

    // optimizer
    torch::optim::SGD optimizer(model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));

    //train
    for (size_t epoch = 0; epoch < 5; epoch++)
    {
        train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
        test(model, device, *test_loader, test_dataset_size);
    }
    // save
    return 1;
}

C++代码训练结果如图:

可以看到C++版本的 MNIST代码能够正常训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1146144.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【算法|动态规划No.32 | 完全背包问题】完全背包模板题

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【手撕算法系列专栏】【LeetCode】 &#x1f354;本专栏旨在提高自己算法能力的同时&#xff0c;记录一下自己的学习过程&#xff0c;希望…

2023年【加氢工艺】考试题库及加氢工艺免费试题

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年加氢工艺考试题库为正在备考加氢工艺操作证的学员准备的理论考试专题&#xff0c;每个月更新的加氢工艺免费试题祝您顺利通过加氢工艺考试。 1、【单选题】《使用有毒物品作业场所劳动保护条例》规定,从事使用高…

Linux常用命令——chown命令

在线Linux命令查询工具 chown 用来变更文件或目录的拥有者或所属群组 补充说明 chown命令改变某个文件或目录的所有者和所属的组&#xff0c;该命令可以向某个用户授权&#xff0c;使该用户变成指定文件的所有者或者改变文件所属的组。用户可以是用户或者是用户D&#xff0…

杨辉三角形

要求输出10行杨辉三角形如下图&#xff1a; 杨辉三角的特点: 1,只需要处理下三角形; 2.第一列和主对角线的值为1; 3.其它位置的值等于上一行前 一列上一行同列的值。 int main() { #define ROW 10//行和列int arr[ROW][ROW];for (int i 0; i < ROW; i){for (int j 0; j &l…

第四章 文件管理 十一、虚拟文件系统

目录 一、虚拟文件系统图 二、虚拟文件系统的特点 三、存在的问题 四、文件系统挂载 一、虚拟文件系统图 二、虚拟文件系统的特点 1、向上层用户进程提供统一标准的系统调用接口&#xff0c;屏蔽底层具体文件系统的实现差异。 2、VFS要求下层的文件系统必须实现某些规定的…

SPI 串行外围设备接口

SPI&#xff08;Serial Peripheral interface&#xff09;&#xff0c;串行外围设备接口。是一种全双工形式的高速同步通信总线。 SPI 硬件接口由四根信号线组成&#xff0c;分别是&#xff1a; SDI&#xff1a;数据输入SDO&#xff1a;数据输出SCK&#xff1a;时钟CS/SS&…

BUUCTF 简单注册器 1

题目是简单注册器 分析 直接运行下 有个错误提示&#xff0c;使用jadx查找 &#xff08;ctrl shift f&#xff09; 直接复制下代码 int flag 1; String xx editview.getText().toString(); if (xx.length() ! 32 || xx.charAt(31) ! a || xx.charAt(1) ! b || (xx.cha…

数据库连接技术

一、许多编程语言 都可以 连接数据库。不是在C中加入SQL语句&#xff0c;而是使 用C编程语言 连接数据库&#xff0c;并执行SQL语句&#xff0c;以获得数据。 数据库连接&#xff0c;有一些通用的方式。C中连接数据库并执行SQL语句&#xff0c;主要有以下几种方式&#xff1a; …

PLC-200 smart 字节与字

这里写自定义目录标题 数据存储器的组合——字节与字组合字与双字组合 数据存储&#xff1a;右侧低位&#xff0c;左侧高位 1输出&#xff1b;0不输出 v&#xff1a;存储区标识符 例如&#xff1a;VB100.0&#xff0c;v存储区标识符&#xff1b;100 字节编号&#xff1b;“.0”…

C语言 定义一个函数,并调用,该函数中打印显示九九乘法表

#include<stdio.h> void chengfabiao() {for (int i 1; i < 10; i){for (int j 1; j < i; j){printf("%d * %d %d\t",j,i,i*j);} printf("\n");} } int main(int argc,const char *argv[]) {chengfabiao();return 0; }

饭局从入门到精通

文章目录 你会把你妈卖到妓院吗&#xff1f;声明一 为什么要请客吃饭1 环境变化&#xff0c;身份跟着变化2 酒杯识人3 吃人嘴软 二 饭局的准备1 明确自己设饭局的目的2 掌握客人的特点3 如何设计陪客的名单 三 如何正确选择饭店1 地段选择法2 环境选择法3 菜系选择法 四 如何邀…

【AD9361 数字接口CMOS LVDSSPI】B 并行数据之CMOS

##接上一篇&#xff1b; 本节介绍 AD9361 数字接口CMOS &LVDS&SPI最后一张表中四种工作模式的具体配置及时序波形图。 目录 1、单端口半双工模式 &#xff08;CMOS&#xff09; *代称 SHC*换句话说&#xff0c;最大值是12‘b0111_1111_1111&#xff0c;即0x7FF&#xf…

IOC课程整理-5 Spring IoC 依赖查找

1 依赖查找的今世前生 2 单一类型依赖查找 3 集合类型依赖查找 4 层次性依赖查找 5 延迟依赖查找 6 依赖查找安全性对比 7 内建可查找的依赖 • AbstractApplicationContext 内建可查找的依赖 注解驱动 Spring 应用上下文内建可查找的依赖&#xff08;部分&#xff09; 8 依…

两个手机屏幕的效果对比

其中一个刚买的二手&#xff0c;卖家说坏了&#xff0c;换成国产屏&#xff0c;没有指纹。其实拿到手时&#xff0c;吾就发现屏幕明显泛白&#xff0c;颜色与手头的相差太大。 对比1 对比2

轻量级 SSO 方略

单点登录 SSO&#xff08;Single Sign On&#xff09;是在多个应用系统中&#xff0c;用户只需要登录一次就可以访问所有相互信任的应用系统。打通所有系统的账户密码&#xff0c;只需要记住一个就行&#xff0c;而且登录一个系统后&#xff0c;打开其他系统不需要再登录。广义…

【Unity程序技巧】Input管理器

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;Uni…

【Python 零基础入门】常用内置函数 初探

【Python 零基础入门】内容补充 1 常用内置函数 Python 简介为什么要学习内置函数数据类型和转换int(): 转为整数float(): 转为浮点数list(): 转为列表tuple(): 转换为元组set():转换为集合dict(): 创建字典: 数学运算abs(): 绝对值pow(): 幂运算round(): 四舍五入min(): 最小值…

类与面向对象

章节目录&#xff1a; 一、面向对象二、类2.1 类定义2.2 类对象2.3 self 代表类的实例&#xff0c;而非类 三、类的方法四、多继承五、方法重写六、私有属性及私有方法七、类的专有方法八、专有方法重载九、结束语 一、面向对象 Python 从设计之初就已经是一门面向对象的语言。…

【错误解决方案】ModuleNotFoundError: No module named ‘torch._six‘

1. 错误提示 在python程序中&#xff0c;试图导入一个名为torch._six的模块&#xff0c;但Python提示找不到这个模块。 错误提示&#xff1a;ModuleNotFoundError: No module named torch._six 2. 解决方案 出现这个错误可能是因为你使用的PyTorch版本和你的代码不兼容。在某…

MySQL实战1

文章目录 主要内容一.墨西哥和美国第三高峰1.准备工作代码如下&#xff08;示例&#xff09;: 2.目标3.实现代码如下&#xff08;示例&#xff09;: 4.相似例子代码如下&#xff08;示例&#xff09;: 二.用latest_event查找当前打开的页数1.准备工作代码如下&#xff08;示例&…