【读代码】model.py

news2024/9/9 0:47:58
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from torchvision.models import ResNet50_Weights
# 没有自定义库
# Resnet50FPN  和 CountRegressor 两个类
# weights_normal_init  和 weights_xavier_init 两个函数
# 类中的函数叫方法

# 定义 Resnet50FPN 类,继承自 PyTorch 的 nn.Module
# 类 函数 方法 类里面的函数叫方法

# 这个网络是真简单,就是把 ResNet-50 的前四层作为 conv1,然后分别提取第五、第六、第七层作为 conv2、conv3、conv4
# 只是返回了 conv3 和 conv4 的特征图,没有做其他操作

# 返回特征图 然后回归
class Resnet50FPN(nn.Module):
    # 类的初始化方法
    def __init__(self):
        # 调用父类 nn.Module 的初始化方法
        super(Resnet50FPN, self).__init__()

        # 以下三行代码被注释掉,它们分别使用不同的方式加载预训练的 ResNet-50 模型
        # 使用 torchvision 库加载预训练的 ResNet-50 模型,参数 pretrained=True 表示下载并使用 ImageNet 上预训练的权重
        # self.resnet = torchvision.models.resnet50(pretrained=True)
        # 使用自定义的 ResNet50_Weights.DEFAULT 作为权重,这可能是一个特定的权重配置
        # self.resnet = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # 选择使用 ImageNet1K_V1 预训练权重的 ResNet-50 模型
        self.resnet = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)        
        
        # 获取 resnet 模型的所有子模块
        children = list(self.resnet.children())
        # 分别提取 ResNet-50 中的各个层

        # 提取前四层作为 conv1
        self.conv1 = nn.Sequential(*children[:4])
        self.conv2 = children[4] # 提取第五层作为 conv2
        self.conv3 = children[5] # 提取第六层作为 conv3
        self.conv4 = children[6] # 提取第七层作为 conv4
    # 定义前向传播方法
    def forward(self, im_data):        
        feat = OrderedDict()  # 创建一个有序字典 feat 用于存储特征图
        feat_map = self.conv1(im_data)  # 通过 conv1 层获取特征图
        feat_map = self.conv2(feat_map) #  将 conv1 的输出传递给 conv2       
        feat_map3 = self.conv3(feat_map) # 将 conv2 的输出传递给 conv3
        feat_map4 = self.conv4(feat_map3)  # 将 conv3 的输出传递给 conv4
        feat['map3'] = feat_map3 # 将不同层的特征图存储到 feat 字典中
        feat['map4'] = feat_map4
        return feat

# 定义 CountRegressor 类,继承自 PyTorch 的 nn.Module
# 目的是将输入的特征图转换为一个预测的密度图
# 图像的物体计数问题
# 将输入特征图转换为最终的预测密度图
# 所以是在预测图???怎么实现的物体计数,输出也不是很懂。
class CountRegressor(nn.Module):
    # 类的初始化方法
    def __init__(self, input_channels,pool='mean'):

        # 调用父类 nn.Module 的初始化方法
        super(CountRegressor, self).__init__()

        # 属性 pool 用于指定池化操作的类型,可以是 'mean' 或 'max'
        self.pool = pool

        # 定义回归器网络结构,使用 nn.Sequential 来顺序添加层
        # 卷积 激活 采样 卷积 激活 采样... 最后一层卷积激活
        self.regressor = nn.Sequential(
            nn.Conv2d(input_channels, 196, 7, padding=3),   # 第一个卷积层
            nn.ReLU(),  # 激活函数
            nn.UpsamplingBilinear2d(scale_factor=2),   # 上采样

            nn.Conv2d(196, 128, 5, padding=2),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),

            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2),

            nn.Conv2d(64, 32, 1),
            nn.ReLU(),

            nn.Conv2d(32, 1, 1), # 最后一层卷积层, 输出通道数为 1
            # in_channels=32,out_channels=1,kernel_size=1
            nn.ReLU(),
        )

    # 前向传播方法
    def forward(self, im):
        # 获取输入样本的数量
        num_sample =  im.shape[0]

        # 如果输入只有一个样本
        if num_sample == 1:
            # 移除批次维度,并通过网络获取输出
            output = self.regressor(im.squeeze(0))

            # 根据pool属性选择池化操作
            if self.pool == 'mean':
                output = torch.mean(output, dim=(0),keepdim=True)  
                return output
            elif self.pool == 'max':
                output, _ = torch.max(output, 0,keepdim=True)
                return output
        else:
            for i in range(0,num_sample):
                # 对每个样本执行网络前向传播
                output = self.regressor(im[i])

                # 根据pool属性选择池化操作
                if self.pool == 'mean':
                    output = torch.mean(output, dim=(0),keepdim=True)
                elif self.pool == 'max':
                    output, _ = torch.max(output, 0,keepdim=True)
                
                # 如果是第一个样本,直接赋值Output
                if i == 0:
                    Output = output
                # 否则,将输出添加到Output列表中
                else:
                    Output = torch.cat((Output,output),dim=0)
            
            # 返回所有样本的输出
            return Output

