【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)

news2024/10/5 13:42:38

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. __init__(初始化)

2. train(训练)

3. evaluate(评估)

4. predict(预测)

5. save_model

6. load_model

7. 代码整合


一、实验介绍

      

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F
# 绘画时使用的工具包
import matplotlib.pyplot as plt
# 导入鸢尾花数据集
from sklearn.datasets import load_iris
# 构建自己的数据集,继承自Dataset类
from torch.utils.data import Dataset, DataLoader

1. __init__(初始化)

    def __init__(self, model, optimizer, loss_fn, metric, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        # 用于计算评价指标
        self.metric = metric

        # 记录训练过程中的评价指标变化
        self.dev_scores = []
        # 记录训练过程中的损失变化
        self.train_epoch_losses = []
        self.dev_losses = []
        # 记录全局最优评价指标
        self.best_score = 0
  • 五个参数:
    • model(模型)
    • optimizer(优化器)
    • loss_fn(损失函数)
    • metric(评价指标)
    • 其他可选参数。
  • 该类还定义了一些用于记录训练过程中的指标变化和全局最优指标的属性:
    • self.dev_scores(记录验证集评价指标的变化)
    • self.train_epoch_losses(记录训练集损失的变化)
    • self.dev_losses(记录验证集损失的变化)
    • self.best_score(记录全局最优评价指标)

2. train(训练)

 def train(self, train_loader, dev_loader=None, **kwargs):
        # 将模型设置为训练模式,此时模型的参数会被更新
        self.model.train()
        num_epochs = kwargs.get('num_epochs', 0)
        log_steps = kwargs.get('log_steps', 100)
        save_path = kwargs.get('save_path', 'best_mode.pth')
        eval_steps = kwargs.get('eval_steps', 0)
        # 运行的step数,不等于epoch数
        global_step = 0

        if eval_steps:
            if dev_loader is None:
                raise RuntimeError('Error: dev_loader can not be None!')
            if self.metric is None:
                raise RuntimeError('Error: Metric can not be None')

        # 遍历训练的轮数
        for epoch in range(num_epochs):
            total_loss = 0
            # 遍历数据集
            for step, data in enumerate(train_loader):
                x, y = data
                logits = self.model(x.float())
                loss = self.loss_fn(logits, y.long())
                total_loss += loss
                if log_steps and global_step % log_steps == 0:
                    print(f'loss:{loss.item():.5f}')

                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            # 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件
            if (epoch + 1) % eval_steps == 0:

                dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)
                print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')

                if dev_score > self.best_score:
                    self.save_model(f'model_{epoch + 1}.pth')

                    print(
                        f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')
                    self.best_score = dev_score

                # 验证过程结束后,请记住将模型调回训练模式
                self.model.train()

            global_step += 1
            # 保存当前轮次训练损失的累计值
            train_loss = (total_loss / len(train_loader)).item()
            self.train_epoch_losses.append((global_step, train_loss))

        print('[Train] Train done')

3. evaluate(评估)

    def evaluate(self, dev_loader, **kwargs):
        assert self.metric is not None
        # 将模型设置为验证模式,此模式下,模型的参数不会更新
        self.model.eval()
        global_step = kwargs.get('global_step', -1)
        total_loss = 0
        self.metric.reset()

        for batch_id, data in enumerate(dev_loader):
            x, y = data
            logits = self.model(x.float())
            loss = self.loss_fn(logits, y.long()).item()
            total_loss += loss
            self.metric.update(logits, y)

        dev_loss = (total_loss / len(dev_loader))
        self.dev_losses.append((global_step, dev_loss))
        dev_score = self.metric.accumulate()
        self.dev_scores.append(dev_score)
        return dev_score, dev_loss

4. predict(预测)

    predict方法用于模型的阶段,输入数据x,返回模型对输入的预测结果。

 def predict(self, x, **kwargs):
        self.model.eval()
        logits = self.model(x)
        return logits

5. save_model

 def save_model(self, save_path):
        torch.save(self.model.state_dict(),save_path)

6. load_model

  def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

7. 代码整合

class Runner(object):
    def __init__(self, model, optimizer, loss_fn, metric, **kwargs):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        # 用于计算评价指标
        self.metric = metric
        
        # 记录训练过程中的评价指标变化
        self.dev_scores = []
        # 记录训练过程中的损失变化
        self.train_epoch_losses = []
        self.dev_losses = []
        # 记录全局最优评价指标
        self.best_score = 0
   
 
# 模型训练阶段
    def train(self, train_loader, dev_loader=None, **kwargs):
        # 将模型设置为训练模式,此时模型的参数会被更新
        self.model.train()
        
        num_epochs = kwargs.get('num_epochs', 0)
        log_steps = kwargs.get('log_steps', 100)
        save_path = kwargs.get('save_path','best_mode.pth')
        eval_steps = kwargs.get('eval_steps', 0)
        # 运行的step数,不等于epoch数
        global_step = 0
        
        if eval_steps:
            if dev_loader is None:
                raise RuntimeError('Error: dev_loader can not be None!')
            if self.metric is None:
                raise RuntimeError('Error: Metric can not be None')
                
        # 遍历训练的轮数
        for epoch in range(num_epochs):
            total_loss = 0
            # 遍历数据集
            for step, data in enumerate(train_loader):
                x, y = data
                logits = self.model(x.float())
                loss = self.loss_fn(logits, y.long())
                total_loss += loss
                if log_steps and global_step%log_steps == 0:
                    print(f'loss:{loss.item():.5f}')
                    
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            # 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件
            if (epoch+1)% eval_steps ==  0:

                dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)
                print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')
                
                if dev_score > self.best_score:
                    self.save_model(f'model_{epoch+1}.pth')
                    
                    print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')
                    self.best_score = dev_score
                    
                # 验证过程结束后,请记住将模型调回训练模式   
                self.model.train()
            
            global_step += 1
            # 保存当前轮次训练损失的累计值
            train_loss = (total_loss/len(train_loader)).item()
            self.train_epoch_losses.append((global_step,train_loss))
            
        print('[Train] Train done')
        
    # 模型评价阶段
    def evaluate(self, dev_loader, **kwargs):
        assert self.metric is not None
        # 将模型设置为验证模式,此模式下,模型的参数不会更新
        self.model.eval()
        global_step = kwargs.get('global_step',-1)
        total_loss = 0
        self.metric.reset()
        
        for batch_id, data in enumerate(dev_loader):
            x, y = data
            logits = self.model(x.float())
            loss = self.loss_fn(logits, y.long()).item()
            total_loss += loss 
            self.metric.update(logits, y)
            
        dev_loss = (total_loss/len(dev_loader))
        self.dev_losses.append((global_step, dev_loss))
        dev_score = self.metric.accumulate()
        self.dev_scores.append(dev_score)
        return dev_score, dev_loss
    
    # 模型预测阶段,
    def predict(self, x, **kwargs):
        self.model.eval()
        logits = self.model(x)
        return logits
    
    # 保存模型的参数
    def save_model(self, save_path):
        torch.save(self.model.state_dict(),save_path)
        
    # 读取模型的参数
    def load_model(self, model_path):
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1035017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

