【深度学习实验】卷积神经网络(六):卷积神经网络模型(VGG)训练、评价

news2024/12/25 16:36:56

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

b. CIFAR10Dataset

2. 构建模型(FeedForward)

3.整合训练、评估、预测过程(Runner)

4. __main__

代码整合


一、实验介绍

        本实验实现了一个简化版VGG网络,并基于此完成图像分类任务。(包括模型训练、评价)
       

        VGG网络是深度卷积神经网络中的经典模型之一,由牛津大学计算机视觉组(Visual Geometry Group)提出。它在2014年的ImageNet图像分类挑战中取得了优异的成绩(分类任务第二,定位任务第一),被广泛应用于图像分类、目标检测和图像生成等任务。

        VGG网络的主要特点是使用了非常小的卷积核尺寸(通常为3x3)和更深的网络结构。该网络通过多个卷积层和池化层堆叠在一起,逐渐增加网络的深度,从而提取图像的多层次特征表示。VGG网络的基本构建块是由连续的卷积层组成,每个卷积层后面跟着一个ReLU激活函数。在每个卷积块的末尾,都会添加一个最大池化层来减小特征图的尺寸。VGG网络的这种简单而有效的结构使得它易于理解和实现,并且在不同的任务上具有很好的泛化性能。

        VGG网络有几个不同的变体,如VGG11、VGG13、VGG16和VGG19,它们的数字代表网络的层数。这些变体在网络深度和参数数量上有所区别,较深的网络通常具有更强大的表示能力,但也更加复杂。

二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,广泛应用于图像识别、计算机视觉和模式识别等领域。它的设计灵感来自于生物学中视觉皮层的工作原理。

        卷积神经网络通过多个卷积层、池化层全连接层组成。

  • 卷积层主要用于提取图像的局部特征,通过卷积操作和激活函数的处理,可以学习到图像的特征表示。
  • 池化层则用于降低特征图的维度,减少参数数量,同时保留主要的特征信息。
  • 全连接层则用于将提取到的特征映射到不同类别的概率上,进行分类或回归任务。

        卷积神经网络在图像处理方面具有很强的优势,它能够自动学习到具有层次结构的特征表示,并且对平移、缩放和旋转等图像变换具有一定的不变性。这些特点使得卷积神经网络成为图像分类、目标检测、语义分割等任务的首选模型。除了图像处理,卷积神经网络也可以应用于其他领域,如自然语言处理和时间序列分析。通过将文本或时间序列数据转换成二维形式,可以利用卷积神经网络进行相关任务的处理。

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F

1. 构建数据集(CIFAR10Dataset)

a. read_csv_labels()

        从CSV文件中读取标签信息并返回一个标签字典。

def read_csv_labels(fname):
    """读取fname来给标签字典返回一个文件名"""
    with open(fname, 'r') as f:
        # 跳过文件头行(列名)
        lines = f.readlines()[1:]
    tokens = [l.rstrip().split(',') for l in lines]
    return dict(((name, label) for name, label in tokens))
  •  使用open函数打开指定文件名的CSV文件,并将文件对象赋值给变量f。这里使用'r'参数以只读模式打开文件。

  • 使用文件对象的readlines()方法读取文件的所有行,并将结果存储在名为lines的列表中。通过切片操作[1:],跳过了文件的第一行(列名),将剩余的行存储在lines列表中。

  • 列表推导式(list comprehension):对lines列表中的每一行进行处理。对于每一行,使用rstrip()方法去除行末尾的换行符,并使用split(',')方法将行按逗号分割为多个标记。最终,将所有行的标记组成的子列表存储在tokens列表中。

  • 使用字典推导式(dictionary comprehension)将tokens列表中的子列表转换为字典。对于tokens中的每个子列表,将子列表的第一个元素作为键(name),第二个元素作为值(label),最终返回一个包含这些键值对的字典。

b. CIFAR10Dataset

