BusterNet网络Python模型实现学习笔记之二

news2024/12/25 9:04:57

文章目录

    • 一、squeeze函数的用法
    • 二、nn.CrossEntropyLoss函数
    • 三、isinstance函数
    • 四、定义冻结层 freeze_layers
    • 五、SummaryWriter 基础用法
    • 六、Python 基础语法
        • 1.变量嵌入到字符串
        • 2. enumerate() 函数
        • 3. 进度条库tqdm
        • 4. 字典(dict)展开为关键字参数(keyword arguments)
        • 5. assert 断言操作
        • 6. \_\_class__.__name__获取对象类名
        • 7. all() 方法判断字符是不是都非零
    • 附录

一、squeeze函数的用法

import torch

# 创建一个具有形状(2, 1, 3)的张量
x = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])
print("Original tensor shape:", x.shape)
# 输出: Original tensor shape: torch.Size([2, 1, 3])

# 使用squeeze()移除所有大小为1的维度
x_squeezed = x.squeeze()
print("Squeezed tensor shape:", x_squeezed.shape)
print(x_squeezed)
# 输出: Squeezed tensor shape: torch.Size([2, 3])

# 使用squeeze(dim)仅移除特定维度
x_squeezed_dim_1 = x.squeeze(1)
print("Squeezed tensor (dim 1) shape:", x_squeezed_dim_1.shape)
# 输出: Squeezed tensor (dim 1) shape: torch.Size([2, 3])

在这个例子中,我们创建了一个形状为(2, 1, 3)的3维张量。我们可以看到,第二个维度(索引为1)的大小为1。使用squeeze()函数移除所有大小为1的维度后,张量的形状变为(2, 3)。同时,使用squeeze(dim)函数仅移除特定维度也可以达到相同效果。

我们注意到给出的张量在第二个维度上不为零,不经让人产生疑问,我用squeeze(2)会报错吗?不妨动手一试

x_squeezed_dim_2 = x.squeeze(2)
print(x_squeezed_dim_2)
print(x_squeezed_dim_2.shape)

tensor([[[1, 2, 3]],
            [[4, 5, 6]]])
torch.Size([2, 1, 3])

结果是代码正常编译了,并没有产生问题。和原本张量一致,张量不会发生压缩。





二、nn.CrossEntropyLoss函数

nn.CrossEntropyLoss()是PyTorch中一个非常常用的损失函数,用于多分类任务。这个损失函数同时执行了nn.LogSoftmax()nn.NLLLoss()(负对数似然损失)。

请注意,这个损失函数需要两个输入:预测值(logits,未经softmax层处理的输出)和真实标签。对于预测值,输入张量的形状应该是(batch_size, num_classes, ...),其中...表示任意其他尺寸。对于真实标签,输入张量的形状应该是(batch_size, ...),标签值应该是0num_classes-1之间的整数。

下例中批量大小为3,类别数量为4

import torch
import torch.nn as nn

# 创建一个批量大小为3,类别数量为4的预测值张量(logits)
logits = torch.tensor([
    [2.5, 1.0, 0.5, 1.5],
    [0.3, 3.2, 2.1, 1.0],
    [1.2, 2.3, 3.1, 0.7]
])

# 创建一个对应的真实标签张量
labels = torch.tensor([0, 1, 2]) # 第一个样本的真实类别是0,第二个是1,第三个是2

# 初始化损失函数
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)

print("Cross entropy loss:", loss.item())

经过运行,我们得到了如下的结果

Cross entropy loss: 0.4313742220401764

Process finished with exit code 0

如果是更高维度的,预测值张量会是什么形式的呢?

我们以语义分割任务为例,假设我们有一个批量大小为2,类别数量为3,图像高度和宽度分别为 4 × 4 4\times4 4×4 的预测值张量。在这种情况下,输入张量的形状应该是例如形状为 (batch_size, num_classes, height, width),其中 heightwidth 分别表示图像的高度和宽度。这意味着我们需要一个(batch_size, height, width)的标签向量。这是一个具体的示例,输入的形状为(2, 3, 4, 4)