# 正态分布初始化模型参数
def weights_normal_init(model, dev=0.01):

    # 如果传入的 model 是一个列表,则遍历列表中的每个模型
    if isinstance(model, list):
        for m in model:
            weights_normal_init(m, dev)  #【递归调用】 对列表中的每个模型递归调用初始化函数
    else:
        # 如果传入的 model 不是一个列表,则对单个模型进行操作
        # 遍历模型的所有子模块
        for m in model.modules():

            # 如果当前模块是 nn.Conv2d 类型(二维卷积层)
            if isinstance(m, nn.Conv2d):   

                # 对权重进行正态分布初始化,均值为 0.0,标准差为 dev             
                m.weight.data.normal_(0.0, dev)

                # 如果存在偏置项,将其初始化为 0.0
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
            
            # 如果当前模块是 nn.Linear 类型(全连接层)
            elif isinstance(m, nn.Linear):

                # 对全连接层的权重也进行正态分布初始化,均值为 0.0,标准差为 dev
                m.weight.data.normal_(0.0, dev)

                # 全连接层可能有偏置项,这里没有显式地初始化偏置,因为 nn.Linear 默认偏置为 None
                # 代码中没有对 nn.BatchNorm2d、nn.ReLU 或其他类型的层进行特殊处理,因为这些层可能不需要权重初始化,或者初始化方式可能不同。

# Xavier初始化
def weights_xavier_init(m):
    #  检查传入的模块 m 是否是二维卷积层 nn.Conv2d
    if isinstance(m, nn.Conv2d):

        # 使用 Xavier 正态分布初始化权重
        # xavier_normal_ 函数是 PyTorch 中用于实现 Xavier 初始化的函数
        # gain 参数用于调整权重的缩放因子,
        # calculate_gain('relu') 所使用的激活函数(在这个例子中是 ReLU)来计算缩放因子
        torch.nn.init.xavier_normal_(m.weight, 
                                     gain=nn.init.calculate_gain('relu'))
        
        # 如果卷积层有偏置项
        if m.bias is not None:

            # 使用 zeros_ 函数将偏置初始化为 0
            torch.nn.init.zeros_(m.bias)
            
# 应该是试了两个初始化方法 这个py文件没有返回值,都是定义的类和函数,类里的函数是方法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1958643.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

读零信任网络:在不可信网络中构建安全系统04最小特权

1. 公钥基础设施 1.1. PKI 1.2. 数字证书本身并不能解决身份认证问题 1.2.1. 需要一种方法来验证获得的公钥的确属于某人 1.2.2. 公钥基础设施(PKI)解决了这个问题 1.3. PKI定义了一组角色及其职责,能够在不可信的网络中安全地分发和验证…

【Websim.ai】一句话让AI帮你生成一个网页

【Websim.ai】一句话让AI帮你生成一个网页 网站链接 websim.ai 简介 websim.ai接入了Claude Sonnet 3.5,GPT-4o等常用的LLM,只需要在websim.ai的官网指令栏中编写相关指令,有点类似大模型的Prompt,指令的好坏决定了网页生成的…

Lc63---1859将句子排序(排序)---Java版(未写完)

