[CVPR 2023]PyramidFlow-训练并推理-附bug调试

news2024/11/16 13:39:01

CVPR2023-PyramidFlow-zero shot异常检测网络 代码调试记录

  • 一.论文以及开源代码
  • 二.前期代码准备
  • 三.环境配置
  • 四.bug调试
    • num_samples should be a positive integer value, but got num_samples=0
    • AttributeError: Can't pickle local object 'fix_randseed.<locals>.seed_worker'
  • 五.数据集准备
  • 六.训练
  • 七.推理

一.论文以及开源代码

PyramidFlow一篇2023年发表于CVPR的关于无监督异常检测算法的论文,由浙江大学出品,下面附上论文和代码链接:
论文链接:PyramidFlow论文
代码链接:PyramidFlow源代码

二.前期代码准备

首先,我们需要把我在一中提到的代码先git clone到我们的项目路径中,这是我们接下去的训练代码,当然其中也包括了验证和测试(推理过程也包含在内部了,需要自己写一小部分)。然后我们还需要去作者的官网git clone一份名为autoFlow的项目代码,这里面包含了训练代码中将会调用的一些函数,十分重要:
进入训练代码的链接后,点击作者头像,如图所示
进入训练代码的链接后,点击作者头像,如图所示。然后我们便进入了作者的github主页,点击主页下方的这个链接:
在这里插入图片描述
就可以跳转到这个页面:
在这里插入图片描述
红框中包含了两个链接,其中一个是我们在第一步就已经clone好的训练代码,不用管他了,现在我们点击蓝色框中的链接:
在这里插入图片描述
点击code然后复制链接,然后打开git工具使用git clone命令行即可:
在这里插入图片描述
此时,两个项目都已经拷贝下来了,我这里选择将autoflow这个文件夹直接复制到了PyramidFlow里面,这样方便PyramidFlow中代码的调用:
在这里插入图片描述

三.环境配置

PyramidFlow的环境,作者已经在Readme中给出,按照里面的版本pip install即可,如果下载速度过慢,可以设置默认源为清华源,可以大大方便我们配置环境。

四.bug调试

这里默认大家的环境已经按照要求配置好了。

num_samples should be a positive integer value, but got num_samples=0

这时候我们直接运行PyramidFlow中的train.py时一般会报这个错误:
在这里插入图片描述
这个问题的原因主要是torch库中的DataLoader函数加载数据时,如果已经设置了batch_size,就不需要设置shuffle=True来打乱数据了,此时需要把shuffle设置为False,DataLoader的具体参数可以参考如下:

DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=True, **loader_dict)

参数解释:
dataset:包含所有数据的数据集,加载的数据集(Dataset对象)

batch_size :每个batch包含的数据数量

Shuffle : 是否打乱数据位置。

sampler : 自定义从数据集中采样的策略,如果制定了采样策略,shuffle则必须为False.

Batch_sampler:和sampler一样,但是每次返回一组的索引,和batch_size, shuffle, sampler, drop_last 互斥。

num_workers : 使用线程的数量,当为0时数据直接加载到主程序,默认为0。

collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

pin_memory:s 是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

drop_last: dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

AttributeError: Can’t pickle local object ‘fix_randseed..seed_worker’

打开util.py:
在这里插入图片描述
我将这段代码中的seed_worker直接独立出来:
在这里插入图片描述
然后在train.py代码中,我们需要创建一个变量接收fix_randseed的返回值,然后将这个返回值作为seed_worker的参数传入:
在这里插入图片描述

五.数据集准备

用到的是Mvtec数据集,放在项目文件夹的同一级路径下,改名为如下所示:
在这里插入图片描述

六.训练

作者在源代码中是在训练代码的最后一npz的形式保存了模型的权重,由于我对这个npz了解甚少,并且我平时推理常用的都是pt,onnx或者tensorRT的engine等,因此,我在训练代码的最后加了一句torch.save()来将模型以pt的方式保存,见110行代码:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import time, argparse
from sklearn.metrics import roc_auc_score

