第P7周:咖啡豆识别(VGG-16复现)

news2024/12/23 13:15:05

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rbOOmire8OocQ90QM78DRA) 中的学习记录博客**
>- **🍖 原作者:[K同学啊 | 接辅导、项目定制](https://mtyjkh.blog.csdn.net/)**

一、前期工作

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device



2. 导入数据

import os,PIL,random,pathlib

data_dir = './7-data/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder("./7-data/",transform=train_transforms)
total_data

3. 划分数据集





二、手动搭建VGG-16模型

VGG-16结构说明:

●13个卷积层(Convolutional Layer),分别用blockX_convX表示
●3个全连接层(Fully connected Layer),分别用fcX与predictions表示
●5个池化层(Pool layer),分别用blockX_pool表示

VGG-16包含了16个隐藏层(13个卷积层和3个全连接层),故称为VGG-16

这里,我制作了一个视频来展示VGG-16的传播过程

play

0:00/0:22

倍速

volumeUp

fullscreen

fullscreen

VGG-16网络动画展示

FC7

FE8

FC6

7*7*512

1*1*512

1*1*1000

1*1*512

CONV5

CONV4

14*14*512

CONV3

28*28*512

CONV2

56*56*256

112*112*128

CONVOLUTION+RELU

K同学啊制作

MAX POOLING

CONVI

百度/谷歌/微信搜索:K同学啊

224*224*64

FULLY CONNECTED+RELU

image.png


1. 搭建模型



2. 查看模型详情


三、 训练模型
1. 编写训练函数


2. 编写测试函数

测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器


3. 正式训练

model.train()、model.eval()训练营往期文章中有详细的介绍。

📌如果将优化器换成 SGD 会发生什么呢?请自行探索接下来发生的诡异事件的原因。



四、 结果可视化
1. Loss与Accuracy图

 

TRAINING AND VALIDATION ACCURACY

TRAINING AND

VALIDATION LOSS

1.4

1.0

TRAINING LOSS

1.2

TEST LOSS

1.0

0.8

0.8

0.6

0.6

0.4

0.4

0.2

TRAINING ACCURACY

TEST ACCURACY

0.0

75

35

15

35

25

20

40

30

40

30

output_29_0.png


2. 指定图片进行预测
 

Python复制代码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

from PIL import Image

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):

test_img = Image.open(image_path).convert('RGB')

plt.imshow(test_img) # 展示预测的图片

test_img = transform(test_img)

img = test_img.to(device).unsqueeze(0)

model.eval()

output = model(img)

_,pred = torch.max(output,1)

pred_class = classes[pred]

print(f'预测结果是:{pred_class}')

Python复制代码

1

2

3

4

5

# 预测训练集中的某张照片

predict_one_image(image_path='./7-data/Dark/dark (1).png',

model=model,

transform=train_transforms,

classes=classes)

Plain Text复制代码

1

预测结果是:Dark

25

50

75

100

125

150

175

200

50

100

150

200

output_32_1.png


3. 模型评估
 

Python复制代码

1

2

best_model.eval()

epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)

Python复制代码

1

epoch_test_acc, epoch_test_loss

Plain Text复制代码

1

(0.9916666666666667, 0.035762640996836126)

Python复制代码

1

2

# 查看是否与我们记录的最高准确率一致

epoch_test_acc

Plain Text复制代码

1

0.9916666666666667

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1313431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 常用命令----mktemp 命令

文章目录 基本用法实例演示高级用法注意事项 mktemp 命令用于创建一个临时文件或目录,这在需要处理临时数据或进行安全性测试时非常有用。使用 mktemp 可以保证文件名的唯一性,避免因文件名冲突而导致的问题。 基本用法 创建临时文件: 命令 mktemp 默认…

【IC验证】perl脚本——分析前/后仿用例回归情况

目录 1 脚本名称 2 脚本使用说明 3 nocare_list文件示例 4 脚本执行方法 5 postsim_result.log文件示例 6 脚本代码 1 脚本名称 post_analysis 2 脚本使用说明 help:打印脚本说明信息 命令:post_analysis help 前/后仿结束后,首先填…

餐饮品牌小红书探店怎么做?建议收藏

对于餐饮行业来说,消费者的口味和选择越来越多样化,如何在众多竞争者中脱颖而出,成为消费者的心头好,探店活动便应运而生。小红书作为国内知名的社交电商平台,拥有庞大的用户群体,特别是年轻人和美食爱好者…

多域名https证书购买选择

多域名https证书是一种特殊的SSL证书,它允许一个证书同时保护多个域名,并且不限制域名的类型,可以保护多个域名和子域名,确保网站传输信息时不被窃取、篡改。那么我们该怎么选择符合需求的多域名https证书呢?今天就随S…

Python神器:快速删除文本文件中指定行的方法

1. 简介 文件操作是编程中的重要方面。Python作为强大的编程语言,提供了处理文件的能力。删除特定行是文件处理中常见的需求。 2. 打开文件和读取内容 当打开文件并读取其内容时,open()函数和with语句是Python中常用的工具。以下是展示如何使用它们的…

