YOLOv5图像分割--SegmentationModel类代码详解

news2025/1/12 6:14:34

目录

​编辑

SegmentationModel类

DetectionModel类

推理阶段

DetectionModel--forward()

BaseModel--forward() 

Segment类

Detect--forward 


 

SegmentationModel类

定义model将会调用models/yolo.py中的类SegmentationModel。该类是继承父类--DetectionModel类。

class SegmentationModel(DetectionModel):  # SegmentationModel这个类是继承了DetectionModel这个类
    # YOLOv5 segmentation model
    def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, anchors=None):
        super().__init__(cfg, ch, nc, anchors)

DetectionModel类

因此直接去看下DetectionModel这个类代码,同时也能发现这个类又是继承BaseModel这个类。这里先看一下DetectionModel,后面再看BaseModel这个类。这个类的功能可以根据yaml文件定义网络【定义网络的函数为parse_model()】,在分割任务中,anchors为None。

class DetectionModel(BaseModel):  # 继承BaseModel这个类
    # YOLOv5 detection model
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        

得到的model如下,这里需要注意的是此时的self指SegmentationModel类。

Sequential(
  (0): Conv(
    (conv): Conv2d(3, 32, kernel_size=(6, 6), stride=(2, 2), padding=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (1): Conv(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (2): C3(
    (cv1): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (3): Conv(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (4): C3(
    (cv1): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
      (1): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (5): Conv(
    (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (6): C3(
    (cv1): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
      (1): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
      (2): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (7): Conv(
    (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (8): C3(
    (cv1): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (9): SPPF(
    (cv1): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
  )
  (10): Conv(
    (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (11): Upsample(scale_factor=2.0, mode=nearest)
  (12): Concat()
  (13): C3(
    (cv1): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (14): Conv(
    (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (15): Upsample(scale_factor=2.0, mode=nearest)
  (16): Concat()
  (17): C3(
    (cv1): Conv(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (18): Conv(
    (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (19): Concat()
  (20): C3(
    (cv1): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (21): Conv(
    (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (22): Concat()
  (23): C3(
    (cv1): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv3): Conv(
      (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): Sequential(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
        (cv2): Conv(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): SiLU()
        )
      )
    )
  )
  (24): Segment(
    (m): ModuleList(
      (0): Conv2d(128, 351, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(256, 351, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(512, 351, kernel_size=(1, 1), stride=(1, 1))
    )
    (proto): Proto(
      (cv1): Conv(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (upsample): Upsample(scale_factor=2.0, mode=nearest)
      (cv2): Conv(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
      (cv3): Conv(
        (conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU()
      )
    )
  )
)

然后继续看下面的代码,m=self.model[-1]是获取上面定义model的最后一个模块即Segment类【这个类又继承Detect类,这个】,所以此时的m类型为Segment类。然后看forward 的lambda表达式那行, 由于通过isinstance判断m为Segment为True,所以此时调用SegmentationModel类的forward函数,并且可以回看前面SegmentationModel这个类发现没有重新父类DetectionModel的forward函数,所以这里直接调用父类的forward即可

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)

下面这两行代码分别为anchors的映射与获得stride,前面的映射是指将anchors映射到对应feature map上。【看到这里可能有些懵,不是前面已经说anchors为None了么,怎么现在又有anchors了,前面的None指在SegmentationModel这个类,而现在的anchors是Segment类中,也就是上面代码中m这个变量,这个anchors是通过YAML文件获取的】 。

m.anchors /= m.stride.view(-1, 1, 1)  # anchors的缩放
self.stride = m.stride

推理阶段

DetectionModel--forward()

从面前我们已经知道了虽然我们可以通过SegmentationModel类的实例化来定义model,但在推理阶段是调用的DetectionModel这个类下的forward函数。

    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

BaseModel--forward() 

可以看到DetectionModel调用的为_forward_once(x,profile,visualize)这个函数,而这个函数是父类BaseModel下的函数。

class BaseModel(nn.Module):
    # YOLOv5 base model
    def forward(self, x, profile=False, visualize=False):
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers 当为segment时xshape:[128,80,80]、[256,40,40],[512,20,20]
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run 将x放入每个卷积层提取特征,得到的x是提取后的
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

此时的x为输入的图像,shape为【1,3,640,640】。self为SegmentationModel,因此后面的self,model调用的前面定义好的分割网络model。 

for m in self.model是遍历网络的每一层,当遍历到head时【也就是遍历到segment类时】,得到的shape大小为[128,80,80],[256,40,40],[512,20,20],也就是会得到三个feature map,这三个层是通过m.f在y[j]中获得的。

下面这行代码是会将[4, 6, 10, 14, 17, 20, 23]这几层输出的output进行保存【这几层可以对照yaml文件看】。 

y.append(x if m.i in self.save else None)  # save output

下面是Segment【head】结构。

经过卷积以后得到的x为tuple类型,包含的内容为:

①【batch,25200,117】,

②【batch,32,160,160】,

③ list【[batch,3,80,80,117],【[batch,3,40,40,117]】,[batch,3,20,20,117]】

注:25200=3*80*80+40*40*3+20*20*3【可理解为将三个featrue map铺平后叠加在一起】;

这里的160是通过将80*80的feature上采样得到的 

这里的117指:5+80+32【这里的32是mask的数量】

最后得到的输出就是我们要的output。

Segment(
  (m): ModuleList(
    (0): Conv2d(128, 351, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(256, 351, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(512, 351, kernel_size=(1, 1), stride=(1, 1))
  )
  (proto): Proto(
    (cv1): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (upsample): Upsample(scale_factor=2.0, mode=nearest)
    (cv2): Conv(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (cv3): Conv(
      (conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
  )

Segment类

 前面我们说到了在BaseModel中对派生类SegmentationModel遍历时,在head部分会得到Segment获得最终的输出,那么我们来看一下这个类。

参数:

nc:分类数量。coco为80个类

anchors:通过yaml文件获得的anchors。

nm:mask数量

npr:protos数量

ch:3通道

Segment继承Detect这个类

在forward部分,x是前面获得的三个feature,分别从网络的17,20,23层获得。

proto的功能是针对x[0]进行卷积,将原来80*80大小的feature通过上采样变为160*160。然后调用Detect中的forward进行前向推理获得输出,然后返回[x[0],p,x[1]]也就是shape为【1,128,80,80】,【1,128,40,40】,【1,256,20,20】的tuple。

class Segment(Detect):
    # YOLOv5 Segment head for segmentation models
    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
        super().__init__(nc, anchors, ch, inplace)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor 5+80+32
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # * output conv
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

    def forward(self, x):
        """
        Args x is list,from 17,20,23
            x[0].shape=[batch_size,128,80,80],
            x[1].shape=[batch,256,40,40],
            x[2].shpe=[batch,512,20,20]

        proto:功能是将P3输出的80*80变160*160
        conv1(x[0])->upsample[x[0]=160*160]->conv2->conv3->output.shape=[batch,32,160,160],
        """
        p = self.proto(x[0])
        x = self.detect(self, x)  # x[0]:[batch,3,80,80,117],x[1]:[1,3,40,40,117],x[2]:[1,3,20,20,117]
        return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])

Detect--forward 

在上面Segment中调用Detect的forward对x进行推理,下面就看看具体发生了什么变化。通过遍历三个head,在self指的Segment类,而self.m是Segment的三个卷积,如下:

(m): ModuleList(
    (0): Conv2d(128, 351, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(256, 351, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(512, 351, kernel_size=(1, 1), stride=(1, 1))
  )

因此用这三个卷积对x进行卷积,x为Segment类中的x,为tuple类型。

class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    # Detect layer init
    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)
    # x是列表类型为P3 P4 P5的输出大小
    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape
            # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                

由于self前面说了是Segment类型,因此可以将x[1,3,80,80,117=5+80+32]进行划分,得到boxes+mask的形式,形式为xy[中心点],wh[宽高],conf,mask ,并在对应head划分网格,最终将xy,wh,conf与mask进行拼接【在第四维度上,也就是最后一个维度】拼接为shape[batch,feature_w,feature_h,117]。

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)

经过上面的操作,我们可以再返回Segment了,经过detect的forward我们得到的输出为:【(1,25200,117),list[(1,3,80,80,117),[1,3,40,40,117],[1,3,20,20,117]]】

再经过下面的操作,返回的形式为【x[0]=[1,25200,117],p=[1,32,160,160],x[1]=list[(1,3,80,80,117),[1,3,40,40,117],[1,3,20,20,117]]】

return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/61427.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数学基础从高一开始1、集合的概念

数学基础从高一开始1、集合的概念 目录 数学基础从高一开始1、集合的概念 一、课程引入 解析:方程​编辑2是否有解? 解析:所有到定点的距离等于定长的点组成何种图形? 结论: 二、课程讲解 问题1: 集…

1548_AURIX_TC275_锁步比较逻辑LCL

全部学习汇总: GreyZhang/g_TC275: happy hacking for TC275! (github.com) 这可能是这段时间看过的最简单的一个章节了,所有的章节内容都可以放进这一份笔记也不显得多。 1. 首先明确LCL的意思,其实是锁步核比较器逻辑的意思,还不…

知识点1--认识Docker

IT界2014年之前,对于服务器虚拟化的使用,有过一个流派,基于Windows server系统VMware组成服务器集群,但是后期由于这样的使用方式维护成本相当高,比如服务器的序列、服务器台账以及服务器与服务器之间的切换等等&#…

据说Linuxer都难忘的25个画面

虽然对 Linux 正式生日是哪天还有些争论,甚至 Linus Torvalds 认为在 1991 那一年有四个日子都可以算作 Linux 的生日。但是不管怎么说,Linux 已经 25 岁了,这里我们为您展示一下这 25 年来发生过的 25 件重大里程碑事件。 1991:L…

SpringMVC学习笔记二(获取Cookies、Session和Header、IDEA热部署)

目录 一、一些前置知识 二、SpringMVC获取cookies和session 🍑获取cookies和header 🍑获取session 三、SpringMVC热部署 📝添加框架支持 📝settings配置开启自动热部署 📝开启运行中热部署: &…

【Autopsy数字取证篇】Autopsy案例创建与镜像分析详细教程

【Autopsy数字取证篇】Autopsy案例创建与镜像分析详细教程 Autopsy是一款非常优秀且功能强大的免费开源数字取证分析工具。—【蘇小沐】 文章目录【Autopsy数字取证篇】Autopsy案例创建与镜像分析详细教程1.实验环境2.Autopsy下载安装(一)创建案例1.软件…

【简单易操作】图漾TM460-E2深度网络相机在ROS-melodic环境下的配置过程

目录一、配置的环境二、下载内容及链接三、ubuntu环境配置下载 Camport3 SDK安装依赖编译运行四、安装OpenNI2套件下载 Camport3 OpenNI2 SDK安装 Camport3 OpenNI2 SDK五、ROS平台安装下载 Camport3 ROS SDK编译配置环境变量运行一、配置的环境 相机型号:TM460-E2…

OpenRASP agent源码分析

目录 前言 准备 源码分析 1. manifest 2. agent分析 3. agent卸载逻辑 总结 前言 笔者在很早前写了(231条消息) OpenRASP Java应用自我保护使用_fenglllle的博客-CSDN博客 实际上很多商业版的rasp工具都是基于OpenRASP的灵感来的,主要就是对核心的Java类通过…

堆(二叉堆)-优先队列-数据结构和算法(Java)

文章目录1 概述1.1 定义1.2 二叉堆表示法2 API3 堆相关算法3.1 上浮(由下至上的堆有序化)3.2 下沉(由上至下的堆有序化)3.3 插入元素3.4 删除最大元素4 实现5 性能和分析5.1 调整数组的大小5.2 元素的不可变性6 简单测试6 后记1 概…

2006-2020年全国31省人口老龄化水平

2006-2020年全国31省人口老龄化 1、时间为2006-2020年 2、来源:人口与就业年鉴 3、数据缺失情况说明: 其中2010年存在缺失,采用线性插值法进行填补,内含原始数据、线性插值 4、计算说明:以城镇地区老年抚养比衡量…

uImage的制作过程详解

1、uImage镜像介绍 参考博客:《vmlinuz/vmlinux、Image、zImage与uImage的区别》; 2、uImage镜像的制作 2.1、mkimage工具介绍 参考博客:《uImage的制作工具mkimage详解(源码编译、使用方法、添加的头解析、uImage的制作)》; 2.2…

软路由搭建:工控机(3865U)安装esxi并在esxi上创建iStoreOS做主路由(网卡直通)

一、硬件介绍 1、工控机(3865U) CPU:3865U 内存:8G 硬盘:120G 网卡:六口网卡 2、无线路由器(荣耀路由器pro2) 3、主机 下载资料、制作启动盘、系统设置 4、U盘 至少8G以上 …

ConcurrentHashMap 1.7与1.8的区别

ConcurrentHashMap 与HashMap和Hashtable 最大的不同在于:put和 get 两次Hash到达指定的HashEntry,第一次hash到达Segment,第二次到达Segment里面的Entry,然后在遍历entry链表 从1.7到1.8版本,由于HashEntry从链表 变成了红黑树所以 concurr…

Python Gui之tkinter(下)

6.Radiobutton单按按钮 Radiobutton控件用于选择同一组单选按钮中的一个。Radiobutton可以显示文本,也可以显示图像。 7.Checkbutton复选按钮 Checkbutton控件用于选择多个按钮的情况。Checkbutton可以显示文本,也可以显示图像。 经典的Gui类的写法&a…

关于liunx 宝塔运行php项目

文章目录前言一、申请liunx服务器安装宝塔环境二、安装php看你自己安装需要的版本三.php文件创建四.数据库创建五.访问项目就可以了前言 自己研究学习,大佬勿喷 一、申请liunx服务器安装宝塔环境 我是线上安装的都一样看个人习惯爱好吧 等待安装完成提示地址和账…

Java基础—重新抛出异常

重新抛出异常 在catch块内处理完后,可以重新抛出异常,异常可以是原来的,也可以是新建的,如下所示: try{ //可能触发异常的代码 }catch(NumberFormatException e){ System.out.println("not valid numbe…

电子印章结构以及规范讲解

前言 为了确保电子印章的完整性、不可伪造性,以及合法用户才能使用,需要定义一个安全的电子印章数据格式,通过数字签名,将印章图像数据与签章者等印章属性进行安全绑定,形成安全电子印章 电子印章:一种由…

MVVM与Vue响应式的实现

Vue的响应式实现原理 MVVM M:模型 》data中的数据 V:视图 》模板 VM:视图模型 》Vue实例对象 ViewModel是一个中间的桥梁将视图View与模型Model连接起来,ViewModel内部通过数据绑定,实现数据变化,视图发…

链接装载(一)虚拟地址与物理地址

文章目录一、基本概念二、一个基本问题三、程序的执行四、从堆中分配的数据的逻辑地址一、基本概念 当我们写出一个程序,即便是最基本的 Hello World,都需要经过 预处理、编译、汇编、链接才能生成最终的可执行文件。 预处理: 预处理过程主…

spring ioc的循环依赖问题

spring ioc的循环依赖问题什么是循环依赖spring中循环依赖的场景通过构造函数注入时的循环依赖通过setter或Autowired注入时的循环依赖循环依赖的处理机制原型bean循环依赖单例bean通过构造函数注入循环依赖单例bean通过setter或者Autowired注入的循环依赖三级缓存对象的创建分…