from model import PyramidFlow
from util import MVTecAD, BatchDiffLoss
from util import fix_randseed, compute_pro_score_fast, getLogger, seed_worker
import cv2


def train(logger, save_name, cls_name, datapath, resnetX, num_layer, vn_dims, \
          ksize, channel, num_stack, device, batch_size, save_memory, ):
    # save config
    save_dict = {'cls_name': cls_name, 'resnetX': resnetX, 'num_layer': num_layer, 'vn_dims': vn_dims,\
                 'ksize': ksize, 'channel': channel, 'num_stack': num_stack, 'batch_size': batch_size}

    #我的改动
    seed_ = fix_randseed(seed=0)
    loader_dict = seed_worker(seed_)

    # model 
    flow = PyramidFlow(resnetX, channel, num_layer, numStack, ksize, vn_dims, saveMem).to(device)
    x_size = 256 if resnetX==0 else 1024
    optimizer = torch.optim.Adam(flow.parameters(), lr=2e-4, eps=1e-04, weight_decay=1e-5, betas=(0.5, 0.9)) # using cs-flow optimizer
    Loss = BatchDiffLoss(batch_size, p=2)

    # dataset
    train_dataset = MVTecAD(cls_name, mode='train', x_size=x_size, y_size=256, datapath=datapath)
    val_dataset = MVTecAD(cls_name, mode='val', x_size=x_size, y_size=256, datapath=datapath)
    test_dataset = MVTecAD(cls_name, mode='test', x_size=x_size, y_size=256, datapath=datapath)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=True, **loader_dict)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=False, **loader_dict)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, **loader_dict)

    # training & evaluation
    pixel_auroc_lst = [0]
    pixel_pro_lst = [0]
    image_auroc_lst = [0]
    losses_lst = [0]
    t0 = time.time()
    for epoch in range(15):
        # train
        flow.train()
        losses = []
        for train_dict in train_loader:
            image, labels = train_dict['images'].to(device), train_dict['labels'].to(device)
            optimizer.zero_grad()
            pyramid2= flow(image)
            diffes = Loss(pyramid2)
            diff_pixel = flow.pyramid.compose_pyramid(diffes).mean(1)  
            loss = torch.fft.fft2(diff_pixel).abs().mean() # Fourier loss
            loss.backward()
            nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1e0) # Avoiding numerical explosions
            optimizer.step()
            losses.append(loss.item())
        mean_loss = np.mean(losses)
        logger.info(f'Epoch: {epoch}, mean_loss: {mean_loss:.4f}, time: {time.time()-t0:.1f}s')
        losses_lst.append(mean_loss)
        
        # val for template 
        flow.eval()
        feat_sum, cnt = [0 for _ in range(num_layer)], 0
        for val_dict in val_loader:
            image = val_dict['images'].to(device)
            with torch.no_grad():
                pyramid2= flow(image) 
                cnt += 1
            feat_sum = [p0+p for p0, p in zip(feat_sum, pyramid2)]
        feat_mean = [p/cnt for p in feat_sum]

        # test
        flow.eval()
        diff_list, labels_list = [], []
        for test_dict in test_loader:
            image, labels = test_dict['images'].to(device), test_dict['labels']
            with torch.no_grad():
                pyramid2 = flow(image) 
                pyramid_diff = [(feat2 - template).abs() for feat2, template in zip(pyramid2, feat_mean)]
                diff = flow.pyramid.compose_pyramid(pyramid_diff).mean(1, keepdim=True)# b,1,h,w
                diff_list.append(diff.cpu())
                labels_list.append(labels.cpu()==1)# b,1,h,w

        labels_all = torch.concat(labels_list, dim=0)# b1hw 
        amaps = torch.concat(diff_list, dim=0)# b1hw 
        amaps, labels_all = amaps[:, 0], labels_all[:, 0] # both b,h,w
        pixel_auroc = roc_auc_score(labels_all.flatten(), amaps.flatten()) # pixel score
        image_auroc = roc_auc_score(labels_all.amax((-1,-2)), amaps.amax((-1,-2))) # image score
        pixel_pro = compute_pro_score_fast(amaps, labels_all) # pro score
        logger.info(f'   TEST Pixel-AUROC: {pixel_auroc}, time: {time.time()-t0:.1f}s')
        logger.info(f'   TEST Image-AUROC: {image_auroc}, time: {time.time()-t0:.1f}s')
        logger.info(f'   TEST Pixel-PRO: {pixel_pro}, time: {time.time()-t0:.1f}s')

        if pixel_auroc > np.max(pixel_auroc_lst):
            save_dict['state_dict_pixel'] = {k: v.cpu() for k, v in flow.state_dict().items()} # save ckpt
        if pixel_pro > np.max(pixel_pro_lst):
            save_dict['state_dict_pro'] = {k: v.cpu() for k, v in flow.state_dict().items()} # save ckpt
        pixel_auroc_lst.append(pixel_auroc)
        pixel_pro_lst.append(pixel_pro)
        image_auroc_lst.append(image_auroc)

        del amaps, labels_all, diff_list, labels_list
    save_dict['pixel_auroc_lst'] = pixel_auroc_lst
    save_dict['image_auroc_lst'] = image_auroc_lst
    save_dict['pixel_pro_lst']   = pixel_pro_lst
    save_dict['losses_lst'] = losses_lst
    torch.save(flow, "best.pt")

    np.savez(f'saveDir/{save_name}.npz', **save_dict) # save all



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training on MVTecAD')
    parser.add_argument('--cls', type=str, default='bottle', choices=\
                        ['tile', 'leather', 'hazelnut', 'toothbrush', 'wood', 'bottle', 'cable', \
                         'capsule', 'pill', 'transistor', 'carpet', 'zipper', 'grid', 'screw', 'metal_nut'])
    parser.add_argument('--datapath', type=str, default='../mvtec_anomaly_detection')
    # hyper-parameters of architecture
    parser.add_argument('--encoder', type=str, default='resnet18', choices=['none', 'resnet18', 'resnet34'])
    parser.add_argument('--numLayer', type=str, default='auto', choices=['auto', '2', '4', '8'])
    parser.add_argument('--volumeNorm', type=str, default='auto', choices=['auto', 'CVN', 'SVN'])
    # non-key parameters of architecture
    parser.add_argument('--kernelSize', type=int, default=7, choices=[3, 5, 7, 9, 11])
    parser.add_argument('--numChannel', type=int, default=16)
    parser.add_argument('--numStack', type=int, default=4)
    # other parameters
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batchSize', type=int, default=2)
    parser.add_argument('--saveMemory', type=bool, default=True) 
    
    args = parser.parse_args()
    cls_name = args.cls
    resnetX = 0 if args.encoder=='none' else int(args.encoder[6:])
    if args.volumeNorm == 'auto':
        vn_dims = (0, 2, 3) if cls_name in ['carpet', 'grid', 'bottle', 'transistor'] else (0, 1)
    elif args.volumeNorm == 'CVN':
        vn_dims = (0, 1)
    elif args.volumeNorm == 'SVN':
        vn_dims = (0, 2, 3)
    if args.numLayer == 'auto':
        num_layer = 4
        if cls_name in ['metal_nut', 'carpet', 'transistor']:
            num_layer = 8
        elif cls_name in ['screw',]:
            num_layer = 2
    else:
        num_layer = int(args.numLayer)
    ksize = args.kernelSize
    numChannel = args.numChannel
    numStack = args.numStack
    gpu_id = args.gpu
    batchSize = args.batchSize
    saveMem = args.saveMemory
    datapath = args.datapath

    logger, save_name = getLogger(f'./saveDir')
    logger.info(f'========== Config ==========')
    logger.info(f'> Class: {cls_name}')
    logger.info(f'> MVTecAD dataset root: {datapath}')
    logger.info(f'> Encoder: {args.encoder}')
    logger.info(f"> Volume Normalization: {'CVN' if len(vn_dims)==2 else 'SVN'}")
    logger.info(f'> Num of Pyramid Layer: {num_layer}')
    logger.info(f'> Conv Kernel Size in NF: {ksize}')
    logger.info(f'> Num of Channels in NF: {numChannel}')
    logger.info(f'> Num of Stack Block: {numStack}')
    logger.info(f'> Batch Size: {batchSize}')
    logger.info(f'> GPU device: cuda:{gpu_id}')
    logger.info(f'> Save Training Memory: {saveMem}')
    logger.info(f'============================')

    train(logger, save_name, cls_name, datapath,\
          resnetX, num_layer, vn_dims, \
          ksize=ksize, channel=numChannel, num_stack=numStack, \
          device=f'cuda:{gpu_id}', batch_size=batchSize, save_memory=saveMem)

