pytorch-激活函数与GPU加速

news2024/10/7 4:27:00

目录

  • 1. sigmod和tanh
  • 2. relu
  • 3. Leaky Relu
  • 4. selu
  • 5. softplus
  • 6. GPU加速
  • 7. 使用GPU加速手写数据训练

1. sigmod和tanh

sigmod梯度区间是0~1,当梯度趋近0或者1时会出现梯度弥散的问题。
tanh区间时-1~1,是sigmod经过平移和缩放而得到的,也存在梯度弥散的问题。
在这里插入图片描述

2. relu

relu函数当梯度<0时,梯度是0,梯度>0时梯度是1,不会出现梯度弥散和梯度爆炸,虽然relu函数使用广泛也不易出现梯度弥散和梯度爆炸,但是不代表它不会出现。
在这里插入图片描述

3. Leaky Relu

在梯度<0的时候,不在是等于0而是变成了a*x, a是一个比较小的系数,确保梯度小于0时不再是0
在这里插入图片描述

4. selu

由两部分组成一部分时Relu,另一部分是一个指数函数,从而使得selu在0点变成了连续的。
在这里插入图片描述

5. softplus

时relu的一个连续光滑的版本,在0处变得光滑而连续
在这里插入图片描述
总结:目前用的最大的sigmod、tanh、relu、leakyrelu,其他两种用的较少

6. GPU加速

torch.device(‘cuda:0’)中的cuda:0代表第几块显卡,如果使用CPU那么就是torch.device(‘cpu’)
使用.to(device)就把模块或者数据搬到了GPU上,然而模块和数据是有一些区别的,模块执行.to(device)返回一个reference和不使用初始化是完全一样的属于一个inplace操作,但是data就不一样了,比如:data2=data.to(device),data2和data是完全不一样的,data2是gpu数据,data是cpu数据。
注意:.cuda()方法已经不推荐使用了
在这里插入图片描述

7. 使用GPU加速手写数据训练

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)






class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

从代码中可以看到网络、loss函数和数据都搬到了GPU上,激活函数改成了LeakyRelu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1626055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang | Leetcode Golang题解之第50题Pow(x,n)

题目&#xff1a; 题解&#xff1a; func myPow(x float64, n int) float64 {if n > 0 {return quickMul(x, n)}return 1.0 / quickMul(x, -n) }func quickMul(x float64, n int) float64 {if n 0 {return 1}y : quickMul(x, n/2)if n%2 0 {return y * y}return y * y * …

微服务组件-反向代理(Nginx)

微服务组件-反向代理(Nginx) Nginx 基本概念 1、nginx是什么&#xff1f; ①、Nginx (engine x) 是一个高性能的HTTP和反向代理web服务器同时也提供了IMAP/POP3/SMTP服务。它是一款轻量级的Web服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器&a…

Linux--内核移植(一)Kernel编译启动

Linux内核编译 编译内核之前需要先在ubuntu上安装lzop库&#xff0c;另外&#xff0c;图形化配置工具还需要ncurses库支持&#xff0c;安装命令为&#xff1a; sudo apt-get install lzop sudo apt-get install build-essential sudo apt-get install libncurses5-dev 在U…

大数据时代的引擎:大数据架构随记

大数据架构通常可以分为以下几层&#xff1a; 一、数据采集层 负责从各种数据源采集、清洗、转换、丰富以及格式化数据&#xff0c;可能包括结构化、半结构化和非结构化的数据。 1.1、常用的技术 在大数据领域&#xff0c;数据采集是一个关键的环节&#xff0c;常用的数据采集…

如何安装sbt(sbt在ubuntu上的安装与配置)(有详细安装网站和图解)

sbt下载官网 选择对应的版本和安装程序 Download | sbt (scala-sbt.org) 安装 解压 将sbt-1.9.0.tgz上传到xshell&#xff0c;并解压 解压&#xff1a; tar -zxvf sbt-1.9.0.tgz 配置 1、在/home/hadoop/sbt中创建sbt脚本 /home/hadoop/sbt 注意要改成自己的地址 cd …

【Linux】详解信号产生的方式

一、kill命令 在命令行中通过kill -数字 pid指令可以给指定进程发送指定信号。这里说明一下几个常见的信号&#xff1a; SIGINT&#xff08;2号信号&#xff09;&#xff1a;中断信号&#xff0c;通常由用户按下CtrlC产生&#xff0c;用于通知进程终止。SIGQUIT&#xff08;3号…

小型内衣裤洗衣机哪个牌子好?六大选购锦囊私藏分享

内衣洗衣机是现代家庭必不可少的小家电&#xff0c;它不仅方便快捷&#xff0c;还能够保持衣物清洁和卫生。然而&#xff0c;市场上洗衣机品牌众多&#xff0c;质量和性能参差不齐&#xff0c;使得消费者购买时难以做出选择。那么&#xff0c;小型内衣裤洗衣机哪个牌子好&#…

企业OA管理|基于SprinBoot+vue的企业OA管理系统(源码+数据库+文档)

