基于深度学习的气象图像分类【mobilenet+VGG16+swin_transfomer+PyQt5界面】

news2024/12/23 18:42:55

在这里插入图片描述

深度学习天气图像分类

文章目录

  • 1 绪论
    • 1.1 研究背景
    • 1.2 国内外研究现状
      • 1.2.1 国内外研究现状
      • 1.2.2 国内外研究现状
  • 2 相关理论基础
    • 2.1 Tensorflow框架
    • 2.2 卷积神经网络
      • 2.2.1 神经元与权值共享
      • 2.2.2 结构组成
      • 2.2.3反向传播算法
    • 2.3 MobileNetV1网络
    • 2.4 VGG16网络
    • 2.5 Transformer网络
  • 3 系统设计与实现
    • 3.1 数据准备
      • 3.1.1 数据集介绍
      • 3.1.2 数据预处理
    • 3.2 算法实现
      • 3.2.1 MobileNet实现
      • 3.2.2 VGG16实现
      • 3.2.3 SwinTransformer实现
      • 3.2.4 模型训练
    • 3.3 界面设计
    • 3.4 功能设计
      • 3.4.1 选择图像
      • 3.4.2 图像显示
      • 3.4.3 AI分类
    • 3.5 效果展示
      • 3.5.1 mobilenet预测
      • 3.5.2 VGG16预测
      • 3.5.3 swin-transformer预测
  • 4 参考文献


1 绪论

1.1 研究背景

  在全球气候变化的背景下,极端天气事件频发,气象预测和灾害监测的重要性日益凸显。气象图像作为关键信息源,对于准确预测和及时监测至关重要。然而,传统气象图像分类方法受限于人工特征提取和专家经验的局限性,难以应对大数据挑战。
深度学习技术的兴起为气象图像分类提供了新路径。它能自动学习图像中的复杂特征,实现精准分类,极大提升了气象预测的准确性和效率。本研究基于深度学习,深入探索气象图像分类的有效方法,对于提升灾害监测效能、减少灾害损失具有重要意义[1]。
  通过构建和优化深度学习模型,成功实现了对多种气象现象的准确分类。实验结果表明,深度学习模型在特征提取和分类性能上显著优于传统方法,展现出强大的应用潜力。
  展望未来,实验将继续完善模型,探索多源数据融合和时空分析等先进技术,进一步提升气象图像分类的准确性和泛化能力。相信随着研究的深入,基于深度学习的气象图像分类将在气象预测和灾害监测中发挥越来越重要的作用,为应对全球气候变化贡献力量。

1.2 国内外研究现状

1.2.1 国内外研究现状

  在现有的研究中,基于深度学习的气象图像分类已成为研究热点。例如,Zhang等人在《Deep Learning for Weather Forecasting: A Review》(2020)中对深度学习在天气预报中的应用进行了全面的解析,其中就包括了利用深度学习技术对气象图像进行分类的部分。关于深度学习在气象图像分类中的具体应用,李华的博士学位论文《基于深度学习的气象图像分类技术研究》(2019)以卷积神经网络(CNN)为研究对象,从网络架构优化的视角切入,论述了如何提高气象图像分类的准确性,既有对深度学习技术的宏观总结,又有针对特定模型的个案分析。
  另外,国内外的研究者们也在不断尝试将深度学习与其他技术结合,以提升气象图像分类的性能。例如,Wang等人在《Meteorological Image Classification Based on Deep Learning and Support Vector Machine》(2021)中提出了一种结合深度学习和支持向量机(SVM)的气象图像分类方法,通过深度学习提取图像特征,再利用SVM进行分类,取得了良好的效果。

1.2.2 国内外研究现状

  从研究方法上看,上述著作和诸多论文的研究方法、技术路线都给予重要的启发。现有的研究多以深度学习模型的优化和应用为主,尤其是卷积神经网络(CNN)在气象图像分类中的应用得到了广泛的研究。然而,目前的研究方法还存在一定的局限性,主要依赖于单一的深度学习模型,对于模型的融合和多种技术的结合应用探索相对较少。
  主观上,深度学习模型的理解和应用需要具备一定的专业知识和技能,而这也在一定程度上限制了其广泛应用。因此,探索深度学习与其他技术的结合,如支持向量机(SVM)、随机森林等,无疑可以增强文献分析研究结论的可靠性和准确性,促进气象图像分类研究的丰富完善。
  总结前人研究的基础上,进一步探索深度学习与其他先进技术的结合,以提高气象图像分类的准确性和效率,为相关领域的研究和应用提供新的思路和方法。同时,也将尝试对深度学习模型进行更深入的理解和优化,以适应更复杂多变的气象图像分类任务。