[
  [
    [
      [0.5, 1.0, 1.2, 0.3],
      [0.2, 0.9, 1.1, 1.4],
      [1.5, 0.7, 0.6, 0.8],
      [0.1, 0.4, 0.2, 0.9]
    ],
    [
      [1.0, 0.5, 0.7, 1.2],
      [1.5, 0.3, 0.8, 0.6],
      [0.2, 1.0, 1.2, 0.5],
      [1.1, 0.9, 0.7, 0.6]
    ],
    [
      [0.7, 1.1, 0.6, 0.9],
      [0.6, 1.2, 0.5, 0.3],
      [1.0, 0.8, 1.1, 1.4],
      [1.2, 0.5, 0.9, 0.8]
    ]
  ],
  [
    [
      [1.1, 0.6, 0.8, 0.5],
      [0.9, 1.0, 1.2, 1.1],
      [0.7, 1.5, 0.3, 0.6],
      [1.3, 0.2, 0.4, 0.9]
    ],
    [
      [0.4, 1.2, 0.9, 1.5],
      [1.6, 0.1, 0.3, 0.7],
      [0.9, 0.6, 1.4, 1.0],
      [0.8, 1.1, 0.5, 0.3]
    ],
    [
      [0.6, 1.3, 1.0, 0.2],
      [0.5, 1.7, 0.8, 0.9],
      [1.2, 0.3, 1.1, 1.5],
      [0.7, 0.9, 1.0, 1.2]
    ]
  ]
]

这是一个随机输出和随机目标张量的示例:

import torch
import torch.nn as nn

# 假设 logits 是我们的模型预测的输出
logits = torch.randn(2, 3, 4, 4)  # 模拟输入张量

# 假设 targets 是我们的真实标签
targets = torch.randint(0, 3, (2, 4, 4))  # 随机生成一个目标张量

# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print(loss)

备注:在Pycharm中想要格式化代码,可以使用快捷键 Windows/Linux:Ctrl+Alt+L

但是,上面的模拟输入张量和随机目标张量都是随机的,为了更具说服力(其实是为了水文章长度),我们就用上面的 2 × 3 × 4 × 4 2\times3\times4\times4 2×3×4×4 维张量来试验一下。我们可以随时调整真实标签的值,来观察loss = criterion(logits, targets)的值是增大还是减小。

tensor(1.2113)
Process finished with exit code 0

现在我们修改一下模型的预测输出结果 logits。可以看到输出的 loss 值明显降低,说明预测值更加符合实际标签。

# 假设 logits 是我们的模型预测的输出
logits = torch.tensor([
    [
        [
            [0, 0, 0, 0],
            [0.2, 0.9, 1.1, 1.4],
            [1.5, 0.7, 0.6, 0.8],
            [0.1, 0.4, 0.2, 0.9]
        ],
        [
            [5, 5, 5, 5],
            [0, 0, 0, 0],
            [0.2, 1.0, 1.2, 0.5],
            [1.1, 0.9, 0.7, 0.6]
        ],
        [
            [0.7, 1.1, 0.6, 0.9],
            [5, 5, 5, 5],
            [1.0, 0.8, 1.1, 1.4],
            [1.2, 0.5, 0.9, 0.8]
        ]
    ],
    [
        [
            [1.1, 0.6, 0.8, 0.5],
            [0.9, 1.0, 1.2, 1.1],
            [0.7, 1.5, 0.3, 0.6],
            [1.3, 0.2, 0.4, 0.9]
        ],
        [
            [0.4, 1.2, 0.9, 1.5],
            [1.6, 0.1, 0.3, 0.7],
            [0.9, 0.6, 1.4, 1.0],
            [0.8, 1.1, 0.5, 0.3]
        ],
        [
            [0.6, 1.3, 1.0, 0.2],
            [0.5, 1.7, 0.8, 0.9],
            [1.2, 0.3, 1.1, 1.5],
            [0.7, 0.9, 1.0, 1.2]
        ]
    ]
])  # 模拟输入张量

# 假设 targets 是我们的真实标签
targets = torch.tensor(
    [
        [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ],
        [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [0, 0, 0, 0],
            [0, 0, 0, 0]
        ]
    ])

# 使用 nn.CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print(loss)

tensor(0.9149)
进程已结束,退出代码0






三、isinstance函数

isinstance() 函数是 Python 的内置函数,用于检查一个对象是否是指定类的实例。该函数具有两个参数:

  • 第一个参数是要检查的对象。
  • 第二个参数是类或类的元组。

函数的返回值是布尔值,如果对象是给定类的实例(或者是元组中任何类的实例),则返回 True,否则返回 False

下面是一些使用 isinstance() 函数的示例:

# 示例 1: 判断变量是否为整数
num = 5
print(isinstance(num, int))  # 输出: True

# 示例 2: 判断变量是否为字符串
text = "Hello, World!"
print(isinstance(text, str))  # 输出: True

# 示例 3: 判断变量是否为整数或浮点数
num2 = 3.14
print(isinstance(num2, (int, float)))  # 输出: True

# 示例 4: 使用自定义类
class MyClass:
    pass

class AnotherClass:
    pass

obj = MyClass()
print(isinstance(obj, MyClass))  # 输出: True
print(isinstance(obj, AnotherClass))  # 输出: False

