【AI炼丹术】写深度学习代码的一些心得体会

news2024/9/21 20:51:57

写深度学习代码的一些心得体会

  • 体会1
  • 体会2
  • 体会3
  • 总结
  • 内容来源

一般情况下,拿到一批数据之后,首先会根据任务先用领域内经典的Model作为baseline跑通,然后再在这个框架内加入自己设计的Model,微调代码以及修改一些超参数即可。总体流程参考如下:

  1. 先写dataset部分,包括数据的读取、预处理、增广等操作,将数据集准备好。
  2. 然后model部分baseline无需修改,proposed是自行设计,定义模型的结构和参数,建立模型架构。
  3. 最后是train部分,这里调用所有的类实现训练:包括定义模型,模型包裹;获取dataloader;定义loss,优化器,学习率,定义early stoping策略;保存模型权重,保存日志。

当然,文无定法。这个顺序并不是固定不变的,也可以根据具体情况作出相应的调整。例如,当你的数据集已经准备好了,可以直接开始定义模型,然后再定义训练过程;或者在进行模型训练之前,先进行数据集的分析和可视化等操作。

体会1

源自:作者三四但不犹豫
对于图像任务:

  1. 顺序上,先写dataset部分,检查基本的transform,再搭model,构建head和loss,就可以把一个基础的、可以跑的网络就能跑起来了(这点很重要);
  2. 可视化很重要,如果是本地开发机,善用cv.imshow直观、便捷地可视化处理的结果;
  3. 一个基础的train/inference流程跑通后,分别构建1 张、10 张的数据用于debug,确保任意改动后,可以overfit;
  4. 调试代码阶段避免随机性、避免数据增强,一定用tensorboard之类的工具观察 loss 下降是否合理;
  5. 一般数据集最好处理成coco的格式,我的任务跟传统任务不太一样,但也尽量仿照coco来设计,写dataset的时候可以参考开源实现;
  6. 善用开源框架,比如Open-MMLab,Detectron2之类的,好处是方便实验,在框架里写不容易出现难以察觉的bug,坏处是开源框架为了适配各种网络,代码复杂程度会高一点,建议从第一版入手了解框架,然后基于最新的一边阅读一边开发。

体会2

源自:捡到一束光
先给结论:以写了两三年pytorch代码的经验而言,比较好的顺序是先写model,再写dataset,最后写train。在讨论码组件的具体顺序前,先分析每一个组件背后的目的和逻辑。

  • model构成了整个深度学习训练与推断系统骨架,也确定了整个AI模型的输入和输出格式
    • 对于视觉任务,模型架构多为卷积神经网络或是最新的ViT模型
    • 对于NLP任务,模型架构多为Transformer以及Bert
    • 对于时间序列预测,模型架构多为RNNLSTM

不同的model对应了不同的数据输入格式,如ResNet一般是输入多通道二维矩阵,而ViT则需要输入带有位置信息的图像patchs。确定了用什么样的model后,数据的输入格式也就确定下来。根据确定的输入格式,我们才能构建对应的dataset。

  • dataset构建了整个AI模型的输入与输出格式。

    • 在写作dataset组件时,我们需要考虑数据的存储位置与存储方式,如数据是否是分布式存储的,模型是否要在多机多卡的情况下运行,读写速度是否存在瓶颈,如果机械硬盘带来了读写瓶颈则需要将数据预加载进内存等。
    • 在写dataset组件时,我们也要反向微调model组件。例如,确定了分布式训练的数据读写后,需要用nn.DataParallel或者nn.DistributedDataParallel等模块包裹model,使模型能够在多机多卡上运行。
    • 此外,dataset组件的写作也会影响训练策略,这也为构建train组件做了铺垫。比如根据显存大小,我们需要确定相应的BatchSize,而BatchSize则直接影响学习率的大小。再比如根据数据的分布情况,我们需要选择不同的采样策略进行Feature Balance,而这也会体现在训练策略中。
  • train构建了模型的训练策略以及评估方法,它是最重要也是最复杂的组件。先构建model与dataset可以添加限制,减少train组件的复杂度。

    • 在train组件中,我们需要根据训练环境(单机多卡,多机多卡或是联邦学习)确定模型更新的策略,以及确定训练总时长epochs,优化器的类型,学习率的大小与衰减策略,参数的初始化方法,模型损失函数
    • 此外,为了对抗过拟合,提升泛化性,还需要引入合适的正则化方法,如Dropout,BatchNorm,L2-Regularization,Data Augmentation等。
    • 有些提升泛化性能的方法可以直接在train组件中实现(如添加L2-Reg,Mixup),有些则需要添加进model中(如Dropout与BatchNorm),还有些需要添加进dataset中(如Data Augmentation)。。