1.题目描述 2.思路 (1)首先将句子按空格分割成若干单词。 (2)每个单词的最后一个字符是它的位置索引。我们可以通过这个索引将单词恢复到正确的位置。 (3)按照单词的索引顺序排序这些单词。 (4…

【已解决】嵌入式linux mobaxterm unable to open connection to comx 串口正常连接,但终端无法输入

1.点击Session重新选择串口,注意看看串口是不是连接到虚拟机,导致串口被占用。 2.选择PC机与开发板连接的串口,不知道的话可以打开设备管理器看看,选择正确的波特率,一般是115200。 3.关键一步:选择后别急…

性能测试:深入探索与实战指南

大家好,我是一名测试开发工程师,已经开源一套【自动化测试框架】和【测试管理平台】,欢迎大家联系我,一起【分享测试知识,交流测试技术】 在当今这个信息化、数字化的时代,软件系统的性能直接关乎到用户体验…

使用 Matlab 绘制带有纹理的柱状图

以下是效果 1. 在 Matlab 里安装两个额外的库: hatchfill2 和 legendflex。 (1)搜索并安装 hatchfill2,用来画纹理 (2) 搜索并安装 legendflex,用来画自定义的图例 2. 代码(说明见注释) data …

Centos 7系统(最小化安装)安装Git 、git-man帮助、补全git命令-详细文章

安装之前由于是最小化安装centos7安装一些开发环境和工具包 文章使用国内阿里源 cd /etc/yum.repos.d/ && mkdir myrepo && mv * myrepo&&lscurl -O https://mirrors.aliyun.com/repo/epel-7.repo;curl -O https://mirrors.aliyun.com/repo/Centos-7…

docker安装phpMyAdmin

直接安装phpMyAdmin需要有php环境,比较麻烦,总结了使用docker安装方法,并提供docker镜像。 1.docker镜像 见我上传的docker镜像:https://download.csdn.net/download/taotao_guiwang/89595177 2.安装 1).加载镜像 docker load …

(leetcode学习)24. 两两交换链表中的节点

给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:head [1,2,3,4] 输出:[2,1,4…

Sonatype Nexus Repository搭建与使用(详细教程3.70.1)

目录 一. 环境准备 二. 安装jdk 三. 搭建Nexus存储库 四. 使用介绍 一. 环境准备 主机名IP系统软件版本配置信息nexus192.168.226.26Rocky_linux9.4 Nexus Repository 3.70.1 MySQL8.0 jdk-11.0.23 2核2G,磁盘20G 进行时间同步,关闭防火墙和selinux…

秋招突击——7/29——操作系统——网络IO

文章目录 引言基础知识零拷贝传统文件读取传统文件传输零拷贝mmap writesendifle 网络通信IO模型阻塞IO非阻塞IO IO多路复用模型selectpollselect和poll的总结epoll边缘触发ET和水平触发LT 信号驱动IO模型异步IO 面试题库1、说一下Linux五种IO模型2、阻塞IO和非阻塞IO应用场景…

可视化目标检测算法推理部署(一)Gradio的UI设计

引言 在先前RT-DETR模型的学习过程中,博主自己使用Flask框架搭建了一个用于模型推理的小案例: FlaskRT-DETR模型推理 在这个过程中,博主需要学习Flask、HTML等相关内容,并且博主做出的页面还很丑,那么,是…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第六十六章 电容屏触摸驱动实验

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

1.5 1.6 操作系统引导 虚拟机

操作系统引导 操作系统引导的概念 操作系统引导是指计算机利用CPU运行特定程序,通过程序识别硬盘,识别硬盘分区,识别硬盘分区上的操作系统,最后通过程序启动操作系统,一环扣一环地完成上述过程 操作系统引导的过程 …

分布式锁 Redis+RedisSon

文章目录 1.什么是分布式锁2.分布式锁应该具备哪些条件3.分布式锁主流的实现方案4.未添加分布式锁存在的问题4.1测试未添加分布式锁的代码通过jmeter发送请求4.2 添加线程同步锁集群部署配置nginx修改jmeter端口号4.3 使用redis的setnx命令实现分布式锁解决办法4.4 使用try、fi…

【2025留学】德国留学真的很难毕业吗?为什么大家不来德国留学?

大家好!我是德国Viviane,一句话讲自己的背景:本科211,硕士在德国读的电子信息工程。 之前网上一句热梗:“德国留学三年将是你人生五年中最难忘的七年。”确实,德国大学的宽进严出机制,延毕、休…

【日常设计案例分享】通道对账

今天跟同事们讨论一个通道对账需求的技术设计。鉴于公司业务线有好几个,为避免不久的将来各业务线都重复竖烟囱,因此,我们打算将通道对账做成系统通用服务,以降低各业务线的开发成本。 以下文稿(草图)&…

正点原子imx6ull-mini-Linux设备树下的LED驱动实验(4)

1&#xff1a;修改设备树文件 在根节点“/”下创建一个名为“alphaled”的子节点&#xff0c;打开 imx6ull-alientek-emmc.dts 文件&#xff0c; 在根节点“/”最后面输入如下所示内容 alphaled {#address-cells <1>;#size-cells <1>;compatible "atkalp…

昇思25天学习打卡营第1天|快速入门实操教程

昇思25天学习打卡营第1天|快速入门实操教程 目录 昇思25天学习打卡营第1天|快速入门实操教程 一、MindSpore内容简介 主要特点&#xff1a; MindSpore的组成部分&#xff1a; 二、入门实操步骤 1. 安装必要的依赖包 2. 下载并处理数据集 3. 构建网络模型 4. 训练模型…

WIN下的文件病毒

文件病毒 一.windows下知识句柄禁用某些警告MAX_PATH_WIN32_FIND_DATAWFindFirstFileW注册到服务代码&#xff08;自启动&#xff09;隐藏窗口 二.客户端代码三.服务端代码 一.windows下知识 句柄 相当于指针&#xff0c;用来表示windows下的一些对象&#xff1b; 禁用某些警…