Learning Notes on a Two-Stage Knowledge-Distillation Model for Rain, Snow, and Haze Removal (Part 1)


In the previous post I finished deploying and training the knowledge-distillation-based rain/snow/haze removal model; now let's study the code itself, stepping through it with a debugger.
First, the network structure. A word of warning before expanding it: the model is not conceptually complicated, it simply has a very large number of layers.

Net(
  (conv_input): ConvLayer(
    (reflection_pad): ReflectionPad2d((5, 5, 5, 5))
    (conv2d): Conv2d(3, 16, kernel_size=(11, 11), stride=(1, 1))
  )
  (dense0): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv2x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv1): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(80, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion1): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense1): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv4x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv2): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion2): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense2): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv8x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv3): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion3): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dense3): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv16x): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2))
  )
  (conv4): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(512, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion4): Encoder_MDCBlock1(
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (dehaze): Sequential(
    (res0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res3): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res4): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res5): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res6): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res7): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res8): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res9): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res10): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res11): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res12): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res13): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res14): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res15): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res16): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (res17): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (convd16x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_4): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_4): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_4): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd8x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_3): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_3): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(160, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_3): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd4x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_2): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_2): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(80, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_2): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (convd2x): UpsampleConvLayer(
    (conv2d): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2))
  )
  (dense_1): Sequential(
    (0): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (1): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
    (2): ResidualBlock(
      (conv1): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (conv2): ConvLayer(
        (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
        (conv2d): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
      )
      (relu): PReLU(num_parameters=1)
    )
  )
  (conv_1): RDB(
    (dense_layers): Sequential(
      (0): make_dense(
        (conv): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): make_dense(
        (conv): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (2): make_dense(
        (conv): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (3): make_dense(
        (conv): Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (conv_1x1): Conv2d(40, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
  (fusion_1): Decoder_MDCBlock1(
    (down_convs): ModuleList(
      (0): ConvBlock(
        (conv): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): ConvBlock(
        (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
    (up_convs): ModuleList(
      (0): DeconvBlock(
        (deconv): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (1): DeconvBlock(
        (deconv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (2): DeconvBlock(
        (deconv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
      (3): DeconvBlock(
        (deconv): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (act): PReLU(num_parameters=1)
      )
    )
  )
  (conv_output): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1))
  )
)
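The whole printout is built from three repeating blocks: a reflection-padded `ConvLayer`, a `ResidualBlock` made of two such layers, and an `RDB` (Residual Dense Block) whose `make_dense` layers concatenate `growth_rate` new channels at each step. The following is a minimal sketch of how these blocks could be written (my reconstruction from the printout above, not the authors' exact source); note how the channel counts reproduce e.g. `(conv1): RDB` with dense inputs 16 → 32 → 48 → 64 and a 1x1 fusion from 80 back to 16.

```python
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    """Reflection padding followed by a plain convolution (no activation)."""
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super().__init__()
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))

class ResidualBlock(nn.Module):
    """conv -> PReLU -> conv, plus an identity skip connection."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        return x + self.conv2(self.relu(self.conv1(x)))

class make_dense(nn.Module):
    """One dense layer: append growth_rate new channels to its input."""
    def __init__(self, in_ch, growth_rate):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, growth_rate, 3, 1, padding=1, bias=False)

    def forward(self, x):
        return torch.cat([x, torch.relu(self.conv(x))], dim=1)

class RDB(nn.Module):
    """Residual Dense Block: stacked dense layers, 1x1 fusion, residual skip."""
    def __init__(self, channels, num_layers=4, growth_rate=None):
        super().__init__()
        growth_rate = growth_rate or channels  # the printout uses growth == channels
        layers, in_ch = [], channels
        for _ in range(num_layers):
            layers.append(make_dense(in_ch, growth_rate))
            in_ch += growth_rate
        self.dense_layers = nn.Sequential(*layers)
        self.conv_1x1 = nn.Conv2d(in_ch, channels, 1, bias=False)

    def forward(self, x):
        return x + self.conv_1x1(self.dense_layers(x))

# Channel bookkeeping for RDB(16): dense inputs 16, 32, 48, 64; fusion 80 -> 16.
x = torch.randn(1, 16, 32, 32)
out = RDB(16)(x)
```

All blocks are shape-preserving at stride 1, which is why the encoder can freely interleave them with the strided `conv2x`/`conv4x`/... downsampling layers.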

We start with the Stage 1 training routine, also called the knowledge-collection stage:

def train_kc_stage(model, teacher_networks, ckt_modules, train_loader, optimizer, scheduler, epoch, criterions):
	print(Fore.CYAN + "==> Training Stage 1")  # Fore comes from colorama
	print("==> Epoch {}/{}".format(epoch, args.max_epoch))
	print("==> Learning Rate = {:.6f}".format(optimizer.param_groups[0]['lr']))
	meters = get_meter(num_meters=5)  # running-average meters for the logged losses
	criterion_l1, criterion_scr, _ = criterions  # L1 and SCR losses; the third is unused in this stage
	model.train()        # the student learns
	ckt_modules.train()  # the knowledge-transfer projectors learn too
	for teacher_network in teacher_networks:
		teacher_network.eval()  # teachers only provide features, they are not updated
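Note that only the student and the CKT modules are put into training mode; the teachers stay in `eval()`, and their outputs are typically also computed under `torch.no_grad()` so no graph is built for them. A minimal, self-contained sketch of that pattern (using hypothetical one-layer networks, not the real models):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical one-layer stand-ins for the student and one teacher.
student = nn.Conv2d(3, 3, 3, padding=1)
teacher = nn.Conv2d(3, 3, 3, padding=1)

student.train()
teacher.eval()  # fixes BatchNorm/Dropout behaviour; it does NOT freeze weights by itself

x = torch.randn(2, 3, 8, 8)
with torch.no_grad():            # the teacher only supplies targets,
    target = teacher(x)          # so we skip building its autograd graph
loss = F.l1_loss(student(x), target)
loss.backward()                  # gradients reach the student only
```

Because the target is produced inside `no_grad()`, `loss.backward()` populates gradients for the student's parameters while the teacher's stay `None`.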

This declares the loss functions we need and puts `ckt_modules` (the Collaborative Knowledge Transfer modules) into training mode.
The detailed structure of `ckt_modules` is as follows:

ModuleList(
  (0): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (1): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (2): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
  (3): CKTModule(
    (teacher_projectors): TeacherProjectors(
      (PFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
      (IPFPs): ModuleList(
        (0): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (1): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ReLU(inplace=True)
          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
      )
    )
    (student_projector): StudentProjector(
      (PFP): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
  )
)
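Given the structure dump above and the call site in stage 1 (`t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)`), the forward pass of one CKTModule can be sketched roughly as below. This is an assumption-laden reconstruction, not the repository's code: the class name, the 64-to-32 channel pair, and the toy shapes are all illustrative.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a CKTModule (not the repository's actual code).
# PFPs project each teacher's features into a smaller common space, IPFPs
# reconstruct the original teacher features from that space, and the
# student has a single projector into the same common space.
class CKTModuleSketch(nn.Module):
    def __init__(self, channels=64, hidden=32, num_teachers=3):
        super().__init__()
        def proj(cin, cout):
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, padding=1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(cout, cout, 3, padding=1, bias=False),
            )
        self.pfps = nn.ModuleList(proj(channels, hidden) for _ in range(num_teachers))
        self.ipfps = nn.ModuleList(proj(hidden, channels) for _ in range(num_teachers))
        self.student_pfp = proj(channels, hidden)

    def forward(self, t_features, s_features):
        # project each teacher's features, then reconstruct them back
        t_proj = [pfp(f) for pfp, f in zip(self.pfps, t_features)]
        t_recons = [ipfp(p) for ipfp, p in zip(self.ipfps, t_proj)]
        # the student batch is projected by a single shared projector
        s_proj = self.student_pfp(s_features)
        return t_proj, t_recons, s_proj

# toy shapes: one image per teacher; the student batch is the 3 inputs concatenated
m = CKTModuleSketch()
t_feats = [torch.randn(1, 64, 8, 8) for _ in range(3)]
s_feat = torch.randn(3, 64, 8, 8)
t_proj, t_recons, s_proj = m(t_feats, s_feat)
```

Note how `torch.cat(t_proj)` then has the same shape as `s_proj`, which is exactly what allows `PFE_loss = criterion_l1(s_proj_features, torch.cat(t_proj_features))` in the training loop.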

criterions holds the loss functions: an L1 loss, the SCR loss, and the HCR loss. Its structure:

ModuleList(
  (0): L1Loss()
  (1): SCRLoss(
    (vgg): Vgg19(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (slice2): Sequential(
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
      )
      (slice3): Sequential(
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
      )
      (slice4): Sequential(
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (17): ReLU(inplace=True)
        (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
      )
      (slice5): Sequential(
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (24): ReLU(inplace=True)
        (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (26): ReLU(inplace=True)
        (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
      )
    )
    (l1): L1Loss()
  )
  (2): HCRLoss(
    (vgg): Vgg19(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (slice2): Sequential(
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
      )
      (slice3): Sequential(
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
      )
      (slice4): Sequential(
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (17): ReLU(inplace=True)
        (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
      )
      (slice5): Sequential(
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (24): ReLU(inplace=True)
        (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (26): ReLU(inplace=True)
        (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
      )
    )
    (l1): L1Loss()
  )
)
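The SCRLoss above pairs a frozen VGG19 feature extractor with an L1 distance. Its exact formulation is not printed here, but judging from the call `criterion_scr(preds_from_student, target_images, torch.cat(input_images))`, it plausibly follows the common VGG-based contrastive-regularization pattern: pull the restored image toward the clean target and push it away from the degraded input in feature space. The sketch below shows only that generic pattern; the per-level weights and the tiny two-layer stand-in extractor are invented for illustration and are not the real VGG19 slices.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a contrastive loss of the SCR kind (assumed form):
# anchor = restored image, positive = clean target, negative = degraded input.
class SCRLossSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()
        # toy stand-in for the five frozen VGG19 slices of the real loss
        self.extractor = nn.ModuleList([
            nn.Conv2d(3, 8, 3, padding=1),
            nn.Conv2d(8, 8, 3, padding=1),
        ])
        self.weights = [0.5, 1.0]  # per-level weights (assumed values)

    def features(self, x):
        feats = []
        for layer in self.extractor:
            x = torch.relu(layer(x))
            feats.append(x)
        return feats

    def forward(self, anchor, positive, negative):
        a, p, n = map(self.features, (anchor, positive, negative))
        loss = 0.0
        for w, fa, fp, fn in zip(self.weights, a, p, n):
            # small distance to the target, large distance to the input
            loss = loss + w * self.l1(fa, fp) / (self.l1(fa, fn) + 1e-7)
        return loss

loss_fn = SCRLossSketch()
loss = loss_fn(torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8))
```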

As you can see, each teacher network is just the Net shown earlier: it is instantiated three times, with each copy loading different pretrained weights. That is, three separate models, one per degradation type.

The training loop then proceeds:

start = time.time()
pBar = tqdm(train_loader, desc='Training')
for target_images, input_images in pBar:
	
	# Check whether the batch contains all types of degraded data
	if target_images is None: continue

	# move to GPU
	target_images = target_images.cuda()
	input_images = [images.cuda() for images in input_images]

	# Fix all teachers and collect reconstruction results and features from the corresponding teacher
	preds_from_teachers = []
	features_from_each_teachers = []
	with torch.no_grad():
		for i in range(len(teacher_networks)):
			preds, features = teacher_networks[i](input_images[i], return_feat=True)
			preds_from_teachers.append(preds)
			features_from_each_teachers.append(features)	
			
	preds_from_teachers = torch.cat(preds_from_teachers)
	features_from_teachers = []
	for layer in range(len(features_from_each_teachers[0])):
		features_from_teachers.append([features_from_each_teachers[i][layer] for i in range(len(teacher_networks))])

	preds_from_student, features_from_student = model(torch.cat(input_images), return_feat=True)   

	
	# Project the features to common feature space and calculate the loss
	PFE_loss, PFV_loss = 0., 0.
	for i, (s_features, t_features) in enumerate(zip(features_from_student, features_from_teachers)):
		t_proj_features, t_recons_features, s_proj_features = ckt_modules[i](t_features, s_features)
		PFE_loss += criterion_l1(s_proj_features, torch.cat(t_proj_features))
		PFV_loss += 0.05 * criterion_l1(torch.cat(t_recons_features), torch.cat(t_features))

	T_loss = criterion_l1(preds_from_student, preds_from_teachers)
	SCR_loss = 0.1 * criterion_scr(preds_from_student, target_images, torch.cat(input_images))
	total_loss = T_loss + PFE_loss + PFV_loss + SCR_loss

	optimizer.zero_grad()
	total_loss.backward()
	optimizer.step()
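The regrouping step in the middle of the loop deserves a note: `features_from_each_teachers` is indexed `[teacher][layer]`, but each CKT module consumes one list per layer, i.e. `[layer][teacher]`. The nested comprehension simply transposes the nesting, as this pure-Python illustration shows (the string tags stand in for feature tensors):

```python
# features_from_each_teachers is indexed [teacher][layer]; the CKT modules
# need [layer][teacher], so the comprehension transposes the nesting.
features_from_each_teachers = [
    ["t0_l0", "t0_l1"],  # teacher 0: features from layer 0, layer 1
    ["t1_l0", "t1_l1"],  # teacher 1
    ["t2_l0", "t2_l1"],  # teacher 2
]
features_from_teachers = [
    [features_from_each_teachers[i][layer] for i in range(3)]
    for layer in range(len(features_from_each_teachers[0]))
]
print(features_from_teachers)  # → [['t0_l0', 't1_l0', 't2_l0'], ['t0_l1', 't1_l1', 't2_l1']]
```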

Next comes the evaluation module: given the model and the validation loader, it ultimately outputs PSNR and SSIM:

if epoch % args.val_freq == 0:
	psnr, ssim = evaluate(model, val_loader, epoch)
	# Check whether the model is a top-k model
	top_k_state = save_top_k(model, optimizer, scheduler, top_k_state, args.top_k, epoch, args.save_dir, psnr=psnr, ssim=ssim)

The detailed code of evaluate(model, val_loader, epoch) appeared in the original post only as a screenshot; its key steps are walked through below.
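Pieced together from the snippets that follow (the psnr_list/ssim_list appends and the np.mean return), the evaluate function presumably looks roughly like this. It is a hypothetical reconstruction, not the actual code; the SSIM branch (pytorch_ssim.ssim) is omitted so the sketch stays self-contained.

```python
import torch
import numpy as np

# Hypothetical sketch of evaluate(): loop over the validation set,
# run the model, accumulate per-image PSNR, and return the mean.
@torch.no_grad()
def evaluate_sketch(model, val_loader):
    psnr_list = []
    for input_image, target in val_loader:
        pred = model(input_image)
        # same PSNR computation as torchPSNR later in this post
        imdff = torch.clamp(pred, 0, 1) - torch.clamp(target, 0, 1)
        rmse = (imdff ** 2).mean().sqrt()
        psnr_list.append((20 * torch.log10(1 / rmse)).item())
    return np.mean(psnr_list)

# toy check: an all-zero prediction against a constant-0.5 target,
# so RMSE = 0.5 and PSNR = 20*log10(2) ≈ 6.02 dB
loader = [(torch.zeros(1, 3, 4, 4), torch.full((1, 3, 4, 4), 0.5))]
psnr = evaluate_sketch(lambda x: x, loader)
```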
It then produces the restored output:

pred = model(image)

That is, execution jumps into Net's forward to perform feature extraction.

Input:

The input x is a 640x480 image, so the initial tensor shape is torch.Size([1, 3, 480, 640]).
It then passes through a series of convolution and downsampling layers that produce the intermediate feature maps; that process is not elaborated here.

Output:

The forward returns x and feature; the final x still has shape torch.Size([1, 3, 480, 640]).

feature is a list of 4 feature maps; their shapes were inspected in the debugger.

During evaluation the model is set to return only x, so pred is exactly the restored image. With the prediction in hand, the evaluation metrics can be computed:

psnr_list.append(torchPSNR(pred, target).item())
ssim_list.append(pytorch_ssim.ssim(pred, target).item())

The PSNR implementation:

@torch.no_grad()
def torchPSNR(prd_img, tar_img):
	if not isinstance(prd_img, torch.Tensor):
		prd_img = torch.from_numpy(prd_img)
		tar_img = torch.from_numpy(tar_img)

	imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
	rmse = (imdff**2).mean().sqrt()
	ps = 20 * torch.log10(1/rmse)
	return ps
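As a quick sanity check of the formula above (assuming images scaled to [0, 1]): an all-zero prediction against a constant 0.1 target has RMSE 0.1, so the PSNR should be 20*log10(1/0.1) = 20 dB. The values below are made up purely for this arithmetic check:

```python
import math

# pred is all zeros, target is a constant 0.1 image, so RMSE = 0.1
pred = [0.0] * 16
target = [0.1] * 16
rmse = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred))
psnr = 20 * math.log10(1 / rmse)
print(round(psnr, 4))  # → 20.0
```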

Eventually all 19 validation images are evaluated, filling psnr_list with one value per image. (All 19 must be processed before averaging; the debug run shown in the original screenshots had only stepped through the first two.)

Finally, the averages over the whole validation set are returned:

return np.mean(psnr_list), np.mean(ssim_list)

The function thus returns the mean PSNR and mean SSIM.
