Resnet代码实现

news2024/11/6 3:59:58

 

图2 18-layer、34-layer的残差结构

图3 50-layer、101-layer、102-layer的残差结构

import torch
import torch.nn as nn


#这个18或者34层网络的残差模块,根据ResNet的具体实现可以自动匹配
class BasicBlock(nn.Module):
    '''
    conv1 stride=1对应的实线残差,因为不会改变高宽
          stride=2对应的虚线残差,因为会改变高宽(减半)

    conv2 stride都为1
    '''
    expansion = 1#便于控制通道数
    def __init__(self,in_channels,out_channels,stride=1,downsample=None):
        super(BasicBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self,x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


#50,101,152
class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self,in_channels,out_channels,stride=1,downsample=None):
        super(Bottleneck,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels,out_channels*self.expansion,kernel_size=1,stride=1,bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self,x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self,block,layer_nums,num_classes=1000,include_top=True):
        '''

        :param block:
        :param layer_nums:模块数目34layers[3,4,6,4]
        :param include_top:
        '''
        super(ResNet,self).__init__()
        self.include_top = include_top
        self.in_channel = 64#经过cnov1 后,进入残差块的通道数变成64

        #这里指定你的数据通道数是3
        self.conv1 = nn.Conv2d(1,self.in_channel,kernel_size=7,stride=2,padding=3,bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)#长宽减半

        self.layer1 = self._make_layer(block,64,layer_nums[0])#conv2 图中看出,不同deep的layer只有通道不同,不用调整步长
        self.layer2 = self._make_layer(block,128,layer_nums[1],stride=2)#conv3
        self.layer3 = self._make_layer(block,256,layer_nums[2],stride=2)#conv4
        self.layer4 = self._make_layer(block,512,layer_nums[3],stride=2)#conv5
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1,1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self,block,out_channels,block_num,stride=1):
        '''

        :param block: 选择不同的残差模块
        :param out_channels: 输出通道
        :param block_num: 残差模块的个数 3就是这个block连续用3次数
        :param stride: 默认步长1
        '''
        downsample = None
        if stride != 1 or self.in_channel != out_channels*block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel,out_channels*block.expansion,kernel_size=1,stride=stride,bias=False),
                nn.BatchNorm2d(out_channels*block.expansion)
            )
        layers = []
        layers.append(block(self.in_channel,out_channels,stride,downsample))
        self.in_channel = out_channels*block.expansion

        for _ in range(1,block_num):
            layers.append(block(self.in_channel,out_channels))

        return nn.Sequential(*layers)

    def forward(self,x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x,1)#用于将输入张量展平成一维张量torch.Size([10, 3072])
            x = self.fc(x)
        return x