此外,train还需要记录训练过程的一些重要信息,并将这些信息可视化出来,比如在每个epoch上记录训练集的平均损失以及测试集精度,并将这些信息写入tensorboard,然后在网页端实时监控。在构建train组件中,我们需要随时根据模型表现进行参数微调,并根据结果改进model和dataset两个组件。
tensorboard

体会3

源自:芙兰朵露
作为data driven的学科,不同的AI model适合不同的数据类型,选择用哪个模型是基于你的数据长什么样来决定的。初学者知道用CNN处理图片,用RNN处理时间序列/语言,但这些都是最基础的工作,真正体现水平的是根据数据的性质来选择合适的细分模型。比如稀疏图像需要用Sparse CNN,语言Transformer效果比较好,但对某些特殊的时间序列RNN也有奇效。

接下来还有很多技术细节,比如需不需要数据增强?需不需要标签平滑?需不需要残差链接?需不需要多loss,如果需要如何平衡?需不需要解释模型?我甚至没有提到超参数,因为超参数是锦上添花而不是雪中送炭。只要没有明确的信息瓶颈,超参数对模型的影响是很小的。

上面提到的这些问题不需要全想明白,但心里要大致有个谱,至少也要知道这些问题是可能影响你的训练结果的,这其实需要相当的阅读和积累。这样之后出了问题才知道去哪里debug。

然后就可以开始写了。这些问题想明白之后,其实先写哪个part已经不重要了,因为你的心中已经有了一个picture,先把这个picture给sketch下来,然后开始跑,第一遍效果肯定不好,但你要根据输出的结果大致判断哪个部分出了问题,然后针对性地去改进。这一步真的没什么好办法,很多时候其实是直觉,做多了自然就知道了。训练模型-发现问题-修改模型-再训练,就像炼丹一样,经过无数遍的抟炼,才能得到最后的金丹。

其实洋洋洒洒说了这么多,本质不过是几个字:解决问题的能力making things to work几乎是机器学习中最重要的能力了,而这种能力就是在日常的积累和训练中反复磨练出来的,成功的路上没有捷径

总结

单纯就个人习惯而言,先写model,确保model的结果没有错误,调试正确。然后写dataset,并调试输出正确。之后写损失函数,并调试正确。最后写train训练代码,推理代码。

内容来源

  1. 写深度学习代码是先写model还是dataset还是train呢,有个一般化的顺序吗?
  2. A Recipe for Training Neural Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/459186.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Matlab进阶绘图第18期—相关性气泡热图

相关性气泡热图是一种特殊的气泡热图。 与一般的气泡热图相比,其数值位于[-1 1]区间,其颜色用于表示正负,而其气泡的大小用于表示数值绝对值的大小,可以十分直观地对两个变量的相关性进行分析。 由于Matlab中未收录相关性气泡热…

In-Context Learning中的示例选择及效果

一. ICL的背景 大型语言模型(LLM)如GPT-3是在大规模的互联网文本数据上训练,以给定的前缀来预测生成下一个token(Next token prediction)。这样简单的训练目标,大规模数据集以及高参数量模型相结合&#x…

国内可直接使用的OpenAI DALL*E 图片AI体验站,可通过自然语言生成图片

体验站最终演示效果 国内可直接使用的图片AI体验站:https://zizhu888.cn/text2img/index.html ChatGPT3.5 Turbo国内体验站: https://zizhu888.cn/chatgpt/index.html OpenAI DALL*E可以通过自然语言生成图片,内容创作者的福音,大大降低了创…

基于飞桨 PaddleVideo 的骨骼行为识别模型 CTR-GCN

main.pysame_seedsparse_argsmain ensemble.pyconfigs 文件夹Joint(J)的配置文件ctrgcn_fsd_J_fold0.yamlctrgcn_fsd_J_fold1.yaml Joint Angle(JA)的配置文件ctrgcn_fsd_JA_fold0.yaml paddlevideo 文件夹utils 文件夹__init__.p…

【Python 协程详解】

0.前言 前面讲了线程和进程,其实python还有一个特殊的线程就是协程。 协程不是计算机提供的,计算机只提供:进程、线程。协程是人工创造的一种用户态切换的微进程,使用一个线程去来回切换多个进程。 为什么需要协程? …

中国制造再击败一家海外企业,彻底取得垄断地位

中国制造已在13个行业取得领先优势,凸显出中国制造的快速崛起,日前中国制造又在一个行业彻底击败海外同行,再次证明了中国制造的实力。 一、海外企业承认失败 提前LGD宣布它位于广州的8.5代液晶面板生产线停产,预计该项目将出售给…

