人工智能基础部分16-神经网络与GPU加速训练

news2025/1/12 16:14:26

大家好,我是微学AI,今天给大家介绍一下人工智能基础部分15-神经网络与GPU加速训练,在深度学习领域,神经网络已经成为了一种流行的、表现优秀的技术。然而,随着神经网络的规模越来越大,训练神经网络所需的时间和计算资源也在快速增长。为加速训练过程,研究者们开始利用图形处理器(GPU)来进行并行计算。在本文中,我们将研究神经网络与GPU的关系,以及如何使用GPU加速神经网络训练。

一、神经网络与GPU的关系

神经网络是一种模拟人脑神经元连接的计算系统,具有非常强大的表达和学习能力。与传统的计算机程序不同,神经网络是一个大规模并行的计算系统,因此天然适合于使用GPU进行并行计算。
GPU,全称是Graphical Processing Unit,中文为图形处理器。最初设计用于图形渲染任务。随着计算能力的提升,现在的GPU也被广泛应用于通用计算任务,如深度学习。相较于通用的CPU(中央处理器),GPU具有更多的内核,更大的内存带宽,从而能够更好地完成这些并行计算任务。
深度学习框架如TensorFlow和PyTorch等已经实现了利用GPU加速神经网络训练的功能。在接下来的部分,我们将详细介绍一个简单的神经网络实现,并展示如何使用PyTorch库和GPU加速训练。

二、 GPU的原理

GPU的基本原理是将大规模的并行计算任务拆分成更小的任务,并将其分发给GPU上的多个内核同时进行处理。这种基于数据并行的计算方式非常适合神经网络的训练。在神经网络的训练过程中,包括前向传播、反向传播和权重更新等步骤可以很容易地拆分成并行任务,因此利用GPU进行加速非常有效。

9541a6dddf194a24aad395c672ac4962.png

三、代码示例

为了演示如何使用PyTorch和GPU进行神经网络训练,我们以一个简单的多层感知机(MLP)为例。这是一个简单的线性二分类问题,我们使用随机生成的数据集进行训练。利用代码生成了一个包含1000个样本的数据集,每个样本具有20个特征。类别标签为0或1。

import numpy as np

np.random.seed(0)
NUM_SAMPLES = 1000
NUM_FEATURES = 20

X = np.random.randn(NUM_SAMPLES, NUM_FEATURES)
y = np.random.randint(0, 2, (NUM_SAMPLES,))

print("X shape:", X.shape)
print("y shape:", y.shape)


### 构建简单的神经网络

#接下来,我们使用PyTorch框架来构建一个简单的多层感知机。这个感知机包括一个输入层、一个隐藏层和一个输出层。

import torch
import torch.nn as nn

class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

INPUT_SIZE = 20
HIDDEN_SIZE = 10
OUTPUT_SIZE = 1

model = SimpleMLP(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)

print(model)


### 训练神经网络

#现在,我们将训练这个简单神经网络。为了使用GPU进行训练,请确保已经安装了适当的PyTorch GPU版本。

# 判断是否有GPU可用,如果有,则将模型和数据移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


### 训练循环

#下面的代码将执行训练循环,并在每个循环后输出训练损失。

# 超参数设置
learning_rate = 0.001
num_epochs = 500
batch_size = 40
num_batches = NUM_SAMPLES // batch_size

# 创建优化器和损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()

# 转换数据为PyTorch张量
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)

# 训练循环
for epoch in range(num_epochs):
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        inputs = X[start:end].to(device)  # 将数据移动到GPU
        targets = y[start:end].to(device)  # 将数据移动到GPU

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")

运行结果:

,,,
Epoch [489/500], Loss: 0.476614773273468
Epoch [490/500], Loss: 0.47668877243995667
Epoch [491/500], Loss: 0.47654348611831665
Epoch [492/500], Loss: 0.47667455673217773
Epoch [493/500], Loss: 0.4764176309108734
Epoch [494/500], Loss: 0.476365864276886
Epoch [495/500], Loss: 0.47667425870895386
Epoch [496/500], Loss: 0.47667378187179565
Epoch [497/500], Loss: 0.4764541685581207
Epoch [498/500], Loss: 0.47650662064552307
Epoch [499/500], Loss: 0.47656387090682983
Epoch [500/500], Loss: 0.4765079915523529

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/525161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka Connect JNDI注入漏洞复现(CVE-2023-25194)

漏洞原理 Apache Kafka Connect中存在JNDI注入漏洞,当攻击者可访问Kafka Connect Worker,且可以创建或修改连接器时,通过设置sasl.jaas.config属性为com.sun.security.auth.module.JndiLoginModule,进而可导致JNDI注入&#xff0c…

数字设计小思 - 谈谈非理想时钟的时钟偏差

写在前面 本系列整理数字系统设计的相关知识体系架构,为了方便后续自己查阅与求职准备。在FPGA和ASIC设计中,时钟信号的好坏很大程度上影响了整个系统的稳定性,本文主要介绍了数字设计中的非理想时钟的偏差来源与影响。 (本文长…

数据结构-排序-(直接插入、折半插入、希尔排序、冒泡、快速排序)

目录 一、直接插入排序 二、折半插入排序 三、希尔排序 四、冒泡排序 五、快速排序 *效率分析 一、直接插入排序 思想:每次将一个待排序的记录按其关键字大小插入到前面已经排好序中,直到全部记录插入完毕 保证稳定性 空间复杂度:O(1…

SpringBoot 基本介绍--依赖管理和自动配置--容器功能

目录 SpringBoot 基本介绍 官方文档 Spring Boot 是什么? SpringBoot 快速入门 需求/图解说明 完成步骤 创建MainApp.java SpringBoot 应用主程序 创建HelloController.java 控制器 运行MainApp.java 完成测试 快速入门小结 Spring SpringMVC SpringBoot 的关系 梳…

【论文阅读】RapSheet:端点检测和响应系统的战术来源分析(SP-2020)

Tactical Provenance Analysis for Endpoint Detection and Response Systems S&P-2022 伊利诺伊大学香槟分校 Hassan W U, Bates A, Marino D. Tactical provenance analysis for endpoint detection and response systems[C]//2020 IEEE Symposium on Security and Priva…

【YOLO系列】--YOLOv5网络结构超详细解读/总结

前言 官方源码仓库:GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite YOLOv5至今没有论文发表出来,YOLOv5项目的作者是Glenn Jocher并不是原作者Joseph Redmon。作者当时也有说准备在2021年的12月1号之…

linux pl320 mbox控制器驱动分析 - (1) pl320手册分析

linux pl320 mbox控制器驱动分析 1 pl320简介1.1 pl320用途1.2 pl320 IPCM 由以下部分组成:1.3 pl320 IPCM可配置的参数1.4 功能操作1.5 IPCM 操作流程1.6 Channel ID 2 Using mailboxes(使用邮箱中断)2.1 Defining source core2.2 Defining …

JNI中GetObjectArrayElement, GetStringUTFChars,ReleaseStringUTFChars函数讲解

文章目录 GetObjectArrayElement函数使用场景代码演示GetStringUTFChars 函数使用场景ReleaseStringUTFChars函数 GetObjectArrayElement函数 函数原型: jobject GetObjectArrayElement(JNIEnv *env, jobjectArray array, jsize index); Returns an element of a J…

STL容器之deque

文章目录 deque容器简介deque的操作 deque容器简介 deque是“double-ended queue”的缩写,和vector一样都是STL的容器 deque是双端数组,而vector是单端的deque在接口上和vector非常相似,在许多操作的地方都可以直接替换deque可以随机存取元…

C语言-【操作符二】

Hello,大家好,前面的文章里边介绍了算术、赋值以及移位操作符,这篇文章呢,就介绍一下C语言中的其他操作符吧~ 目录 位操作符 单目操作符 关系操作符 逻辑操作符 条件操作符 逗号表达式 下标引用,函…

C++11多线程:windows临界区和Linux互斥锁、递归锁的区别与使用。

文章目录 前言一、windows临界区1.1 基本概念1.2 函数使用 二、使用步骤1.代码示例1 总结 前言 多线程windows临界区和Linux互斥锁 提示:以下是本篇文章正文内容,下面案例可供参考 一、windows临界区 1.1 基本概念 Linux下有递归锁,递归锁…

着重讲解一下自动化测试框架的思想与构建策略,让你重新了解自动化测试框架

目录 序言: 一、简述自动化测试框架 二、自动化测试框架思想 三、构建自动化测试框架的策略 四、自动化测试框架的发展趋势 序言: 也许到现在大家对所谓的“自动化测试框架”仍然觉得是一种神秘的东西,仍然觉得其与各位很远;…

【JavaScript】ES6新特性(1)

1. let 声明变量 let 声明的变量只在 let 命令所在的代码块内有效 块级作用域 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge">&l…

08-04 中间件和平台运行期监控

缓存中间件的三大坑 缓存击穿 用户访问热点数据&#xff0c;并且缓存中没有热点数据&#xff0c;大量访问直接到DB&#xff0c;热点击穿采用Canal做数据异构方案&#xff0c;把数据库中的值全部放到缓存热点缓存策略&#xff1a;通过分析调用日志获取热点数据&#xff0c;放到…

PMP项目管理-[第十一章]风险管理

风险管理知识体系&#xff1a; 规划风险管理&#xff1a; 识别风险&#xff1a; 实施定性风险分析&#xff1a; 实施定量风险分析&#xff1a; 监督风险&#xff1a; 11.1 风险 定义&#xff1a;是一种不确定的事件或条件&#xff0c;一旦发生&#xff0c;就会对一个或多个项目…

Elasticsearch(二)

Clasticsearch&#xff08;二&#xff09; DSL查询语法 文档 文档&#xff1a;https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html 常见查询类型包括&#xff1a; 查询所有&#xff1a;查询出所有数据&#xff0c;一般测试用。如&#xff1a…

eNSP模拟器下VRRP+MSTP实验配置

①&#xff1a;底层配置&#xff1a; vlan trunk 略 ②&#xff1a;MSTP配置&#xff1a; 所有交换机&#xff1a; stp region-configuration region-name aa revision-level 1 instance 1 vlan 2 to 3 instance 2 vlan 4 to 5 active region-configuration 核心1&…

Java笔记_21(网络编程)

Java笔记_21 一、网路编程1.1、初始网络编程1.2、网络编程三要素1.3、IP1.4、端口号1.5、协议1.6、UDP协议 一、网路编程 1.1、初始网络编程 什么是网络编程 在网络通信协议下&#xff0c;不同计算机上运行的程序&#xff0c;进行的数据传输。 应用场景:即时通信、网游对战…

(一)SAS初识

1、SAS常用工作窗口 “结果”&#xff08;Result&#xff09;窗口——管理SAS程序的输出结果&#xff1b; “日志”&#xff08;Log&#xff09;窗口——记录程序的运行情况&#xff1b; “SAS资源管理器”&#xff08;Explore&#xff09;窗口&#xff1b; “输出”&#xff0…

洛谷P1217-回文质数 Prime Palindromes

洛谷P1217-回文质数 Prime Palindromes 这个题目我做出来了但是超时了&#xff0c;时间复杂度有点高&#xff0c;主要是因为我用了大量的循环&#xff0c; 所以我这个是比较暴力的解法&#xff0c;下面我分析我的暴力代码 首先是判断回文数的函数 第一步将标识传入参数是不是…