[PyTorch][chapter 50][自定义网络 ResNet18]

news2024/9/28 12:16:27

前言:

        这里结合一个ResNet-18 网络,讲解一下自己定义一个深度学习网络的完整流程。

经过20轮的训练,测试集上面的精度85%

一   残差块定义

针对图像处理有两种结构,下面代码左右实现的是左边的结构.

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 12:00:57 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch.nn import functional as F

class ResBlk(nn.Module):
  
    """
    resnet block
    """
    def __init__(self, in_ch, out_ch, step):
       
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(in_channels = in_ch,
                               out_channels = out_ch,
                               kernel_size =3,
                               stride =step,
                               padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(in_channels = out_ch,
                               out_channels = out_ch,
                               kernel_size =3,
                               stride =1,
                               padding=1)
        
        self.bn2 = nn.BatchNorm2d(out_ch)
        
        self.extra = nn.Sequential()
        
        #残差块部分
        if in_ch != out_ch:
            
            self.extra = nn.Sequential(
                #[b,in_ch, h,w]=>[b, out_ch, h,w]
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride = step),
                nn.BatchNorm2d(out_ch)
                )
        
    def forward(self,x):
        
        """
        param x: [b ,ch, h,w]
        return 
        """
        
        print(x.shape)
        
        conv = self.conv1(x)
        bn1 = self.bn1(conv)
        out = F.relu(bn1)

        
        conv = self.conv2(out)
        bn2 = self.bn2(conv)
        out = F.relu(bn2)

        out = self.extra(x)+out
        out = F.relu(out)
        
        return out
    


    
    
        
        

      


二 定义网络

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 14:22:34 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch.nn import functional as F
from ResBlock import ResBlk


class ResNet18(nn.Module):
    
    def __init__(self, num_class):
        
        super(ResNet18, self).__init__()
        
        conv = nn.Conv2d(in_channels = 3,
                               out_channels = 16,
                               kernel_size =3,
                               stride =2,
                               padding=0)
        bn = nn.BatchNorm2d(16)
        
        self.conv1 = nn.Sequential(conv, bn)
        
        #followed 4 blocks
        
        #[b,16,h,w]=>[b,32,h,w]
        self.blk1 = ResBlk(16, 32, 3)
        
        #[b,16,h,w]=>[b,32,h,w]
        self.blk2 = ResBlk(32, 64, 3)
        
        #[b,16,h,w]=>[b,32,h,w]
        self.blk3 = ResBlk(64, 128, 3)
        
        #[b,16,h,w]=>[b,32,h,w]
        self.blk4 = ResBlk(128, 256, 3)
        
     
        
        self.fc = nn.Linear(256*2*2, num_class) 
        
    def forward(self, x):
        
        
        a = self.conv1(x)
        a = F.relu(a)
        print("\n a ",a.shape)
        a = self.blk1(a)
        a = self.blk2(a)
        a = self.blk3(a)
        a = self.blk4(a)
        
        #print(x.shape)
        print("\n fc a: ",a.shape)
        a = a.view(a.size(0),-1) #Flatten
        y = self.fc(a)
        
        return y
    

def main():
    
    blk = ResBlk(64, 128,2)
    #tmp: [batch, channel, width, height]
    tmp = torch.randn(2,64,224,224)
    out = blk(tmp)
    print("\n resBlock: ",out.shape)
    
    
    model =ResNet18(5)
    
    tmp = torch.randn(2,3,224,224)
    
    out = model(tmp)
    
    print("resnet-18 ",out.shape)
    
    #numbel是指tensor占用内存的数量
 
    mp =map(lambda p:p.numel(),  model.parameters())
    sz = sum(mp)
    print("\n parameters size ",sz)
   

if __name__ == "__main__":
    
     main() 
        
        
    
        
        
        

三 Train& Test

   逻辑如下:

   先使用训练集数据训练

    使用验证集数据过拟合检查,保存模型参数

    加载模型参数,进行测试

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:28:13 2023

@author: chengxf2
"""



for epoch in range(epochs):
    
    train(train_db)
    
    if epoch %10 ==0:
        
        val_acc = evaluate(val_db)
        
        if val_ass is the best:
            #报错模型参数,防止过拟合
            save_ckpt()
        
        if out_of_patience():
            
            break
#加载模型参数        
load_ckpt()

test_acc = evaluate(test_db)

四 训练,验证,测试部分完整代码

  

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 15 15:38:18 2023

@author: chengxf2
"""

import torch
from torch import optim,nn
import visdom
from torch.utils.data import DataLoader
from ResNet_18 import ResNet18
from PokeDataset import Pokemon

batchNum = 32
lr = 1e-3
epochs = 20
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