上面提到第二个参数可以是类的元组,表示的关系,下面是一个示例:

class Animal:
    pass

class Dog(Animal):
    pass

class Cat(Animal):
    pass

class Car:
    pass

# 创建一个 Dog 对象
dog = Dog()

# 使用 isinstance() 函数检查 dog 是否是 Dog 或 Cat 类的实例
print(isinstance(dog, (Dog, Cat)))  # 输出: True

# 使用 isinstance() 函数检查 dog 是否是 Animal 或 Car 类的实例
print(isinstance(dog, (Animal, Car)))  # 输出: True, 因为 Dog 类是 Animal 类的子类
                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break

es_patienceEarly Stopping 的一种实现方式,其中’es’是early的缩写,'patience’指的是在停止训练之前允许的性能停滞时间。具体来说,es_patience 是一种在训练过程中使用的技术,它基于可以允许的性能停滞时间,在模型的训练过程中始终监测验证集的性能,以便及早停止训练并避免过拟合。

images = images.cuda()  # 将图片数据从 CPU 发送到 GPU 上进行处理
labels = labels.cuda()  # 将标签数据从 CPU 发送到 GPU 上进行处理  
if loss == 0 or not torch.isfinite(loss):
    continue

这行代码通常用于在训练神经网络时,处理梯度下降过程中产生的非数值(NaN)和无穷大(Inf)的情况。

loss 是一个 tensor 类型(张量),记录了当前模型输出与真实标签之间的损失值。在 PyTorch 中,如果 loss 的值为 0 0 0 或者不是有限数(即 NaN 或 Inf),则会出现异常,并且程序会中断。

# 创建一个包含 NaN 和 Inf 的张量
data = torch.tensor([float('nan'), float('inf')])

# 判断张量的元素是否为有限数
if torch.isfinite(data).all():
    # 如果所有元素都是有限数,则进行其他操作
    print("All elements are finite.")
else:
    # 如果存在非有限数元素,则跳过此操作
    print("There are infinite or NaN elements.")





四、定义冻结层 freeze_layers

    if opt.freeze_layers is not None:
        assert isinstance(opt.freeze_layers, list), "Required List string"
        def freeze_layers(m):
            classname = m.__class__.__name__ 
            for ntl in opt.freeze_layers:
                if ntl in classname:#可以理解为 "need to freeze layer"
                    for param in m.parameters():
                        param.require_grad = False 
        
        model.apply(freeze_layers)#将该函数作用于模型上,以实现对特定层的参数进行冻结
        print('[Info] freeze layers in ', opt.freeze_layers)

以上代码实现了对模型特定层的权重冻结,具体过程如下:

  1. 首先进行一个条件判断,如果 opt.freeze_layers 不为 None,则进入到定义函数 freeze_layers 的块中。

  2. freeze_layers 函数中,通过 m.__class__.__name__ 获取当前遍历的模块 m 的类名,并将其与 opt.freeze_layers 中的每个字符串进行比较。若 classname 包含 ntl,则说明该模块需要被冻结。

  3. 如果发现有层需要被冻结,则会遍历该层的参数列表,并将各参数的 require_grad 属性设置为 False,防止其在后续训练中被更新。

  4. 在对所有层都完成操作之后,通过 model.apply(freeze_layers)freeze_layers 这个函数作用于模型中的所有层次上,从而实现对特定层的参数进行冻结

  5. 最后,程序输出一条信息提示,显示哪些层被冻结了。






五、SummaryWriter 基础用法

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('logs')
writer.close()

我新建一个文件,尝试运行上述代码结果发生了下面的报错:

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.

我在命令行中以管理员身份运行下述代码,修改了 protobuf 的版本,成功解决了这个报错。(注意安装新版本之前一定要卸载干净旧的版本,否则会有未知的错误)

pip install protobuf==3.19.0

但是出现了新的报错如下:

AttributeError: module ‘tensorflow’ has no attribute ‘io’

在这里插入图片描述
根据提示,我们打开 event_file_writer.py 文件,修改代码为

from tensorboard.compat import tensorflow_stub as tf

回到最开始的文件,此时编译就可以正常通过了,结果如下

Connected to pydev debugger (build 222.3345.131)
Process finished with exit code 0

此时 logs 文件夹已经生成,但是如果我们想要看到 tensorflow 可视化工具还会出现一些问题。我们在 PowerShell 界面尝试输入下面的代码,但是会报错。出现的错误原因是

tensorboard --logdir=logs

tensorboard ValueError: Duplicate plugins for name projector