到这里,训练完之后,我们就得到了模型的权重pt文件,为我们后面的推理做准备。

七.推理

未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/921985.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++信息学奥赛1136:密码翻译

#include <iostream> #include <string> using namespace std;int main() {string arr;getline(cin, arr); // 输入字符串&#xff0c;包括空格for (int i 0; i < arr.length(); i) {char a arr[i] 1; // 字符加1if (arr[i] z) {a a; // 如果当前字符是…

springboot 基于JAVA的动漫周边商城的设计与实现64n21

动漫周边商城分为二个模块&#xff0c;分别是管理员功能模块和用户功能模块。管理员功能模块包括&#xff1a;文章资讯、文章类型、动漫活动、动漫商品功能&#xff0c;用户功能模块包括&#xff1a;文章资讯、动漫活动、动漫商品、购物车&#xff0c;传统的管理方式对时间、地…

PyTorch深度学习实战(13)——可视化神经网络中间层输出

PyTorch深度学习实战&#xff08;13&#xff09;——可视化神经网络中间层输出 0. 前言1. 可视化特征学习的结果2. 可视化第一个卷积层的输出3. 可视化不同网络层的特征图小结系列链接 0. 前言 随着深度学习的快速发展&#xff0c;神经网络已成为解决各种复杂任务的重要工具。…