class CIFAR10Dataset(Dataset):
    def __init__(self, folder_path, fname):
        self.labels = read_csv_labels(os.path.join(folder_path, fname))
        self.folder_path = os.path.join(folder_path, 'train')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = read_image(self.folder_path + '/' + str(idx + 1) + '.png')
        label = self.labels[str(idx + 1)]

        return img, torch.tensor(int(label))
  • 构造函数:

    • 接受两个参数

      • folder_path表示数据集所在的文件夹路径

      • fname表示包含标签信息的文件名。

    • 调用read_csv_labels函数,传递folder_pathfname作为参数,以读取CSV文件中的标签信息,并将返回的标签字典存储在self.labels变量中。

    • 通过拼接folder_path和字符串'train'来构建数据集的文件夹路径,将结果存储在self.folder_path变量中。

  • def __len__(self)

    • 这是CIFAR10Dataset类的方法,用于返回数据集的长度,即样本的数量。

  • def __getitem__(self, idx): 这是CIFAR10Dataset类的方法,用于根据给定的索引idx获取数据集中的一个样本。它首先根据索引idx构建图像文件的路径,并调用read_image函数来读取图像数据,将结果存储在img变量中。然后,它通过将索引转换为字符串,并使用该字符串作为键来从self.labels字典中获取相应的标签,将结果存储在label变量中。最后,它返回一个元组,包含图像数据和经过torch.tensor转换的标签。

2. 构建模型(FeedForward)

        参考前文:

【深度学习实验】卷积神经网络(五):深度卷积神经网络经典模型——VGG网络(卷积层、池化层、全连接层)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133350927?spm=1001.2014.3001.5501

3.整合训练、评估、预测过程(Runner)

        参考前文:

【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)_QomolangmaH的博客-CSDN博客icon-default.png?t=N7T8https://blog.csdn.net/m0_63834988/article/details/133219448?spm=1001.2014.3001.5501

        (略有改动:)

class Runner(object):
    def __init__(self, model, optimizer, loss_fn, metric=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        # 用于计算评价指标
        self.metric = metric
        
        # 记录训练过程中的评价指标变化
        self.dev_scores = []
        # 记录训练过程中的损失变化
        self.train_epoch_losses = []
        self.dev_losses = []
        # 记录全局最优评价指标
        self.best_score = 0
   
 
# 模型训练阶段
    def train(self, train_loader, dev_loader=None, **kwargs):
        # 将模型设置为训练模式,此时模型的参数会被更新
        self.model.train()
        
        num_epochs = kwargs.get('num_epochs', 0)
        log_steps = kwargs.get('log_steps', 100)
        save_path = kwargs.get('save_path','best_model.pth')
        eval_steps = kwargs.get('eval_steps', 0)
        # 运行的step数,不等于epoch数
        global_step = 0
        
        if eval_steps:
            if dev_loader is None:
                raise RuntimeError('Error: dev_loader can not be None!')
            if self.metric is None:
                raise RuntimeError('Error: Metric can not be None')
                
        # 遍历训练的轮数
        for epoch in range(num_epochs):
            total_loss = 0
            # 遍历数据集
            for step, data in enumerate(train_loader):
                x, y = data
                logits = self.model(x.float())
                loss = self.loss_fn(logits, y.long())
                total_loss += loss
                if step%log_steps == 0:
                    print(f'loss:{loss.item():.5f}')
                    
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            # 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件
            if eval_steps != 0 :
                if (epoch+1) % eval_steps ==  0:

                    dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)
                    print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')
                
                    if dev_score > self.best_score:
                        self.save_model(f'model_{epoch+1}.pth')
                    
                        print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')
                        self.best_score = dev_score
                    
                # 验证过程结束后,请记住将模型调回训练模式   
                    self.model.train()
            
            global_step += 1
            # 保存当前轮次训练损失的累计值
            train_loss = (total_loss/len(train_loader)).item()
            self.train_epoch_losses.append((global_step,train_loss))
        self.save_model(f'{save_path}.pth')   
        print('[Train] Train done')
        
    # 模型评价阶段
    def evaluate(self, dev_loader, **kwargs):
        assert self.metric is not None
        # 将模型设置为验证模式,此模式下,模型的参数不会更新
        self.model.eval()
        global_step = kwargs.get('global_step',-1)
        total_loss = 0
        self.metric.reset()
        
        for batch_id, data in enumerate(dev_loader):
            x, y = data
            logits = self.model(x.float())
            loss = self.loss_fn(logits, y.long()).item()
            total_loss += loss 
            self.metric.update(logits, y)
            
        dev_loss = (total_loss/len(dev_loader))
        self.dev_losses.append((global_step, dev_loss))
        dev_score = self.metric.accumulate()
        self.dev_scores.append(dev_score)
        return dev_score, dev_loss
    
    # 模型预测阶段,
    def predict(self, x, **kwargs):
        self.model.eval()
        logits = self.model(x)
        return logits
    
    # 保存模型的参数
    def save_model(self, save_path):
        torch.save(self.model.state_dict(), save_path)
        
    # 读取模型的参数
    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

