前馈神经网络dropout实例

news2024/7/4 5:18:27

直接看代码。

(一)手动实现


import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt


#下载MNIST手写数据集  
mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  

#读取数据  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  


#初始化参数  
num_inputs,num_hiddens,num_outputs =784, 256,10

num_epochs=30

lr = 0.001

def init_param():
    W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  
    b1 = torch.zeros(1, dtype=torch.float32)  
    W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  
    b2 = torch.zeros(1, dtype=torch.float32)  
    params =[W1,b1,W2,b2]
    for param in params:  
        param.requires_grad_(requires_grad=True)  
    return W1,b1,W2,b2

def dropout(X, drop_prob):
    X = X.float()
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    if keep_prob == 0:
        return torch.zeros_like(X)
    mask = (torch.rand(X.shape) < keep_prob).float()
    print(mask)
    return mask * X / keep_prob

def net(X, is_training=True):
    X = X.view(-1, num_inputs)
    H1 = (torch.matmul(X, W1.t()) + b1).relu()
    if is_training:
        H1 = dropout(H1, drop_prob)
    return (torch.matmul(H1,W2.t()) + b2).relu()


def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        ls, count = 0, 0
        for X,y in train_iter:
            l=loss(net(X),y)
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            
            ls += l.item()
            count += y.shape[0]
            
        train_ls.append(ls)
        
        ls, count = 0, 0
        
        for X,y in test_iter:
            
            l=loss(net(X,is_training=False),y)
            
            ls += l.item()
            count += y.shape[0]
            
        test_ls.append(ls)
        
        if(epoch+1)%10==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
            
    return train_ls,test_ls


drop_probs = np.arange(0,1.1,0.1)

Train_ls, Test_ls = [], []

for drop_prob in drop_probs:

    W1,b1,W2,b2 = init_param()
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)
    train_ls, test_ls =  train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer)   
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
    
x = np.linspace(0,len(train_ls),len(train_ls))

plt.figure(figsize=(10,8))

for i in range(0,len(drop_probs)):
    plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    
# plt.legend()
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()

运行结果:

在这里插入图片描述

(二)torch.nn实现

import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

mnist_train = torchvision.datasets.MNIST(root='./MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  
batch_size = 256 
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  


class LinearNet(nn.Module):
    def __init__(self,num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob1,drop_prob2):
        super(LinearNet,self).__init__()
        self.linear1 = nn.Linear(num_inputs,num_hiddens1)
        self.relu = nn.ReLU()
        self.drop1 = nn.Dropout(drop_prob1)
        self.linear2 = nn.Linear(num_hiddens1,num_hiddens2)
        self.drop2 = nn.Dropout(drop_prob2)
        self.linear3 = nn.Linear(num_hiddens2,num_outputs)
        self.flatten  = nn.Flatten()
    
    def forward(self,x):
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.drop1(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.drop2(x)
        x = self.linear3(x)
        y = self.relu(x)
        return y
    
    
    
    
def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,optimizer=None):
    train_ls, test_ls = [], []
    for epoch in range(num_epochs):
        ls, count = 0, 0
        for X,y in train_iter:
            l=loss(net(X),y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            ls += l.item()
            count += y.shape[0]
        train_ls.append(ls)
        ls, count = 0, 0
        for X,y in test_iter:
            l=loss(net(X),y)
            ls += l.item()
            count += y.shape[0]
        test_ls.append(ls)
        if(epoch+1)%5==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
    return train_ls,test_ls    
    
    
    
    
    
num_inputs,num_hiddens1,num_hiddens2,num_outputs =784, 256,256,10
num_epochs=20
lr = 0.001
drop_probs = np.arange(0,1.1,0.1)
Train_ls, Test_ls = [], []

for drop_prob in drop_probs:
    net = LinearNet(num_inputs, num_outputs, num_hiddens1, num_hiddens2, drop_prob,drop_prob)
    for param in net.parameters():
        nn.init.normal_(param,mean=0, std= 0.01)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(),lr)
    train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,net.parameters,lr,optimizer)
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
    
    
    