day 38 | ● 518. 零钱兑换 II ● 377. 组合总和 Ⅳ

518. 零钱兑换 II 这道题就是完全背包问题&#xff0c;因为可以选择的数量是无限的。所以第二层的遍历顺序就是从前往后。 因为是次数问题&#xff0c;递推公式是 的&#xff0c;初值应该设定为dp【0】 1&#xff0c;否则无法进行累加。 func change(amount int, coins []i…

Python编程基础-基本语法II

循环语句 for()语句 可以遍历任何序列的项目&#xff0c;如一个列表、元组或者一个字符串 格式&#xff1a; for 循环索引值 in 序列 循环体 #for循环把字符串中字符遍历出来 for letter in Python:print ( 当前字母 :, letter )#通过索引循环 fruits [banana, apple, m…

百度地图:设置复杂的自定义覆盖物,添加自定义覆盖物ComplexCustomOverlay

// 设置复杂的自定义覆盖物 setComplexCustomOverlay({coordinate,icon 1,label,contentHTML, }) {var mp this.map;let _BMAP this.data.type 3 ? BMapGL : BMap;// 自定义覆盖物----------------------------------------function ComplexCustomOverlay({point,icon,lab…

【全站最全】被苹果、谷歌和Microsoft停产的产品(一)

目录 ​编辑 2025 Skype for Business 2023 Cortana Google Domains Google Optimize Google Universal Analytics YouTube Stories Grasshopper Google Currents (2019) Google Stadia 2022 YouTube Originals Google OnHub Atom Google Surveys Apple Watc…

【3dsmax】练习——制作碗椅

目录 目标 步骤 一、制作主体部分 二、制作靠垫部分 三、制作支架部分 目标 制作如下图所示的碗椅 步骤 一、制作主体部分 1. 首先创建一个球体 2. 转换为可编辑多边形&#xff0c;然后切换到边层级&#xff0c;选中球体上部的所有边&#xff0c;然后删除 3. 通过“壳…

Linux下的系统编程——gdb调试工具