def resnet18(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
def resnet34(num_classes=1000,include_top=True):
    return ResNet(BasicBlock,[3,4,6,3],num_classes=num_classes, include_top=include_top)

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    input_tensor = torch.rand(1, 3, 224, 224).to(device)  # Ensure input tensor is on the correct device
    model = resnet18(num_classes=1000, include_top=True).to(device)  # Explicitly pass num_classes and include_top
    # print(model)
    print(model(input_tensor).shape)

 参考源:

deep-learning-for-image-processing/pytorch_classification/Test5_resnet/model.py at master · WZMIAOMIAO/deep-learning-for-image-processing

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2230555.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么大家都在学数字孪生呢?

随着物联网,大数据、人工智能等技术的发展,新一代信息技术与制造业正在深度融合,人们与物理世界的交互方式正在发生转折性的变化。数字化转型正在成为企业的重要战略,而数字孪生则成为全新的焦点。 当下,在数字技术和…

IDEA使用Maven Helper查看整个项目的jar冲突

在插件市场安装Maven Helper,安装好后,重启IDEA;双击打开可能存在jar冲突的pom文件;在右侧面板查看冲突,text是引入的依赖明细,点击Dependecy Analyzer选项卡即可查看冲突的jar。

「Pytorch」如何理解深度学习中的算子(operator)

在深度学习中,“算子”(operator)通常指的是在神经网络中进行的各种数学运算或函数。这些算子可以是基本的数学操作,如加法、乘法、卷积,也可以是更复杂的变换,如激活函数和池化操作。 主要类型的算子 线性…

Hbuilder html5+沉浸式状态栏

manifest.json源码视图添加 {"statusbar": {"immersed": true }如图: 2、plusready准备,将状态栏字体变黑,不然背景白色、状态栏白色看不到 //2.1 如果你用了mui, mui.plusReady(function(){plus.navigat…

windows/linux注册服务与阿里镜像仓库使用

这里写目录标题 启动Windows将jar注册服务Linux将jar设置开机启动 外网环境编译打包 启动 Windows将jar注册服务 将jar包导入到服务器上,将WinSW工具也放到服务器上。 winSw下载地址:https://github.com/winsw/winsw/releases 依据下图修改xml内容即可…

建筑行业知识库搭建:好处、方法与注意事项

在建筑行业,知识管理对于提升项目效率、降低成本、增强创新能力以及构建竞争优势具有至关重要的作用。搭建一个高效、系统的建筑行业知识库,不仅有助于实现知识的有效沉淀与便捷共享,还能促进知识在项目实践中的灵活应用,从而加速…

Oracle与SQL Server的语法区别

1)日期和日期转换函数。 SQL: SELECT A.*, CASE WHEN NVL(PAA009,) OR PAA009 >Convert(Varchar(10), SYSDATE,120) THEN Y ELSE N END AS ActiveUser FROM POWPAA A WHERE PAA001admin or PAA002admin Oracle: SELECT A.*, CASE WHEN NVL(PAA009,) or PAA009&…

【算法赌场】区间合并

区间问题 区间问题的引入 数学上,用两个数字可以确定数轴上的一个区间,较小的数字叫做区间的左端点,也叫区间起点,较大的数字叫做区间的右端点,也叫区间终点。 在算法竞赛中,很多题目是以区间为单位去进行…

GPT-Sovits-2-微调模型

1. 大致步骤 上一步整理完数据集后&#xff0c;此步输入数据, 微调2个模型VITS和GPT&#xff0c;位置在 <<1-GPT-SoVITS-tts>>下的<<1B-微调训练>> 页面的两个按钮分别执行两个文件: <./GPT_SoVITS/s2_train.py> 这一步微调VITS的预训练模型…

使用列表推导式处理列表中符合条件的元素将结果组成新的列表

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 使用列表推导式处理 列表中符合条件的元素 将结果组成新的列表 执行以下代码后&#xff0c;输出是什么&#xff1f; def process_numbers(numbers): return [1 / x for x in numbers if x ! …

SSM项目部署到服务器

将SSM&#xff08;Spring Spring MVC MyBatis&#xff09;项目部署到服务器上&#xff0c;通常需要以下步骤&#xff1a; 打包项目 生成一个WAR文件&#xff0c;通常位于target目录下 配置Tomcat&#xff1a; 将生成的WAR文件复制到Tomcat的webapps目录下。 配置conf/se…

TortoiseSVN 文件夹以及文件不显示差异感叹解决步骤

直接修改注册表&#xff0c;把TortoiseSVN图标悬浮注册项提前&#xff0c;靠后就不显示&#xff0c; 如下图 打开注册表&#xff0c;重命名TortoiseSVN 相关项&#xff0c;前面加上三四个空格&#xff0c;重启电脑即可。

架构师备考-软件测试

定义 软件测试是使用人工或自动的手段来运行或测定某个软件系统的过程&#xff0c;其目的在于检验它是否满足规定的需求或弄清预期结果与实际结果之间的差别。 软件测试的目的就是确保软件的质量、确认软件以正确的方式做了用户所期望的事情&#xff0c;所以软件测试工作主要是…

【实验九】前馈神经网络(5)--鸢尾花分类

实验内容 目录 1 .小批量梯度下降法 2 .数据处理 &#xff08;1&#xff09;将数据集封装为Dataset类 &#xff08;2&#xff09;用DataLoader进行封装 3 .模型构建 4 .完善Runner类 5 .模型训练 可视化观察训练集损失和训练集loss变化情况 6 .模型评价 7.模型预测 …

能提升幸福感的好物品牌有哪些?一定不能错过的五款品牌推荐!

最近&#xff0c;是不是有很多小伙伴们都在为不知道该买些什么而感到纠结呢&#xff1f;其实&#xff0c;对于那些还在犹豫不决&#xff0c;不知道该选择什么商品的朋友们&#xff0c;完全不必过于焦虑。我最近在购物时发现了一些能够显著提升生活幸福感的好物品牌&#xff0c;…

Cyber​​Panel upgrademysqlstatus 远程命令执行漏洞(QVD-2024-44346)

0x01 产品简介 CyberPanel是一个开源的Web控制面板,它提供了一个用户友好的界面,用于管理网站、电子邮件、数据库、FTP账户等。CyberPanel旨在简化网站管理任务,使非技术用户也能轻松管理自己的在线资源。 0x02 漏洞概述 该漏洞源于upgrademysqlstatus接口未做身份验证和…

Lua 从基础入门到精通(非常详细)

目录 什么是 Lua&#xff1f; Lua 环境安装 Lua基本语法 注释 数据类型 nil&#xff08;空&#xff09; Boolean number&#xff08;数字&#xff09; string&#xff08;字符串&#xff09; function&#xff08;函数&#xff09; userdata thread table&#xff…

Java:数据结构-MapSet

搜索树 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树 若它的左子树不为空&#xff0c;则左子树上所有节点的值都小于根节点的值若它的右子树不为空&#xff0c;则右子树上所有节点的值都大于根节点的值它的左右子树也分别为…

全新更新!Fastreport.NET 2025.1版本发布,提升报告开发体验

在.NET 2025.1版本中&#xff0c;我们带来了巨大的期待功能&#xff0c;进一步简化了报告模板的开发过程。新功能包括通过添加链接报告页面、异步报告准备、HTML段落旋转、代码文本编辑器中的文本搜索、WebReport图像导出等&#xff0c;大幅提升用户体验。 FastReport .NET 是…

楼梯区域分割系统:Web效果惊艳

楼梯区域分割系统源码&#xff06;数据集分享 [yolov8-seg-FocalModulation&#xff06;yolov8-seg-GFPN等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Global Al l…