企业OA管理目录 基于SprinBootvue的企业OA管理系统 一、前言 二、系统设计 三、系统功能设计 1 管理员模块的实现 1.1 用户信息管理 1.2 公告信息管理 1.3 客户关系管理 1.4 通讯录管理 2 用户模块的实现 2.1 客户关系添加 2.2 通讯录添加 2.3 日程安排添加 四、…

7-32 说反话-加强版

题目链接&#xff1a;7-32 说反话-加强版 一. 题目 1. 题目 2. 输入输出样例 3. 限制 二、代码 1. 代码实现 str1 input().split(\n)[0] // 按行获取输入 list_str str1.split()[::-1] // 按空格分割为字符串组&#xff0c;然后将字符串组逆序 str1 .join(list_str) //…

LCD液晶显示屏强光老化测试设备太阳光模拟器仪器

1. LCD液晶显示屏老化测试的意义 LCD液晶显示屏老化测试是评估显示屏寿命和性能的重要手段。随着科技的发展&#xff0c;LCD液晶显示屏已经成为我们日常生活中不可或缺的一部分。长期使用后&#xff0c;LCD液晶显示屏可能会出现亮度下降、颜色失真、响应速度变慢等问题。通过进…

已解决java.lang.IllegalThreadStateException: 非法线程状态异常的正确解决方法,亲测有效!!!

已解决java.lang.IllegalThreadStateException: 非法线程状态异常的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 场景描述 报错原因 解决思路 解决方法 检查线程状态 正确管理线程生命周期 异常处理 总结 博主v&#xff1a…

STM32中断实现旋转编码器计数

系列文章目录 STM32单片机系列专栏 C语言理论和实践总结专栏 文章目录 1. 旋转编码器 2. 中断代码编写 2.1 Interrupt.c 2.2 Interrupt.h 2.3 完整工程文件 1. 旋转编码器 旋转编码器主要用于测量轴的旋转位置、速度或者是角度的变化&#xff0c;它能够将转动的角度或者…

LeetCode57. 插入区间

LeetCode57.插入区间 题目思路: 代码 /* 前置知识&#xff1a; vector<vector<int>> a,b; 二维vector数组是可以将二维中的一维vector数组给push_back的&#xff0c; 不是只有单个元素才可以&#xff0c;整个一维的vector数组也可以 b[0] {1,2,3},b[1] {4,5,6}…

积极应对半导体测试挑战 加速科技助力行业“芯”升级

在全球半导体产业高速发展的今天&#xff0c;中国“芯”正迎来前所未有的发展机遇。AI、5G、物联网、自动驾驶、元宇宙、智慧城市等终端应用方兴未艾&#xff0c;为测试行业带来新的市场规模突破点&#xff0c;成为测试设备未来重要的增量市场。新兴领域芯片产品性能不断提升、…

如何解决IntelliJ IDEA 2024打开项目时频繁闪退问题

&#x1f42f; 如何解决IntelliJ IDEA 2024打开项目时频繁闪退问题 &#x1f43e; 文章目录 &#x1f42f; 如何解决IntelliJ IDEA 2024打开项目时频繁闪退问题 &#x1f43e;摘要引言正文&#x1f4d8; 识别问题&#x1f4d9; 内存配置调整步骤1: 定位vmoptions文件步骤2: 修改…

企业年度规划:你的未来,我们帮你“画”出来!

亲爱的朋友们&#xff0c;您是不是常常觉得企业运营就像一场没有剧本的戏&#xff0c;时而高歌猛进&#xff0c;时而摸黑前行&#xff1f;别慌&#xff0c;今天我们就来科普一下&#xff0c;如何给企业来一场精心策划的“年度大戏”——年度规划&#xff01; 首先&#xff0c;…

qt实现方框调整

效果 在四周调整 代码 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QWidget>class MainWindow : public QWidget {Q_OBJECT public:explicit MainWindow(QWidget *parent 0);~MainWindow();void paintEvent(QPaintEvent *event);void updateRect();void re…

每年首版次测试报告的要求有哪些?

每年首版次测试报告的要求可能因不同的地区、行业或产品而有所差异&#xff0c;但一般而言&#xff0c;它们通常遵循一些基本的标准和原则。以下是一些常见的首版次测试报告要求&#xff1a; 完整性&#xff1a;测试报告应包含所有必要的测试内容&#xff0c;包括但不限于测试…

git merge 和 git rebese的区别

git merge 和 git rebese的区别 拉取分支和合并代码会涉及两种选择&#xff0c;git merge 和 git rebase&#xff1a; rebase&#xff1a;变基&#xff0c;会有一个干净的分支&#xff0c;但是对于记录来源不够清楚merge&#xff1a;合并&#xff0c;git 分支看起来比较混乱&…

Linux 调度优先级

Linux中的每个任务都有其优先级。这个优先级的范围从-20到19。优先级越低&#xff08;-20&#xff09;&#xff0c;分配 给任务的CPU时间就越多。默认的优先级是0。 并非所有的任务都需要使用相同的优先级。交互式应用要求快速响应&#xff0c;通过 crontab 运行的后台…