【Week-P7】VGG16识别咖啡豆

news2024/11/28 6:43:48

Week-P7 VGG16识别咖啡豆

  • 一、环境配置
  • 二、准备数据
  • 三、搭建网络结构 VGG16
  • 四、开始训练
  • 五、查看训练结果
  • 六、改变优化器,VSCode运行

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

(1)本次学习分两次执行,第一次是优化器为Adam时,使用Jupyter Notebook运行;第二次修改优化器为SGD,使用VSCode运行,需要注意一点,VSCode遇到PIPEERROR时需要修改DataLoader中的num_workers=0,参考这里;

(2)本次学习在jupyter运行时,需要安装torchsummary,因为时第一次用到,之前也没有安装过,后续在VSCode里运行,不需要重复安装;
在这里插入图片描述

(3)Adam优化器训练后,Test_Acc=100%。而SGD优化器训练后,Test_Acc始终为22.1%
(4)用VSCode训练的速度比Jupyter NoteBook快超级多,Jupyter大概需要一整天,而且还会显得电脑卡,用VSCode大概4~5小时可以完成,VSCode的效率高太多。

一、环境配置

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

在这里插入图片描述

二、准备数据

2.1 打印classeNames列表,显示每个文件所属的类别名称
2.2 打印归一化后的类别名称,01
2.3 划分数据集,划分为训练集&测试集,torch.utils.data.DataLoader()参数详解
2.4 检查数据集的shape

import os,PIL,random,pathlib

data_dir = 'D:/jupyter notebook/DL-100-days/datasets/coffebeans-data/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[5] for path in data_paths]
print("classeNames: ", classeNames)

在这里插入图片描述

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder("D:/jupyter notebook/DL-100-days/datasets/coffebeans-data/",transform=train_transforms)
print("total_data: ",total_data)

在这里插入图片描述

total_data.class_to_idx

在这里插入图片描述

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset: ", train_dataset)
print("test_dataset: ", test_dataset)

在这里插入图片描述

batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)

在这里插入图片描述

三、搭建网络结构 VGG16

import torch.nn.functional as F

class vgg16(nn.Module):
    def __init__(self):
        super(vgg16, self).__init__()
        # 卷积块1
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块2
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块3
        self.block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块4
        self.block4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        # 卷积块5
        self.block5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        )
        
        # 全连接网络层,用于分类
        self.classifier = nn.Sequential(
            nn.Linear(in_features=512*7*7, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=4)
        )

    def forward(self, x):

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)

        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
    
model = vgg16().to(device)
model

在这里插入图片描述

# 统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model, (3, 224, 224))

在这里插入图片描述

四、开始训练

4.1 设置超参数
4.2 编写训练函数
4.3 编写测试函数
4.4 开始正式训练,epochs==40

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss
def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss
'''
如果将优化器换成 SGD 会发生什么呢?请自行探索接下来发生的诡异事件的原因。
--> 稍后会在VSCode里面实现,Jupyter 实在是太太太太太慢了。。。。。
'''
import copy

optimizer  = torch.optim.Adam(model.parameters(), lr= 1e-4)
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数

epochs     = 40

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

best_acc = 0    # 设置一个最佳准确率,作为最佳模型的判别指标

for epoch in range(epochs):
    
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc   = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

print('Done')

在这里插入图片描述
最终训练结果是:Test_acc=100%(记得之前有个老师说,准确率不可能完全到达100%,只会无限趋近100%,所以不知道这里的100%准确率会不会有问题。)

五、查看训练结果

5.1 绘制Accurac_Loss图
5.2 指定图片进行预测

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

from PIL import Image 

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)

    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
    
# 预测训练集中的某张照片
predict_one_image(image_path='D:/jupyter notebook/DL-100-days/datasets/coffebeans-data/Dark/dark (1).png', 
                  model=model, 
                  transform=train_transforms, 
                  classes=classes)

在这里插入图片描述

六、改变优化器,VSCode运行

使用VSCode运行,需要修改/增加以下几个地方:

  • (1)安装pytorch相关库:
    在这里插入图片描述
  • (2)遇到PIPEERROR,修改DataLoader()中的num_workers的值:
    在这里插入图片描述
    训练结果如下:
    在这里插入图片描述
    已经训练了14个epoch,Train_acc、Train_loss、Test_acc、Test_loss、Lr等的值始终没发生变化,始终是22.1%,所以我暂停训练了。

