[Pytorch案例实践008]基于卷积神经网络和通道注意力机制的图像分类实战

news2024/11/24 22:57:21

一、项目介绍

        这是一个蜜蜂、蚂蚁图像分类项目,旨在使用卷积神经网络(CNN)结合SE(Squeeze-and-Excitation)模块进行二分类任务。以下是项目的详细介绍:

项目背景

图像分类是计算机视觉中的一个基本任务,广泛应用于医疗诊断、自动驾驶、安防监控等领域。本项目的目标是通过设计并训练一个神经网络模型,将输入的图像正确地分类为两个类别。

模型架构

模型架构主要包括卷积层、SE模块、池化层、全连接层和Dropout层。具体细节如下:

  1. 卷积层(Convolutional Layers)

    • 两个卷积层(Conv2d),用于提取图像的局部特征。
    • 每个卷积层后跟随一个SE模块,用于提升特征表达的能力。
  2. SE模块(Squeeze-and-Excitation Block)

    • SE模块通过“压缩”和“激励”机制,重新校准特征通道的权重,增强有用特征,抑制无用特征。
    • 具体实现上,通过全局平均池化获取通道的全局信息,然后通过全连接层和激活函数(ReLU和Sigmoid)来重新计算通道权重。
  3. 池化层(Pooling Layers)

    • 最大池化层(MaxPool2d),用于降低特征图的尺寸,减少计算量和过拟合风险。
  4. 全连接层(Fully Connected Layers)

    • 增加了多个全连接层(Linear),分别具有1024、512、256和2个神经元。
    • 通过增加网络的深度和复杂度,提高模型的分类能力。
  5. Dropout层

    • 在每个全连接层后增加Dropout层,防止过拟合,增强模型的泛化能力。

数据预处理

为了提高模型的泛化能力,对训练数据进行了数据增强和预处理:

  • 训练集

    • 随机裁剪(RandomResizedCrop)
    • 随机水平翻转(RandomHorizontalFlip)
    • 随机旋转(RandomRotation)
    • 颜色抖动(ColorJitter)
    • 归一化(Normalize)
  • 验证集

    • 调整大小(Resize)
    • 中心裁剪(CenterCrop)
    • 归一化(Normalize)

训练过程

训练过程中采用交叉熵损失函数(CrossEntropyLoss)和Adam优化器,并引入了以下技术:

  • L2正则化(weight_decay):防止过拟合。
  • 学习率调度器(ReduceLROnPlateau):根据验证损失动态调整学习率。
  • 早停法:如果验证损失在若干个epoch内没有下降,则提前停止训练。

测试脚本

测试脚本用于加载训练好的模型,并对给定的输入图片或文件夹进行推理,输出分类结果和置信度,并将结果保存在图片上和指定文件夹中。主要步骤如下:

  1. 加载保存的模型。
  2. 对输入图像进行预处理。
  3. 进行模型推理。
  4. 处理输出,绘制分类结果和置信度,并保存结果。

二、通道注意力机制

通道注意力机制(Channel Attention Mechanism)是深度学习中一种增强模型特征表示能力的方法。它通过动态地调整不同特征通道的权重,来突出重要特征并抑制无关特征,从而提高模型的性能。下面是关于通道注意力机制的详细介绍:

概述

通道注意力机制旨在通过对每个通道分配一个权重,来调整不同特征通道的重要性。这些权重是动态计算的,基于输入特征自适应地调整。常见的通道注意力机制有Squeeze-and-Excitation(SE)块和其他一些变体。

Squeeze-and-Excitation(SE)块

SE块是通道注意力机制的一种经典实现方法,由Jie Hu等人在论文《Squeeze-and-Excitation Networks》中提出。SE块的主要思想是通过“压缩”(Squeeze)和“激励”(Excitation)两个步骤来重新校准通道特征。

1. 压缩(Squeeze)