2 相关理论基础

2.1 Tensorflow框架

  TensorFlow是Google Brain团队发布流行深度学习框架,以其丰富函数库和工具受到研究者和开发人员青睐。其核心在于计算图概念,它让模型构建和优化变得灵活。TensorFlow不仅适用于深度学习,也广泛应用于强化学习、自然语言处理和计算机视觉等领域。其预定义模型和算法库,CNN、RNN和GAN,满足了各种应用需求。作为强大而灵活框架,TensorFlow在学术界和工业界都有广泛应用。本设计利用TensorFlow构建和训练气象图像分类模型,推动气象预测和灾害监测的应用。

2.2 卷积神经网络

  卷积神经网络(CNN)是一种受生物自然视觉认知机制启发而来的深度学习架构,最初是为解决图像识别等问题设计的。CNN的最大特点在于卷积的权值共享结构,可以大幅减少神经网络的参数量,防止过拟合的同时又降低了神经网络模型的复杂度[2]。下面详细介绍卷积神经网络中的一些技术点。

2.2.1 神经元与权值共享

  卷积神经网络中的神经元设计模拟了生物神经元的连接和工作方式。每个神经元只与前一层的一个小区域内(即局部感受野)的神经元相连,这种局部连接方式减少了参数数量,提高了计算效率。同时权值共享结构使得同一个卷积核在图像的不同位置进行卷积操作时共享相同的权重,进一步减少了参数量。神经元框架如下图所示:
在这里插入图片描述

2.2.2 结构组成

(1)输入层:是整个神经网络的输入,它一般代表了一张图片的像素矩阵。在图像处理中,输入层将图像表示为像素的向量。
(2)卷积层:卷积层是CNN中最重要的部分之一,由若干个卷积单元组成,每个卷积单元的参数都是通过反向传播算法优化得到的。卷积层使用卷积核对输入图像进行特征提取和特征映射,通过卷积操作,提取出图像中的局部特征。
(3)池化层:池化层通常位于卷积层之后,用于进行下采样,对特征图进行稀疏处理,减少数据运算量。池化操作包括最大池化和平均池化。
(4)全连接层:全连接层通常在CNN的尾部,用于将卷积输出的二维特征图转化成一维向量,实现端到端的学习过程。全连接层的作用是将输入特征映射到输出结果,用于分类、回归等任务。
(5)输出层:用于输出结果,可以根据具体任务设计不同的输出形式。如下图结构所示。
在这里插入图片描述

2.2.3反向传播算法

  图像通过CNN的每一层进行前向传播,直到得到网络的最终输出分类概率;然后,根据网络的最终输出和真实标签计算损失函数得出误差,将误差从输入层向隐藏层反向传播,直至传播到输入层[3]。在训练过程中,CNN通过反向传播算法优化网络参数,通过多次迭代来不断优化网络参数,直到网络在验证数据上出现过拟合为止,使模型在训练数据上达到最佳性能。训练完成后,CNN用于对新的图像数据进行特征提取和分类任务。

2.3 MobileNetV1网络

  MobileNetV1是Google研究团队所提出的一种轻量级卷积神经网络模型,它的出现,极大地推动了深度学习在移动和嵌入式视觉应用中的发展。该模型的核心创新在于其深度可分离卷积技术,这一技术的引入使得模型在计算负担和参数数量上都有了显著减少,从而实现了高效与性能之间的平衡。
  传统的卷积神经网络虽然能够取得不错的分类效果,但往往伴随着庞大的计算量和存储空间需求,这在资源有限的环境中显得尤为突出。而MobileNetV1的出现,正好解决了这一难题。其轻量级的特性使得它能够在保证一定精度的同时,大大减少计算资源和存储空间的占用,非常适合进行实时气象图像分类。
  本课题将通过构建基于MobileNetV1的气象图像分类模型,对大量气象图像数据进行训练和测试,以验证其在实际应用中的性能表现。同时,本研究也将探讨如何对MobileNetV1进行改进和优化,以适应气象图像分类任务的特殊需求,进一步提高分类精度和效率如下图所示:
在这里插入图片描述

2.4 VGG16网络

  VGG16,作为牛津大学Visual Geometry Group(VGG)团队所研发的一种经典卷积神经网络模型,在深度学习领域具有举足轻重的地位。其设计理念独特且实用,通过构建深层的网络结构以及采用小尺寸的卷积核,使得VGG16在图像处理任务中展现出了强大的特征提取能力。
  VGG16的网络结构深度适中,既保证了模型的复杂度,又避免了过拟合现象的发生。它多次堆叠简单的卷积层和池化层,通过这种“简单而有效”的方式,逐步提取图像从低层次到高层次的特征。这种设计不仅提高了模型的泛化能力,还使得模型在处理复杂图像数据时更加高效。
  在本课题中,将研究VGG16在气象图像分类任务中的应用。通过与MobileNetV1模型进行比较,能够更全面地评估VGG16在气象图像分类领域的性能表现。VGG16结构如下图所示
