Interesting bug caused by getattr

news2024/11/14 5:35:21

题意:由 getattr 引起的有趣的 bug

问题背景:

I try to train 8 CNN models with the same structures simultaneously. After training a model on a batch, I need to synchronize the weights of the feature extraction layers in other 7 models.

This is the model:

class GNet(nn.Module):
    def __init__(self, dim_output, dropout=0.5):
        super(GNet, self).__init__()
        self.out_dim = dim_output
        # Load the pretrained AlexNet model
        alexnet = models.alexnet(pretrained=True)

        self.pre_filtering = nn.Sequential(
            alexnet.features[:4]
        )

        # Set requires_grad to False for all parameters in the pre_filtering network
        for param in self.pre_filtering.parameters():
            param.requires_grad = False

        # construct the feature extractor
        # every intermediate feature will be fed to the feature extractor

        # res: 25 x 25
        self.feat_ex1 = nn.Conv2d(192, 128, kernel_size=3, stride=1)

        # res: 25 x 25
        self.feat_ex2 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        )

        # res: 25 x 25
        self.feat_ex3 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        )

        # res: 13 x 13
        self.feat_ex4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        )

        # res: 13 x 13
        self.feat_ex5 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        )

        # res: 13 x 13
        self.feat_ex6 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        )

        # res: 13 x 13
        self.feat_ex7 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Dropout(p=dropout),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        )

        # define the flexible pooling field of each layer
        # use a full convolution layer here to perform flexible pooling
        self.fpf13 = nn.Conv2d(in_channels=448, out_channels=448, kernel_size=13, groups=448)
        self.fpf25 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=25, groups=384)
        self.linears = {}
        for i in range(self.out_dim):
            self.linears[f'linear_{i+1}'] = nn.Linear(832, 1)

        self.LogTanh = LogTanh()
        self.flatten = nn.Flatten()

And this is the function to synchronize the weights:

def sync_weights(models, current_sub, sync_seqs):
    for sub in range(1, 9):
        if sub != current_sub:
            # Synchronize the specified layers
            with torch.no_grad():
                for seq_name in sync_seqs:
                    reference_layer = getattr(models[current_sub], seq_name)[2]
                    layer = getattr(models[sub], seq_name)[2]
                    layer.weight.data = reference_layer.weight.data.clone()
                    if layer.bias is not None:
                        layer.bias.data = reference_layer.bias.data.clone()

then an error is raised:

'Conv2d' object is not iterable

which means the getattr() returns a Conv2D object. But if I remove [2]:

def sync_weights(models, current_sub, sync_seqs):
    for sub in range(1, 9):
        if sub != current_sub:
            # Synchronize the specified layers
            with torch.no_grad():
                for seq_name in sync_seqs:
                    reference_layer = getattr(models[current_sub], seq_name)
                    layer = getattr(models[sub], seq_name)
                    layer.weight.data = reference_layer.weight.data.clone()
                    if layer.bias is not None:
                        layer.bias.data = reference_layer.bias.data.clone()

I get another error:

'Sequential' object has no attribute 'weight'

which means the getattr() returns a Sequential. But previously it returns a Conv2D object. Does anyone know anything about this? For your information, the sync_seqs parameter passed in sync_weights is:

sync_seqs = [
    'feat_ex1',
    'feat_ex2',
    'feat_ex3',
    'feat_ex4',
    'feat_ex5',
    'feat_ex6',
    'feat_ex7'
]

问题解决:

In both instances, getattr is returning a Sequential, which in turn contains a bunch of objects. In the second case, you're directly assigning that Sequential to a variable, so reference_layer ends up containing a Sequential.

In the first case, however, you're not doing that direct assignemnt. You're taking the Sequential object and then indexing it with [2]. That means reference_layer contains the third item in the Sequential, which is a Conv2d object.

Take a more simple example. Suppose I had a ListContainer class that did nothing except hold a list. I could then recreate your example as follows, with test1 corresponding to your first test case and vice versa:

class ListContainer:
    def __init__(self, list_items):
        self.list_items = list_items

