1、mmdet中损失函数模块简介
1.1. Loss的注册器
先来看段代码:mmseg/models/builder.py
# mmseg/registry/registry.py
# mangage all kinds of modules inheriting `nn.Module`
# MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models'])
from mmseg.registry import MODELS
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS # 损失
SEGMENTORS = MODELS
这里MODELS注册器同时赋予给了其他模块。
再看看mmseg\models_init_.py
from .assigners import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .data_preprocessor import SegDataPreProcessor
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403
from .text_encoder import * # noqa: F401,F403
from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
build_head, build_loss, build_segmentor)
__all__ = [
'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
'build_head', 'build_loss', 'build_segmentor', 'SegDataPreProcessor'
]
# build_mtl_SHUAI
1.2. 注册FocalLoss()
models\losses\focal_loss.py
上述初始化参数比较简单,就两个参数:init():部分主要关注gamma和alpha两个参数,forward()部分主要关注pred和target两个参数。
举个实际例子算一下:
import torch
from mmseg.models import build_loss
# 配置dict
loss_bbox = dict(type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.5,
reduction='mean',
class_weight=None,
loss_weight=1.0,
loss_name='loss_focal')
# 从注册器中构建
focal_loss = build_loss(loss_bbox)
# 使用focal loss
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]]) # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = focal_loss(pred, target)
print("loss:",loss)
1.3. 总结
基本上mmseg所有损失的计算流程就上述过程,在使用Focal Loss时,不必关心那么多超参,直接build loss然后传入pred和target即可,其余参数基本默认即可。