在这里插入图片描述

2.5 Transformer网络

  Transformer模型是Google团队于2017年提出的一种革命性的NLP架构,由Ashish Vaswani等人在论文《Attention Is All You Need》中详细介绍。这一模型在机器翻译任务上显著超越了RNN和CNN的性能,仅通过encoder-decoder结构和注意力机制就取得了出色的效果,其最大的优势在于能够高效地进行并行化计算。
  Encoder部分由N个相同的层堆叠而成,每一层包含两个子层:一个多头自注意力层和一个简单的全连接前馈神经网络。每个子层都使用了残差连接和层归一化技术,[4]以增强模型的训练稳定性和表达能力。
  Decoder部分同样由N个相同的层组成,但结构略有不同。每个Decoder层包含三个子层:一个自注意力层、一个encoder-decoder注意力层和一个全连接前馈神经网络。其中,自注意力层和encoder-decoder注意力层都是基于多头注意力机制实现的。特别地,Decoder中的自注意力层使用了masking技术,以确保在预测某个位置的输出时不会使用到未来位置的信息,从而保证了训练的合理性。
  Transformer模型通过其独特的encoder-decoder结构和多头注意力机制,实现了对序列数据的高效处理,并在多项NLP任务中取得了卓越的性能。如下图所示:
在这里插入图片描述

3 系统设计与实现

3.1 数据准备

3.1.1 数据集介绍

  在本课题中,主要使用了一个气象图像数据集,该数据集包含了多种不同的天气状况的图像。具体而言,数据集涵盖了闪电、多云、雨天、雪天、薄雾以及雷雨等天气类别。每个类别都有200张图像以上,总计达到6872张图像,为模型的训练提供了丰富的数据支持。如下图所示
在这里插入图片描述
  数据集总共包含11个类别,分别是:露水、雾霾、霜、冰壳、冰雹、闪电、雨、彩虹、霜冻、沙尘暴与雪。为了评估模型的性能,将数据集分为训练集、验证集和测试集。在本课题中,采用了70%用作训练集,10%用作验证集,剩余20%用作测试集的比例进行划分。训练集用于模型的训练和优化,而测试集则用于评估模型在未见过的数据上的表现[5]。具体各类别图片数量可见下表所示:

类别图片数量
露水dew698
雾霾fogsmog861
霜frost475
冰壳glaze639
冰雹hail591
闪电lightning377
雨rain526
彩虹rainbow232
霜冻rime1160
沙尘暴sandstorm692
雪snow621

3.1.2 数据预处理

  数据预处理是确保模型性能的关键步骤。在本课题中,进行了以下预处理操作:缩放:由于原始图像的尺寸可能不统一,将其缩放到统一的640×640尺寸,以适应模型的输入要求。归一化,对图像的像素值进行归一化处理,将图像像素值统一缩放到[0,1]范围内,有助于模型的训练收敛。为了增加模型的泛化能力,采用了数据增强的方法,具体有随机旋转、裁剪和翻转,以扩充训练集。
  通过这些预处理步骤,确保了输入到模型中的数据具有统一且规范的格式,有助于提高模型的训练效率和性能。
  数据准备是气象图像分类任务中的关键步骤,通过选择合适的数据集、划分训练集和测试集,并进行必要的预处理操作,可以为模型的训练和评估提供有力的支持。

3.2 算法实现

3.2.1 MobileNet实现

def MobileNetV1(input_shape=None,
              alpha=1.0,
              depth_multiplier=1,
              dropout=1e-3,
              classes=1000):

    img_input = Input(shape=input_shape)

    # 224,224,3 -> 112,112,32  
    x = _conv_block(img_input, 32, alpha, strides=(2, 2))
    
    # 112,112,32 -> 112,112,64
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)


    # 112,112,64 -> 56,56,128
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                              strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)


    # 56,56,128 -> 28,28,256
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                              strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)
    

    # 28,28,256 -> 14,14,512
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                              strides=(2, 2), block_id=6)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)

    # 14,14,512 -> 7,7,1024
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                              strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)

    # 7,7,1024 -> 1,1,1024
    x = GlobalAveragePooling2D()(x)

    shape = (1, 1, int(1024 * alpha))

    x = Reshape(shape, name='reshape_1')(x)
    x = Dropout(dropout, name='dropout')(x)

    x = Conv2D(classes, (1, 1),padding='same', name='conv_preds')(x)
    x = Activation('softmax', name='act_softmax')(x)
    x = Reshape((classes,), name='reshape_2')(x)

    inputs = img_input

    model = Model(inputs, x, name='mobilenet_%0.2f' % (alpha))
    return model