letters = ["a", "b", "c"]
container = ListContainer(letters)

test1 = getattr(container, "list_items")[0]
test2 = getattr(container, "list_items")

print(type(test1)) # <class 'str'>
print(type(test2)) # <class 'list'>

In both tests, getattr itself is returning a list - but in the second, we're doing something with that list after we get it, so test2 ends up being a string instead.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1944132.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3+Element Plus 实现table表格中input的验证

实现效果 html部分 <template><div class"table"><el-form ref"tableFormRef" :model"form"><el-table :data"form.detailList"><el-table-column type"selection" width"55" align&…

初识c++(string和模拟实现string)

一、标准库中的string类 string类的文档介绍&#xff1a;cplusplus.com/reference/string/string/?kwstring 1、auto和范围for auto&#xff1a; 在早期C/C中auto的含义是&#xff1a;使用auto修饰的变量&#xff0c;是具有自动存储器的局部变量&#xff0c;后来这个 不重…

【北航主办丨本届SPIE独立出版丨已确认ISSN号】第三届智能机械与人机交互技术学术会议(IHCIT 2024,7月27)

由北京航空航天大学指导&#xff0c;北京航空航天大学自动化科学与电气工程学院主办&#xff0c;AEIC学术交流中心承办的第三届智能机械与人机交互技术学术会议&#xff08;IHCIT 2024&#xff09;将定于2024年7月27日于中国杭州召开。 大会面向基础与前沿、学科与产业&#xf…

初识c++:string类 (1)

目录 # 初识c&#xff1a;string类 1.为什么学习string类 2.标准库中的string类 2.1 string类的了解 2.2 auto和范围for 2.3 string类的常用接口说明 2.3.1string类对象的常见构造 2.3.2string类对象的容量操作 2.3.3string类对象的访问及遍历操作 2.3.4string类对象…

DNS概述及DNS服务器的搭建(twelve day)

回顾 关闭防火墙 systemctl stop firewalld 永久停止防火墙 systemctl disable firewalld 关闭selinux setenforce 0 永久关闭selinux安全架构 vim /etc/selinux/config 修改静态IP地址 vim /etc/sysconfig/network-scripts/ifcfg-ens160 #修改uuid的目的是为了保证网络的唯一…

计算机的错误计算(四十)

摘要 计算机的错误计算&#xff08;三十九&#xff09;阐明有时计算机将0算成非0&#xff0c;非0算成0&#xff1b;并且前面介绍的这些错误计算相对来说均是由软件完成。本节讨论手持式计算器对这些算式的计算效果。 例1. 用手持式计算器计算 与 . 我们用同一个计算器计算…

机械学习—零基础学习日志(高数10——函数图形)

零基础为了学人工智能&#xff0c;真的开始复习高数 函数图像&#xff0c;开始新的学习&#xff01;本次就多做一做题目&#xff01; 第一题&#xff1a; 这个解法是有点不太懂的了。以后再多研究一下。再出一道题目。 张宇老师&#xff0c;比较多提示了大家&#xff0c;一定…

哪些工作可以年入几十万到2亿?

关注卢松松&#xff0c;会经常给你分享一些我的经验和观点。 从今年起&#xff0c; 每个月都会有年入几十万到2亿的新闻案例出来&#xff0c;而且很多都是官方媒体发的&#xff0c;你们看看&#xff1a; 7月19日35岁小伙扛楼一年多存了40万 7月4日老板娘一天卖出200斤知了日入…

Leetcode3217. 从链表中移除在数组中存在的节点

Every day a Leetcode 题目来源&#xff1a;3217. 从链表中移除在数组中存在的节点 解法1&#xff1a;集合 链表遍历 代码&#xff1a; /** lc appleetcode.cn id3217 langcpp** [3217] 从链表中移除在数组中存在的节点*/// lc codestart /*** Definition for singly-link…

docker--容器数据进行持久化存储的三种方式