latexocr安装过程中遇到的问题解决办法

环境要求:需要Python版本3.7,并安装相应依赖文件 具体的详细安装步骤可见我上次写的博文:Mathpix替代者|科研人必备公式识别插件|latexocr安装教程 ‘latexocr‘ 不是内部或外部命令,也不是可运行的程序或批处理文件的相关解决办…

获取文件创建时间

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl Java源码 public void testGetFileTime() {try {String string "E://test.txt";File file new File(string);Path path file.toPath();BasicFileAttributes ba…

Python 运行代码

一、Python运行代码 可以使用三种方式运行Python,如下: 1、交互式 通过命令行窗口进入 Python 并开始在交互式解释器中开始编写 Python 代码 2、命令行脚本 可以把代码放到文件中,通过python 文件名.py命令执行代码,如下&#xff…

Grafana离线安装部署以及插件安装

Grafana是一个可视化面板(Dashboard),有着非常漂亮的图表和布局展示,功能齐全的度量仪表盘和图形编辑器,支持Graphite、zabbix、InfluxDB、Prometheus和OpenTSDB作为数据源。Grafana主要特性:灵活丰富的图形…

Idea远程调试Linux服务器(和docker容器)中的Java服务,实测可用

目录 idea远程调试linux上的java服务 idea远程调试linux上docker中的java服务 本地配置 打开设置 添加Remote 环境:java1.8 idea远程调试linux上的java服务 在远程Linux服务器上启动Java应用程序时,请使用以下命令行选项启用调试: java -agentlib:…

R3300L, Q7 ATV Android9固件

R3300L, Q7 ATV Android9固件 固件来源 https://www.znds.com/tv-1239603-1-1.html 之前在恩山上发布过1080p安卓6固件 https://www.right.com.cn/forum/thread-1761250-1-1.html, 这个固件的不足之处就是没有 Google Service Framework, 只能通过 Smart Youtube 之类的第三方…

Mac台式电脑内存清理方法教程

对于一些小白用户,如果觉得以上的清理方法比较复杂却又想要更好的优化Mac电脑内存,专业的系统清理软件是一个不错的选择。比起花几个小时时间浏览文件夹、删除临时文件、缓存和卸载残留。Cleanmymac X,只需单击几下即可完成所有内存清理工作&…

华为云:成就卓越企业,引领智能未来

编辑:阿冒 设计:沐由 “预感总是倏然来临,灵光一现,好像一种确凿无疑的信念在瞬间萌生却无从捕捉。”在其传世的《百年孤独》一书中,加西亚马尔克斯曾经这样写道。 如此笃信却无力的感觉,也让很多急于以数字…

Python_it_heima

P63 list P68 元组 注意:元组内部嵌套的list包含的内容可以修改,但list本身不能修改。 P69 字符串 P71 数据容器(序列)的切片 P73 集合 P75 字典 字典的常用操作 字典课后练习 P78 类数据容器的总结对比 P79 数据容器的通用操作 不…

ros2发布者节点

没有书看啊,就来看看这个代码吧: 首先:第一个函数:init,通过f12可以知道这个函数共有4个变量,前面2个就这么填,后面2个有默认值,不用管,所以一般这么写就好了。 spin函…

IDEA2023新UI回退老UI

idea2023年发布了新UI,如下所示 但是用起来真心不好用,各种位置也是错乱,用下面方法可以回退老UI

聊聊Spring的Aware接口

文章目录 0.前言1.什么是Aware接口2.Aware接口的设计目的3.详解3.1. ApplicationContextAware我们举个例子来说明 3.2. BeanFactoryAware3.3. BeanNameAware3.4. ServletContextAware3.5. MessageSourceAware3.6. ResourceLoaderAware 4.参考文档 0.前言 背景: 最近…

C语言数组和指针笔试题(三)(一定要看)

目录 字符数组四例题1例题2例题3例题4例题5例题6例题7 结果字符数组五例题1例题2例题3例题4例题5例题6例题7结果字符数组六例题1例题2例题3例题4例题5例题6例题7 结果 感谢各位大佬对我的支持,如果我的文章对你有用,欢迎点击以下链接 🐒🐒🐒个…

建构居住安全生态,鹿客科技2023秋季发布会圆满举办

9月20日,以「Lockin Opening」为主题的2023鹿客秋季发布会在上海隆重举办,面向居住安全领域鹿客带来了最新的高端旗舰智能锁新品、多眸OS1.0、Lockin Care服务以及全联接OPENING计划。此外,现场还邀请了国家机构、合作伙伴、技术专家等业界同…

ai智能生成文章-智能生成文章软件

您是否曾为创作内容而感到头疼不已?是否一直在寻找一种能够帮助您轻松生成高质量文章的解决方案?什么是AI智能生成文章,特别是147SEO智能原创文章生成。这是一种先进的技术,利用人工智能和自然语言处理,能够自动生成各…

Ant Design分页组件中实现禁止点击当前页按钮的方法

这里需要使用到Ant Design分页组件pagination的一个回调函数onChange onChange函数用来监听鼠标点击事件, 它有两个入参》1. 点击分页按钮时获取到的页码 2. 每页最大显示条数 所以,禁止点击当前分页按钮的核心逻辑是: if {当前页的页…

mybatis-plus异常:dynamic-datasource can not find primary datasource

现象 使用mybatis-plus多数据源配置时出现异常 com.baomidou.dynamic.datasource.exception.CannotFindDataSourceException: dynamic-datasource can not find primary datasource分析 异常原因是没有设置默认数据源,在类上没有使用DS指定数据源时,默…

批量、在线学习, 参数、非参数学习

批量学习(Batch Learning)和在线学习(Online Learning) 批量学习 批量学习的概念非常容易理解,我们之前介绍的许多机器学习算法,如果没有特殊说明,都可以采用批量学习的方式。批量学习的过程通…

网络安全——(黑客)自学

想自学网络安全(黑客技术)首先你得了解什么是网络安全!什么是黑客!!! 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队…

react多条件查询

1、声明一个filter常量 2.filter接受(condition,data)两个参数 3、调用data里面的filter进行筛选 4、任意一个item当筛选条件 5、使用object.key获取对象所有key 6、对每个key使用Array.prototype.every()方法判断是否满足条…