【深度学习】卷积网络代码实战ResNet

news2025/1/2 22:49:44

        ResNet (Residual Network) 是由微软研究院的何凯明等人在2015年提出的一种深度卷积神经网络结构。ResNet的设计目标是解决深层网络训练中的梯度消失和梯度爆炸问题,进一步提高网络的表现。下面是一个ResNet模型实现,使用PyTorch框架来展示如何实现基本的ResNet结构。这个例子包括了一个基本的残差块(Residual Block)以及ResNet-18的实现,代码结构分为model.py(模型文件)和train.py(训练文件)。

model.py 

      首先,我们导入所需要的包 

import torch
from torch import nn
from torch.nn import functional as F

        然后,定义Resnet Block(ResBlk)类。

class ResBlk(nn.Module):
    def __init__(self):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1)
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(x)))
        out = self.extra(x) + out
        return out

        最后,根据ResNet18的结构对ResNet Block进行堆叠。

class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            nn.BatchNorm2d(64)
        )
        self.blk1 = ResBlk(64, 128)
        self.blk2 = ResBlk(128, 256)
        self.blk3 = ResBlk(256, 512)
        self.blk4 = ResBlk(512, 1024)
        self.outlayer = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        
        # print('after conv1:', x.shape)
        x = F.adaptive_avg_pool2d(x, [1,1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x


        其中,在网络结构搭建过程中,需要用到中间阶段的图片参数,用下述测试过程求得。

def main():
    tmp = torch.randn(2, 3, 32, 32)
    out = blk(tmp)
    print('block', out.shape)
    
    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)

train.py

        首先,导入所需要的包

import torch
from torchvision import datasets
from torchvision import transforms
from torch import nn, optimizer

        然后,定义main()函数

def main():
    batchsz = 32
    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
        ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
        ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
 
    x, label = iter(cifar_train).next()
    print('x:', x.shape, 'label:', label.shape)
    
    device = torch.device('cuda')
    model = ResNet18().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)
 
    for epoch in range(100):
        for batchidx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criteon(logitsm label)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    print(loss.item())
 
 
    with torch.no_grad():
        total_correct = 0
        total_num = 0
        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)
            total_correct += torch.eq(pred, label).floot().sum().item()
            total_num += x.size(0)
 
        acc = total_correct / total_num
        print(epoch, acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267890.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

添砖java第四更@(+)@

今天的学习内容主要是围绕着实体类来进行的,就是说在java里面我们常常会把数据存放和方法分别存放在不同的类里面。 首先就是关于实体类是什么,实体类就是只提供了get方法,set方法,和默认构造器的类。 接着就是熟悉java与别的语言的不同之处就在于它是…

算法题(19):多数元素

审题: 数组不为空且一定存在众数。需要返回众数的数值 思路: 方法一:哈希映射 先用哈希映射去存储对应数据出现的次数,然后遍历找到众数并输出 当然也可以在第一次映射的过程中就维护一个出现次数最多的数据,这样子就可…

电子商务网站的三层架构的理解和实践

在电子商务领域,网站架构的设计对于系统的稳定性、可扩展性和用户体验至关重要。其中,三层架构作为一种经典的设计模式,被广泛应用于各类电子商务网站中。本文将从理论、理解和实践三个方面,详细探讨电子商务网站的三层架构。 一、…

LVS 负载均衡原理 | 配置示例

注:本文为 “ LVS 负载均衡原理 | 配置” 相关文章合辑。 部分内容已过时,可以看看原理实现。 使用 LVS 实现负载均衡原理及安装配置详解 posted on 2017-02-12 14:35 肖邦 linux 负载均衡集群是 load balance 集群的简写,翻译成中文就是负…

JavaScript甘特图 dhtmlx-gantt

背景 需求是在后台中,需要用甘特图去展示管理任务相关视图,并且不用依赖vue,兼容JavaScript原生开发。最终使用dhtmlx-gantt,一个半开源的库,基础功能免费,更多功能付费。 甘特图需求如图: 调…

领域驱动设计第一篇-DP主题

一:领域驱动设计概述 领域驱动设计。Domain-Driven Design 可以理解为基于领域的工程设计。 1:什么是领域? 初步理解领域:业务问题的范畴。 领域可大可小,对应着大小业务问题的边界。业务上要做的几个事&#xff0…

EMNLP'24 最佳论文解读 | 大语言模型的预训练数据检测:基于散度的校准方法

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 点击 阅读原文 观看作者讲解回放! 作者简介 张伟超,中国科学院计算所网络数据科学与技术重点实验室三年级直博生 内容简介 近年来,大语言模型(LLMs)的…

IntelliJ IDEA 远程调试

IntelliJ IDEA 远程调试 在平时开发 JAVA 程序时,在遇到比较棘手的 Bug 或者是线上线下结果不一致的情况下,我们会通过打 Log 或者 Debug 的方式去定位并解决问题,两种方式各有利弊,今天就简要介绍下如何通过远程 Debug 的情况下…

【Webug】攻防实战详情

世界上只有一种真正的英雄主义,那就是认清了生活的真相后,仍然热爱她 显错注入 首先整体浏览网站 注入点: control/sqlinject/manifest_error.php?id1 判断注入类型 输入: and 11 正常, 再输入: and 12 还正常, 排除数字型 输入单引号:…

SpringMVC核心、两种视图解析方法、过滤器拦截器 “ / “ 的意义

SpringMVC的执行流程 1. Spring MVC 的视图解析机制 Spring MVC 的核心职责之一是将数据绑定到视图并呈现给用户。它通过 视图解析器(View Resolver) 来将逻辑视图名称解析为具体的视图文件(如 HTML、JSP)。 核心流程 Controlle…

CyclicBarrier线程辅助类的简单使用

文章目录 简述内部机制构造函数使用案例异常处理 简述 CyclicBarrier 是另一个用于协调多个线程之间操作的同步辅助类,它允许一组线程互相等待彼此到达一个共同的屏障点(barrier)。与 CountDownLatch 不同的是,CyclicBarrier 可以…

B站推荐模型数据流的一致性架构

01 背景 推荐系统的模型,通过学习用户历史行为来达到个性化精准推荐的目的,因此模型训练依赖的样本数据,需要包括用户特征、服务端推荐的视频特征,以及用户在推荐视频上是否有一系列的消费行为。 推荐模型数据流,即为…

无需训练!多提示视频生成最新SOTA!港中文腾讯等发布DiTCtrl:基于MM-DiT架构

文章链接:https://arxiv.org/pdf/2412.18597 项目链接:https://github.com/TencentARC/DiTCtrl 亮点直击 DiTCtrl,这是一种基于MM-DiT架构的、首次无需调优的多提示视频生成方法。本文的方法结合了新颖的KV共享机制和隐混合策略,使…

尔湾市圣诞节文化交流会成功举办,展示多元文化魅力

洛杉矶——12月21日,圣诞节文化交流会在尔湾成功举办。圣诞节文化交流会旨在促进不同文化之间的交流与理解。通过举办舞蹈表演、演讲和互动游戏等,为参与者提供了一个展示和欣赏多元文化艺术的平台。这些活动不仅增加了社区成员之间的互动,也加深了他们对不同文化传统和艺术形式…

适用于项目经理的跨团队协作实践:Atlassian Jira与Confluence集成

适用于项目经理的跨团队协作实践:Atlassian Jira与Confluence集成 现代项目经理的核心职责是提供可视性、保持团队一致,并确保团队拥有交付出色工作所需的资源。在过去几年中,由于分布式团队的需求不断增加,项目经理这一角色已迅速…

Spring Cloud LoadBalancer (负载均衡)

目录 什么是负载均衡 服务端负载均衡 客户端负载均衡 Spring Cloud LoadBalancer快速上手 启动多个product-service实例 测试负载均衡 负载均衡策略 自定义负载均衡策略 什么是负载均衡 负载均衡(Load Balance,简称 LB) , 是高并发, 高可用系统必不可少的关…

探究步进电机与输入脉冲的关系

深入了解步进电机 前言一、 步进电机原理二、 细分三、脉冲数总结 前言 主要是探究以下内容: 1、步进电机的步进角。 2、什么是细分。 3、脉冲的计算。 最后再扩展以下STM32定时器的计算方法。 一、 步进电机原理 其实语言描述怎么样都不直观,我更建议…

HCIA-Access V2.5_7_1_XG(S)原理_系统概述

近年来,随着全球范围内接入市场的飞快发展以及全业务运营的快速开展,已有的PON技术标准在带宽需求,业务支撑能力以及接入节点设备和配套设备的性能提升等方面都面临新的升级需求,而GPON已经向10G GPON演示,本章将介绍1…

安装了python,环境变量也设置了,但是输入python不报错也没反应是为什么?window的锅!

目录 问题 结论总结 衍生问题 1 第1步:小白python安装,不要埋头一直点下一步!!! 2 第2步:可以选择删了之前的,重新安装python 3 第3步:如果你不想或不能删了重装python&#…

留学生交流互动系统|Java|SSM|VUE| 前后端分离

【技术栈】 1⃣️:架构: B/S、MVC 2⃣️:系统环境:Windowsh/Mac 3⃣️:开发环境:IDEA、JDK1.8、Maven、Mysql5.7 4⃣️:技术栈:Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5⃣️数据库可…