PyTorch 训练自定义功能齐全的神经网络模型的详细教程

news2024/9/21 14:47:37

在前面的文章中,老牛同学介绍了不少大语言模型的部署、推理和微调,也通过大模型演示了我们的日常的工作需求场景。我们通过大语言模型,实实在在的感受到了它强大的功能,同时也从中受益颇多。

今天,老牛同学想和大家一起来训练一个自定义的、但是功能齐全的简单的神经网络模型。这个模型虽然在参数规模、训练数据集、应用场景等方面均无法与大语言模型相媲美,但是我们旨在通过这个模型的训练过程,一窥神经网络模型的训练全貌。正所谓“麻雀虽小,五脏俱全”,同时老牛同学也希望能通过本文,与大家一起学习加深对训练神经网络的理解,逐步做到“肚里有货,从容不迫”!

由于模型训练过程的代码可能会反复调试和修改,老牛同学强烈建议大家使用Jupyter Lab来编写和调试代码。如果还没有配置好Jupyter Lab环境,请先移步老牛同学之前的文章,首先完成大模型研发的基础环境配置:大模型应用研发基础环境配置(Miniconda、Python、Jupyter Lab、Ollama 等)

定义神经网络模型

本文重在演示训练过程,因此为了方便我们训练,我们模型定义如下:

  1. 它是一个简单的线性计算模型
  2. 它只有3 个权重参数
  3. 它输出一个数值结果

根据以上定义,我们的模型的线性运算公式定义为:y = W1*x1 + W2*x2 + W3*x3 + b

  • y 为模型输出,在训练时,则代表模型的目标训练数据集
  • x 为模型输入,在训练时,则代表模型的输入训练数据集
  • W 为模型权重,是模型训练的最终结果
  • b 为调整线性运算结果的偏置向量

我们将根据 yx 训练数据集,逐步训练得出模型权重 Wb 值。

本文的源码地址,老牛同学放到评论区。如果大家不想一步一步地跟着老牛同学进行模型训练,也可以直接看源代码,一步到位看完整代码(源代码中还有 1 个权重的样例)。

准备训练数据集

首先,打开 Jupyter Lab 编辑器:

conda activate PY3.12
jupyter-lab .

为了后面创建数据集、创建模型、模型训练等操作,我们直接引入所有的依赖包:

import torch
import torch.nn as nn
import random
import torch.optim as optim
import numpy as np

为了方便构建数据集,我们先假设W的内容(偏置向量b值初始化为0):

# 随便写几个数字
weights = [1.3, 2.9, 3.7]
w_count = len(weights)

大家可能会有疑问,既然我们都已经知道模型权重了,那我们还训练个啥呢?

别着急,老牛同学提前定义它,有 2 个目的,后面在实际训练时不会使用它:

  1. 方便构造我们的训练数据集:因为我们已经确定了模型的线性运算公式,那么我们只需要随机一些x,就可以容易得到训练数据集y
  2. 模型训练结束之后,方便后面做个比对,看下我们训练结果我们预期值是否符合我们预期

我们先构建我们输入数据集,即线性运算公式x的内容,我们通过随机函数构建了 100 个随机数字,并转换模型训练时 PyTorch 张量类型:

x_list = []
for _ in range(100):
    x_list.append([random.randint(1, 50) for _ in range(w_count)])

inputs = torch.tensor(x_list, dtype=torch.float32, requires_grad=True)

输入训练数据集

然后,我们根据的输入数据集 x,构建目标训练数据集,同样转换模型训练时 PyTorch 张量类型:

y_list = []
for x in x_list:
    y_list.append(np.dot(x, weights))

targets = torch.tensor(y_list, dtype=torch.float32)

结果 y 就是输入 x 和 模型权重的点积运算:

输出训练数据集

至此,我们的训练数据集已经构建完成(我们可以忘掉权重了)。接下来,我们来构建神经网络模型。

构建神经网络模型

我们把即将构建的神经网络模型定义为LNTXModel(即:老牛同学线性模型):

