华为昇腾系列——入门学习

news2024/9/28 21:27:10

概述

昇腾(Ascend)是华为推出的人工智能处理器品牌,其系列产品包括昇腾910和昇腾310芯片等。

生态情况

众所周知,华为昇腾存在的意义就是替代英伟达的GPU。从事AI开发的小伙伴,应该明白这个替代,不仅仅是 Ascend-910加速卡的算力 达到了Nvidia-A100的算力,而是需要整个AI开发生态的替代。下面简单列一下,昇腾生态与英伟达生态的一些对标项。

AscendNvidia
加速卡Ascend-910、Ascend-310Nvidia-A100、Nvidia-H100...
服务器Atlas 800 训练服务器NVIDIA DGX
计算架构CANNCUDA cuDNN NVCC
集合通信库HCCLNCCL

入门使用

假设原有基于GPU运行代码如下:

# 引入模块
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision

# 初始化运行device
device = torch.device('cuda:0')   

# 定义模型网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            # 卷积层
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            # 池化层
            nn.MaxPool2d(kernel_size=2),
            # 卷积层
            nn.Conv2d(16, 32, 3, 1, 1),
            # 池化层
            nn.MaxPool2d(2),
            # 将多维输入一维化
            nn.Flatten(),
            nn.Linear(32*7*7, 16),
            # 激活函数
            nn.ReLU(),
            nn.Linear(16, 10)
        )
    def forward(self, x):
        return self.net(x)

# 下载数据集
train_data = torchvision.datasets.MNIST(
    root='mnist',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor()
)

# 定义训练相关参数
batch_size = 64   
model = CNN().to(device)  # 定义模型
train_dataloader = DataLoader(train_data, batch_size=batch_size)    # 定义DataLoader
loss_func = nn.CrossEntropyLoss().to(device)    # 定义损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)    # 定义优化器
epochs = 10  # 设置循环次数

# 设置循环
for epoch in range(epochs):
    for imgs, labels in train_dataloader:
        start_time = time.time()    # 记录训练开始时间
        imgs = imgs.to(device)    # 把img数据放到指定NPU上
        labels = labels.to(device)    # 把label数据放到指定NPU上
        outputs = model(imgs)    # 前向计算
        loss = loss_func(outputs, labels)    # 损失函数计算
        optimizer.zero_grad()
        loss.backward()    # 损失函数反向计算
        optimizer.step()    # 更新优化器

# 定义保存模型
torch.save({
               'epoch': 10,
               'arch': CNN,
               'state_dict': model.state_dict(),
               'optimizer' : optimizer.state_dict(),
            },'checkpoint.pth.tar')

参考华为官方文档快速体验-PyTorch 网络模型迁移和训练-模型开发(PyTorch)-...-文档首页-昇腾社区 (hiascend.com)

改造后,可以得到以下 用于在昇腾NPU上运行的训练代码(故意加多了全连接层的参数,以便看NPU使用情况):

# 引入模块
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision

import torch_npu
from torch_npu.npu import amp # 导入AMP模块
from torch_npu.contrib import transfer_to_npu    # 使能自动迁移

# 初始化运行device
device = torch.device('npu:0')   

# 定义模型网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.net = nn.Sequential(
            # 卷积层
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            # 池化层
            nn.MaxPool2d(kernel_size=2),
            # 卷积层
            nn.Conv2d(16, 32, 3, 1, 1),
            # 池化层
            nn.MaxPool2d(2),
            # 将多维输入一维化
            nn.Flatten(),
            nn.Linear(32*7*7, 4000), 
            # 激活函数
            nn.ReLU(),
            nn.Linear(4000, 10000), 
            nn.ReLU(),
            nn.Linear(10000, 10)
        )
    def forward(self, x):
        return self.net(x)

# 下载数据集
train_data = torchvision.datasets.MNIST(
    root='mnist',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor()
)

# 定义训练相关参数
# batch_size = 64   
batch_size = 128
model = CNN().to(device)  # 定义模型
train_dataloader = DataLoader(train_data, batch_size=batch_size)    # 定义DataLoader
loss_func = nn.CrossEntropyLoss().to(device)    # 定义损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)    # 定义优化器

scaler = amp.GradScaler()    # 在模型、优化器定义之后,定义GradScaler