root ='pokemon'
resize =224

csvfile ='data.csv'
train_db = Pokemon(root, resize, 'train',csvfile)
val_db = Pokemon(root, resize, 'val',csvfile)
test_db = Pokemon(root, resize, 'test',csvfile)

train_loader = DataLoader(train_db, batch_size =batchNum,shuffle= True,num_workers=4)
val_loader = DataLoader(val_db, batch_size =batchNum,shuffle= True,num_workers=2)
test_loader = DataLoader(test_db, batch_size =batchNum,shuffle= True,num_workers=2)
viz = visdom.Visdom()

def evalute(model, loader):
    
    total =len(loader.dataset)
    correct =0
    for x,y in loader:
        
        x = x.to(device)
        y = y.to(device)
        
        with torch.no_grad():
            
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    
    acc = correct/total
    
    return acc   
        
        

def main():
    
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(),lr =lr) 
    criteon = nn.CrossEntropyLoss()
    
    best_epoch=0,
    best_acc=0
    viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))
    viz.line([0],[-1],win='val_loss',  opts =dict(title='val_acc'))
    global_step =0
    
    for epoch in range(epochs):
        print("\n --main---: ",epoch)
        for step, (x,y) in enumerate(train_loader):
            #x:[b,3,224,224] y:[b]

             x = x.to(device)
             y = y.to(device)
             #print("\n --x---: ",x.shape)
             
             logits =model(x)
             loss = criteon(logits, y)
             #print("\n --loss---: ",loss.shape)
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             
             viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')
             global_step +=1
             
        if epoch %2 ==0:
            
             val_acc = evalute(model, val_loader)
             
             if val_acc>best_acc:
                 best_acc = val_acc
                 best_epoch =epoch
                 torch.save(model.state_dict(),'best.mdl')
             print("\n val_acc ",val_acc)
             viz.line([val_acc],[global_step],win='val_loss',update='append')
    print('\n best acc',best_acc, "best_epoch: ",best_epoch)
    
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')
    
    test_acc = evalute(model, test_loader)
    print('\n test acc',test_acc)
                 

if __name__ == "__main__":
    
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/887686.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于YOLOv8模型和Caltech数据集的行人检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要 基于YOLOv8模型和Caltech数据集的行人检测系统可用于日常生活中检测与定位行人,利用深度学习算法可实现图片、视频、摄像头等方式的行人目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集…