# 线性模型
class LNTXModel(nn.Module):
    def __init__(self):
        super(LNTXModel, self).__init__()
        self.linear = nn.Linear(in_features=w_count, out_features=w_count)

    def forward(self, x):
        return self.linear(x)

# 实例化模型
model = LNTXModel()
model

在本模型定义中,我们直接使用了nn.Linear线性层,它有 2 个参数:

  • in_features: 输入特征的数量(即输入向量的维度)
  • out_features: 输出特征的数量(即输出向量的维度)
  • bias:偏置向量参数默认为True

定义和初始化模型

然后我们初始化了模型(大语言模型一般为加载模型)。接下来,我们就可以开始使用训练数据集来训练这个模型了。

训练神经网络模型

神经网络模型的训练过程,通常包括以下几步:

  • 首先,进行前向传播以预测结果(即forward函数)
  • 然后,将预测结果与目标结果进行比较,即计算损失值
  • 接着,利用反向传播算法计算损失值的梯度
  • 最后,根据梯度更新模型的参数

首先,我们定义损失函数和优化器:损失函数用于根据模型的预测结果和目标结果计算损失值,而优化器则用于根据计算出的梯度更新模型的权重,以最小化损失。

# 定义损失函数
loss_fn = nn.MSELoss()

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.0005)

nn.MSELoss() 损失函数常用于回归任务中,用来衡量模型预测值与实际目标值之间的差距。函数返回一个标量张量,代表了所有输入的均方误差。

torch.optim.SGD 是随机梯度下降优化算法,一般用于最小化损失函数。与标准的梯度下降算法相比,随机梯度下降算法每一步更新只基于一个或一小批样本的梯度估计。这种方法能够更快地收敛,并且有助于跳出局部极小点。

其中,lr 学习率(Learning Rate)参数是优化算法中的一个重要超参数,它决定了模型参数在每次更新时的变化幅度。较高的学习率可以加快收敛的速度,但可能会导致优化过程震荡或者无法稳定在最小值附近;而较低的学习率有助于更精确地找到最小值,但可能会陷入局部最小值或者导致训练过程非常缓慢。确定最佳学习率通常需要基于模型、数据集和问题的特性进行反复试验。一般情况下,我们可以从一个较小的学习率开始(比如老牛同学本次设置为0.0005),然后根据模型的收敛情况逐渐增加或减少学习率。

现在,所有准备工作都已经完成,我们可以开始训练我们的模型了。

# 训练循环,迭代1000次
num_epochs = 1000
for epoch in range(num_epochs):
    for i, x in enumerate(inputs):
        # 前向传播
        predictions = model(x)

        # 计算损失
        loss = loss_fn(predictions, targets[i])

        # 清空梯度
        optimizer.zero_grad()

        # 反向传播
        loss.backward()

        # 更新参数
        optimizer.step()
    if (epoch+1) % 100 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# 训练完成
print('Train done.')

可以看到,经过900 轮的训练,预测损失接近为0

模型训练结果

我们可以打印出模型的训练结果,与我们预期结果进行比较:

print(f'Final weights:{model.linear.weight.data}')
print(f'Final bias:{model.linear.bias.data}')

模型权重和偏置量

可以看出,模型权重与我们预期结果基本吻合,模型巡检结果基本符合预期!

使用神经网络模型

模型训练完成,我们就可以使用我们的模型了:model(x)

model(torch.tensor([float(1), float(1), float(1)]))

总结:扩展模型大小

至此,我们整个训练过程已经完成了。在上面演示案例中,我们只是用了 3 个权重参数的简单模型,我们可以根据需求,进一步扩大模型参数。但是不论模型权重参数扩大到多少,他们的训练流程基本是一样的:

  1. 初始化模型(大语言模型成为加载模型)
  2. 根据x输入预测输出y
  3. 通过损失函数计算损失梯度值
  4. 最后根据梯度更新模型参数值
  5. 直到训练结束,模型权重符合预期

最后的最后,8 月开始了,大家S1 绩效基本都沟通确定了吧?公众号回复都是匿名的,最终绩效结果如何,大家若感觉兴趣,欢迎在评论区留言分享~