由@EnableWebMvc注解引发的Jackson解析异常

同事合了代码到开发分支,并没有涉及到改动的类却报错。错误信息如下: Servlet.service() for servlet [dispatcherServlet] in context with path [] threw exception [Request processing failed; nested exception is org.springframework.http.conv…

第1章:企业级研发测试流程

通过实际(自研互联网)企业的研发流程一览图。 我们发现分为9个阶段,当然每个公司细节并不一样。 所以我希望你能理解这句话: 一切的流程、行为、结果都是围绕“产品质量”这4个字开展活动。而作为测试,你该考虑的是如何…

数据结构与算法—查找算法(线性查找、二分查找、插值查找、斐波那契查找)

查找算法 文章目录 查找算法1. 线性查找算法2. 二分查找算法2.1 二分查找思路分析2.2 应用实例 3. 插值查找3.1 基本原理3.2 应用实例 4. 斐波那契4.1 基本原理4.2 应用实例 5. 查找总结 在java中,常用的查找有四种: 顺序(线性)查找二分查找/折半查找插值…

JS基本语法

JS基本语法 变量数据类型原始数据类型 函数定义第一种方式第二种方式 JS 对象ArrayStringJavaScript 自定义对象JSONDOMBOM JS 事件事件监听事件绑定常见事件 变量 数据类型 原始数据类型 函数定义 第一种方式 第二种方式 JS 对象 Array String JavaScript 自定义对象 JSON …

2023版本QT学习记录 -2- 标准文件对话框

头文件的使用 #include "QFileDialog"函数原型 getOpenFileName效果 参数 未完待续

电阻的运用

本文引注 https://baijiahao.baidu.com/s?id1749115196647029942&wfrspider&forpc 一、零欧电阻 在电子电路设计时经常用到的一种元件就是电阻,我们都知道电阻在电路中起到分压限流的作用。然而,实际使用时会用到一种特殊的电阻:零…

3ds max软件中的一些常用功能分享!

3ds max软件有很多小伙伴反馈说,明明有很多3ds max教程资料。却不知道如何入门3dmax。 掌握3dmax基本功能是开始使用3dmax的基础之一,所以,小编带大家盘点一下3dmax常用操作。 3dmax常用功能介绍如下,快快跟着小编一起看起来。 1…

[渗透测试学习] Codify - HackTheBox

首先nmap扫描端口 nmap -sV -sC -p- -v --min-rate 1000 10.10.11.239扫出来三个端口,22端口为ssh服务,80端口有http服务,3000端口为nodejs框架 尝试访问下80端口,发现页面重定向 将该域名添加到hosts里 sudo vim /etc/hosts 成…

优先考虑泛型

Java中的泛型(Generics)提供了一种参数化类型的机制,使得你可以编写更灵活、类型安全的代码。下面是一个例子,说明在Java中优先考虑泛型的好处: 考虑一个简单的容器类,它可以存储任意类型的元素&#xff0…

三种好用的在线色彩提取工具

#三种好用的在线色彩提取工具 1.ecjson网站 网址: https://www.ecjson.com/image_color#b1cfea 或点击链接: ecjson在线色彩提取 图1 ecjson网站色彩提取举例 2.微查网 网址:http://zxqsq.wiicha.com/ 或点击链接: 微查网在线色彩提取 …

JaveEE:手动实现定时器精讲

前言 在Java并发编程学习中,定时器是必不可少的环节。 我们知道线程的调度是随机的,但是有的时候我们就是需要它有序一些,此时的定时器就可以很好的解决这个问题。它可以按照一定的先后顺序,将我们的任务依次执行。 目录 一.Java官…

elementui + vue2实现表格行的上下移动

场景&#xff1a; 如上&#xff0c;要实现表格行的上下移动 实现&#xff1a; <el-dialogappend-to-bodytitle"条件编辑":visible.sync"dialogVisible"width"60%"><el-table :data"data1" border style"width: 100%&q…

mybatis高级扩展-插件和分页插件PageHelper

1、建库建表 create database mybatis-example; use mybatis-example; create table emp (empNo varchar(40),empName varchar(100),sal int,deptno varchar(10) ); insert into emp values(e001,张三,8000,d001); insert into emp values(e002,李四,9000,d001); insert into…

MATLAB代码:分布式电源接入对配电网影响分析

微♥关注“电击小子程高兴的MATLAB小屋”获取专属优惠 关键词&#xff1a;分布式电源 配电网 评估 仿真平台&#xff1a;MATLAB 主要内容&#xff1a;代码主要做的是分布式电源接入场景下对配电网运行影响的分析&#xff0c;其中&#xff0c;可以自己设置分布式电源接入…

【信息学奥赛】拼在起跑线上,想入道就别落下自己!

编程无难事&#xff0c;只怕有心人&#xff0c;学就是了&#xff01; 文章目录 1 信息学奥赛简介2 信息学竞赛的经验回顾3 优秀参考图书推荐《信息学奥赛一本通关》4 高质量技术圈开放 1 信息学奥赛简介 信息学奥赛&#xff0c;作为全国中学生学科奥林匹克“五大学科竞赛”之一…