4. __main__

 batch_size = 20
    # 构建训练集
    train_data = CIFAR10Dataset('cifar10_tiny', 'trainLabels.csv')
    train_iter = DataLoader(train_data, batch_size=batch_size)
    # 构建测试集
  



    num_classes = 10
    # 定义模型
    model = VGG_S(num_classes)
    # 定义损失函数
    loss_fn = F.cross_entropy
    # 定义优化器
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    runner = Runner(model, optimizer, loss_fn, metric=None)
    runner.train(train_iter, num_epochs=10, save_path='chapter_5')

本文有待进一步完善……

代码整合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1049125.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Chrome(谷歌浏览器)如何关闭搜索栏历史记录

目录 问题描述解决方法插件解决(亲测有效)自带设置解决步骤首先打开 地址 输入:chrome://flags关闭浏览器,重新打开Chrome 发现 已经正常 问题描述 Chrome是大家熟知的浏览器,但是搜索栏的历史记录如何自己一条条的删…

学校宿舍一键视频对讲

学校宿舍一键视频对讲 大学宿舍一键视频对讲是指在大学宿舍内安装一套视频对讲系统,通过一键操作,实现与宿舍内其他人进行视频通话的功能。 该系统通常包括以下组成部分: 1. 室内终端:每个宿舍内安装一个室内终端,室…

JavaScript求数组的交集和差集

1. 求交集(从2个数组中找到相同的元素, 组成新数组, 注意去重): 1) Setfilterincludes // 求交集: const arr1 [0, 1, 2] const arr2 [3, 2, 0] function intersectSet(arr1, arr2) {return [...new Set(arr1)].filter(item>arr2.includes(item)) } const values inter…

26593-2011 无损检测仪器 工业用X射线CT装置性能测试方法

声明 本文是学习GB-T 26593-2011 无损检测仪器 工业用X射线CT装置性能测试方法. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了工业用X 射线CT 装置(以下简称CT 装置)性能测试的术语、定义、缩略语以及空间 分辨力、密度分辨率…

BChecks 自定义poc检测 - 把BurpSuite 打造成强大的漏洞扫描器

BChecks是什么? BChecks可以创建和导入的自定义扫描检查。Burp Scanner在执行其内置扫描例程的同时运行这些检查,帮助您定位扫描并使测试工作流尽可能高效。 每个BCheck都定义为一个以.bcheck文件扩展名结尾的纯文本文件。这些文件使用自定义语言来指定…

配置OSPF路由

OSPF路由 1.OSPF路由 1.1 OSPF简介 OSPF(Open Shortest Path First,开放式最短路径优先)路由协议是另一个比较常用的路由协议之一,它通过路由器之间通告网络接口的状态,使用最短路径算法建立路由表。在生成路由表时,…

Spring Cloud Netflix 教程和源码

本教程目标 想要系统地学习 Spring Cloud Netflix, 把自己的学习过程记录下来。 状态 持续更新中 微服务架构 微服务架构是一种将应用程序拆分为一组独立的、可独立部署的服务的架构模式。每个服务都运行在自己的进程中,可以独立地进行开发、测试和…

数据库管理-第108期 因Exadata存储节点操作系统空间异常的紧急处理(20230928)