x = np.linspace(0,len(train_ls),len(train_ls))
plt.figure(figsize=(10,8))
for i in range(0,len(drop_probs)):
    plt.plot(x,Train_ls[i],label= 'drop_prob=%.1f'%(drop_probs[i]),linewidth=1.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
plt.legend(loc=2, bbox_to_anchor=(1.05,1.0),borderaxespad = 0.)
plt.title('train loss with dropout')
plt.show()


input = torch.randn(2, 5, 5)
m = nn.Sequential(
nn.Flatten()
)
output = m(input)
output.size()

运行结果:

在这里插入图片描述

关于dropout的原理,网上资料很多,一般都是用一个正态分布的矩阵,比较矩阵元素和(1-dropout),大于(1-dropout)的矩阵元素值的修正为1,小于(1-dropout)的改为1,将输入的值乘以修改后的矩阵,再除以(1-dropout)。

疑问:

  1. 数值经过正态分布矩阵的筛选后,还要除以 (1-dropout),这样做的原因是什么?
  2. Flatten层用来将输入“压平”,即把多维的输入一维化,常用在从卷积层到全连接层的过渡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/894320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录-数组篇

2-二分查找 方法一&#xff1a; 左闭右闭&#xff0c;[left, right] class Solution { public:int search(vector<int>& nums, int target) {//[left, right]int left 0;int right nums.size() - 1 ;while(left < right){int middle left ((right - left)…

git 回滚相关问题

原本用as自带的git执行回滚任务&#xff0c; 但是提交之后发现并没有成功&#xff0c; 后面通过命令行的方式重新回滚并且提交上去&#xff0c;就可以了 说明as的git还是有点小瑕疵&#xff0c;还是命令行最稳妥 相关博文&#xff1a; git代码回滚操作_imkaifan的博客-CSDN博…

WebDAV之π-Disk派盘 + 那样记账

那样记账是一款个人记账应用,致力于提供简单和轻量的记账体验。以下是该应用的一些特点和功能: 1. 快速记账:那样记账提供多种直接记账方式,让您能够快速记录收入和支出。 2. 自定义:您可以自定义收支分类,以及记账的时间和金额。根据个人需求,随时修改和调整记账信息…

项目经理到底要不要懂技术?

大家好&#xff0c;我是老原。 “项目经理要不要懂技术&#xff1f;” 这个问题的争议性一直挺大&#xff0c;就好比先有的鸡还是先有的蛋&#xff0c;各方说各方有理。 我见过只负责流程不需要懂技术的pm&#xff0c;也见过pm连软硬件设计都有review的。 一般总结来说就是…

数据结构——链表详解

链表 文章目录 链表前言认识链表单链表结构图带头单循环链表结构图双向循环链表结构图带头双向循环链表结构图 链表特点 链表实现(带头双向循环链表实现)链表结构体(1) 新建头节点(2) 建立新节点(3)尾部插入节点(4)删除节点(5)头部插入节点(6) 头删节点(7) 寻找节点(8) pos位置…

GBU812-ASEMI新能源专用整流桥GBU812

编辑&#xff1a;ll GBU812-ASEMI新能源专用整流桥GBU812 型号&#xff1a;GBU812 品牌&#xff1a;ASEMI 封装&#xff1a;GBU-4 恢复时间&#xff1a;&#xff1e;50ns 正向电流&#xff1a;80A 反向耐压&#xff1a;1200V 芯片个数&#xff1a;4 引脚数量&#xff…

Windows基础安全知识

目录 常用DOS命令 ipconfig ping dir cd net user 常用DOS命令 内置账户访问控制 Windows访问控制 安全标识符 访问控制项 用户账户控制 UAC令牌 其他安全配置 本地安全策略 用户密码策略复杂性要求 强制密码历史&#xff1a; 禁止密码重复使用 密码最短使用期限…

【菜鸡读论文】MS-TCT: Multi-Scale Temporal ConvTransformer for Action Detection

【菜鸡读论文】MS-TCT: Multi-Scale Temporal ConvTransformer for Action Detection 大家好哇&#xff01;是谁美滋滋地准备开始放暑假了&#xff01;没错&#xff01;你没有听错&#xff01;放暑假&#xff01; 谁能想到都已经立秋了&#xff0c;竟然有人还在实验室&#xf…

java-IONIO

一、JAVA IO 1.1. 阻塞 IO 模型 最传统的一种 IO 模型&#xff0c;即在读写数据过程中会发生阻塞现象。当用户线程发出 IO 请求之后&#xff0c;内 核会去查看数据是否就绪&#xff0c;如果没有就绪就会等待数据就绪&#xff0c;而用户线程就会处于阻塞状态&#xff0c;用户线…

Codeforces Round 893 (Div. 2)B题题解

文章目录 [The Walkway](https://codeforces.com/contest/1858/problem/B)问题建模问题分析1.分析所求2.如何快速计算每个商贩被去除后的饼干数量代码 The Walkway 问题建模 给定n个椅子&#xff0c;其中有m个位置存在商贩&#xff0c;在商贩处必须购买饼干吃&#xff0c;每隔…

u8g2 自制字体

显示器 SSD1306 单片机ARDUINO NANO 使用U8G2 将表情生成字库文件 使用DRAWGLYPH 显示表情字库 GIF转成40X80 GIF转PNG PNG 转1位 PNG生成BDP BDP生成 C U8G2源代码的TOOL\FONT中包含了PNG转BDP BDP转.C 文件 下载原代码 &#xff1a; GitHub - olikraus/u8g2: U8gl…

python3 0基础学习----数据结构(基础+练习)

python 0基础学习笔记之数据结构 &#x1f4da; 几种常见数据结构列表 &#xff08;List&#xff09;1. 定义2. 实例&#xff1a;3. 列表中常用方法.append(要添加内容) 向列表末尾添加数据.extend(列表) 将可迭代对象逐个添加到列表中.insert(索引&#xff0c;插入内容) 向指定…

redis查看执行的命令+配置文件命令

1.SLOWLOG LEN 获取 Slowlog 的长度&#xff0c;以确定 Slowlog 中有多少条记录 2.SLOWLOG GET 获取 Slowlog 中的具体记录。你可以使用 SLOWLOG GET 命令来获取第 n 条记录的详细信息&#xff0c;其中 n 是记录的索引&#xff08;从 0 开始&#xff09; 3.如果你想获取多条最…

RFID赋能新能源电池生产的智慧演进

随着全球对可再生能源的需求不断增长&#xff0c;新能源电池作为储能和供电的重要组成部分&#xff0c;正逐渐成为关注的焦点。然而&#xff0c;新能源电池的生产过程中存在着一系列挑战&#xff0c;如追踪和管理电池的生命周期、确保质量和安全等。在这方面&#xff0c;RFID正…

【WPF】 本地化的最佳做法

【WPF】 本地化的最佳做法 资源文件英文资源文件 en-US.xaml中文资源文件 zh-CN.xaml 资源使用App.xaml主界面布局cs代码 App.config辅助类语言切换操作类资源 binding 解析类 实现效果 应用程序本地化有很多种方式&#xff0c;选择合适的才是最好的。这里只讨论一种方式&#…

微信公众平台发布小程序流程

最近因为部署小程序&#xff0c;学习了下如何部署小程序 1. 取消不检验合法域名并上传小程序 建议在小程序上传之前&#xff0c;先取消不校验合法域名并真机调试下。 2. 登录微信公众平台 登录微信公众平台 3. 设置服务器域名 在开放->开发管理->开发设置找到服务器…

Minio知识点+linux下安装+面试总结

一 Minio简介 MinIO 是一个基于Apache License v2.0开源协议的对象存储服务。它兼容亚马逊S3云存储服务接口&#xff0c;非常适合于存储大容量非结构化的数据&#xff0c;例如图片、视频、日志文件、备份数据和容器/虚拟机镜像等&#xff0c;而一个对象文件可以是任意大小&…

Word设置只读后,为什么还能编辑?

Word文档设置了只读模式&#xff0c;是可以编辑的&#xff0c;但是当我们进行保存的时候就会发现&#xff0c;word提示需要重命名并选择新路径才能够保存。 这种操作&#xff0c;即使可以编辑文字&#xff0c;但是原文件是不会受到影响的&#xff0c;编辑之后的word文件会保存到…

【Unity小技巧】Unity探究自制对象池和官方内置对象池(ObjectPool)的使用

文章目录 前言不使用对象池使用官方内置对象池应用 自制对象池总结源码参考完结 前言 对象池&#xff08;Object Pool&#xff09;是一种软件设计模式&#xff0c;用于管理和重用已创建的对象。在对象池中&#xff0c;一组预先创建的对象被维护在一个池中&#xff0c;并在需要时…

七夕节送礼物清单,总有一款他/她会喜欢!

马上就要到一年一度的七夕节了&#xff0c;你想好送给对方什么礼物了吗&#xff1f;送礼不一定是贵的好&#xff0c;但一定要表达出自己心意&#xff0c;也有人说&#xff0c;七夕不适合单身狗&#xff0c;其实是错的&#xff0c;单身狗正好可以趁七夕这个浪漫的节日&#xff0…