基于 Qwen2 大模型微调技术详细教程(LoRA 参数高效微调和 SwanLab 可视化监控)

LivePortrait 数字人:开源的图生视频模型,本地部署和专业视频制作详细教程

基于 Qwen2/Lllama3 等大模型,部署团队私有化 RAG 知识库系统的详细教程(Docker+AnythingLLM)

使用 Llama3/Qwen2 等开源大模型,部署团队私有化 Code Copilot 和使用教程

本地部署 GLM-4-9B 清华智谱开源大模型方法和对话效果体验

玩转 AI,笔记本电脑安装属于自己的 Llama 3 8B 大模型和对话客户端

ChatTTS 开源文本转语音模型本地部署、API 使用和搭建 WebUI 界面

Ollama 完整教程:本地 LLM 管理、WebUI 对话、Python/Java 客户端 API 应用

微信公众号:老牛同学

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1973284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Android Studiio】default activity 原生安卓和uniapp默认启动分析

文章目录 思路: 一、原生安卓二、uniapp 探究方向:找到Default Activity 思路: 在Android开发中,"default activity"这个概念通常指的是应用启动时默认会加载和显示的那个Activity。AndroidManifest.xml文件是Android…

基于Selenium实现操作网页及操作windows桌面应用

Selenium操作Web页面 Why? 通常情况下,网络安全相关领域,更多是偏重于协议和通信。但是,如果协议通信过程被加密或者无法了解其协议构成,是无法直接通过协议进行处理。此时,可以考虑模拟UI操作,进而实现相…

声音和数据之间的调制解调 —— 电报机和电传打字机如何影响计算机的演变

注:机翻,未校对。 The Squeal of Data The through line between the telegraph and the computer is more direct than you might realize. Its influence can be seen in common technologies, like the modem. 电报和计算机之间的直通线比你想象的要…

基于IOT架构的数据采集监控平台!

LP-SCADA数据采集监控平台是蓝鹏测控推出的一款聚焦于工业领域的自动化数据采集监控系统, 助力数字工厂建设的统一监控平台。 为企业提供从下到上的完整的生产信息采集与集成服务,从而为企业综合自动化、工厂数字化及完整的"管控一体化”的解决方案…

LockSupport详解

文章目录 理解可重入锁LockSupport线程等待唤醒机制(wait/notify) waitNotify限制awaitSignal限制LockSupport重点说明 理解可重入锁 可重入锁的种类: 隐式锁(即synchronized关键字使用的锁)默认是可重入锁。 同步代…

站在临床数据科学的角度,药物试验归根结底是这两大假设

在临床数据科学的领域中,药物试验的设计和实施是评估药物效果及其安全性的关键环节。药物试验的基础无外乎两大核心假设:有效性与安全性。这两个假设不仅是药物试验的起点,也是整个研究过程中的重要指导原则。 药物试验的核心主旨在于对待测试…

Python高性能计算:进程、线程、协程、并发、并行、同步、异步

这里写目录标题 进程、线程、协程并发、并行同步、异步I/O密集型任务、CPU密集型任务 进程、线程、协程 进程、线程和协程是计算机程序执行的三种不同方式,它们在资源管理、执行模型和调度机制上有显著的区别。以下是对它们的详细解释和比较: 进程&…

一款有趣的工具,锁定鼠标键盘,绿色免安装

这是一款完全免费的程序,可以实现在不锁定屏幕的情况下锁定鼠标键盘,让鼠标键盘无法操作。比较适合防止误碰鼠标键盘,以及离开电脑时不希望别人操作自己的电脑。 ★★★★★锁定鼠标键盘工具:https://pan.quark.cn/s/e5c518a2165…

路由配置修改(五)

一、默认约定式路由 1、umi 会根据 pages 目录自动生成路由配置。 * name umi 的路由配置* description 只支持 path,component,routes,redirect,wrappers,name,icon 的配置* param path path 只支持两种占位符配置,第一种是动态参数 :id 的形式,第二种…

win11 intel新显卡控制面板无自定义分辨率选项解决