前言&#xff1a; 程序中除了一目了然的Bug之外都需要一定的调试手段来分析到底错在哪。到目前为止我们的调试手段只有一种∶根据程序执行时的出错现象假设错误原因﹐然后在代码中适当的位置插入printf﹐执行程序并分析打印结果﹐如果结果和预期的一样﹐就基本上证明了自己假设…

东风纳米首款车型纳米 01 亮相:纯电小车,固态电池 + 超级快充

根据近期发布的消息&#xff0c;东风纳米品牌推出的首款车型纳米 01 在全新发布会上正式亮相。这款车型采用东风量子架构 3 号平台&#xff0c;被寄予厚望将在国内市场迎来广泛的认可度。 作为一款小型纯电动车&#xff0c;纳米 01注重家庭出游、市区代步、个人通勤、接送孩子、…

“好声音”塌房、星空华文市值暴跌,两个交易日蒸发234亿港元

8月17日&#xff0c;因李玟生前录音事件&#xff0c;再次将《中国好声音》被舆论推至风口浪尖&#xff0c;引发社会对后者的质疑。 次日(8月18日)&#xff0c;《中国好声音》的IP运营商星空华文(06698.HK)股价大跌&#xff0c;其港股收盘股价跌幅达到23.4%&#xff0c;一天内市…

基于逻辑斯蒂回归的肿瘤预测案例

导入包 import pandas as pd import numpy as np from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report,roc_a…

uniapp - 全平台兼容实现上传图片带进度条功能,用户上传图像到服务器时显示上传进度条效果功能(一键复制源码,开箱即用)

效果图 uniapp小程序/h5网页/app实现上传图片并监听上传进度,显示进度条完整功能示例代码 一键复制,改下样式即可。 全部代码 记得改下样式,或直接

JavaWeb_LeadNews_Day7-ElasticSearch, Mongodb

JavaWeb_LeadNews_Day7-ElasticSearch, Mongodb elasticsearch安装配置 app文章搜索创建索引库app文章搜索思路分析具体实现 新增文章创建索引思路分析具体实现 MongoDB安装配置SpringBoot集成MongoDB app文章搜索记录保存搜索记录思路分析具体实现 查询搜索历史删除搜索历史 搜…

关于java三元组的问题

在改代码的时候&#xff0c;发现一个奇怪的地方&#xff0c;举例如下 Testpublic void buildTest(){TT t new TT();Long time tnull?System.currentTimeMillis():t.getTime();System.out.println("done");}Datapublic static class TT{Long time;}这个地方运行就…

STL---vector

目录 1.vector的介绍及使用 2.vector接口说明及模拟实现 2.1vector定义 2.2vector迭代器的使用 2.3vector容量 2.4vector增删查改 3迭代器失效 4.使用memcpy拷贝 5.模拟实现 1.vector的介绍及使用 vector的文档介绍 1. vector是表示可变大小数组的序列容器。 2. 就像数…

电子病历系统EMR

电子病历系统EMR源码 一体化电子病历系统基于云端SaaS服务的方式&#xff0c;采用B/S&#xff08;Browser/Server&#xff09;架构提供&#xff0c;覆盖了医疗机构电子病历模板制作到管理使用的整个流程。除实现在线制作内容丰富、图文并茂、功能完善的电子病历模板外&#xff…

phpcms重置密码为123456

0b817b72c5e28b61b32ab813fd1ebd7f3vbCrK

实战项目ssm权限系统 3-总结篇,权限模块保护业务模块

一 工程模块介绍 1.1 工程模块关系 在业务微服务模块中引入安全认证模块&#xff0c;起到对业务模块的认证授权保护

ubuntu18.04复现yolo v8环境配置之CUDA与pytorch版本问题以及多CUDA版本安装及切换

最近在复现yolo v8的程序&#xff0c;特记录一下过程 环境&#xff1a;ubuntu18.04ros melodic 小知识&#xff1a;GPU并行计算能力高于CPU—B站UP主说的 Ubuntu可以安装多个版本的CUDA。如果某个程序的Pyorch需要不同版本的CUDA&#xff0c;不必删除之前的CUDA&#xff0c;…