epochs = 20  # 设置循环次数

# 设置循环
for epoch in range(epochs):
    for imgs, labels in train_dataloader:
        start_time = time.time()    # 记录训练开始时间
        imgs = imgs.to(device)    # 把img数据放到指定NPU上
        labels = labels.to(device)    # 把label数据放到指定NPU上
        
        with amp.autocast(): 
            outputs = model(imgs)    # 前向计算
            loss = loss_func(outputs, labels)    # 损失函数计算
        optimizer.zero_grad()

        # 进行反向传播前后的loss缩放、参数更新
        scaler.scale(loss).backward()    # loss缩放并反向转播
        scaler.step(optimizer)    # 更新参数(自动unscaling)
        scaler.update()    # 基于动态Loss Scale更新loss_scaling系数
        

# 定义保存模型
torch.save({
               'epoch': 10,
               'arch': CNN,
               'state_dict': model.state_dict(),
               'optimizer' : optimizer.state_dict(),
            },'checkpoint.pth.tar')

 使用 "python train.py" 运行代码后,我们可以通过以下命令查看昇腾NPU的使用情况:

watch -n 1 npu-smi info

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1491789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣——盛最多水的容器

题目描述: 给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线,使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明:…

ZYNQ--PS_PL交互(AXI_HP)

AXI_HP接口 通过AXI_HP接口,可直接通过AXI_FULL协议向DDR中通过DMA传输数据。 BD设计 AXI_HP接口设置 AXI_Master代码 module axi_full_master #(parameter C_M_TARGET_SLAVE_BASE_ADDR = 32h40000000,parameter integer C_M_AXI_BURST_LEN = 16,parameter integer …

【送书活动1】基于React低代码平台开发:构建高效、灵活的应用新范式

【送书活动1】基于React低代码平台开发:构建高效、灵活的应用新范式 写在最前面一、React与低代码平台的结合优势二、基于React的低代码平台开发挑战三、基于React的低代码平台开发实践四、未来展望《低代码平台开发实践:基于React》编辑推荐内容简介作者…

ttkefu在线客服如何获取代码