在这里插入图片描述
这是因为我们曾经安装过 tensorflow,但是因为 Python 的版本控制问题,他会有一些安装包没有卸载干净。我们必须要删除掉一些遗留的文件夹才能解决掉这个问题。后来我安装了 anaconda 解决了 environment 这一困难,当然这是后话了。当然还有一个小 bug 就是可视化界面必须要在 chrome 内核的浏览器中才能打开。当然,这也是后话了。

六、Python 基础语法

1.变量嵌入到字符串

当需要将变量嵌入到字符串中时,可以使用字符串格式化方法。在Python中,有多种实现这种方法的方式,下面是一些例子:

  1. 使用百分号:
name = "Alice"
age = 25
message = "My name is %s, and I'm %d years old." % (name, age)
print(message)  # 输出:"My name is Alice, and I'm 25 years old."
  1. 使用format()函数:
name = "Bob"
weight = 68.5
height = 1.75
message = "Hello, my name is {} and my weight is {:.1f} kg. My height is {:.2f} m.".format(name, weight, height)`在这里插入代码片`
print(message)  # 输出:"Hello, my name is Bob and my weight is 68.5 kg. My height is 1.75 m."
  1. 使用f-string:
x = 3
y = 4
result = f'{x} + {y} = {x+y}'
print(result)  # 输出: "3 + 4 = 7"


2. enumerate() 函数

  1. 打印列表中的元素及其对应下标
fruits = ['banana', 'apple',  'mango']

for index, fruit in enumerate(fruits):
    print(index, fruit)

0 banana
1 apple
2 mango
Process finished with exit code 0

  1. 将列表转化为字典,其中字典的 key 是列表元素的下标,value 是列表元素本身
fruits = ['banana', 'apple',  'mango']

d = {index: fruit for index, fruit in enumerate(fruits)}
print(d)

{0: ‘banana’, 1: ‘apple’, 2: ‘mango’}
Process finished with exit code 0

  1. 枚举字符串中的字符
word = 'hello'

for i, char in enumerate(word):
   print(i, char)

0 h
1 e
2 l
3 l
4 o



3. 进度条库tqdm

tqdm 是 Python 中的一个进度条库,它可以让我们在循环体内添加一个进度条,以便在程序运行时实时显示循环进度,并可随时停止、暂停、恢复进度条等操作。

from tqdm import tqdm
import time

# 定义一个包含 10000 个元素的列表
l = list(range(10000))

# 使用 tqdm 显示循环进度
for i in tqdm(l):
    # 模拟耗时操作
    time.sleep(0.001)

在这里插入图片描述

tqdm 源自阿拉伯语 taqaddum (تقدّم) ,意思是进程 (“progress”)



4. 字典(dict)展开为关键字参数(keyword arguments)

在 Python 中,使用两个星号 ** 可以将一个字典(dict)展开为关键字参数(keyword arguments)。这意味着,如果我们有一个包含若干个关键字参数的字典 params,我们可以通过在函数调用时使用双星号来将这些参数传递到函数中。例如:

def some_function(a, b, c):
    print(f"a={a}, b={b}, c={c}")

params = {"a": 1, "b": 2, "c": 3}

some_function(**params)  # 等价于 some_function(a=1, b=2, c=3)

a=1, b=2, c=3

将一个字典展开为关键字参数时,字典中的键(key)必须和定义函数时的关键词参数名一致。只有这样,Python 才能正确地将字典中的值(value)分配给相应的关键词参数。

值得注意的是,如果在字典中缺少任何一个关键词参数,或者字典中存在多余的关键词参数,则会引发 TypeError 异常。我们将代码进行如下修改:

def some_function(a, b, c):
    print(f"a={a}, b={b}, c={c}")

params = {"a": 1, "b": 2, "c": 3,"d":5}

TypeError: some_function() got an unexpected keyword argument ‘d’



5. assert 断言操作

def add_numbers(x, y):
    assert isinstance(x, int) and isinstance(y, int), "x and y must be integers."
    return x + y

print(add_numbers(2, 3))  # Output: 5
print(add_numbers('Hello', 3))  # AssertionError: x and y must be integers.

AssertionError: x and y must be integers.
5

在以上示例中,第一行计算了 2 和 3 的和,输出结果 5,符合预期。而第二行在调用 add_numbers 函数时,将一个字符串 "Hello" 和整数 3 作为函数参数传入,因此此时 assert 语句判断失败,抛出异常并打印出错误信息 "x and y must be integers."



6. __class__.__name__获取对象类名

m.__class__.__name__ 是 Python 中一种获取对象类型的方式。在示例代码中,它是用来获取当前遍历到的模块 m 的类名。