Q:为什么SGD优化器会导致这样的现象?
参考这里
仔细分析两种优化器的优化原理,Adam对设置的初始学习率不敏感,可以在很广泛的区间内优化到一个较好的参数,SGD在很小的区间内可以得到很好的参数,如果范围太大,可能会得到一个当前区间内的最优解,但是对于全局来说却不是最优解。(猜测可能是这个原因。)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1425545.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI工具【OCR 01】Java可使用的OCR工具Tess4J使用举例(身份证信息识别核心代码及信息提取方法分享)

Java可使用的OCR工具Tess4J使用举例 1.简介1.1 简单介绍1.2 官方说明 2.使用举例2.1 依赖及语言数据包2.2 核心代码2.3 识别身份证信息2.3.1 核心代码2.3.2 截取指定字符2.3.3 去掉字符串里的非中文字符2.3.4 提取出生日期(待优化)2.3.5 实测 3.总结 1.简…

阿里云AI通义千问出bug,解决不了直接弃,开始对国产AI由支持变失望

AI怀疑人生 引言对比出大问题思考尝试解决代码结尾 引言 今天的第二篇 原本是想写这个爬取什么值得买 延续零基础爬什么值得买的榜单——爬虫练习题目一(答一) 但没想到 这个阿里云的AI 通义千问 删了我很多的对话 也就是说 我之前一直提问的AI角色没了…

VxTerm:C++ MFC,在工具栏中增加Edit/ComboBox等组件,打造一个地址栏/搜索栏功能

VxTerm软件可以在本站链接下载:唯一国产化SSH工具下载,单文件纯绿色不需要安装,替代SecureCRT 在软件的主界面中,增加了一个地址栏功能。 本人的文章内容都是经本人亲自实现并验证成功的干货,关注我,互相交…

代理模式详解(重点解析JDK动态代理)

- 定义 在解析动态代理模式之前,先简单看下整个代理模式。代理模式分为普通代理、强制模式、动态代理模式。其中动态代理模式主要实现方式为Java JDK提供的JDK动态代理,第三方类库提供的,例如CGLIB动态代理。 代理模式就是为其他对象提供一种…

【百度Apollo】自动驾驶规划技术:实现安全高效的智能驾驶

🎬 鸽芷咕:个人主页 🔥 个人专栏:《linux深造日志》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! ⛳️ 推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下…

与数组相关经典面试题

𝙉𝙞𝙘𝙚!!👏🏻‧✧̣̥̇‧✦👏🏻‧✧̣̥̇‧✦ 👏🏻‧✧̣̥̇:Solitary-walk ⸝⋆ ━━━┓ - 个性标签 - :来于“云”的“羽球人”。…

安卓网格布局GridLayout

<?xml version"1.0" encoding"utf-8"?> <GridLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:tools"http://schemas.android.com/tools"android:layout_width"match_parent"android:la…

C#用正则表达式Regex.Matches 方法检查字符串中重复出现的词

目录 一、Regex.Matches 方法 1.重载 二、Matches(String, String, RegexOptions, TimeSpan) 1.定义 2.示例 三、Matches(String, String, RegexOptions) 1.定义 2.示例 3.示例&#xff1a;用正则表达式检查字符串中重复出现的词 四、Matches(String, Int32) 1.定义…

Docker容器化安装SonarQube9.9

文章目录 1.环境准备1.1 版本信息1.2 系统设置 2.Docker环境安装2.1 卸载旧版本2.2 设置源2.3 安装Docker2.4 设置阿里仓库2.5 启动Docker 3.Docker Compose4.登录4.1 首页4.2 安装插件 5.制作镜像离线安装 1.环境准备 1.1 版本信息 名称版本备注Docker25.0.1当前2024-01-01最…

【设计模式】六大原则详解,每个原则提供代码示例

设计模式六大原则 目录 一、单一职责原则——SRP 1、作用2、基本要点3、举例 二、开放封闭原则——OCP 1、作用2、基本要点3、举例 三、里氏替换原则——LSP 1、作用2、基本要点3、举例 四、依赖倒置原则——DLP 1、作用2、基本要点3、举例 五、迪米特法则——LoD 1、作用2、…