数据库管理-第108期 因Exadata存储节点操作系统空间异常的紧急处理(20230928) 众所周知,明天放假了,本着对客户数据库软硬件负责任的态度,进行了一次深入彻底的软硬件巡检(就是检查包括计算节点、存储节点…

vue3中状态适配

写一个函数,在函数中定义一个对象 用于存放键值对,最后返回指定状态所对应的的值,即对象[指定状态] 的 对象的值。 在模板中把状态传入 // vue3 setup语法糖中 const formatXXXState (xxxState)>{const stateMap {键1: 值1,键2: 值2,.…

Linux-正则三剑客

目录 一、正则简介 1.正则表达式分两类: 2.正则表达式的意义 二、Linux三剑客简介 1.文本处理工具,均支持正则表达式引擎 2.正则表达式分类 3.基本正则表达式BRE集合 4.扩展正则表达式ere集合 三、grep 1.简介 2.实践 3.贪婪匹配 四、sed …

VS+Qt+opencascade三维绘图stp/step/igs/stl格式图形读取显示

程序示例精选 VSQtopencascade三维绘图stp/step/igs/stl格式图形读取显示 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《VSQtopencascade三维绘图stp/step/igs/stl格式图形读取显示》编写…

postman安装使用教程

本文只是基于 Chrome 浏览器的扩展插件来进行的安装,并非单独应用程序。 首先,你要台电脑,其次,安装有 Chrome 浏览器,那你接着往下看吧。 1. 官网安装(别看) 打开官网,https://ww…

【计算机网络】P2P文件分发介绍

文章目录 P2P体系结构的自扩展性BitTorrent协议参考资料 考虑一个场景:从单一服务器向大量主机(称为对等方)分发一个大文件。 两种处理方式 客户-服务器文件分发:服务器需要向每个对等方发送该文件的一个副本 P2P文件分发&#xf…

使用代理后pip install 出现ssl错误

window直接设置代理 httphttp://127.0.0.1:7890;httpshttp://127.0.0.1

Java 并发编程面试题——BlockingQueue

目录 1.什么是阻塞队列 (BlockingQueue)?2.BlockingQueue 有哪些核心方法?3.BlockingQueue 有哪些常用的实现类?3.1.ArrayBlockingQueue3.2.DelayQueue3.3.LinkedBlockingQueue3.4.PriorityBlockingQueue3.5.SynchronousQueue 4.✨BlockingQu…

【C++】构造函数和析构函数第二部分(拷贝构造函数)--- 2023.9.28

目录 什么是拷贝构造函数?编译器默认的拷贝构造函数构造函数的分类及调用结束语 什么是拷贝构造函数? 用一句话来描述为拷贝构造即 “用一个已知的对象去初始化另一个对象” 具体怎么使用我们直接看代码,代码如下: class Maker…

什么是DOM和DOM操作

什么是DOM? DOM(文档对象模型):HTML文档的结构化表示。允许JavaScript访问HTML元素和样式来操作它们。(更改文本,HTML属性甚至CSS样式) 树结构由HTML加载后自动生成 DOM树结构 这个是一个很简单的HTML代…

Redis与分布式-主从复制

接上文 常用中间件-OAuth2 1.主从复制 启动两个redis服务器。 修改第一个服务器地址 修改第二个redis 然后分别启动 redis-server.exe redis.windows.conf) 查看当前服务器的主从状态,打开客户端:输入info replication命令来查看当前的主从状态&am…

数据结构基础9:排序全家桶

排序全家桶: 一:插入排序:1.简单插入排序:2.希尔排序: 二:选择排序:1.简单选择排序:2.堆排序(空间复杂度为O(1)): 三:快速排序;方法一…

共同见证丨酷雷曼武汉运营中心成立2周年

酷雷曼武汉运营中心2周年 全国合作商齐贺武汉公司2周年庆 2021年 作为酷雷曼辐射全国版图的又一重要据点 酷雷曼武汉运营中心 在“中国光谷”正式成立 沉浸式参观酷雷曼武汉公司 2年时间 尽管历经诸多客观因素的挑战 但后浪扬帆,依然交出了不斐的成绩 解决…