【霹雳吧啦】手把手带你入门语义分割の番外13:U2-Net 源码讲解(PyTorch)—— 损失的计算

news2024/11/26 0:36:43

目录

前言

Preparation

一、U2-Net 网络结构图

二、U2-Net 网络源代码

1、损失计算

2、model.py

3、train_and_eval.py

附:train_and_eval.py 源代码


前言

文章性质:学习笔记 📖

视频教程:U2-Net 源码讲解(PyTorch)- 3 损失计算

主要内容:根据 视频教程 中提供的 U2-Net 源代码(PyTorch),对 train_and_val.py 文件中的 criterion 函数进行具体讲解。

Preparation

源代码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/u2net

├── src: 搭建网络相关代码
├── train_utils: 训练以及验证相关代码
├── my_dataset.py: 自定义数据集读取相关代码
├── predict.py: 简易的预测代码
├── train.py: 单GPU或CPU训练代码
├── train_multi_GPU.py: 多GPU并行训练代码
├── validation.py: 单独验证模型相关代码
├── transforms.py: 数据预处理相关代码
└── requirements.txt: 项目依赖

一、U2-Net 网络结构图

原论文提供的 U2-Net 网络结构图如下所示: 

​​

【说明】在 Encoder 阶段,每通过一个 block 后都经 Maxpool 下采样 2 倍,在 Decoder 阶段,每通过一个 block 后都经 Bilinear 上采样 2 倍。U2-Net 网络的核心 block 是 ReSidual U-block,分为具备上下采样的 block 和不具备上下采样的 block:

  • 具备了上下采样的 block:Encoder1~Encoder4、Decoder1~Decoder4
  • 不具备上下采样的 block:Encoder5、Encoder6、Decoder5

二、U2-Net 网络源代码

1、损失计算

原论文给出了 U2-Net 的损失计算公式:

L=\sum_{m=1}^{M}{w_{side}^{(m)}}\, {l_{side}^{(m)}}+w_{fuse}\,l_{fuse}

式中:l 代表 二值交叉熵损失 ,w 代表每个损失的权重,M=6 表示有 Decoder1~Decoder5 和 Encoder6 等 6 个输出。

这个损失函数可以看成两部分, + 前半部分 是来自于不同尺度上的一个输出,令其通过对应的 3x3 卷积层和双线性插值,将其还原回原图尺度,再将得到的 Sup1~Sup6 特征图与手工标注的 Ground Truth 去计算损失,进行加权求和; + 后半部分 是融合后得到的最终的预测概率图与 GT 之间的损失。在源码中,权重 w 全部等于 1 。

2、model.py

【说明】在训练模式下,这里的 x 代表网络最终融合的一个输出,而 side_outputs 则是列表形式,收集了图中所示的 Sup1~Sup6 特征图,注意在训练模式下没有经过 sigmoid 函数,这样做是为了在使用混合精度训练时更加稳定。

3、train_and_eval.py

【说明】通过 for 循环去遍历 inputs 列表中的每一项,inputs 列表中存储的就是最终的一个融合预测特征图以及 Sup1~Sup6 特征图,将其与对应的 Ground Truth ,也就是 target ,进行损失计算,采用 F.binary_cross_entropy_with_logits 计算二值交叉熵损失。

附:train_and_eval.py 源代码

import math
import torch
from torch.nn import functional as F
import train_utils.distributed_utils as utils


def criterion(inputs, target):
    losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
    total_loss = sum(losses)

    return total_loss


def evaluate(model, data_loader, device):
    model.eval()
    mae_metric = utils.MeanAbsoluteError()
    f1_metric = utils.F1Score()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    with torch.no_grad():
        for images, targets in metric_logger.log_every(data_loader, 100, header):
            images, targets = images.to(device), targets.to(device)
            output = model(images)

            # post norm
            # ma = torch.max(output)
            # mi = torch.min(output)
            # output = (output - mi) / (ma - mi)

            mae_metric.update(output, targets)
            f1_metric.update(output, targets)

        mae_metric.gather_from_all_processes()
        f1_metric.reduce_from_all_processes()

    return mae_metric, f1_metric


