精准检测花生豆:基于EfficientNet的深度学习分类项目
在现代农业生产中,作物的质量检测和分类是确保产品质量的重要环节。针对花生豆的检测与分类需求,我们开发了一套基于深度学习的解决方案,利用EfficientNetB0模型实现高效、准确的花生豆分类。本博客将详细介绍该项目的背景、数据处理、模型架构、训练过程、评估方法及预测应用。
目录
- 项目背景
- 项目概述
- 数据处理
- 数据集结构
- 数据增强与规范化
- 模型架构
- 训练过程
- 训练脚本 (
train.py
)
- 训练脚本 (
- 模型评估
- 评估脚本 (
evaluate.py
)
- 评估脚本 (
- 预测与应用
- 预测脚本 (
predict.py
)
- 预测脚本 (
- 项目成果
- 结论与未来工作
项目背景
花生豆作为一种重要的经济作物,其品质直接影响到市场价值和消费者满意度。传统的人工检测方法不仅耗时耗力,而且易受主观因素影响,难以实现大规模、精准的分类。因此,开发一种高效、准确的自动化检测系统显得尤为重要。
项目概述
本项目旨在利用深度学习技术,构建一个能够自动检测和分类花生豆的系统。通过收集和处理大量花生豆图像数据,训练一个高性能的卷积神经网络模型,实现对不同类别花生豆的精准分类。项目主要包括以下几个部分:
- 数据处理:图像数据的加载、预处理与增强。
- 模型架构:基于EfficientNetB0的分类模型设计。
- 训练过程:模型的训练与优化,包括断点续训与学习率调度。
- 模型评估:在测试集上的性能评估。
- 预测应用:对新图像进行花生豆分类与标注。
数据处理
数据集结构
项目使用的数据集分为训练集、验证集和测试集,具体结构如下:
./data/dataset/
├── train/
│ ├── baiban/
│ ├── bandian/
│ ├── famei/
│ ├── faya/
│ ├── hongpi/
│ ├── qipao/
│ ├── youwu/
│ └── zhengchang/
├── validation/
│ ├── baiban/
│ ├── bandian/
│ ├── famei/
│ ├── faya/
│ ├── hongpi/
│ ├── qipao/
│ ├── youwu/
│ └── zhengchang/
└── test/
├── baiban/
├── bandian/
├── famei/
├── faya/
├── hongpi/
├── qipao/
├── youwu/
└── zhengchang/
每个子文件夹对应一种花生豆类别,包含相应的图像数据。
数据增强与规范化
为了提高模型的泛化能力,训练过程中对图像数据进行了多种数据增强操作,如随机裁剪、水平翻转、旋转和颜色抖动。同时,使用ImageNet的均值和标准差对图像进行了归一化处理,与预训练模型的输入要求保持一致。
# utils/dataLoader.py
train_transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.ToTensor(),
transforms.Normalize(*stats)
])
validation_transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(*stats)
])
test_transform = transforms.Compose([
transforms.Resize((image_size, image_size)),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(*stats)
])
模型架构
本项目采用了EfficientNetB0作为基础模型。EfficientNet系列通过系统性地平衡网络的宽度、深度和分辨率,在模型性能和计算效率之间取得了优异的平衡。具体来说:
- 预训练权重:使用在ImageNet上预训练的权重,帮助模型在较小的数据集上快速收敛。
- 冻结特征提取部分:根据需要,可以选择冻结模型的特征提取层,仅训练最后的分类器,适用于数据量较小的情况。
- 分类器设计:在原有分类器前添加了Dropout层,减少过拟合风险。
# utils/model.py
class EfficientNetB0(nn.Module):
def __init__(self, num_classes, pretrained=True, freeze_features=False):
super(EfficientNetB0, self).__init__()
if pretrained:
self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
else:
self.model = models.efficientnet_b0(weights=None)
if freeze_features:
for param in self.model.features.parameters():
param.requires_grad = False
in_features = self.model.classifier[1].in_features
self.model.classifier = nn.Sequential(
nn.Dropout(p=0.4, inplace=True),
nn.Linear(in_features, num_classes)
)
def forward(self, x):
return self.model(x)
训练过程
训练脚本 (train.py
)
训练脚本负责模型的训练与验证,包括数据加载、模型初始化、训练循环、学习率调度、模型保存和训练曲线绘制等功能。
关键功能包括:
- 训练与验证循环:每个epoch包括训练阶段和验证阶段,记录损失与准确率。
- 优化与调度:使用Adam优化器和
ReduceLROnPlateau
学习率调度器,根据验证损失动态调整学习率。 - 模型保存:保存验证集准确率最高的模型,并定期自动保存模型检查点。
- 断点续训:支持从保存的检查点继续训练,避免重复计算。
- 训练曲线绘制:训练结束后,生成并保存训练与验证的准确率和损失曲线。
# train.py
import torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import os
def accuracy(predictions, labels):
pred = torch.argmax(predictions, dim=1)
correct = (pred == labels).sum().item()
return correct
def train(net, start_epoch, epochs, train_loader, validation_loader, device, criterion, optimizer, scheduler, model_path, auto_save):
# 初始化
train_acc_list, validation_acc_list = [], []
train_loss_list, validation_loss_list = [], []
best_validation_acc = 0
net = net.to(device)
if start_epoch > 0:
print(f"从 epoch {start_epoch} 开始训练。")
for epoch in range(start_epoch, epochs):
# 训练阶段
net.train()
train_correct, train_loss, total = 0, 0, 0
with tqdm(train_loader, ncols=100, colour='green', desc=f"Train Epoch {epoch+1}/{epochs}") as pbar:
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = net(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
train_correct += accuracy(outputs, labels)
total += labels.size(0)
pbar.set_postfix({'loss': f"{train_loss / total:.4f}", 'acc': f"{train_correct / total:.4f}"})
train_acc = train_correct / total
train_loss = train_loss / total
train_acc_list.append(train_acc)
train_loss_list.append(train_loss)
# 验证阶段
net.eval()
validation_correct, validation_loss, total_validation = 0, 0, 0
with torch.no_grad():
with tqdm(validation_loader, ncols=100, colour='blue', desc=f"Validation Epoch {epoch+1}/{epochs}") as pbar:
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
outputs = net(images)
loss = criterion(outputs, labels)
validation_loss += loss.item() * images.size(0)
validation_correct += accuracy(outputs, labels)
total_validation += labels.size(0)
pbar.set_postfix({'loss': f"{validation_loss / total_validation:.4f}", 'acc': f"{validation_correct / total_validation:.4f}"})
validation_acc = validation_correct / total_validation
validation_loss = validation_loss / total_validation
validation_acc_list.append(validation_acc)
validation_loss_list.append(validation_loss)
# 更新学习率
scheduler.step(validation_loss)
# 保存最佳模型
if validation_acc > best_validation_acc:
best_validation_acc = validation_acc
checkpoint = {
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'validation_acc': best_validation_acc
}
torch.save(checkpoint, model_path)
print(f"保存最佳模型,验证准确率: {best_validation_acc:.4f}")
# 自动保存模型
if (epoch + 1) % auto_save == 0:
save_path = model_path.replace('.pth', f'_epoch{epoch+1}.pth')
checkpoint = {
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'validation_acc': best_validation_acc
}
torch.save(checkpoint, save_path)
print(f"自动保存模型到 {save_path}")
# 绘制训练曲线
def plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs):
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs+1), train_acc_list, 'bo-', label="训练准确率")
plt.plot(range(1, epochs+1), validation_acc_list, 'ro-', label="验证准确率")
plt.title("训练准确率 vs 验证准确率")
plt.xlabel("轮次")
plt.ylabel("准确率")
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs+1), train_loss_list, 'bo-', label="训练损失")
plt.plot(range(1, epochs+1), validation_loss_list, 'ro-', label="验证损失")
plt.title("训练损失 vs 验证损失")
plt.xlabel("轮次")
plt.ylabel("损失")
plt.legend()
os.makedirs('logs', exist_ok=True)
plt.savefig('logs/training_curve.png')
plt.show()
plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs)
if __name__ == '__main__':
batch_size = 64
image_size = 224
classes_num = 8
num_epochs = 100
auto_save = 10
lr = 1e-4
weight_decay = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}
train_loader, validation_loader, test_loader = load_data(batch_size, image_size, classify)
net = EfficientNetB0(classes_num, pretrained=True, freeze_features=False)
model_path = 'model_weights/EfficientNetB0.pth'
os.makedirs('model_weights', exist_ok=True)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
# 检查点续训
start_epoch = 0
best_validation_acc = 0
if os.path.exists(model_path):
try:
checkpoint = torch.load(model_path, map_location=device)
required_keys = ['model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'epoch', 'validation_acc']
if all(key in checkpoint for key in required_keys):
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
best_validation_acc = checkpoint['validation_acc']
print(f"从 epoch {checkpoint['epoch']} 继续训练,最佳验证准确率: {best_validation_acc:.4f}")
else:
print(f"检查点文件缺少必要的键,开始从头训练。")
except Exception as e:
print(f"加载检查点时发生错误: {e}")
print("开始从头训练。")
print("训练开始")
time_start = time.time()
train(net, start_epoch, num_epochs, train_loader, validation_loader, device=device,
criterion=criterion, optimizer=optimizer, scheduler=scheduler,
model_path=model_path, auto_save=auto_save)
time_end = time.time()
seconds = time_end - time_start
m, s = divmod(seconds, 60)
h, m = divmod(m, 60)
print("训练结束")
print("本次训练时长为:%02d:%02d:%02d" % (h, m, s))
主要特点:
- 进度条可视化:使用
tqdm
库实时展示训练和验证进度。 - 断点续训:支持从上一次中断的epoch继续训练,确保训练过程的连续性。
- 自动保存:定期保存模型检查点,防止意外中断导致的训练损失。
- 训练曲线:生成并保存训练与验证的准确率和损失曲线,便于后续分析与调优。
模型评估
评估脚本 (evaluate.py
)
评估脚本用于在测试集上评估训练好的模型性能,计算准确率和损失,并将结果保存到文件中。
# evaluate.py
import torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import os
def accuracy(predictions, labels):
pred = torch.argmax(predictions, dim=1)
correct = (pred == labels).sum().item()
return correct
def evaluate(net, test_loader, device, criterion, output_path):
net.eval()
test_correct, test_loss, total_test = 0, 0, 0
with torch.no_grad():
with tqdm(test_loader, ncols=100, colour='blue', desc=f"Evaluating on Test Set") as pbar:
for images, labels in pbar:
images, labels = images.to(device), labels.to(device)
outputs = net(images)
loss = criterion(outputs, labels)
test_loss += loss.item() * images.size(0)
test_correct += accuracy(outputs, labels)
total_test += labels.size(0)
pbar.set_postfix({'loss': f"{test_loss / total_test:.4f}", 'acc': f"{test_correct / total_test:.4f}"})
test_acc = test_correct / total_test
test_loss = test_loss / total_test
result = f"测试集准确率: {test_acc:.4f}, 测试集损失: {test_loss:.4f}"
print(result)
# 保存结果到文件
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'a') as f:
f.write(result + '\n')
if __name__ == '__main__':
batch_size = 64
image_size = 224
classes_num = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}
_, _, test_loader = load_data(batch_size, image_size, classify)
net = EfficientNetB0(classes_num, pretrained=False)
model_path = 'model_weights/EfficientNetB0.pth'
if not os.path.exists(model_path):
print(f"模型权重文件 {model_path} 不存在,请先训练模型。")
exit()
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
criterion = nn.CrossEntropyLoss()
evaluation_output_path = 'outputs/evaluation_results.txt'
# 清空之前的评估结果
if os.path.exists(evaluation_output_path):
os.remove(evaluation_output_path)
print("评估开始")
evaluate(net, test_loader, device=device, criterion=criterion, output_path=evaluation_output_path)
print("评估结束")
评估流程:
- 加载模型:从保存的权重文件中加载训练好的模型。
- 模型评估:在测试集上计算模型的准确率和损失。
- 结果保存:将评估结果保存到指定的输出文件中,便于后续查看与分析。
预测与应用
预测脚本 (predict.py
)
预测脚本用于对新图像进行花生豆分类,并在图像上标注分类结果和边框。
# predict.py
import os
import cv2
import numpy as np
import torch
from PIL import Image
from utils.model import EfficientNetB0
from torchvision import transforms
def delet_contours(contours, delete_list):
delta = 0
for i in range(len(delete_list)):
del contours[delete_list[i] - delta]
delta += 1
return contours
def main():
input_path = 'data/pic'
output_dir = 'outputs/predicted_images'
os.makedirs(output_dir, exist_ok=True)
image_files = os.listdir(input_path)
classify = {0: 'baiban', 1: 'bandian', 2: 'famei', 3: 'faya', 4: 'hongpi', 5: 'qipao', 6: 'youwu', 7: 'zhengchang'}
# 与训练时相同的预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.485, 0.456, 0.406), # ImageNet均值
std=(0.229, 0.224, 0.225)) # ImageNet标准差
])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = EfficientNetB0(8, pretrained=False)
model_path = 'model_weights/EfficientNetB0.pth'
if not os.path.exists(model_path):
print(f"模型权重文件 {model_path} 不存在,请先训练模型。")
return
net.load_state_dict(torch.load(model_path, map_location=device))
net.to(device)
net.eval()
min_size = 30
max_size = 400
for img_name in image_files:
img_path = os.path.join(input_path, img_name)
img = cv2.imread(img_path)
if img is None:
print(f"无法读取图像: {img_path}")
continue
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # 转换到HSV颜色空间
# 根据HSV颜色范围进行掩膜操作(根据实际情况调整颜色范围)
lower_blue = np.array([100, 100, 8])
upper_blue = np.array([255, 255, 255])
mask = cv2.inRange(hsv, lower_blue, upper_blue) # 创建掩膜
result = cv2.bitwise_and(img, img, mask=mask) # 应用掩膜
result = result.astype(np.uint8)
# 转换为灰度图并二值化
gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
_, binary_image = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
# 查找轮廓
contours, _ = cv2.findContours(binary_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
contours = list(contours)
# 过滤轮廓
delete_list = []
for idx, contour in enumerate(contours):
perimeter = cv2.arcLength(contour, True)
if perimeter < min_size or perimeter > max_size:
delete_list.append(idx)
contours = delet_contours(contours, delete_list)
# 对每个轮廓进行分类
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
crop = img[y:y+h, x:x+w]
if crop.size == 0:
continue
crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
crop_tensor = transform(crop_pil).unsqueeze(0).to(device)
with torch.no_grad():
output = net(crop_tensor)
pred = torch.argmax(output, dim=1).item()
label = classify[pred]
# 标注图像
cv2.putText(img, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)
cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)
# 保存结果图像到outputs/predicted_images/
output_image_path = os.path.join(output_dir, f"{os.path.splitext(img_name)[0]}_predicted.jpg")
cv2.imwrite(output_image_path, img)
print(f"保存预测结果到 {output_image_path}")
print("所有图像的预测和标注已完成并保存到 ./outputs/predicted_images/")
if __name__ == '__main__':
main()
预测流程:
- 图像预处理:将输入图像转换到HSV颜色空间,应用颜色掩膜提取花生豆区域。
- 轮廓检测与过滤:查找并过滤不符合大小要求的轮廓,确保只处理有效的花生豆区域。
- 分类与标注:对每个有效轮廓进行裁剪、预处理,并使用训练好的模型进行分类。在图像上标注分类结果和边框。
- 结果保存:将标注后的图像保存到指定的输出目录。
预测结果
项目成果
通过本项目,我们成功构建了一个能够高效、准确地检测和分类花生豆的深度学习模型。主要成果包括:
- 高准确率:模型在测试集上达到了令人满意的分类准确率。
- 自动化检测:实现了对新图像的自动检测与分类,大大提高了检测效率。
- 可视化结果:通过图像标注,直观展示了分类结果,便于用户理解和应用。
训练与验证的准确率和损失曲线示例
开源的程序及数据集
git clone https://gitee.com/songaoxiangsoar/peanut-bean-testing.git
结论与未来工作
本项目展示了基于深度学习的花生豆检测与分类的可行性与有效性。通过采用预训练的EfficientNetB0模型,并结合数据增强与优化策略,模型在花生豆分类任务中表现出色。
未来的工作方向包括:
- 模型优化:尝试更深更复杂的模型,如EfficientNetB7,以进一步提升分类性能。
- 数据扩展:收集更多多样化的花生豆图像,增强模型的泛化能力。
- 实时检测:优化模型推理速度,实现实时花生豆检测与分类。
- 部署应用:将模型集成到移动设备或嵌入式系统,便于现场检测与应用。
通过持续的优化与扩展,我们相信这一系统将在农业生产中发挥更大的价值,助力智能农业的发展。
感谢阅读本博客!如果您对本项目有任何疑问或建议,欢迎在下方留言交流。