文章目录 为什么Docker容器需要使用持久化存储1.什么是Docker容器&#xff1f;2.什么是持久化存储&#xff1f;3.为什么Docker容器需要持久化存储&#xff1f;4.Docker如何实现持久化存储&#xff1f;(1)、Docker卷(Volumes)简介适用环境:使用场景:使用案例: (2)、绑定挂载&…

Python 实现PDF和TIFF图像之间的相互转换

PDF是数据文档管理领域常用格式之一&#xff0c;主要用于存储和共享包含文本、图像、表格、链接等的复杂文档。而TIFF&#xff08;Tagged Image File Format&#xff09;常见于图像处理领域&#xff0c;主要用于高质量的图像文件存储。 在实际应用中&#xff0c;我们可能有时需…

哪个邮箱最安全最好用啊

企业邮箱安全至关重要&#xff0c;需保护隐私、防财务损失、维护通信安全、避免纠纷&#xff0c;并维持业务连续性。哪个企业邮箱最安全好用呢&#xff1f;Zoho企业邮箱&#xff0c;采用加密技术、反垃圾邮件和病毒保护&#xff0c;支持多因素认证&#xff0c;确保数据安全合规…

php的文件上传

&#x1f3bc;个人主页&#xff1a;金灰 &#x1f60e;作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持&#xff01;❤️ &#x1f34a;易编橙终身成长社群&#…

【每日刷题Day85】

【每日刷题Day85】 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 125. 验证回文串 - 力扣&#xff08;LeetCode&#xff09; 2. 43. 字符串相乘 - 力扣&#xff08;L…

【es】elasticsearch 自定义排序-按关键字位置排序

一 背景 要求es查询的结果按关键字位置排序&#xff0c;位置越靠前优先级越高。 es版本7.14.0&#xff0c;项目是thrift&#xff0c;也可以平替springboot&#xff0c;使用easyes连接es。 二 easyes使用 配easyes按官方文档就差不多了 排序 | Easy-Es 主要的一个问题是easy…

FPGA深入浅出IP核学习(一)-- vivado中clk IP MMCM核的使用

FPGA深入浅出IP核学习系列文章主要是自己关于学习FGPA IP核的学习笔记&#xff0c;供大家参考学习指正。 目录 前言 一、MMCM-IP核简介 二、MMCM-IP核使用 1.IP核配置 2.模块程序设计 总结 前言 本文主要参考B站野火FGPA相关学习视频、正点原子达芬奇FPGA开发指南和赛灵思官方用…

一文入门SpringSecurity 5

目录 提示 Apache Shiro和Spring Security 认证和授权 RBAC Demo 环境 Controller 引入Spring Security 初探Security原理 认证授权图示​编辑 图中涉及的类和接口 流程总结 提示 Spring Security源码的接口名和方法名都很长&#xff0c;看源码的时候要见名知意&am…

docker笔记4-镜像理解

docker笔记4-镜像理解 一、镜像原理之联合文件系统二、镜像原理之分层理解三、commit镜像 一、镜像原理之联合文件系统 UnionFS&#xff08;联合文件系统&#xff09;: Union文件系统&#xff08;UnionFS&#xff09;是一种分层、轻量级并且高性能的文件系统&#xff0c;它支持…

深入了解路由器工作原理:从零开始的简单讲解

简介 在现代网络中&#xff0c;路由器扮演着至关重要的角色。它不仅连接了不同的设备&#xff0c;还确保数据能够准确地传输到目的地。本文将带你深入探讨路由器的工作原理&#xff0c;帮助网络基础小白们理解这一重要设备的基本功能。 路由器的构成 路由器是一种具有多个输入…

云计算实训12——配置web服务器、配置客户端服务器、配置DNS服务、实现DNS域名解析

一、配置web服务器 准备操作 首先在正式配置之前需要做以下操作 关闭防火墙 systemctl stop firewalld 永久关闭防火墙 systemctl disable firewalld 关闭selinux setenforce 0 永久关闭selinux vim /etc/selinux/config selinuxpermissive 还需要保证能够正常ping通www.bai…