3.2.2 VGG16实现


def VGG16(input_shape=None, classes=1000):
    img_input = Input(shape=input_shape)

    # Block 1
    # 224, 224, 3 -> 224, 224, 64
    x = Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv1')(img_input)
    x = Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block1_conv2')(x)
    # 224, 224, 64 -> 112, 112, 64
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    # 112, 112, 64 -> 112, 112, 128
    x = Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block2_conv1')(x)
    x = Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block2_conv2')(x)
    # 112, 112, 128 -> 56, 56, 128
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    # 56, 56, 128 -> 56, 56, 256
    x = Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv1')(x)
    x = Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv2')(x)
    x = Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block3_conv3')(x)
    # 56, 56, 256 -> 28, 28, 256
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    # 28, 28, 256 -> 28, 28, 512
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv1')(x)
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv2')(x)
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block4_conv3')(x)
                      
    # 28, 28, 512 -> 14, 14, 512
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    # 14, 14, 512 -> 14, 14, 512
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv1')(x)
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv2')(x)
    x = Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      name='block5_conv3')(x)
    # 14, 14, 512 -> 7, 7, 512
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)

    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    inputs = img_input

    model = Model(inputs, x, name='vgg16')
    return model

3.2.3 SwinTransformer实现

class SwinTransformerBlock():
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., name=""):
        super().__init__()
        self.dim                = dim
        self.input_resolution   = input_resolution
        self.num_heads          = num_heads
        self.window_size        = window_size
        self.shift_size         = shift_size
        self.mlp_ratio          = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.dim                = dim
        self.input_resolution   = input_resolution
        self.num_heads          = num_heads
        self.window_size        = window_size
        self.shift_size         = shift_size
        self.mlp_ratio          = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1              = LayerNormalization(epsilon=1e-5, name=name + ".norm1")
        self.pre                = SwinTransformerBlock_pre(self.input_resolution, self.window_size, self.shift_size)
        self.attn               = WindowAttention(
            dim, 
            window_size = (self.window_size, self.window_size), 
            num_heads   = num_heads,
            qkv_bias    = qkv_bias, 
            qk_scale    = qk_scale, 
            attn_drop   = attn_drop, 
            proj_drop   = drop, 
            name        = name + ".attn"
        )
        self.post               = SwinTransformerBlock_post(self.dim, self.input_resolution, self.window_size, self.shift_size)
        self.drop_path          = DropPath(drop_path_prob if drop_path_prob > 0. else 0.)
        self.norm2              = LayerNormalization(epsilon=1e-5, name=name + ".norm2")
        mlp_hidden_dim          = int(dim * mlp_ratio)
        self.mlp                = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop, name=name + ".mlp")
        self.add                = Add()


    def call(self, x):
        H, W = self.input_resolution
        B, L, C = x.get_shape().as_list()
        assert L == H * W, "input feature has wrong size"
        # 56, 56, 96

        shortcut = x

        x = self.norm1(x)
        x = self.pre(x)
        # 64, 49, 97 -> 64, 49, 97
        x = self.attn.call(x, mask=self.pre.attn_mask)
        x = self.post(x)

        # FFN
        # 56 * 56, 96
        x = self.add([shortcut, self.drop_path(x)])
        x = self.add([x, self.drop_path(self.mlp.call(self.norm2(x)))])
        return x

class PatchMerging(keras.layers.Layer):
    def __init__(self, input_resolution):
        super().__init__()
        self.input_resolution   = input_resolution

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1] // 4, input_shape[2] * 4)
        
    def call(self, x):
        H, W = self.input_resolution
        B, L, C = x.get_shape().as_list()
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        # 56, 56, 96
        x = tf.reshape(x, shape=[-1, H, W, C])

        # 28, 28, 96
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        # 28, 28, 384
        x = tf.concat([x0, x1, x2, x3], axis=-1)
        # 784, 384
        x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C])

        return x

def BasicLayer(
    x, dim, input_resolution, depth, num_heads, window_size,
    mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path_prob=0., name=""
):
    for i in range(depth):
        x = SwinTransformerBlock(
                    dim                 = dim, 
                    input_resolution    = input_resolution,
                    num_heads           = num_heads, 
                    window_size         = window_size,
                    shift_size          = 0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio           = mlp_ratio,
                    qkv_bias            = qkv_bias, 
                    qk_scale            = qk_scale,
                    drop                = drop, 
                    attn_drop           = attn_drop,
                    drop_path_prob      = drop_path_prob[i] if isinstance(drop_path_prob, list) else drop_path_prob,
                    name                = name + ".blocks." + str(i),
                ).call(x)
    return x