具体来说,在 Python 中,任何一个对象都有一个类(或类型),可以使用 type() 或者对象的 __class__属性来获取它们的类型/类。例如,以下代码创建了两个对象并打印它们的类型:

a = 1
b = "hello"
print(type(a)) # <class 'int'>
print(b.__class__) # <class 'str'>

<class ‘int’>
<class ‘str’>

因为调用 type() 得到结果的标准格式不便于直接作为字符串进行处理,所以常常使用 __class__.__name__ 来获取对象类型的名称。__name__是指该类型名称,而 __class__则表示该类本身。例如,在上面示例代码中,使用 __class__.__name__ 可以将结果转化为字符串类型的对象名称:

a = 1
b = "hello"
print(a.__class__.__name__) # 'int'
print(b.__class__.__name__) # 'str'

int
str

类似地,对于 PyTorch 中的 nn 模块,也可以使用 __class__.__name__ 获取模块的类名。比如下面的代码:

import torch.nn as nn

linear_layer = nn.Linear(10, 5)  # 创建一层线性变换
conv_layer = nn.Conv2d(3, 16, (3,3), padding=1)  # 创建一层卷积变换

print(linear_layer.__class__.__name__)  # 输出:Linear
print(conv_layer.__class__.__name__)  # 输出:Conv2d

Linear
Conv2d

在代码中,我们使用 nn.Linear 和 nn.Conv2d 分别创建了两种不同的神经网络层。linear_layer 对象被初始化为 nn.Linear(10, 5),因此,linear_layer.__class__ 是 nn.Linear 类型,使用 __class__.__name__ 获取其类名为 ‘Linear’。

对于 conv_layer,也是类似的过程。因此,conv_layer.__class__.__name__ 会返回 ‘Conv2d’ 字符串表示它是一个卷积层。



7. all() 方法判断字符是不是都非零

all 方法在Python中用来判断一个数组是不是都是非零的。下面是例子:

# 定义一个包含零和正数元素的张量
x = torch.tensor([1, 2, 0, 4, 5])

# 判断张量中的所有元素是否都非零
if x.all():
    print("All elements are nonzero.")
else:
    print("There are zero elements.")

There are zero elements.

附录

import argparse
import datetime
import os
import traceback

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from tensorboardX import SummaryWriter
from tqdm import tqdm

from dataset import USCISIDataset
from net import BusterNet
from utils import CustomDataParallel

