基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

news2024/9/9 1:07:39

基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度

  • 一.曲线
    • 1.train_acc
    • 2.val_acc
    • 3.train_loss
    • 4.lr
  • 二.代码

本文介绍了如何基于pytorch_lightning测试resnet18不同激活方式在CIFAR10数据集上的精度
特别说明:
1.NoActive:没有任何激活函数
2.SparseActivation:只保留topk的激活,其余清零,topk通过训练得到[初衷是想让激活变得稀疏]
3.SelectiveActive:通过训练得到使用的激活函数
可参考的代码片段
1.pytorch_lightning 如何使用
2.pytorch如何替换激活函数
3.如何对自定义权值做衰减

一.曲线

1.train_acc

在这里插入图片描述

2.val_acc

在这里插入图片描述

3.train_loss

在这里插入图片描述

4.lr

在这里插入图片描述

二.代码

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import os
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger

#torch.set_float32_matmul_precision('medium')

class ResidualBlock(nn.Module):
    def __init__(self, inchannel, outchannel, stride=1):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.shortcut = nn.Sequential()
        if stride != 1 or inchannel != outchannel:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel)
            )
        self.act=nn.ReLU()

    def forward(self, x):
        out = self.left(x)
        out += self.shortcut(x)
        out = self.act(out)
        return out

class ResNet(nn.Module):
    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)
        self.dropout=nn.Dropout(0.5)

    def make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.inchannel, channels, stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc(out)
        return out
  
class SparseActivation(nn.Module):
    act_array=[x.cuda() for x in [nn.ReLU(),
                nn.ReLU6(),
                nn.Sigmoid(),
                nn.Hardsigmoid(),
                nn.GELU(),
                nn.SiLU(),
                nn.Mish(),
                nn.LeakyReLU(),
                nn.Hardswish(),
                nn.PReLU(),
                nn.SELU(),
                nn.Softplus(),
                nn.Softsign()]]
                    
    def __init__(self,args):
        super(SparseActivation, self).__init__()
        self.input_weights = nn.Parameter(torch.randn(1)).cuda()
        self.act=SparseActivation.act_array
        self.act_weights = nn.Parameter(torch.randn(len(self.act))).cuda()
        self.args=args
        
    def forward(self, x):        
        
        index=self.args.act
        if index>=0:
            index=index-1
            if index==-1:
                prob=F.softmax(self.act_weights,dim=0)
                _, index = torch.topk(prob, 1, dim=0)
            x=self.act[index](x)
        
        if self.args.sparse==0:
            return x
            
        input=x.flatten(1)
        input_weights = torch.sigmoid(self.input_weights)        
        topk = input.size(1)*input_weights
        topk=topk.int()
        topk_vals, topk_indices = torch.topk(input, topk, dim=1)
        mask = torch.zeros_like(input).scatter(1, topk_indices, topk_vals)
        return mask.reshape(x.shape)
            