def build_model(input_shape = [224, 224], patch_size=(4, 4), classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1):
    #-----------------------------------------------#
    #   224, 224, 3
    #-----------------------------------------------#
    inputs = Input(shape = (input_shape[0], input_shape[1], 3))
    
    #-----------------------------------------------#
    #   224, 224, 3 -> 56, 56, 768
    #-----------------------------------------------#
    x = Conv2D(embed_dim, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs)
    #-----------------------------------------------#
    #   56, 56, 768 -> 3136, 768
    #-----------------------------------------------#
    x = Reshape(((input_shape[0] // patch_size[0]) * (input_shape[1] // patch_size[0]), embed_dim))(x)
    x = LayerNormalization(epsilon=1e-5, name = "patch_embed.norm")(x)
    x = Dropout(drop_rate)(x)

    num_layers          = len(depths)
    patches_resolution  = [input_shape[0] // patch_size[0], input_shape[1] // patch_size[1]]
    dpr                 = [x for x in np.linspace(0., drop_path_rate, sum(depths))]
    #-----------------------------------------------#
    #   3136, 768 -> 3136, 49 
    #-----------------------------------------------#
    for i_layer in range(num_layers):
        dim                 = int(embed_dim * 2 ** i_layer)
        input_resolution    = (patches_resolution[0] // (2 ** i_layer), patches_resolution[1] // (2 ** i_layer))
        x = BasicLayer(
            x,
            dim                 = dim,
            input_resolution    = input_resolution,
            depth               = depths[i_layer],
            num_heads           = num_heads[i_layer],
            window_size         = window_size,
            mlp_ratio           = mlp_ratio,
            qkv_bias            = qkv_bias, qk_scale=qk_scale,
            drop                = drop_rate, attn_drop=attn_drop_rate,
            drop_path_prob      = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
            name                = "layers." + str(i_layer)
        )
        if (i_layer < num_layers - 1):
            x   = PatchMerging(input_resolution)(x)
            x   = LayerNormalization(epsilon=1e-5, name = "layers." + str(i_layer) + ".downsample.norm")(x)
            x   = Dense(2 * dim, use_bias=False, name = "layers." + str(i_layer) + ".downsample.reduction")(x)

    x = LayerNormalization(epsilon=1e-5, name="norm")(x)
    x = GlobalAveragePooling1D()(x)
    x = Dense(classes, name="head")(x)
    x = Softmax()(x)
    return keras.models.Model(inputs, x)

3.2.4 模型训练


batch_size  = Unfreeze_batch_size
start_epoch = Freeze_Epoch if start_epoch < Freeze_Epoch else start_epoch
end_epoch   = UnFreeze_Epoch
    
#-------------------------------------------------------------------#
#   判断当前batch_size,自适应调整学习率
#-------------------------------------------------------------------#
nbs             = 64
lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
lr_limit_min    = 1e-4 if optimizer_type == 'adam' else 5e-4
if backbone == 'vit':
    nbs             = 256
    lr_limit_max    = 1e-3 if optimizer_type == 'adam' else 1e-1
    lr_limit_min    = 1e-5 if optimizer_type == 'adam' else 5e-4
Init_lr_fit     = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max)
Min_lr_fit      = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
#---------------------------------------#
#   获得学习率下降的公式
#---------------------------------------#
lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
lr_scheduler    = LearningRateScheduler(lr_scheduler_func, verbose = 1)
callbacks       = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]

for i in range(len(model.layers)): 
    model.layers[i].trainable = True
if ngpus_per_node > 1:
    with strategy.scope():
        model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy'])
else:
    model.compile(loss = 'categorical_crossentropy', optimizer = optimizer, metrics = ['categorical_accuracy'])

epoch_step      = num_train // batch_size
epoch_step_val  = num_val // batch_size

if epoch_step == 0 or epoch_step_val == 0:
    raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。")

train_dataloader.batch_size    = Unfreeze_batch_size
val_dataloader.batch_size      = Unfreeze_batch_size

print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
model.fit(
    x                   = train_dataloader,
    steps_per_epoch     = epoch_step,
    validation_data     = val_dataloader,
    validation_steps    = epoch_step_val,
    epochs              = end_epoch,
    initial_epoch       = start_epoch,
    use_multiprocessing = True if num_workers > 1 else False,
    workers             = num_workers,
    callbacks           = callbacks
)

3.3 界面设计

  QT Designer是一款用于创建Qt应用程序用户界面的可视化设计工具。它允许开发人员通过拖放界面元素来创建界面,从而避免了手动编写复杂的布局代码。以下是使用QT Designer进行UI设计的详细步骤:

第一步,打开QT Designer:
  在QT Creator中,双击项目中的.ui文件。
第二步,添加控件:
  从左侧的“Widget Box”中拖拽所需的控件到中间的“Form Editor”区域。例如,添加一个按钮(QPushButton)、一个标签(QLabel)和一个文本框(QLineEdit)。
第三步,调整布局:
  使用“Layout”面板来管理控件的布局。可以选择水平布局(QHBoxLayout)、垂直布局(QVBoxLayout)、网格布局(QGridLayout)等。
第四步,设置控件属性:
  使用“Property Editor”面板来配置控件的属性。例如,设置按钮的文本、标签的文本、文本框的初始文本等。
第五步,连接信号与槽:
  在“Form Editor”中,选中需要连接信号与槽的控件,然后点击工具栏中的“Edit Signals/Slots”按钮。在弹出的对话框中,选择信号(如clicked())和槽函数(如on_button_clicked()),然后点击“OK”按钮
第六步,保存UI文件:
  完成设计后,点击“File” > “Save”保存.ui文件。
第七步,生成Python代码:
  使用pyuic5命令将.ui文件转换为Python代码。
在命令行中输入:pyuic5 designer.ui -o designer.py

在这里插入图片描述

3.4 功能设计

3.4.1 选择图像

    def pushButton_select_file_slot(self):
        self.ui.label_category.clear()
        self.ui.label_score.clear()
        self.ui.label_image.clear()

        self.file_path, ret = QFileDialog.getOpenFileName(None, '选择图像', "./datasets", "图像文件 (*.jpg *.jpeg *.png)")

        if self.file_path =='':
            QMessageBox.critical(self, '提示', '未选择图像')
            return
        image = cv2.imread(self.file_path)

        self.show_image(image, self.ui.label_image)

3.4.2 图像显示


    def show_image(img_src, label):
        try:
            ih, iw, _ = img_src.shape
            w = label.geometry().width()
            h = label.geometry().height()
            # 保持纵横比
            # 找出长边
            if iw > ih:
                scal = w / iw
                nw = w
                nh = int(scal * ih)
                img_src_ = cv2.resize(img_src, (nw, nh))

            else:
                scal = h / ih
                nw = int(scal * iw)
                nh = h
                img_src_ = cv2.resize(img_src, (nw, nh))

            frame = cv2.cvtColor(img_src_, cv2.COLOR_BGR2RGB)
            img = QImage(frame.data, frame.shape[1], frame.shape[0], frame.shape[2] * frame.shape[1],
                         QImage.Format_RGB888)
            label.setPixmap(QPixmap.fromImage(img))

        except Exception as e:
            print(repr(e))

3.4.3 AI分类

    def pushButton_cnn_infer_slot(self):
        try:
            if self.ui.radioButton_model_mobilenet.isChecked():
                cls_res = self.classfication_mobilenet.detect_image(Image.open(self.file_path))
            elif self.ui.radioButton_model_vgg16.isChecked():
                cls_res = self.classfication_vgg16.detect_image(Image.open(self.file_path))
            else :
                cls_res = self.transformer_mixed_cnn.detect_image(Image.open(self.file_path))
            print(cls_res)
            class_label, score = cls_res

            self.ui.label_category.setText(class_label)
            self.ui.label_score.setText(str(round(score,7)))

        except Exception as e:
            print(e)

3.5 效果展示

3.5.1 mobilenet预测

在这里插入图片描述

3.5.2 VGG16预测

在这里插入图片描述

3.5.3 swin-transformer预测

在这里插入图片描述

4 参考文献

[1] 龙学军,高枫.基于深度学习的高速公路气象识别方法[J].中国交通信息化,2021,(05):134-136.DOI:10.13439.
[2] 辛聪,李菁.基于MobileNet的电力设备图像识别[J].工业控制计算机,2024,37(03):157-158+166.
[3] 李林红,杨杰,蒋严宣,等.基于改进MobileNet v2的服装图像分类算法[J/OL].现代纺织技术,1-11[2024-03-30].https://doi.org/10.19398.
[4] 童立靖,王清河,冯金芝.基于混合Transformer模型的三维视线估计[J].中南民族大学学报(自然科学版),2024,43(01):97-103.
[5] 封皓元.基于深度迁移学习的天气图像识别[D].沈阳市沈阳工业大学,2023.
[6] 于晓,庄光耀.基于轻量化VGG16和BCBAM的电力设备故障红外图像诊断识别[J].河南科技学院学报(自然科学版),2023,51(06):60-69.
[7] Mavaddati S .Voice-based age, gender, and language recognition based on ResNet deep model and transfer learning in spectro-temporal domain[J].Neurocomputing,2024,580127429-.
[8] Huang Y ,Pi Y ,Ma K , et al.Predicting the error magnitude in patient-specific QA during radiotherapy based on ResNet.[J].Journal of X-ray science and technology,2024,
[9] 韩鑫豪,何月顺,陈杰,等.基于Swin Transformer的岩石岩性智能识别研究[J].现代电子技术,2024,47(07):37-44.DOI:10.16652/j.issn.1004-373x.2024.07.006.
[10] 张卓然,张倩,宋智,等.基于残差Swin Transformer的天气图像识别技术研究[J].成都信息工程大学学报,2023,38(06):637-642.DOI:10.16836.
[11] 符锌砂,胡弘毅,莫宇蓉,等.基于Vision Transformer的高速公路监控场景天气识别[J].公路交通科技,2023,40(07):164-169.
[12] 侯晓明,邱亚峰.基于卷积神经网络与特征融合的天气识别方法[J].应用光学,2023,44(02):323-329.
[13] 沈晨鑫.基于数据增强的天气图像分类方法研究[D].南京市东南大学,2022.
[14] 黄小猛,林岩銮,熊巍,等.数值预报AI气象大模型国际发展动态研究[J].大气科学学报,2024,47(01):46-54.
[15] 赖尉文.基于自适应证据推理规则的气象识别方法研究[D].哈尔滨市哈尔滨师范大学,2023.
[16] 朱坤.气象传真图像识别技术研究[D].哈尔滨市哈尔滨工程大学,2020.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2112368.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

tb-nightly库安装报错

使用pip安装&#xff08;默认清华镜像&#xff09;tb-nightly库报如下错误&#xff1a; 网上查阅资料&#xff0c;尝试了以下方式&#xff1a; 使用conda安装失败&#xff01;使用pip install tb-nightly --index-url https://pypi.org/simple安装失败最后&#xff0c;换成阿…

[Linux]:进程(上)

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;Linux学习 贝蒂的主页&#xff1a;Betty’s blog 1. 初识进程 1.1 进程的概念 在计算机世界中&#xff0c;进程是一个关键概念…

机器学习中的聚类艺术:探索数据的隐秘之美

一 什么是聚类 聚类是一种经典的无监督学习方法&#xff0c;无监督学习的目标是通过对无标记训练样本的学习&#xff0c;发掘和揭示数据集本身潜在的结构与规律&#xff0c;即不依赖于训练数据集的类标记信息。聚类则是试图将数据集的样本划分为若干个互不相交的类簇&#xff…

Confluence8.5.14安装

一、Centos8、安装jdk11(略) 二、mysql数据库 1、mysql安装包下载: MySQL :: Download MySQL Community Server 2、安装: https://downloads.mysql.com/archives/get/p/23/file/mysql-8.0.37-1.el8.x86_64.rpm-bundle.tar tar -xvf mysql-8.0.37-1.el8.x86_64.rpm-bund…

浏览器剪贴板 API Clipboard API

在 Web 开发领域&#xff0c;Clipboard API 就是一个备受关注的新利器&#xff0c;它为我们提供了在网页中访问和操作剪贴板的能力&#xff0c;极大地丰富了用户交互体验。本文将深入探讨 Clipboard API 的使用方法和潜在应用场景。 一. 什么是 Clipboard API&#xff1f; Cl…

集合及映射

1、集合类图 1&#xff09;ArrayList与LinkedList 区别 LinkedList 实现了双向队列的接口&#xff0c;对于数据的插入速度较快&#xff0c;只需要修改前后的指向即可&#xff1b;ArrayList对于特定位置插入数据&#xff0c;需要移动特定位置后面的数据&#xff0c;有额外开销 …

Windows 安装mysql 教程,mysql 多版本共存教程,傻瓜式安装教程

mysql 各版本官方下载地址&#xff1a;⬇ ⬇⬇⬇⬇⬇⬇⬇⬇⬇(点击下面链接前往)MySQL :: Download MySQL Community Server (Archived Versions)https://downloads.mysql.com/archives/community/ 首先我本地安装了 mysql8.0版本了&#xff0c;通过msi 进行安装的也就是傻瓜式…

SprinBoot+Vue高校网上缴费综合务系统的设计与实现

目录 1 项目介绍2 项目截图3 核心代码3.1 Controller3.2 Service3.3 Dao3.4 application.yml3.5 SpringbootApplication3.5 Vue 4 数据库表设计5 文档参考6 计算机毕设选题推荐7 源码获取 1 项目介绍 博主个人介绍&#xff1a;CSDN认证博客专家&#xff0c;CSDN平台Java领域优质…

文心快码前端工程师观点分享:人机协同新模式的探索之路(三)

本系列视频来自百度工程效能部的前端研发经理杨经纬&#xff0c;她在由开源中国主办的“AI编程革新研发效能”OSC源创会杭州站105期线下沙龙活动上&#xff0c;从一款文心快码&#xff08;Baidu Comate&#xff09;前端工程师的角度&#xff0c;分享了关于智能研发工具本身的研…

AIGC是如何颠覆文旅行业的?

AI技术正在以前所未有的速度和规模&#xff0c;颠覆着各行各业的发展。在文旅行业&#xff0c;这种颠覆尤为显著。今天&#xff0c;我们深入探讨AIGC是如何颠覆文旅行业的。 传统的文旅内容创作方式&#xff0c;往往需要大量的人力、物力和财力投入。拍摄、录制、剪辑&#xf…

第二天旅游线路规划和预览

第二天&#xff1a;从克拉玛依市乌尔禾区到五彩滩&#xff0c;晚上住宿贾登峪&#xff1b; 规划结果见下图&#xff1a; 1、行程安排 根据上面的耗时情况&#xff0c;规划一天的行程安排如下&#xff1a; 1&#xff09;早上7&#xff1a;30起床&#xff0c;吃完早饭&#xff0c…

微信小程序页面制作——本地生活(含代码)

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

基于ASP+ACCESS的教师信息管理系统

摘要 随着我国社会主义市场经济的发展和改革开放的不断深入&#xff0c;计算机的应用已遍及国民经济的各个领域&#xff0c;计算机来到我们的工作和生活中&#xff0c;改变着我们和周围的一切。在以前&#xff0c;学校用手工处理教师档案以及工资发放等繁多的工作和数据时&…

谷粒商城の缓存篇

文章目录 前言一、本地缓存和分布式缓存1.本地缓存2.分布式缓存 二、项目实战1.配置Redis2.整合业务代码2.1 缓存击穿2.2 缓存雪崩2.3 缓存穿透2.4 业务代码1.0版2.5 分布式锁1.0版2.6 分布式锁2.0版2.7 Spring Cache及缓存一致性问题2.7.1 Spring Cache2.7.2 缓存一致性问题2.…

[003].第3节.在Windows环境中搭建Redis(单机版)环境

我的后端学习大纲 我的Redis学习大纲 1.Redis下载: 1.中文2.英文 2.Windows下搭建Redis环境&#xff1a; 2.1.单机

[论文笔记]Making Large Language Models A Better Foundation For Dense Retrieval

引言 今天带来北京智源研究院(BAAI)团队带来的一篇关于如何微调LLM变成密集检索器的论文笔记——Making Large Language Models A Better Foundation For Dense Retrieval。 为了简单&#xff0c;下文中以翻译的口吻记录&#xff0c;比如替换"作者"为"我们&quo…

深入理解C语言中的POSIX定时器

引言 在Unix和类Unix系统中&#xff0c;定时器是一种常见的机制&#xff0c;用于在特定时间间隔后执行某些操作。POSIX定时器因其灵活性和功能丰富而被广泛采用。本文将深入探讨POSIX定时器的工作原理、内部机制、使用方法及其在实际开发中的应用。 POSIX定时器基础 POSIX定…

【视频讲解】Python贝叶斯卷积神经网络分类胸部X光图像数据集实例

全文链接&#xff1a;https://tecdat.cn/?p37604 分析师&#xff1a;Yuanchun Niu 在人工智能的诸多领域中&#xff0c;分类技术扮演着核心角色&#xff0c;其应用广泛而深远。无论是在金融风险评估、医疗诊断、安全监控还是日常的交互式服务中&#xff0c;有效的分类算法都是…

数据仓库理论知识

1、数据仓库的概念 数据仓库&#xff08;英文&#xff1a;Date Warehouse&#xff0c;简称数仓、DW&#xff09;&#xff0c;是一个用于数据存储、分析、报告的数据系统。数据仓库的建设目的是面向分析的集成化数据环境&#xff0c;其数据来源于不同的外部系统&#…

Git 修改Push后的Commit Message

向远程仓库push代码之后&#xff0c;在IDEA中无法直接修改Commit Message&#xff0c;需要在终端或控制台中输入以下命令&#xff08;HEAD~1中的1表示只对最后一个提交进行修改&#xff0c;因此1可以自定义&#xff09; git rebase -i HEAD~1执行完rebase指令后&#xff0c;会…