def get_args():
    parser = argparse.ArgumentParser('Buster Net')
    parser.add_argument('-n', '--num_workers', type=int, default=16, help='num_workers of dataloader')
    parser.add_argument('-b', '--batch_size', type=int, default=4,  help='The number of images per batch among all devices')
    parser.add_argument('--num_gpus', type=int, default=1,  help='The number of gpus') # Multi gpus not spport yet.
    parser.add_argument('--freeze_layers', nargs='*', default=None,
                        help='freeze layers with strategy')
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, '
                                                                   'suggest using \'adamw\' or \'adam\' until the'
                                                                   ' very final stage then switch to \'sgd\'')
    parser.add_argument('--num_epochs', type=int, default=500)
    parser.add_argument('--val_interval', type=int, default=1, help='Number of epoches between valing phases')
    parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving')
    parser.add_argument('--es_min_delta', type=float, default=0.0,
                        help='Early stopping\'s parameter: minimum change loss to qualify as an improvement')
    parser.add_argument('--es_patience', type=int, default=0,
                        help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.')
    parser.add_argument('--lmdb_dir', type=str, default='./datasets/USCISI-CMFD', help='the root folder of dataset')
    parser.add_argument('--log_path', type=str, default='./logs/')
    parser.add_argument('-w', '--load_weights', type=str, default=None,
                        help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint')
    parser.add_argument('--saved_path', type=str, default='logs/')

    args = parser.parse_args()
    return args


class ModelWithLoss(nn.Module):
    def __init__(self, model, train_simi=True, train_mani=True, train_fusion=True, debug=False):
        super().__init__()
        self.ce_criterion = nn.CrossEntropyLoss()
        self.bce_criterion = nn.BCELoss()
        self.model = model
        self.train_simi = train_simi
        self.train_mani = train_mani 
        self.train_fusion = train_fusion
        self.debug = debug

    def forward(self, imgs, gts):
        fusion_preds, mani_preds, simi_preds = self.model(imgs)
        simi_gts = (1 - gts[:, 2, :, :]).type(torch.float)
        mani_gts = gts[:, 0, :, :].type(torch.float)
        _, fusion_gts = gts.max(dim=1)

        loss = torch.zeros(3)
        if self.train_fusion:
            fusion_loss = self.ce_criterion(fusion_preds, fusion_gts)
            loss[0] = fusion_loss
        if self.train_mani:
            mani_preds = mani_preds.squeeze(1)
            mani_loss = self.bce_criterion(mani_preds, mani_gts)
            loss[1] = mani_loss
        if self.train_simi:
            simi_preds = simi_preds.squeeze(1)
            simi_loss = self.bce_criterion(simi_preds, simi_gts)#ground truth segmentation 真值分割
            loss[2] = simi_loss

        return loss


def train(opt):
    train_file = 'train.keys'
    val_file = 'valid.keys'
    # Train similarity network or manipulation network independently or the whole network.
    train_simi=True
    train_mani=True
    train_fusion=True

    # According to the papers, set input_size default to 256.  
    input_size = 256

    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
    ])
    target_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
    ])
    train_set = USCISIDataset(opt.lmdb_dir, train_file, train_transform, target_transform)
    val_set = USCISIDataset(opt.lmdb_dir, val_file, val_transform, target_transform)
    
    training_params = {'batch_size': opt.batch_size,
                       'shuffle': True,
                       'drop_last': True,
                    #    'collate_fn': collater,
                       'num_workers': opt.num_workers}

    val_params = {'batch_size': opt.batch_size,
                  'shuffle': False,
                  'drop_last': True,
                #   'collate_fn': collater,
                  'num_workers': opt.num_workers}

    training_generator = DataLoader(train_set, **training_params)
    val_generator = DataLoader(val_set, **val_params)

    model = BusterNet(image_size=input_size)

    if opt.load_weights is not None:
        try:
            # Load pretrain VGG16 in https://download.pytorch.org/models/vgg16-397923af.pth or continuing training
            if 'vgg16_bn' in opt.load_weights:
                vgg_backbone = torch.load(opt.load_weights)
                model.manipulation_net.load_state_dict(vgg_backbone, strict=False)
                model.similarity_net.load_state_dict(vgg_backbone, strict=False)
            else:
                model.load_state_dict(torch.load(opt.load_weights), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
        print(
            f'[Info] loaded weights: {os.path.basename(opt.load_weights)}')
    else:
        print('[Info] initializing weights...')
    #     init_weights(model)

    if opt.freeze_layers is not None:
        assert isinstance(opt.freeze_layers, list), "Required List string"
        def freeze_layers(m):
            classname = m.__class__.__name__ 
            for ntl in opt.freeze_layers:
                if ntl in classname:
                    for param in m.parameters():
                        param.require_grad = False 
        
        model.apply(freeze_layers)
        print('[Info] freeze layers in ', opt.freeze_layers)
    
    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model, train_simi=train_simi, train_mani=train_mani, train_fusion=train_fusion)

    if opt.num_gpus > 1 and opt.batch_size // opt.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    os.makedirs(opt.saved_path, exist_ok=True)
    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = CustomDataParallel(model, opt.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)
    
    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    elif opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    last_step = 0
    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()
    
    num_iter_per_epoch = len(training_generator)
    
    try:
        for epoch in range(opt.num_epochs):

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                last_epoch = step // num_iter_per_epoch
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs, gts, _ = data

                    if opt.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        gts = gts.cuda()

                    optimizer.zero_grad()

                    fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                    fusion_loss = fusion_loss.mean()
                    simi_loss = simi_loss.mean()
                    mani_loss = mani_loss.mean()

                    loss = fusion_loss + mani_loss + simi_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Fusion loss: {:.5f}. Mani loss: {:.5f}. Mini loss: {:.5f} Total loss: {:.5f}'.format(
                            step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, fusion_loss.item(),
                            mani_loss.item(), simi_loss.item(), loss.item()))
                    writer.add_scalar('Loss', loss, step)
                    writer.add_scalar('fusion_loss', fusion_loss, step)
                    writer.add_scalar('simi_loss', simi_loss, step)
                    writer.add_scalar('mani_loss', mani_loss, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(model, f'model_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_fusion_ls = []
                loss_simi_ls = []
                loss_mani_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs, gts, _ = data

                        if opt.num_gpus == 1:
                            imgs = imgs.cuda()
                            gts = gts.cuda()

                        fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                        fusion_loss = fusion_loss.mean()
                        simi_loss = simi_loss.mean()
                        mani_loss = mani_loss.mean()

                        loss = fusion_loss + mani_loss + simi_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_fusion_ls.append(fusion_loss.item())
                        loss_simi_ls.append(simi_loss.item())
                        loss_mani_ls.append(mani_loss.item())

                fusion_loss = np.mean(loss_fusion_ls)
                simi_loss = np.mean(loss_simi_ls)
                mani_loss = np.mean(loss_mani_ls)
                loss = fusion_loss + simi_loss + mani_loss

                print(
                    'Val. Epoch: {}/{}. Fusion loss: {:1.5f}. Simi loss: {:1.5f}. Mani loss: {:1.5f}. Total loss: {:1.5f}'.format(
                        epoch, opt.num_epochs, fusion_loss, simi_loss, mani_loss, loss))
                writer.add_scalar('Val_Loss', loss, step)
                writer.add_scalar('Val_Fusion_loss', fusion_loss, step)
                writer.add_scalar('Val_Simi_loss', simi_loss, step)
                writer.add_scalar('Val_Mani_loss', mani_loss, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f'model_{epoch}_{step}.pth')

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f'model_{epoch}_{step}.pth')
        writer.close()
    writer.close()

def save_checkpoint(model, name):
    if isinstance(model, CustomDataParallel):
        torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
    else:
        torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))