class LitNet(pl.LightningModule):
    def __init__(self, args):
        super(LitNet, self).__init__()
        self.save_hyperparameters()
        self.args = args
        self.resnet18 = ResNet(ResidualBlock)
        self.criterion = nn.CrossEntropyLoss()
        self.ws=[]
        self.replace_activation(self.resnet18,nn.ReLU, SparseActivation,self.ws)    
        
    def replace_activation(self,module, old_activation, new_activation,ws):
        for name, child in module.named_children():
            if isinstance(child, old_activation):
                op=new_activation(self.args)
                ws.append(op.input_weights)
                setattr(module, name,op)
            else:
                self.replace_activation(child, old_activation, new_activation,ws)        
        
    def forward(self, x):
        return self.resnet18(x)

    def on_train_epoch_start(self):
        self.train_total_loss=[]
        self.train_total_acc=[]

    def on_train_epoch_end(self):
        self.log('epoch_train_loss', np.mean(self.train_total_loss))
        self.log('epoch_train_acc', np.mean(self.train_total_acc)) 
        self.log("lr",self.optimizer.state_dict()['param_groups'][0]['lr'])
        
    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        
        l2_reg = torch.tensor(0.).cuda()
        l2_lambda=0.001
        for param in self.ws:
            l2_reg += torch.norm(param+4)                    
        loss += l2_lambda * l2_reg        
        self.log('iter_train_loss', loss)

        _, predicted = torch.max(output.data, 1)
        correct = (predicted == target).sum()
        acc = 100. * correct / target.size(0)      
        self.train_total_loss.append(loss.item())
        self.train_total_acc.append(acc.item())
        
        return loss       

    def on_validation_epoch_start(self):
        self.val_total_loss=[]
        self.val_total_acc=[]

    def on_validation_epoch_end(self):
        self.log('epoch_val_loss', np.mean(self.val_total_loss))
        self.log('epoch_val_acc', np.mean(self.val_total_acc))

    def validation_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        _, predicted = torch.max(output.data, 1)
        correct = (predicted == target).sum()
        acc = 100. * correct / target.size(0)
        loss = self.criterion(output, target)        
        self.val_total_loss.append(loss.item())
        self.val_total_acc.append(acc.item())

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = self.criterion(output, target)
        self.log('test_loss', loss)
        return loss
        
    def configure_optimizers(self):
        self.optimizer = optim.SGD(self.parameters(), lr=self.args.lr, momentum=0.9,weight_decay=5e-4)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,step_size=10,gamma = 0.8)            
        return [self.optimizer],[self.scheduler]

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4), 
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        self.train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
        self.test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size,shuffle=True,num_workers=2,persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size,shuffle=False,num_workers=2,persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',help='learning rate (default: 1.0)')
    parser.add_argument('--act', type=int, default=-1,help='learning rate (default: 1.0)')
    parser.add_argument('--sparse', type=int, default=0,help='learning rate (default: 1.0)')
    args = parser.parse_args()

    cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)
    log_dir = "lightning_logs"
    
    
    args.sparse=0   #不开启稀疏
    args.act=0      #自适应激活
    model = LitNet(args)
    
    logger = TensorBoardLogger(save_dir=log_dir, name="SelectiveActive")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)    
    
    args.sparse=0     #不开启稀疏
    args.act=-1       #不用激活
    model = LitNet(args)    
    cifar10_data = CIFAR10DataModule(batch_size=args.batch_size)
    
    logger = TensorBoardLogger(save_dir=log_dir, name="NoActive")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)  
   
    args.sparse=1
    args.act=-1       #不用激活,开启稀疏
    model = LitNet(args)       
    
    logger = TensorBoardLogger(save_dir=log_dir, name="SparseActivation")    
    trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
    trainer.fit(model, cifar10_data)  

    for idx,act_name in enumerate(SparseActivation.act_array):
        name=act_name.__class__.__name__
        print(name)
        
        args.act=idx+1
        args.sparse=0
        model = LitNet(args)     
        
        logger = TensorBoardLogger(save_dir=log_dir, name=name)    
        trainer = pl.Trainer(logger=logger,devices=1,max_epochs=args.epochs,val_check_interval=1.0,gradient_clip_val=0.9, gradient_clip_algorithm="value")
        trainer.fit(model, cifar10_data)

if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1807755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode 第 401 场周赛】K秒后第 N 个元素的值

文章目录 1. K秒后第 N 个元素的值🆗 1. K秒后第 N 个元素的值🆗 题目链接🔗 🐧解题思路: 前缀和 小规律🍎 🍎 从上图观察可知,规律一目了然,arr[i] arr[i] 对上一…

超越出身与学府:揭秘成功者共有的七大特质

在当今多元化的世界里,个人成功的故事如同繁星点点,照亮了无数追梦者的前行之路。新东方创始人俞敏洪先生曾深刻地指出,真正的成功并不取决于家庭背景的显赫与否,也不在于就读大学的名气大小,而是深深植根于个人内在的…

知识图谱的应用---智慧农业

文章目录 智慧农业典型应用 智慧农业 智慧农业通过生产领域的智能化、经营领域的差异性以及服务领域的全方位信息服务,推动农业产业链改造升级;实现农业精细化、高效化与绿色化,保障农产品安全、农业竞争力提升和农业可持续发展。目前,我国的…

战略引领下的成功产品开发之路

在当今竞争激烈的市场环境中,成功的产品开发不仅仅依赖于创意和技术的卓越,更需要战略性的规划和执行。本文将探讨战略在成功产品开发中的重要性,并结合实际案例,分析如何在战略的指引下,将创意转化为商业化的产品或服…

nginx mirror流量镜像详细介绍以及实战示例

nginx mirror流量镜像详细介绍以及实战示例 1.nginx mirror作用2.nginx安装3.修改配置3.1.nginx.conf3.2.conf.d目录下添加default.conf配置文件3.3.nginx配置注意事项3.3.nginx重启 4.测试 1.nginx mirror作用 为了便于排查问题,可能希望线上的请求能够同步到测试…

TMS320F280049学习3:烧录

TMS320F280049学习3:烧录 文章目录 TMS320F280049学习3:烧录前言一、烧录RAM二、烧录FLASH总结 前言 DSP的烧录分为两种,一种是将程序烧录到RAM中,一种是烧录到FLASH中,烧录ARM中的程序,只要未掉电&#x…

Linux驱动应用编程(四)IIC(获取BMP180温度/气压数据)