在“压缩”步骤中,通过全局平均池化(Global Average Pooling)操作,将每个通道的空间维度(宽和高)压缩为一个标量,从而获取每个通道的全局信息。这一步骤可以看作是对每个通道进行全局信息的聚合。

2. 激励(Excitation)

在“激励”步骤中,使用两个全连接层(Fully Connected Layers)对“压缩”后的特征进行处理,得到每个通道的权重。第一个全连接层将通道数降维,第二个全连接层将通道数升维。通过非线性激活函数ReLU和Sigmoid,得到每个通道的权重。

3. 重校准(Recalibration)

将得到的通道权重重新分配给原始特征图的每个通道,从而实现通道的重新校准。

通道注意力机制的优点

  1. 增强特征表示能力:通过动态调整通道的权重,突出重要特征,抑制无关特征。
  2. 提高模型性能:在许多图像分类任务中,加入通道注意力机制后,模型的准确率和泛化能力都有显著提升。
  3. 轻量级:SE模块增加的参数和计算量较少,适用于各种CNN架构。

通过引入通道注意力机制,模型能够更好地理解和处理输入数据中的关键信息,从而提高分类效果。

三、代码

训练代码:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# 从自定义的文件中加载网络
from cnn_SE import CNNWithSE


# 权重初始化函数
def initialize_weights(model):
    """
    初始化模型的权重,包括卷积层和全连接层
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.bias, 0)


# 设置超参数
batch_size = 4  # 每批处理的样本数
learning_rate = 0.0001  # 初始学习率
num_epochs = 30  # 训练的轮次
# lr_step_size = 10  # 学习率每隔多少个epoch降低一次
# lr_gamma = 0.1  # 学习率降低的倍率
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 使用GPU或CPU

# Data augmentation and preprocessing
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载数据集
train_dataset = ImageFolder(root=r'I:\code\pytorch\cnn_SE\datasets\train', transform=train_transform)
val_dataset = ImageFolder(root=r'I:\code\pytorch\cnn_SE\datasets\val', transform=val_transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# 初始化模型、损失函数和优化器
model = CNNWithSE().to(device)  # 初始化模型并移动到GPU或CPU
initialize_weights(model)  # 初始化权重
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # 优化器

# 学习率调度器,每隔 lr_step_size 个 epoch 将学习率乘以 lr_gamma
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)

# 用于保存训练过程中的损失和准确率数据
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
best_accuracy = 0.0

# 训练和验证模型
for epoch in range(num_epochs):
    model.train()  # 进入训练模式
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    # 训练循环
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        running_loss += loss.item()  # 累计损失
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total_train += labels.size(0)  # 累计样本数
        correct_train += (predicted == labels).sum().item()  # 累计正确预测的样本数

    # 记录训练损失和准确率
    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)
    train_accuracy = 100 * correct_train / total_train
    train_accuracies.append(train_accuracy)

    # 验证模型
    model.eval()  # 进入评估模式
    running_val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    # 记录验证损失和准确率
    val_loss = running_val_loss / len(val_loader)
    val_losses.append(val_loss)
    val_accuracy = 100 * correct_val / total_val
    val_accuracies.append(val_accuracy)
    print(f'Validation Accuracy: {val_accuracy:.2f}%')

    # 保存准确率最高的模型
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Model saved with accuracy: {best_accuracy:.2f}%")

    # 更新学习率
    # scheduler.step()

# 保存损失和准确率曲线
plt.figure(figsize=(12, 5))

# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()

# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(val_accuracies, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid()

# 保存并展示图像
plt.savefig('training_curves.png')
plt.show()

测试代码:

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import os
from cnn_SE import CNNWithSE

def load_model(model_path, num_classes, device):
    model = CNNWithSE()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict(model, image, device, transform):
    image = transform(image).unsqueeze(0).to(device)
    outputs = model(image)
    _, preds = torch.max(outputs, 1)
    confidence = nn.functional.softmax(outputs, dim=1)[0][preds].item()
    return preds.item(), confidence


def draw_label(image, label, confidence):
    draw = ImageDraw.Draw(image)

    # 使用更大的字体
    try:
        font = ImageFont.truetype("arial", 30)  # 使用 Arial 字体,字号为 36
    except IOError:
        font = ImageFont.load_default()  # 如果 Arial 字体不可用,使用默认字体

    text = f"{label}: {confidence:.2f}"

    # 使用 textbbox 获取文本边界框
    text_bbox = draw.textbbox((20, 20), text, font=font)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # 矩形背景框
    position = (20, 20)
    draw.rectangle([position, (position[0] + text_width, position[1] + text_height)], fill="black")
    draw.text(position, text, fill="white", font=font)
    return image


def process_image(model, image_path, output_dir, device, transform, class_names):
    image = Image.open(image_path).convert("RGB")
    pred, confidence = predict(model, image, device, transform)
    label = class_names[pred]
    image_with_label = draw_label(image, label, confidence)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    image_with_label.save(output_path)
    image_with_label.show()  # 显示处理后的图像

def process_folder(model, folder_path, output_dir, device, transform, class_names):
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
            image_path = os.path.join(folder_path, filename)
            process_image(model, image_path, output_dir, device, transform, class_names)

def main():
    # 硬编码参数
    model_path = r'I:\code\pytorch\cnn_SE\best_model.pth'  # 模型权重文件路径
    input_path = r'I:\code\pytorch\cnn_SE\datasets\val\ants\445356866_6cb3289067.jpg'  # 输入图片或文件夹路径
    output_dir = r'I:\code\pytorch\cnn_SE\result'  # 输出保存路径
    num_classes = 2  # 分类任务的类别数
    class_names = ['ants', 'bees']  # 类别名称

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(model_path, num_classes, device)

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    if os.path.isfile(input_path):
        process_image(model, input_path, output_dir, device, transform, class_names)
    elif os.path.isdir(input_path):
        process_folder(model, input_path, output_dir, device, transform, class_names)
    else:
        print("Invalid input path. Must be a file or directory.")

if __name__ == "__main__":
    main()

四、总结

主要工作
  1. 模型设计

    • 构建了一个包含SE模块的CNN,通过“压缩”和“激励”机制动态调整特征通道的权重,提升特征表示能力。
    • 增加了多个全连接层和Dropout层,增强模型的复杂度和防止过拟合。
  2. 数据预处理和增强

    • 对训练数据进行了数据增强(随机裁剪、水平翻转、旋转、颜色抖动等)和预处理(归一化),提高模型的泛化能力。
    • 对验证数据进行了统一的尺寸调整和归一化处理。
  3. 训练过程

    • 采用交叉熵损失函数和Adam优化器,结合L2正则化防止过拟合。
    • 使用学习率调度器根据验证损失动态调整学习率。
    • 实施了早停法,防止过拟合并减少训练时间。
  4. 模型验证和测试

    • 在验证集上评估模型性能,选择最优模型保存。
    • 编写测试脚本,加载保存的模型,对输入图像或文件夹进行推理,并保存分类结果和置信度。
主要技术点
  • 通道注意力机制(SE模块):通过全局平均池化和全连接层,动态调整通道权重,提升特征表示能力。
  • 数据增强:通过随机裁剪、翻转、旋转、颜色抖动等方法,增加数据多样性,防止过拟合。
  • Dropout层:在全连接层后应用Dropout,防止过拟合,增强模型的泛化能力。
  • 学习率调度:使用ReduceLROnPlateau调度器,根据验证损失动态调整学习率,提高模型收敛速度和性能。
项目成果
  • 成功设计并训练了一个带有通道注意力机制的CNN,在验证集上取得了良好的性能。
  • 通过数据增强和优化训练过程,显著提高了模型的泛化能力和稳定性。
  • 编写了完整的测试脚本,实现了对输入图像的推理和结果保存,便于后续应用。
未来工作
  • 模型优化:继续优化模型架构,尝试其他类型的注意力机制,如空间注意力机制,进一步提升模型性能。
  • 数据集扩展:增加更多种类和数量的训练数据,提高模型的鲁棒性和泛化能力。

        通过本项目的研究和实践,展示了通道注意力机制在图像分类任务中的有效性,并为进一步的研究和应用提供了基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2033338.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一图看懂数据仓库、数据平台、数据中台、数据湖的内涵和区别!

当大数据平台出现的时候,有人是说这不就是大号的数据仓库吗?当数据中台出现的时候,有人说这不就是数据仓库的进一步包装吗?数据湖的出现更是让很多人陷入困惑。 事实上,数据仓库、数据平台、数据中台、数据湖还是有区别的,不仅…

算法 三

堆 满二叉树:节点满的。 完全二叉树定义:最下层从左往右满,不跳。 下标性质 大根堆:某个节点为根节点,其下的所有结点都小于根节点。 小根堆 重要的变量 heapSize:当前堆的有效节点个数 重要的两个过程…

RCE-无字母数字绕过正则表达式

目录 一、源码展示 二、分析源码 2.1异或运算 2.2或运算 2.3取反运算 一、源码展示 <?php error_reporting(0); highlight_file(__FILE__); $code$_GET[code]; if(preg_match(/[a-z0-9]/i,$code)){die(hacker); } eval($code); 二、分析源码 根据源码&#xff0c;我…

数据治理:国家标准 GB/T 43697-2024《数据安全技术 数据分类分级规则》

按照国家数据分类分级保护有关要求,参照本文件制定本行业本领域的数据分类分级标准规范,重点可明确以下内容: 明确行业数据分类细则,确定数据分类所依据的业务属性,给出按照业务属性划分的数据类别:分析行业领域数据的领域、群体、区域、精度、规模、深度、重要性等分级要素…

设计模式-单一职责模式

DecoratorBridge Decorator 动机 在某些情况下我们可能会 “过度地使用继承来扩展对象的功能”&#xff0c;由于继承为类型引入的静态特质&#xff0c;使得这种扩展方式缺乏灵活性&#xff1b;并且随着子类的增多&#xff08;扩展功能的增多&#xff09;&#xff0c;各种子类的…

基于RK3568+FPGA医用心电监护仪解决方案

医用心电监护仪解决方案 随着我国老龄化速度加快、规模扩大&#xff0c;越来越多民生领域的热点引起民众的关注。庞大的老龄化群体将是一个严峻的问题&#xff0c;各种社会保障政策的实施和各级医疗资源的扩展与升级正在有效化解这一难题。 在这种背景下&#xff0c;医用心电监…

如何构建一个帮助你高效学习编程的完美笔记系统?

在编程学习的过程中&#xff0c;笔记记录是一项至关重要的技能。尤其是在学习Python这样一门功能强大、广泛应用的编程语言时&#xff0c;建立一个高效的笔记系统不仅能帮助你更好地掌握知识&#xff0c;还能提高你的编程效率。那么&#xff0c;如何构建一个帮助你高效学习Pyth…

Java面试八股之消息队列有哪些协议?各种协议有哪些具体实现

消息队列有哪些协议&#xff1f;各种协议有哪些具体实现 消息队列协议是指在消息队列系统中&#xff0c;用于消息的发送、接收和管理的一套通信规则。不同的协议有着不同的特性和应用场景&#xff0c;以下是一些常见的消息队列协议及其具体实现&#xff1a; AMQP (Advanced M…

【leetcode】杨辉三角 、移除元素(Java语言描述)

杨辉三角 给定一个非负整数 numRows&#xff0c;生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]]示例 2: 输入: numRows 1 输出: [[1]] …

SecureCoding in C and C++(二)

经过上期的环境搭建过后&#xff0c;我们将正式的学习C系列&#xff0c;首先要学习的是C的一些常用的变量 从编译和连接学起似乎也是不错的选择。 个人总结的一句话&#xff1a;编译其实就是对预处理语句进行处理后&#xff0c;然后对语句进行处理。对预处理语句&#xff0c;例…

C++——list列表容器经典案例——手机按销量降序排列,若销量相同则按价格降序排列

需求&#xff1a;使用list列表对商品进行排序&#xff0c;先通过销量降序排&#xff0c;若销量相同则根据价格升序排列输出 涉及到的知识点&#xff1a;list列表容器、自定义数据类型、自定义排序规则 实现步骤&#xff1a; 1&#xff0c;自定义数据类型Product&#xff0c;…

Android 实现多进程通讯(如何实现多进程开发,Binder、AIDL)

目录 1&#xff09;为什么App需要多进程 2&#xff09;什么是多进程开发? 3&#xff09;如何实现多进程开发&#xff1f; 4&#xff09;跨进程间通讯(案例) 5&#xff09;多进程需要注意什么问题&#xff1f; 6&#xff09;多进程的底层原理是什么&#xff1f;【待写】 …

【Python机器学习】树回归——使用Python的tkinter库创建GUI

机器学习给我们提供了一些强大的工具&#xff0c;能从未知数据中抽取出有用的信息。因此&#xff0c;能否这些信息以易于人们理解的方式呈现十分重要。如果人们可以直接与算法和数据交互&#xff0c;将可以比较轻松的进行解释。其中一个能够同时支持数据呈现和用户交互的方式就…

手机IP地址:是根据网络还是设备决定的?

在日益数字化的今天&#xff0c;手机已经成为我们日常生活中不可或缺的一部分。它不仅是我们沟通的桥梁&#xff0c;更是我们获取信息、享受娱乐和完成工作的得力助手。然而&#xff0c;在使用手机上网的过程中&#xff0c;你是否曾经好奇过手机的IP地址是如何被分配的&#xf…

Java中class文件结构分析二

第17个常量池:01 00 15 28 4C 6A 61 76 61 2F 6C 61 6E 67 2F 53 74 72 69 6E 67 3B 29 56 01&#xff1a;tag位表示的是utf8类型的字面量常量 00 15 二个字节表示的是字面量常量的长度为21 接下来21个字节: 28 4C 6A 61 76 61 2F 6C 61 6E 67 2F 53 74 72 69 6E 67 3B 29 56…

经典大语言模型解读(1):BERT——基于双向Transformer的预训练语言模型

论文&#xff1a;BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 前言 BERT&#xff08;Bidirectional Encoder Representation from Transformer&#xff09;是Google于2019年提出的预训练语言模型。与寻常的Transformer架构不同&#…

eval和长度限制

目录 源码 解决方案 方法一 方法二 方法三 源码 <?php $param $_REQUEST[param]; if(strlen($param)<17 && stripos($param,eval) false && stripos($param,assert) false) {eval($param); } ?> 限制条件&#xff1a; 传入的参数长度不能…

Go语言+Vue3开发前后端后台管理系统实战 用户管理的前端界面和表结构分析

首页&#xff1a; 用户管理界面&#xff1a; 到这一步以后来看一下后端代码的表结构是如何设计的&#xff1a; 后端代码中&#xff0c;使用的操作MySQL的技术是gorm&#xff1a; gorm.io/gorm v1.25.5其中&#xff0c;用户表的定义位置如下&#xff1a; 此时的完整代码如…

C++虚函数习题

#include <iostream>using namespace std;class Animal { public:Animal() {}virtual void perform()0; };class Lion:public Animal { public:Lion() {}void perform(){cout << "狮子会吃小朋友&#xff01;&#xff01;&#xff01;快跑&#xff01;&#x…

设计模式(1)创建型模式和结构型模式

1、目标 本文的主要目标是学习创建型模式和结构型模式&#xff0c;并分别代码实现每种设计模式 2、创建型模式 2.1 单例模式&#xff08;singleton&#xff09; 单例模式是创建一个对象保证只有这个类的唯一实例&#xff0c;单例模式分为饿汉式和懒汉式&#xff0c;饿汉式是…