crm day03 创建市场活动

页面切割 div切割,ifram显示 如何分割的呢,在主页面上打开iframe $(function(){ //页面加载时window.open("workbench/main/index.do","workareaFrame"); })注意所有在WEB-INF的页面都会收到保护,因此到达此目录下的页…

不得不的创建型模式-建造者模式

目录 建造者模式是什么 下面是一个简单的示例代码,演示了如何使用建造者模式来构建一个复杂对象: 面试中可能遇到的问题及回答: 建造者模式是什么 建造者模式是一种创建型模式,它的目的是将复杂对象的构造过程分离成多个简单的…

你知道项目进度控制和跟踪的目的是什么吗?

项目进度控制和跟踪的目的是: 增强项目进度的透明度,当项目进展与项目计划出现偏差时,可以及时采取适当的措施。 1、计划是项目监控的有效手段 项目控制的手段是根据计划对项目的各项活动进行监控,项目经理可以使用甘特图来制…

界面控件DevExtreme使用指南 - 折叠组件快速入门(二)

DevExtreme拥有高性能的HTML5 / JavaScript小部件集合,使您可以利用现代Web开发堆栈(包括React,Angular,ASP.NET Core,jQuery,Knockout等)构建交互式的Web应用程序,该套件附带功能齐…

微信小程序nodejs+python+php+springboot+vue 微型整容医美挂号预约app系统

(a) 管理员;管理员使用本系统涉到的功能主要有首页、个人中心、用户管理、体检预约管理、项目预约、系统管理等功能 (b) 用户;用户进入app可以实现首页、美容产品、我的等,在我的页面可以对在线预约、体检预约、项目预约等功能进行操作 本基于…

Unity之OpenXR+XR Interaction Toolkit实现 UI交互

一.前言 在VR中我们经常会和一些3D的UI进行交互,今天我们就来说一下如何实现OpenXRXRInteraction Toolkit和UI的交互。 二.准备工作 有了前两篇的配置介绍,我们就不在详细说明这些了,大家自行复习 Unity之OpenXRXR Interaction Toolkit接入Pico VR一体…

钉钉用一条斜杠,金山系用一张表格,做了华为一直想做的事

阿里的“新钉钉”又一次站在风口上 一场疫情导致数万企业停工的同时,却让阿里的钉钉、腾讯会议,还有字节跳动的飞书等在线协同办公产品火得一塌糊涂。 今天,OpenAI公司的一个chatGPT,让阿里、百度等各大互联网巨头扎堆发布大模型产品。 回顾…

如何在Web上实现激光点云数据在线浏览和展示?

无人机激光雷达测量是一项综合性较强的应用系统,具有数据精度高、层次细节丰富、全天候作业等优势,能够精确测量三维现实世界,为各个行业提供了丰富有效的数据信息。但无人机激光雷达测量产生的点云数据需要占用大量的存储空间,甚…

Gantt图和PERT图的相关知识

1、Gantt 图 Gantt图以时间为基准描述项目任务,可以清晰的描述每个任务从何时开始,到何时结束,以及每个任务的并行关系,但是不能反映项目各任务之间的依赖关系,也无法确定整个任务的关键所在。 2、PERT图 计划评审…

Canvas实现动态绘制圆周效果(沿圆周运动的圆的绘制)

步骤实现: 首先,创建一个 HTML 画布和一个 JavaScript 动画函数。 在画布上绘制一个圆。 定义一个变量来表示圆心的坐标和半径。 进行动画循环以更新圆心坐标,使其沿外圆周运动。 使用三角函数(如 sin 和 cos)来计…

前端代码版本管理规范

Git 是目前最流行的源代码管理工具。为规范开发,保持代码提交记录以及 git分支结构清晰,方便后续维护,总结了如下规范。 分支约定 ├── master # 生产分支 ├── release # 测试分支├── develop # 开发分支…

学系统集成项目管理工程师(中项)系列11b_沟通管理(下)

1. 沟通过程的有效性 1.1. 效果 1.1.1. 在适当的时间、适当的方式、信息被准确的发送给适当的沟通参与方(信息的接收方),并且能够被正确的理解,最终参与方能够正确的采取行动 1.2. 效率 1.2.1. 强调的是及时提供所需的信息 2…

两数之和hash

给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以按任意顺序返回…

基于opencv-python的深度学习模块案例

目录 图像分类 目标检测 人脸检测 姿态估计 车辆检测 一、图像分类 图像分类是基于深度学习的计算机视觉任务中最简单、也是最基础的一类,它其中用到的CNN特征提取技术也是目标检测、目标分割等视觉任务的基础。 具体到图像分类任务而言,其具体流…