if __name__ == '__main__':
    opt = get_args()
    train(opt)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/479527.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TAPFixer总结

相关工作 Menshen 检测属性用户写 et al检测属性就简单三个 未来工作&#xff1a; liveness; implicit; 数据集&#xff1b; 抽象方式合并&#xff1b;抽象规则配置&#xff1b;缓解谓词爆炸&#xff1b;concurrency的说明; 代码简化工作&#xff1b;给出能修复的漏洞种类 …

《基于光电容积法和机器学习的冠状动脉疾病患者出血风险预测》阅读笔记

目录 一、论文摘要 二、论文十问 三、论文亮点与不足之处 四、与其他研究的比较 五、实际应用与影响 六、个人思考与启示 参考文献 一、论文摘要 在冠状动脉疾病&#xff08;CAD&#xff09;患者的抗血栓治疗过程中&#xff0c;出血事件是关注的主要焦点。本研究旨在探讨…

浅谈一下布隆过滤器的设计之美

1 缓存穿透 2 原理解析 3 Guava实现 4 Redisson实现 5 实战要点 6 总结 布隆过滤器是一个非常有用的数据结构。它可以在大规模数据中高效地判断某个元素是否存在。布隆过滤器的应用非常广泛&#xff0c;不仅在搜索引擎、防垃圾邮件等领域中经常用到&#xff0c;而且在许多…

R语言单因素方差分析

R中的方差分析 介绍用于比较独立组的不同类型的方差分析&#xff0c;包括&#xff1a; 单因素方差分析&#xff1a;独立样本 t 检验的扩展&#xff0c;用于在存在两个以上组的情况下比较均值。这是方差分析检验的最简单情况&#xff0c;其中数据仅根据一个分组变量&#xff0…

【数据结构】七大排序总结

目录 &#x1f33e;前言 &#x1f33e; 内部排序 &#x1f308;1. 直接插入排序 &#x1f308;2. 希尔排序 &#x1f308;3. 直接选择排序 &#x1f308;4. 堆排序 &#x1f308;5. 归并排序 &#x1f308;6. 冒泡排序 &#x1f308;7. 快速排序 &#x1f33e;外部排序 &…

4 月份 火火火火 的开源项目

盘点 4 月份 GitHub 上 Star 攀升最多的开源项目&#xff0c;整个 4 月份最火项目 90% 都是 AI 项目&#xff08;准确的说&#xff0c;最近半年的热榜都是 AI 项目&#xff09; 本期推荐开源项目目录&#xff1a; 1. AI 生成逼真语音 2. 复旦大模型 MOSS&#xff01; 3. 让画中…

万万没想到在生产环境翻车了,之前以为很熟悉 CountDownLatch

前言 需求背景 具体实现 解决方案 总结 前言 之前我们分享了CountDownLatch的使用。这是一个用来控制并发流程的同步工具&#xff0c;主要作用是为了等待多个线程同时完成任务后&#xff0c;在进行主线程任务。然而&#xff0c;在生产环境中&#xff0c;我们万万没想到会…

【LeetCode】583. 两个字符串的删除操作

583. 两个字符串的删除操作&#xff08;中等&#xff09; 思路 这道题的状态定义和 1143. 最长公共子序列 相同&#xff0c;「定义一个 dp 数组&#xff0c;其中 dp[i]表示到位置 i 为止的子序列性质&#xff0c;并不是必须以 i 结尾」&#xff0c;此时 dp 数组的最后一位即为…

富士康终于醒悟了,重新加码中国制造,印度制造信不过

4月25日富士康在郑州揭牌新事业总部&#xff0c;显示出在扰攘了数年之后&#xff0c;富士康再度加强郑州富士康的发展力度&#xff0c;这应该是富士康在印度努力数年之后终于清醒了&#xff0c;印度制造终究不如中国制造可靠。 一、苹果和富士康在印度发展的教训 这两年苹果和富…