本文目录 一、基础1. 查看开发板手册,获取可用IIC总线2. 挂载从机,查看从机地址。3. 查看BMP180手册,使用命令读/写某寄存器值。4. 查看BMP180手册通信流程。 二、IIC常用API1. iic数据包/报2. ioctl函数 三、数据包如何被处理四、代码编写流…

配网终端通讯管理板,稳控装置通讯管理卡,铁路信号通讯管理卡

配网终端通讯管理板 ● 配网终端通讯管理板 ● ARM Cortex™-A5 ,533MHz ● 256MB RAM,512MB FLASH 配网终端通讯管理板 ARM Cortex™-A5 ,533MHz 256MB RAM,512MB FLASH 2x10/100/1000Mbps LAN(RJ45) 6x…

FastAPI系列 4 -路由管理APIRouter

FastAPI系列 -路由管理APIRouter 文章目录 FastAPI系列 -路由管理APIRouter一、前言二、APIRouter使用示例1、功能拆分2、users、books模块开发3、FastAPI主体 三、运行结果 一、前言 未来的py开发者请上座,在使用python做为后端开发一个应用程序或 Web API&#x…

MySQL数据库---LIMIT、EXPLAIN详解

分页查询 语法 select _column,_column from _table [where Clause] [limit N][offset M]select * : 返回所有记录limit N : 返回 N 条记录offset M : 跳过 M 条记录, 默认 M0, 单独使用似乎不起作用 limit N,M : 相当于 limit M offset N , 从第 N 条记录开始, 返回 M 条记录…

贪心算法学习三

例题一 解法(贪⼼): 贪⼼策略: ⽤尽可能多的字符去构造回⽂串: a. 如果字符出现偶数个,那么全部都可以⽤来构造回⽂串; b. 如果字符出现奇数个,减去⼀个之后,剩下的…

对象存储OSS 客户端签名直传的安全风险和解决方法

1. 前言 阿里云对象存储OSS(Object Storage Service)是一款海量、安全、低成本、高可靠的云存储服务,可提供99.9999999999%(12个9)的数据持久性,99.995%的数据可用性。多种存储类型供选择,全面…

AXI Quad SPI IP核中命令的使用

1 双通道SPI和混合内存模式下支持的常用命令 对于配置中Mode设置为Dual且Slave Device设置为Mixed的情况,IP核支持表3-1中列出的命令。这些命令在Winbond、Micron和Spansion内存设备上具有相同的命令、地址和数据行为。 某些命令,如fast read、dual I/…

产品创新:驱动企业增长的核心动力

在当今快速变化的市场环境中,产品创新已成为企业生存和发展的关键。产品创新不仅涉及全新产品或服务的开发,也包括对现有产品或服务的持续改进和优化。本文将深入探讨产品创新的定义、重要性以及如何通过创新驱动企业增长,并结合实际案例进行…

每位比特币人都终将成为一个国际主义者

原创 | 刘教链 周末BTC(比特币)趁势向着30日均线回归,现于69k一线悬停。7万刀以下加仓的机会窗口,和那蹉跎一生的岁月一样,过一天少一天,在每个纠结和拧巴的日子里,在软弱和彷徨的等待中&#x…

Python 算法交易实验71 QTV200数据流设计

说明 结构作为工程的基础,应该在最初的时候进行合理设计。这一次版本迭代,我希望最终实现的效果,除了在财务方法可以达到预期,在工程方面应该可以支持长期的维护、演进。 内容 1 财务表现期待 假设初始为60万资金作为主动资金…

Java学习 - MyBatis - 初识MyBatis

前言 什么是持久化 持久化是将程序数据在持久状态和瞬时状态间转换的机制,将数据保存到可永久保存的存储设备中。最常见的就是将内存中的对象存储在数据库中,或者存在磁盘文件、XML 数据文件中等等。其中,文件 IO 属于持久化机制&#xff0…

Web后端开发(请求-数组集合、日期、JSON参数)(三)

数组参数:请求参数名与形参数组名称相同且请求参数为多个,定义数组类型形参即可接收参数 RequestMapping("/arrayParam") public String arrayParam(String[] hobby){System.out.println(Arrays.toString(hobby));return "OK"; } …

Spring Event如何优雅实现系统业务解耦

Spring Event如何优雅实现系统业务解耦 一、介绍 Spring事件(Spring Event)是Spring框架的一项功能,它允许不同组件之间通过发布-订阅机制进行解耦的通信。在Spring中,事件是表示应用程序中特定事件的对象,例如用户注…

AI数据分析:根据Excel表格数据绘制柱形图

工作任务:将Excel文件中2013年至2019年间线上图书的销售额,以条形图的形式呈现,每个条形的高度代表相应年份的销售额,同时在每个条形上方标注具体的销售额数值 在deepseek中输入提示词: 你是一个Python编程专家&#…