2022年06月 C/C++(二级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;小白鼠再排队 N只小白鼠(1 < N < 100)&#xff0c;每只鼠头上戴着一顶有颜色的帽子。现在称出每只白鼠的重量&#xff0c;要求按照白鼠重量从小到大的顺序输出它们头上帽子的颜色。帽子的颜色用 “red”&#xff0c;“blue”等字符串来表示。不同的小白…

实现简单纯Canvas文本输入框,新手适用

文章目录 概要效果技术细节代码 概要 Canvas上面提供输入&#xff1a; 一、最简单可能是用dom渲染一个input,覆盖在图形上面进行文本编辑&#xff0c;编辑完再把内容更新到图形.这样简单&#xff0c;但是缺点也明显&#xff0c;就是它不是真正绘制在canvas上面&#xff0c;没…

爬虫逆向实战(三)--天某云登录

一、数据接口分析 主页地址&#xff1a;天某云 1、抓包 通过抓包可以发现登录接口是account/login 2、判断是否有加密参数 请求参数是否加密&#xff1f; 通过“载荷”模块可以发现password、comParam_signature、comParam_seqCode是加密的 请求头是否加密&#xff1f; 无…

嵌入式学习之字符串

通过今天的学习&#xff0c;我主要提高了对sizeof 和 strlen、puts()、gets()、strcmp 、strncmp、strstr、strtok的理解。重点对sizeof的使用有了更加深刻的理解

【会议征稿信息】第二届信息学,网络与计算技术国际学术会议(ICINC2023)

2023年第二届信息学&#xff0c;网络与计算技术国际学术会议(ICINC2023) 2023 2nd International Conference on Informatics,Networking and Computing (ICINC 2023) 2023年第二届信息学&#xff0c;网络与计算技术国际学术会议(ICINC2023)将于2023年10月27-29日于中国武汉召…

MongoDB:数据库初步应用

一.连接MongoDB 1.MongoDBCompass连接数据库 连接路径:mongodb://用户名:密码localhost:27017/ 2.创建数据库(集合) MongoDB中数据库被称为集合. MongoDBCompass连接后,点击红色框加号创建集合,点击蓝色框加号创建文档(数据表) 文档中的数据结构(相当于表中的列)设计不用管…

mqtt学习记录

目录 1 匿名登录2 ⽤户名密码登录&#xff0c;配置接收的主题mosquitto 配置文件修改添加⽤户信息添加topic和⽤户的关系登录演示 1 匿名登录 ⾸先打开三个终端&#xff0c; 启动代理服务&#xff1a;mosquitto -v -v 详细模式 打印调试信息 默认占⽤&#xff1a;1883端⼝订阅…

机器学习笔记:线性链条件随机场(CRF)

0 引入&#xff1a;以词性标注为例 比如我们要对如下句子进行标注&#xff1a; “小明一把把把把住了”那么我么可能有很多种词性标注的方法&#xff0c;中间四个“把”&#xff0c;可以是“名词名词动词名词”&#xff0c;可以是“名词动词动词名词”等多种形式。 那么&#…

安装chromedriver 115,对应chrome版本115(经检验,116也可以使用)

目录 1. 查看Chrome浏览器的版本2. 找到对应的chromedriver3. 安装ChromeDriver 1. 查看Chrome浏览器的版本 点进这个网站查看&#xff1a;chrome://settings/help &#xff08;真是的&#xff0c;上一秒还是115版本&#xff0c;更新后就是116版本了&#xff0c;好在chromedi…

Python程序设计——列表

一、引言 关键点&#xff1a;一个列表可以存储任意大小的数据集合。 程序一般都需要存储大量的数值。假设&#xff0c;举个例子&#xff0c;需要读取100个数字&#xff0c;计算出它们的平均值&#xff0c;然后找出多少个数字是高于这个平均值的。程序首先读取100个数字并计算它…

基于 matplotlib module 的物理示意图绘制

基于 matplotlib module 的物理示意图绘制 # 创建画布和子图 fig, ax plt.subplots()# 去除 x 轴和 y 轴的边框线 ax.spines[bottom].set_visible(False) ax.spines[top].set_visible(False) ax.spines[left].set_visible(False) ax.spines[right].set_visible(False)# 隐藏 …

Azure如何调整虚拟机的大小

参考 https://blog.csdn.net/m0_48468018/article/details/132267096 创建虚拟机进入资源&#xff0c;点击大小选项&#xff0c;并对大小进行调整 点击如下图的cloud shell,进入Azure CLI,使用az vm resize 进行大小调整 命令中的g对应资源组&#xff0c;n对应虚拟机名称&am…

巧妙使用js IntersectionObserver实现dom懒加载

效果 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>IntersectionObserver</title></head><body style"text-align: center"><div id"container">…

数据库高性能架构模式

互联网业务兴起之后&#xff0c;海量用户加上海量数据的特点&#xff0c;单个数据库服务器已经难以满足业务需要&#xff0c;必须考虑数据库集群的方式来提升性能。高性能数据库集群的第一种方式是“读写分离”&#xff0c;第二种方式是“数据库分片”。 1、读写分离架构 **读…

Vue父子组件数据双向绑定

今天写一个功能时&#xff0c;遇到一些问题&#xff1a; 为什么子组件的contentList改变&#xff0c;也会将form中的trContentVOList的值改变&#xff1f; 吓的我立马去补充知识&#xff08;小白一枚&#xff09;,也借鉴了别的大佬的一些文章&#xff0c;这里自己整理一下&…

时序预测 | MATLAB实现基于CNN-GRU卷积门控循环单元的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于CNN-GRU卷积门控循环单元的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于CNN-GRU卷积门控循环单元的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 MATLAB实现基于CNN-GRU卷积…

华为网络篇 RIP的默认路由-30

难度2复杂度2 目录 一、实验原理 二、实验拓扑 三、实验步骤 四、实验过程 总结 一、实验原理 使用RIP搭建内部网络后&#xff0c;我们还需要在边界路由器进行相应的配置&#xff0c;否则无法与Internet通信。默认情况&#xff0c;内部的RIP路由器是不知道Internet的路由条…

Linux驱动开发之点亮三盏小灯

头文件 #ifndef __HEAD_H__ #define __HEAD_H__//LED1和LED3的硬件地址 #define PHY_LED1_MODER 0x50006000 #define PHY_LED1_ODR 0x50006014 #define PHY_LED1_RCC 0x50000A28 //LED2的硬件地址 #define PHY_LED2_MODER 0x50007000 #define PHY_LED2_ODR 0x50007014 #define…

人工智能驱动的视频分析技术:实时洞察与关键信息提供者

引言&#xff1a;人工智能在视频分析领域的应用为监控视频提供了更加智能化和高效的处理方式。通过实时分析监控视频&#xff0c;人工智能可以自动识别特定的对象、运动模式、区域异常等&#xff0c;并生成相关的报告和统计数据&#xff0c;为用户提供关键信息和洞察。本文将详…