智能算法系列之基于粒子群优化的模拟退火算法

文章目录 前言1. 算法结合思路2. 问题场景2.1 Sphere2.2 Himmelblau2.3 Ackley2.4 函数可视化 3. 算法实现代码仓库&#xff1a;IALib[GitHub] 前言 本篇是智能算法(Python复现)专栏的第四篇文章&#xff0c;主要介绍粒子群优化算法与模拟退火算法的结合&#xff0c;以弥补各自…

【unity项目实战】3DRPG游戏开发07——其他详细的设计

敌人动画设计 新增图层动画,把权重设为1 在新图层默认新建一个空状态Base State,实现怪物默认动画播放Base State,因为Base State是空动画,所以默认会找上一个层的动画,这样就实现了两个图层动画的切换,也可以选择修改权重的方式实现 敌人随机巡逻 显示敌人巡逻的范…

网络字节序和主机字节序详解(附代码)

一、网络字节序和主机字节序 网络字节序和主机字节序是计算机网络中常用的两种数据存储格式。 主机字节序&#xff1a; 指的是在计算机内部存储数据时采用的字节排序方式。对于一个长为4个字节的整数&#xff0c;若采用大端字节序&#xff0c;则该整数在内存中的存储顺序是&a…

AppScan-被动手动扫描

被动扫描是针对性的扫描&#xff0c;浏览器代理到AppScan&#xff0c;然后进行手工操作&#xff0c;探索产生出的流量给AppScan进行扫描。这样可以使得扫描足够精准&#xff0c;覆盖率更加高&#xff0c;还能减少不必要的干扰 &#xff08;一&#xff09;环境准备 1、火狐安装…

SAP UI5 之Controls (控件) 笔记三

文章目录 官网 Walkthrough学习-Controls控件1.0.1 在index.html中使用class id 属性控制页面展示的属性1.0.2 我们在index.js文件中引入 text文本控制1.0.3打开浏览器查看结果 官网 Walkthrough学习-Controls控件 Controls控件 在前面展示在浏览器中的Hello World 是在Html …

Presto 之Hash Join的Partition

一. 前言 在Presto中&#xff0c;当两表Join为Hash Join并且join_distribution_type为PARTITIONED的时候&#xff0c;Presto会将Build表分区&#xff08;Partition&#xff09;后再进行Join操作。在Presto中的Join操作中&#xff0c;对表的分区有两级&#xff0c;第一级是将Has…

超简单搭建一个自用的ChatGPT网站(支持给网站添加访问密码)

前言&#xff1a; 有小伙伴留言想在自己的服务器搭建上图所示的ChatGPT网站&#xff0c;那么今天就是教大家如何在自己的服务器搭建像上图所示的ChatGPT网站 准备条件&#xff1a; 1&#xff09;一台服务器(这里用centos7) 2&#xff09;ChatGPT的API-KEY 一、Docker环境部署…

存储资源调优技术——SmartThin智能精简配置技术

目录 基本概念 工作原理 SmartThin关键技术 SmartThin主要功能 应用场景 精简LUN&#xff0c;存储空间超分配 按需动态分配存储资源&#xff0c;提高存储资源利用率 Thick和Thin LUN的区别如下 基本概念 Thin Lun属于存储资源的虚拟化&#xff0c;因此需要基于RAID 2.0存…

当影像遇上Python:用MoviePy库轻松搞定视频编辑

I. 简介 当影像遇上Python&#xff1a;用MoviePy库轻松搞定视频编辑 I. 简介II. 安装III. 使用 &#x1f680;&#x1f3ac;1. 创建一个视频剪辑对象2. 剪辑视频3. 剪切视频片段4. 改变视频尺寸和速度5. 合并视频6. 合并多个视频7. 用混合模式合并视频8. 添加音频9. 添加背景音…

台北房价预测

目录 1.数据理解1.1分析数据集的基本结构&#xff0c;查询并输出数据的前 10 行和 后 10 行1.2识别并输出所有变量 2.数据清洗2.1输出所有变量折线图2.2缺失值处理2.3异常值处理 3.数据分析3.1寻找相关性3.2划分数据集 4.数据整理4.1数据标准化 5.回归预测分析5.1线性回归&…

C++之深入解析如何实现一个线程池

一、基础概念 当进行并行的任务作业操作时&#xff0c;线程的建立与销毁的开销是&#xff0c;阻碍性能进步的关键&#xff0c;因此线程池&#xff0c;由此产生。使用多个线程&#xff0c;无限制循环等待队列&#xff0c;进行计算和操作&#xff0c;帮助快速降低和减少性能损耗…