深度学习-图像处理篇1.3pytorch神经网络例子

news2024/9/21 21:09:32

在这里插入图片描述
batch:一批·图像·数量
官方例子

#model
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3,16,5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5)
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x



import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms


def main():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, num_workers=0)

    # 10000张验证图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,
                                             shuffle=False, num_workers=0)
    val_data_iter = iter(val_loader)
    val_image, val_label = next(val_data_iter)
    
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()#损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.001)#优化器

    for epoch in range(5):  # loop over the dataset multiple times

        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()#将历史损失梯度清零
            # forward + backward + optimize
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()#反向传播
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if step % 500 == 499:    # print every 500 mini-batches
                with torch.no_grad():#接下来计算过程中不计算损失梯度,更好分配内存
                    outputs = net(val_image)  # [batch, 10]
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')

    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet


def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('Lenet.pth'))

    # im = Image.open('cat.jpg')

    im = Image.open('airplane.png').convert('RGB')  # 转换为RGB图像
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].numpy()
        predict = predict.item()  # 从数组中提取标量值
    print(classes[int(predict)])


if __name__ == '__main__':
    main()

测试结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2153311.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python 找到给定点集的简单闭合路径(Find Simple Closed Path for a given set of points)

给定一组点,将这些点连接起来而不相交 例子: 输入:points[] {(0, 3), (1, 1), (2, 2), (4, 4), (0, 0), (1, 2), (3, 1}, {3, 3}}; 输出:按以下顺序连接点将 不造成任何交叉 {(0, 0), (3, …

将sqlite3移植到开发板上

1、下载c源码 sqlite官网下载C源码:SQLite Download Page 点击第二个链接下载 2、解压 1、将下载好的c源码,放在linux下, 2、解压压缩包:tar -zxvf sqlite-autoconf-3460100 新建一个用存放 编译出来的文件: mkd…

Web开发:使用C#创建、安装、调试和卸载服务以及要注意的写法

目录 一、创建服务 1.创建项目(.NET Framework) 2.重命名 3.编写逻辑代码 二、安装服务 1.方案一:利用VS2022安装文件的配置 选择添加安装程序 安装文件的介绍及配置 ​编辑​ 重新编译 工具安装 2.方案二:编写bat脚本安…

Excel match 函数使用方法,和 index 函数是绝配

大家好,这里是效率办公指南! 🔎 在处理Excel数据时,我们经常需要找出特定数据在列表或数组中的位置。MATCH函数正是为此设计的,它可以返回一个值在指定数组中的相对位置。今天,我们将详细介绍MATCH函数的使…

MySQL基础(13)- MySQL数据类型

目录 一、数据类型概述 1.MySQL中的数据类型 二、整型 1.数据类型可选属性 2.使用建议 三、浮点数、定点数、位类型 1.类型介绍 2.浮点类型 3.定点数类型 4.位类型 四、日期时间类型 1.YEAR 2.DATE 3.TIME 4.DATETIME 5.TIMESTAMP 6.TIMESTAMP和DATETIME的区别…

Android Studio开发发布教程

本文讲解Android Studio如何发布APP。 在Android Studiobuild菜单栏下点击Generate Singed Bundle/APK…打开对话框。 选择APK点击Next 点击Create New...进行创建

Flask找上下文源码

1. app Flask(__name__) app.__call__ 1.1 按住 command 键 点击 进到这个函数里 1.2 接着找 return 看看返回什么 点进去 1.3 找到定义函数 1.4 点进去先看这个里边有啥 1.5 找到定义类 1.6 找到RequestContext 类 1.7 找到 RequestContext ---> 的push 方法 1.8 点击 _c…

Redis面试真题总结(二)

文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ 文章收录在网站:http://hardyfish.top/ Redis常见的数据类型有哪些? string 字符串 字符串类型…

Mobile net V系列详解 理论+实战(3)

Mobilenet 系列 论文精讲部分0.摘要1. 引文2. 引文3. 基础概念的讨论3.1 深度可分离卷积3.2 线性瓶颈3.3 个人理解 4. 模型架构细节5. 实验细节6. 实验讨论7. 总结 论文精讲部分 鉴于上一小节中采用的代码是V2的模型,因此本章节现对V2模型论文讲解,便于…

【C++二叉树】二叉树的前序遍历、中序遍历、后序遍历递归与非递归实现

1.二叉树的前序遍历 144. 二叉树的前序遍历 - 力扣(LeetCode) 前序遍历方式:根-左子树-右子树。 递归实现: 要传一个子函数来实先递归,原因是原函数返回值为vector,在原函数迭代,返回值就难…

基于python上门维修预约服务数据分析系统

目录 技术栈和环境说明解决的思路具体实现截图python语言框架介绍技术路线性能/安全/负载方面可行性分析论证python-flask核心代码部分展示python-django核心代码部分展示操作可行性详细视频演示源码获取 技术栈和环境说明 结合用户的使用需求,本系统采用运用较为广…

一、机器学习算法与实践_03概率论与贝叶斯算法笔记

1、概率论基础知识介绍 人工智能项目本质上是一个统计学项目,是通过对 样本 的分析,来评估/估计 总体 的情况,与数学知识相关联 高等数学 ——> 模型优化 概率论与数理统计 ——> 建模思想 线性代数 ——> 高性能计算 在机器学…

Qt窗口——对话框

文章目录 对话框自定义对话框对话框分类消息对话框QMessageBox使用示例自定义按钮快速构造对话框 颜色对话框QColorDialog文件对话框QFileDialog字体对话框QFontDialog输入对话框QInputDialog 对话框 对话框可以理解成一个弹窗,用于短期任务或者简洁的用户交互 Qt…

AIoT智能工控板

在当今竞争激烈的商业环境中,企业需要强大的科技力量来助力腾飞,AIoT智能工控板就是这样的力量源泉。 其领先的芯片架构设计,使得主板的性能得到了极大的提升。无论是数据的处理速度、图形的渲染能力,还是多任务的并行处理能力&a…

【Linux笔记】虚拟机内Linux内容复制到宿主机的Window文件夹(文件)中

一、共享文件夹 I、Windows宿主机上创建一个文件夹 目录:D:\Centos_iso\shared_files II、在VMware中设置共享文件夹 1、打开VMware Workstation 2、选择需要设置的Linux虚拟机,点击“编辑虚拟机设置”。 3、在“选项”标签页中,选择“共…

【BetterBench博士】2024年中国研究生数学建模竞赛 E题:高速公路应急车道紧急启用模型 问题分析

2024年中国研究生数学建模竞赛 E题:高速公路应急车道紧急启用模型 问题分析 更新进展 【BetterBench博士】2024 “华为杯”第二十一届中国研究生数学建模竞赛 选题分析 【BetterBench博士】2024年中国研究生数学建模竞赛 E题:高速公路应急车道紧急启用…

Remix在SPA模式下,出现ErrorBoundary错误页加载Ant Design组件报错,不能加载样式的问题

Remix是一个既能做服务端渲染,又能做单页应用的框架,如果想做单页应用,又想学服务端渲染,使用Remix可以降低学习成本。最近,在学习Remix的过程中,遇到了在SPA模式下与Ant Design整合的问题。 我用Remix官网…

简单多状态dp第三弹 leetcode -买卖股票的最佳时机问题

309. 买卖股票的最佳时机含冷冻期 买卖股票的最佳时机含冷冻期 分析: 使用动态规划解决 状态表示: 由于有「买入」「可交易」「冷冻期」三个状态,因此我们可以选择用三个数组,其中: ▪ dp[i][0] 表示:第 i 天结束后&#xff0c…

探索AI编程新时代:GitHub Copilot如何重塑开发者工作效率

在当今技术瞬息万变的时代,软件开发者们每天都在努力寻找更高效的编程方法。面对繁忙的工作日程和不断增加的项目压力,如何在编码过程中大幅提升效率成为了一个备受关注的话题。在众多工具中,GitHub Copilot以其强大的AI驱动能力脱颖而出&…

菜鸟也能轻松上手的Java环境配置方法

初学者学习Java这么编程语言,第一个难题往往是Java环境的配置,今天与大家详细地聊一聊,以便大家能独立完成配置方法和过程。 首先,找到“JDK”,点击“archive”: 向下滑,在“previous java rel…