注册并登录ttkefu账号。可以在ttkefu的官方网站(https://www.ttkefu.com/)上进行注册和登录。下载并安装ttkefu的PC端软件。可以在官方网站上的下载页面(https://www.ttkefu.com/download.html)找到下载链接。在软件中获取代码。登…

day12_oop_抽象和接口

今日内容 零、 复习昨日 一、作业 二、抽象 三、接口 零、 复习昨日 final的作用 修饰类,类不能被继承修饰方法,方法不能重写[重点]修饰变量/属性,变成常量,不能更改 static修饰方法的特点 static修饰的方法,可以通过类名调用 static修饰的属性特点 在内存只有一份,被该类的所有…

AI应用开发-python字符串转字典

AI应用开发相关目录 本专栏包括AI应用开发相关内容分享,包括不限于AI算法部署实施细节、AI应用后端分析服务相关概念及开发技巧、AI应用后端应用服务相关概念及开发技巧、AI应用前端实现路径及开发技巧 适用于具备一定算法及Python使用基础的人群 AI应用开发流程概…

10 OpenCV 形态学的应用

文章目录 算子形态学提取直线示例 算子 adaptiveThreshold 二值化算子 adaptiveThreshold(src, dstNone,maxValue, adaptiveMethod, thresholdType, blockSize, C, ) /* *src:灰度化的图片 *dst:输出图像,可选 *maxValue:满足条件…

冒泡经典题

📑前言 本文主要是【】——简单使用的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄每日一句:狠…

2023年09月CCF-GESP编程能力等级认证Scratch图形化编程三级真题解析

本文收录于专栏《Scratch等级认证CCF-GESP真题解析》,专栏总目录・点这里 一、单选题(共15题,共30分) 第1题 我国第一台大型通用电子计算机使用的逻辑部件是( )。 A:集成电路 B:大规模集成电路 C:晶体管 D:电子管 答案:D 第2题 下列流程图的输出结果是?( ) …

H5小游戏,象棋

H5小游戏源码、JS开发网页小游戏开源源码大合集。无需运行环境,解压后浏览器直接打开。有需要的,私信本人,发演示地址,可以后再订阅,发源码,含60+小游戏源码。如五子棋、象棋、植物大战僵尸、开心消消乐、扑鱼达人、飞机大战等等 <!DOCTYPE html PUBLIC "-//W3C/…

MySQL相关问题

MySQL相关问题 一、MySQL支持哪些存储引擎&#xff1f;二、MySQL是如何执行一条SQL的&#xff1f;三、MySQL数据库InnoDB存储引擎是如何工作的&#xff1f;四、如果要对数据库进行优化&#xff0c;该怎么优化&#xff1f;五、MySQL如何定位慢查询&#xff1f;六、如何分析MySQL…

Java生成 word报告

Java生成 word报告 一、方案比较二、Apache POI 生成三、FreeMarker 生成 在网上找了好多天将数据库信息导出到 word 中的解决方案&#xff0c;现在将这几天的总结分享一下。总的来说&#xff0c;Java 导出 word 大致有 5 种。 一、方案比较 1. Jacob Jacob 是 Java-COM Bri…

洛谷:P3068 [USACO13JAN] Party Invitations S(枚举、前缀和)

这题我们数据范围太大&#xff0c;用二维肯定是不行的&#xff0c;我们可以采用一维线性存储。 如题意&#xff0c;我们可以将每组奶牛编号都存在一维数组里面&#xff0c;只需记录每组的头尾指针就可以了。 如题中样例我们就可以存储成1 3 3 4 1 2 3 4 5 6 7 4 3 2 1 然后第…

C语言——结构体(位段)、联合体、枚举

hello&#xff0c;大家好&#xff01;我是柚子&#xff0c;今天给大家分享的内容是C语言中的自定义类型结构体、联合体以及枚举&#xff0c;有什么疑问或建议可以在评论区留言&#xff0c;会顺评论区回访哦~ 一、结构体 struct a.结构体声明 不同于数组的是&#xff0c;结构…

删除有序链表中重复的数字Ⅱ

题目 题目链接 删除有序链表中重复的元素-II_牛客题霸_牛客网 题目描述 代码实现 class Solution { public:/*** 代码中的类名、方法名、参数名已经指定&#xff0c;请勿修改&#xff0c;直接返回方法规定的值即可** * param head ListNode类 * return ListNode类*/ListNod…

旋转链表00

题目链接 旋转链表 题目描述 注意点 链表中节点的数目在范围 [0, 500] 内 解答思路 因为k可能比链表长度大&#xff0c;所以需要先找到链表的长度len&#xff0c;同时储存尾节点&#xff08;需要将尾节点与首节点相连&#xff09;&#xff0c;根据k % len计算出链表需要向…

O2OA(翱途)通过服务来调用接口实现单点登录案例

本文介绍O2OA服务管理中&#xff0c;接口的权限设定和调用方式。 创建接口 具有服务管理设计权限的用户&#xff08;具有ServiceManager角色或Manager角色&#xff09;打开“服务管理平台”&#xff0c;进入接口配置视图&#xff0c;点击左上角的新建按钮&#xff0c;可创建一…

langchain学习笔记(十一)

关于langchain中的memory&#xff0c;即对话历史&#xff08;message history&#xff09; 1、 Add message history (memory) | &#x1f99c;️&#x1f517; Langchain RunnableWithMessageHistory&#xff0c;可用于任何的chain中添加对话历史&#xff0c;将以下之一作为…

IntelliJ IDEA插件php golang python shell docker ignore UML plantuml等插件安装

IntelliJ IDEA插件php golang python shell docker ignore UML plantuml等插件安装 有的插件,需要代理才能搜索和下载 设置代理 不然插件搜索不到&#xff0c;也可能下载不了 Preferences -->Plugins --> Browse repositorise… --〉HTTP Proxy Settings… 选择 Manual…

FreeRTOS操作系统学习——内存管理

C库函数与FreeRTOS内存管理区别 在C语言的库函数中&#xff0c;有mallc、free等函数可以申请以及释放内存空间&#xff0c;那么这为什么不适用于FreeRTOS的内存管理呢&#xff1f; 不适合用在资源紧缺的嵌入式系统中这些函数的实现过于复杂、占据的代码空间太大并非线程安全的…