问题 下图是现在的intel显卡控制面板,不知道为啥变得很傻瓜式了,连所有显卡控制面板都有的分辨率自定义也被干掉了。 解决方式 其实解决很简单,因为自定义分辨率对显卡玩游戏来说还是很常用的,intel在beta版又加回来了&#x…

样式与特效(2)——新闻列表

1.盒子模型的边距概念 ) Margin-top 上面 Margin-bottom 底部 Margin-right 右边 Margin-left 左边 Margin : 10px (上下左右都是10px) Margin :10px,20px (上下边距10px 左右20px) CSS里面最重要的属性之一 将页面理解成…

C++ | Leetcode C++题解之第316题去除重复字母

题目&#xff1a; 题解&#xff1a; class Solution { public:string removeDuplicateLetters(string s) {vector<int> vis(26), num(26);for (char ch : s) {num[ch - a];}string stk;for (char ch : s) {if (!vis[ch - a]) {while (!stk.empty() && stk.back(…

C#值类型和引用类型,类和结构体

1、类class是引用类型&#xff0c;多个引用类型变量的值会互相影响。存储在堆&#xff08;heap&#xff09;上 2、结构体struct是值类型&#xff0c;多个值类型变量的值不会互相影响。存储在栈&#xff08;stack&#xff09;上 using System; using System.Collections.Generi…

PTA题目|象限的判断(python)

题目要求 输入一对坐标&#xff0c;输出它在直角坐标系中的象限。 输入格式: 输入坐标(x,y)&#xff0c;&#xff08;假设输入的x或y坐标值一定不会为0&#xff09;如&#xff1a;(3.5,-2)。 输出格式: 输出对应的象限&#xff0c;如&#xff1a;第四象限 输入样例: 在这…

Python | Leetcode Python题解之第315题计算右侧小于当前元素的个数

题目&#xff1a; 题解&#xff1a; import numpy as np from bisect import bisect_leftclass Solution:max_len 10000c []buckets []def countSmaller(self, nums: List[int]) -> List[int]:self.c [0 for _ in range(len(nums) 5)]counts [0 for _ in range(len(…

Sentinel-1 Level 1数据处理的详细算法定义(五)

《Sentinel-1 Level 1数据处理的详细算法定义》文档定义和描述了Sentinel-1实现的Level 1处理算法和方程,以便生成Level 1产品。这些算法适用于Sentinel-1的Stripmap、Interferometric Wide-swath (IW)、Extra-wide-swath (EW)和Wave模式。 今天介绍的内容如下: Sentinel-1 L…

前端怎么做一个验证码和JWT,使用mockjs模拟后端

流程图 创建一个发起请求 创建一个方法 getCaptchaImg() {this.$axios.get(/captcha).then(res > {console.log(res);this.loginForm.token res.data.data.tokenthis.captchaImg res.data.data.captchaImgconsole.log(this.captchaImg)})}, captchaImg: "", 创…

钡铼技术M12双通道防水分线盒稳定可靠

钡铼技术的DB系列M12双通道防水分线盒是一款专为工业自动化环境设计的高性能产品。其采用耐酸碱腐蚀材料制成的壳体&#xff0c;能够达到IP67防护等级&#xff0c;并通过灌胶工艺进一步提升到IP69K防护等级&#xff0c;确保在恶劣的工业条件下仍然能稳定可靠地运行。 技术特点…

小怡分享之Java的继承和多态

前言&#xff1a; &#x1f308;✨小怡给大家分享了Java的类和对象&#xff0c;今天小怡给大家分享的是继承和多态。 1.继承 1.1 为什么需要继承 Java中使用类对现实世界中实体来进行描述&#xff0c;类经过实例化之后的产物对象&#xff0c;则可以用来表示现实中的实体&…

无人机之环境监测篇

无人机在各个领域的应用越来越广泛&#xff0c;环境监测便是其中之一&#xff0c;它们能够提供高效、安全、经济的监测手段&#xff0c;帮助科学家和管理者更好的理解环境状况并采取相应措施。 一、污染监测 无人机可以搭载各种传感器&#xff0c;如气体检测器、红外热像仪等&…