Arcgis10.3安装

所需软件地址 链接&#xff1a;https://pan.baidu.com/s/1aAykUDjkaXjdwFjDvAR83Q?pwdbs2i 提取码&#xff1a;bs2i 1、安装License Manager 点击License Manager.exe&#xff0c;默认下一步。 安装完&#xff0c;点击License Server Administrator&#xff0c;停止服务。…

Java基础 集合(二)List详解

目录 简介 数组与集合的区别如下&#xff1a; 介绍 AbstractList 和 AbstractSequentialList Vector 替代方案 Stack ArrayList LinkedList 前言-与正文无关 生活远不止眼前的苦劳与奔波&#xff0c;它还充满了无数值得我们去体验和珍惜的美好事物。在这个快节奏的世界…

IT行业中最重要的证书

在IT行业&#xff0c;拥有一些含金量较高的证书是职业发展的关键。这些证书不仅可以证明技能水平&#xff0c;还有助于提升在职场上的竞争力。本文将介绍几个IT行业中最重要的证书。 1. Cisco认证 CCNA&#xff08;Cisco Certified Network Associate&#xff09;是Cisco公司新…

ArcGIS学习(二)属性表的基本操作

ArcGIS学习(二)属性表的基本操作 1.查看属性表 ArcGIS是处理空间数据的平台。对于空间数据,大家可以理解成它是由两个部分构成:1.一个是空间形体,也就是点、线、面三种。线又可以分为直线、曲线,面又分为圆形、正方形、不规则形体等;2.另外一个部分是空间形体所附带的…

【深蓝学院】移动机器人运动规划--第3章 基于采样的路径规划--笔记

0. Preliminaries 做规划都是将WS转到C space下进行。 找到可行解和最优解&#xff08;这两个不同&#xff09; 通过增量或者批次地在C-space中采样来增量式地构建树或者图。 不显式地构造 如果把整个规划问题看成一个大的优化问题&#xff0c;那么大问题可以拆分成小问题进行…

09. 异常处理

目录 1、前言 2、常见的异常 3、异常处理try...except...finally 4、异常信息解读 5、raise 6、自定义异常 7、小结 1、前言 在编程中&#xff0c;异常&#xff08;Exception&#xff09;是程序在运行期间检测到的错误或异常状况。当程序执行过程中发生了一些无法继续执…

Monday.com替代工具大盘点:哪个更适合您的团队协作需求?

市场上有许多项目管理软件解决方案&#xff0c;每个都有自己的优点和缺点&#xff0c;根据您的具体需求和要求&#xff0c;市场上有8种可用的项目管理软件可以作为Monday.com的替代工具&#xff0c;分别是&#xff1a;Zoho Projects、Trello、Asana、Wrike、Basecamp、JIRA、Mi…

Hadoop-生产调优(更新中)

第1章 HDFS-核心参数 1.1 NameNode内存生产配置 1&#xff09;NameNode 内存计算 每个文件块大概占用 150 byte&#xff0c;一台服务器 128G 内存为例&#xff0c;能存储多少文件块呢&#xff1f; 128 * 1024 * 1024 * 1024 / 150byte ≈ 9.1 亿G MB KB Byte 2&#xff09…

【Redis】签到点赞和UV统计

Redis签到点赞和UV统计 点赞 点赞功能分析 需求&#xff1a; 同一个用户只能点赞一次&#xff0c;再次点击则取消点赞如果当前用户已经点赞&#xff0c;则点赞按钮高亮显示&#xff08;前端判断字段isLike属性&#xff09; 实现步骤&#xff1a; 利用Redis的set集合判断是…

安科瑞光伏、储能一体化监控及运维解决方案

上海安科瑞电气股份有限公司 胡冠楠 咨询家&#xff1a;“Acrelhgn”&#xff0c;了解更多产品资讯 前言 今年以来&#xff0c;在政策利好推动下光伏、风力发电、电化学储能及抽水蓄能等新能源行业发展迅速&#xff0c;装机容量均大幅度增长&#xff0c;新能源发电已经成为新…