def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        lr_scheduler.step()

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss.item(), lr=lr)

    return metric_logger.meters["loss"].global_avg, lr


def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3,
                        end_factor=1e-6):
    assert num_step > 0 and epochs > 0
    if warmup is False:
        warmup_epochs = 0

    def f(x):
        """
        根据step数返回一个学习率倍率因子,
        注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
        """
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            # warmup过程中lr倍率因子从warmup_factor -> 1
            return warmup_factor * (1 - alpha) + alpha
        else:
            current_step = (x - warmup_epochs * num_step)
            cosine_steps = (epochs - warmup_epochs) * num_step
            # warmup后lr倍率因子从1 -> end_factor
            return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)


def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-4):
    params_group = [{"params": [], "weight_decay": 0.},  # no decay
                    {"params": [], "weight_decay": weight_decay}]  # with decay

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        if len(param.shape) == 1 or name.endswith(".bias"):
            # bn:(weight,bias)  conv2d:(bias)  linear:(bias)
            params_group[0]["params"].append(param)  # no decay
        else:
            params_group[1]["params"].append(param)  # with decay

    return params_group

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1366321.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HTML JavaScript 康威生命游戏

<!DOCTYPE html> <html> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>康威生命游戏</title><style>body {font-family: Arial, sa…

CompletableFuture超详解与实践

0.背景 一个接口可能需要调用 N 个其他服务的接口&#xff0c;这在项目开发中还是挺常见的。举个例子&#xff1a;用户请求获取订单信息&#xff0c;可能需要调用用户信息、商品详情、物流信息、商品推荐等接口&#xff0c;最后再汇总数据统一返回。 如果是串行&#xff08;按…

SpringBoot项目的三种创建方式

手动创建方式&#xff1a; ①&#xff1a;新建maven项目 ②&#xff1a;引入依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.3.10.RELEASE</version>&l…

近似的同态比较:简单多项式的迭代计算

参考文献&#xff1a; [Gold64] Goldschmidt R E. Applications of division by convergence[D]. Massachusetts Institute of Technology, 1964.[CKKLL19] Cheon J H, Kim D, Kim D, et al. Numerical method for comparison on homomorphically encrypted numbers[C]//Inter…

普通人开视频号小店可以吗?

我是电商珠珠 我做电商也有五六年时间了&#xff0c;天猫、快手、抖店我都做过。 其中做抖店的时间最长&#xff0c;三年的时间&#xff0c;让我从个人化做到了团队化。 在22年的时候&#xff0c;发现了视频号小店这个项目&#xff0c;并开始投入进去。 一开始&#xff0c;…

Flask:模板渲染

本文章只作为个人笔记. 文章目录 前言 一、模板渲染 二、效果 前言 模板渲染 一、模板渲染 from flask import Flask, render_templateapp Flask(__name__)app.route(/) def hello_world():return render_template("index.html")app.route("/blog/<blog_…

java中常见的一些小知识(1)

1.数组转List 1.1. Arrays.asList public class Tesr {public static void main(String[] args) {String[] ary new String[]{ "1", "a"};List<String> list Arrays.asList((ary));list.add("ddsdsa");System.out.println(list);}}但是…

【EI会议征稿通知】第六届信息科学、电气与自动化工程国际学术会议(ISEAE 2024)

第六届信息科学、电气与自动化工程国际学术会议&#xff08;ISEAE 2024&#xff09; 2024 6th International Conference on Information Science, Electrical and Automation Engineering 第六届信息科学、电气与自动化工程国际学术会议&#xff08;ISEAE 2024&#xff09;定…

代码训练营Day.27 | 39. 组合总和、40. 组合总和II、131. 分割回文串

39. 组合总和 1. LeetCode链接 . - 力扣&#xff08;LeetCode&#xff09; 2. 题目描述 3. 解法 与其他组合总和题目不同的是&#xff0c;这一次数组中的数字可以重复使用。 回溯&#xff1a; 1. 参数和返回值。参数&#xff1a;数组、遍历起点、目标值。 2. 终止条件。…

Prometheus Blackbox_exporter笔记

一、安装Promtheus 在 Prometheus 官网 Download | Prometheus 获取适用于 Linux 的 Prometheus 安 装包&#xff0c;这里我选择最新的 2.46.0 版本&#xff0c;我是 Linux 系统&#xff0c;选择下载 prometheus-2.46.0.linux-amd64.tar.gz 下载安装包&#xff1a; wget htt…

Elasticsearch:Serarch tutorial - 使用 Python 进行搜索 (三)

这个是继上一篇文章 “Elasticsearch&#xff1a;Serarch tutorial - 使用 Python 进行搜索 &#xff08;二&#xff09;” 的续篇。在今天的文章中&#xff0c;本节将向你介绍一种不同的搜索方式&#xff0c;利用机器学习 (ML) 技术来解释含义和上下文。 向量搜索 嵌入 (embed…

【Axure高保真原型】日期天数加减计算器

今天和大家分享日期天数加减计算器的原型模板&#xff0c;我们通过这个模板选择指定日期&#xff0c;然后填写需要增加或者减少的天数&#xff0c;点击确认按钮后&#xff0c;就可以计算出对应的结束日期&#xff0c;本案例提供中继器版的日期选择器&#xff0c;以及JS版的日期…

C++常见的代码操作

1.输出C版本&#xff1a;cout << __cplusplus << endl; #include <iostream>int main() { cout << __cplusplus << endl;system("pause");return 0; } 老版的话会输出199711&#xff0c;支持c11的话会输出201103 注&#xff1a;vis…

java中实现对文件高效的复制

不多说我们直接上代码&#xff1a; 这个是使用NIO包下的FileChannel和ByteBuffer进行文件的操作的&#xff0c;会比较高效。

《人生没有太晚的开始》读书笔记

目录 一、作者简介 二、如何开始作画的&#xff1f; 三、经典语句摘录 一、作者简介 摩西奶奶&#xff08;安娜玛丽罗伯逊摩西&#xff09;1860- 1961年 78岁开始学习绘画&#xff0c;93岁登上《时代》杂志封面。 摩西奶奶的一生&#xff0c;是富有传奇色彩的一生&#xf…

企业内部知识库搭建真的很重要,优秀企业必备

在瞬息万变的商界&#xff0c;知识、信息和经验的获取和流通对于企业的生存和发展至关重要。每一个员工的专业知识、经验和教训&#xff0c;都不仅仅是他们自己的财富&#xff0c;更是企业的宝贵资产。然而&#xff0c;这些散布在公司各部门&#xff0c;甚至个别员工头脑中的知…

基于ssm的高校班级同学录网站设计与实现+jsp论文

摘 要 如今社会上各行各业&#xff0c;都喜欢用自己行业的专属软件工作&#xff0c;互联网发展到这个时候&#xff0c;人们已经发现离不开了互联网。新技术的产生&#xff0c;往往能解决一些老技术的弊端问题。因为传统高校班级同学录信息管理难度大&#xff0c;容错率低&…

数环通12月产品更新:新增数据表相关功能、优化编辑器,15+应用进行更新

为了满足用户不断增长的需求&#xff0c;我们持续努力提升产品的功能和性能&#xff0c;以更好地支持用户的工作。 数环通12月的最新产品更新已经正式发布&#xff0c;带来了一系列强大的功能&#xff0c;以提升您的工作效率和系统的可靠性。 更新快速预览 新增&优化功能&a…

【Win10安装Qt6.3】安装教程_保姆级

前言 Windows系统安装Qt4及Qt5.12之前版本和安装Qt.12之后及Qt6方法是不同的 &#xff1b;因为之前的版本提供的有安装包&#xff0c;直接一路点击Next就Ok了。但Qt5.12版本之后&#xff0c;Qt公司就不再提供安装包了&#xff0c;不论是社区版&#xff0c;专业版等&#xff0c…

你的手机可以检测听力啦

我的第一部手机是医院配发给我应对急诊的诺基亚手机&#xff0c;翻盖儿的&#xff0c;只能用来打电话。但现在的手机对于一个医生来讲具备了很多超现实的功能&#xff0c;比如听觉健康管理&#xff01;在你正常的情况下&#xff0c;你未必体会到听觉障碍给你带来的困惑。但是一…