【深度学习实战】kaggle 自动驾驶的假场景分类

news2025/1/19 15:51:00

本次分享我在kaggle中参与竞赛的历程,这个版本是我的第一版,使用的是vgg。欢迎大家进行建议和交流。

概述

  • 判断自动驾驶场景是真是假,训练神经网络或使用任何算法来分类驾驶场景的图像是真实的还是虚假的。

  • 图像采用 RGB 格式并以 JPEG 格式压缩。

  • 标签显示 (1) 真实和 (0) 虚假

  • 二元分类

数据集描述

文件
train.csv - 训练集标签
Sample_submission.csv - 正确格式的示例提交文件
Train/- 训练图像
Test/ - 测试图像

模型思路

由于是要进行图像的二分类任务,因此考虑使用迁移学习,将vgg16中的卷积层和卷积层的参数完全迁移过来,不包括顶部的全连接层,自己设计适合该任务的头部结构,然后加以训练,绘制图像查看训练结果。

vgg16简介

VGG16 是由牛津大学视觉几何组(VGG)在2014年提出的卷积神经网络(CNN)。它由16个层组成,其中包含13个卷积层和3个全连接层。其特点是使用3x3的小卷积核和2x2的最大池化层,网络深度较深,有效提取图像特征。VGG16在图像分类任务中表现优异,尤其是在ImageNet挑战中取得了良好成绩。尽管计算量大、参数众多,但它因其简单而高效的结构,仍广泛应用于迁移学习和其他计算机视觉任务中。

源码+解析

  1. 第一步,导入所需的库。
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications.vgg16 import preprocess_input
  1. 加载文件
# 路径和文件
data_file = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/train.csv'
image_test = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Test/'
image_train = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Train/'

# 加载标签数据
df = pd.read_csv(data_file)
df['image_path'] = df['image'].apply(lambda x: os.path.join(image_train, x))

n_classes = df['label'].nunique()

df.head()  # 显示数据的前几行,检查路径和标签

输出

	image	label	image_path
0	1.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
1	2.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
2	3.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...
3	6.jpg	editada	/kaggle/input/cidaut-ai-fake-scene-classificat...
4	8.jpg	real	/kaggle/input/cidaut-ai-fake-scene-classificat...

原始train.csv文件只有前两列,image 和label 列,为了方便读取图像文件,新添加了一列image_path用来记录图像文件的具体路径。

# 初始化空列表 x 用于存储图像
x = []

# 遍历每一行读取图像
for index, row in df.iterrows():
    image_path = row['image_path']  # 获取图像路径
    img = cv2.imread(image_path)  # 使用 cv2 读取图像
    
    if img is not None:
        img_resized = cv2.resize(img, (256, 256))  # 调整图像尺寸为 (256, 256)
        x.append(img_resized)  # 将读取的图像添加到列表 x 中
    else:
        print(f"图像 {row['image_path']} 读取失败")  # 打印失败的路径

# x 列表现在包含了所有读取的图像
print(f"总共有 {len(x)} 张图像被读取")

输出

总共有 720 张图像被读取

通过输出结果,可以看到图像被正确的读取了。并且将图像的大小调整为vgg所能用的256*256的尺寸,存放在变量x中。

  1. 第三步,进行数据处理
# 将图像转换为 NumPy 数组
x = np.array(x)

# 标签映射并进行 one-hot 编码
y = df['label'].map({'real': 1, 'editada': 0})
y = np.array(y)
y = to_categorical(y, num_classes=2)  # 二分类

# 分割训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# 检查转换后的结果
print(f"x_train.shape: {x_train.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"x_test.shape: {x_test.shape}")
print(f"y_test.shape: {y_test.shape}")

输出

x_train.shape: (576, 256, 256, 3)
y_train.shape: (576, 2)
x_test.shape: (144, 256, 256, 3)
y_test.shape: (144, 2)

这里是为了将原始的图像转换为numpy数组,并且将标签进行独热编码,(对分类的标签一定要进行独热编码,转换为矩阵形式),并且切分数据集。

  1. 第四步,设计模型结构
from tensorflow.keras.regularizers import l2
# 加载预训练的VGG16卷积基(不包括顶部的全连接层)
vgg16_model = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))

# 冻结VGG16的卷积层
for layer in vgg16_model.layers:
    layer.trainable = False

# 创建一个新的模型
model_fine_tuning = Sequential()

# 将VGG16的卷积基添加到新模型中
model_fine_tuning.add(vgg16_model)  # 添加VGG16卷积基
model_fine_tuning.add(Flatten())  # 将卷积特征图展平

# 添加新的全连接层并进行正则化
model_fine_tuning.add(Dense(512, activation='relu', kernel_regularizer=l2(0.01)))  # L2正则化
model_fine_tuning.add(Dropout(0.3))  # Dropout层,减少过拟合
model_fine_tuning.add(Dense(256, activation='relu', kernel_regularizer=l2(0.01)))  # 较小的全连接层
model_fine_tuning.add(Dropout(0.3) ) # 再次使用Dropout层

# 输出层
model_fine_tuning.add(Dense(2, activation='softmax'))  # 对于二分类问题,使用softmax

# 查看模型架构
model_fine_tuning.summary()

输出:

Layer (type)Output ShapeParam #
vgg16 (Functional)(None, 8, 8, 512)14,714,688
flatten (Flatten)(None, 32768)0
dense (Dense)(None, 512)16,777,728
dropout (Dropout)(None, 512)0
dense_1 (Dense)(None, 256)131,328
dropout_1 (Dropout)(None, 256)0
dense_2 (Dense)(None, 2)514

这里实现了一个基于预训练VGG16模型的迁移学习框架,用于图像分类任务。首先,加载了预训练的VGG16卷积基(不包括全连接层),并通过设置include_top=False来只使用卷积部分,从而利用其在ImageNet数据集上学到的特征。接着,冻结VGG16的卷积层,即通过将trainable属性设为False,使得这些层在训练过程中不进行更新。接下来,创建了一个新的Sequential模型,并将VGG16的卷积基添加进去,随后使用Flatten层将卷积特征图展平,为全连接层准备输入。为了增加模型的表达能力,添加了两个全连接层,每个层都应用了ReLU激活函数,并使用L2正则化来防止过拟合。为了进一步减少过拟合,模型还在每个全连接层后添加了Dropout层,丢弃30%的神经元。最后,输出层是一个具有两个神经元的全连接层,采用softmax激活函数,用于处理二分类问题。model_fine_tuning.summary()方法输出模型架构,帮助查看各层的结构和参数。通过这种方式,模型能够利用VGG16的预训练卷积基进行特征提取,并通过新添加的全连接层进行分类。

  1. 第五步,编译并训练模型
# 编译模型
model_fine_tuning.compile(loss='binary_crossentropy', 
                          optimizer=Adam(), 
                          metrics=['accuracy'])

datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    preprocessing_function=preprocess_input)  # 使用VGG16的预处理函数

# 对原始图像进行增强,并进行训练
history = model_fine_tuning.fit(datagen.flow(x_train, y_train, batch_size=32),
                                epochs=20,
                                validation_data=(x_test, y_test),
                                callbacks=[ModelCheckpoint('best_model.keras', save_best_only=True),
                                           EarlyStopping(patience=5)])

这里主要完成了对已经构建的模型(model_fine_tuning)的编译与训练过程。

  • 首先,使用compile()方法对模型进行编译,指定损失函数为binary_crossentropy,适用于二分类问题,同时选择Adam优化器,这是一种自适应学习率的优化算法,能够有效提升训练性能。在编译时,还通过metrics=['accuracy']设置了准确率作为评估指标。
  • 接着,创建了一个ImageDataGenerator对象用于数据增强,它包含多种图像变换方式,如旋转、平移、剪切、缩放、水平翻转等,这些操作可以增加数据多样性,减少过拟合,提升模型的泛化能力。
  • 此外,preprocessing_function=preprocess_input使用了VGG16预训练模型的标准预处理函数,确保输入图像的像素范围符合VGG16的训练要求。
  • 随后,通过fit()方法开始训练模型,训练数据通过datagen.flow()进行增强和批量生成,训练将在20个周期(epochs)内进行。在训练过程中,还设置了两个回调函数:ModelCheckpoint,用于保存最好的模型权重文件(best_model.keras),并且只保存验证集上表现最好的模型;
  • EarlyStopping,用于在验证集准确率不再提升时提前停止训练,patience=5表示如果5个周期内没有改进,则停止训练。这样,通过数据增强和回调函数的配合,能够有效提高训练的效果和模型的稳定性。

到这里,整个部分就基本完成了。

  1. 绘制损失和准确率图像
import matplotlib.pyplot as plt

# 获取训练过程中的损失和准确率数据
history_dict = history.history
loss = history_dict['loss']
accuracy = history_dict['accuracy']
val_loss = history_dict['val_loss']
val_accuracy = history_dict['val_accuracy']

# 绘制损失图
plt.figure(figsize=(12, 6))

# 损失图
plt.subplot(1, 2, 1)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 准确率图
plt.subplot(1, 2, 2)
plt.plot(accuracy, label='Training Accuracy')
plt.plot(val_accuracy, label='Validation Accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

# 展示图像
plt.tight_layout()
plt.show()

在这里插入图片描述
数据文件已经上传,感兴趣的小伙伴可以下载后自己尝试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2278974.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络编程 | UDP套接字通信及编程实现经验教程

1、UDP基础 传输层主要应用的协议模型有两种,一种是TCP协议,另外一种则是UDP协议。在上一篇博客文章中,已经对TCP协议及如何编程实现进行了详细的梳理讲解,在本文中,主要讲解与TCP一样广泛使用了另一种协议&#xff1a…

【Linux】线程全解:概念、操作、互斥与同步机制、线程池实现

🎬 个人主页:谁在夜里看海. 📖 个人专栏:《C系列》《Linux系列》《算法系列》 ⛰️ 道阻且长,行则将至 目录 📚一、线程概念 📖 回顾进程 📖 引入线程 📖 总结 &a…

掌握未来:从零开始学习生成式AI大模型!

随着人工智能技术的飞速发展,生成式AI大模型已成为当今科技领域的热点。从文本生成、图像创作到音乐创作,生成式AI大模型在各个领域展现出惊人的潜力。本文将带领大家从零开始,逐步学习生成式AI大模型,掌握未来的关键技术。 一、生…

多肽合成 -- 液相合成(liquid-phase peptide synthesis (LPPS))

液相合成的定义 液相合成(Solution Synthesis)是指在液体介质中进行的化学合成反应,是化学合成中一种基本且重要的方法。它涉及到将反应物溶解在溶剂中,在一定的温度、压力和其他反应条件下进行化学反应,从而生成所需的…

第23篇 基于ARM A9处理器用汇编语言实现中断<五>

Q:怎样修改HPS Timer 0定时器产生的中断周期? A:在上一期实验的基础上,可以修改按键中断服务程序,实现红色LED上的计数值递增的速率,主程序和其余代码文件不用修改。 实现以下功能:按下KEY0…

ChatGPT大模型极简应用开发-CH1-初识 GPT-4 和 ChatGPT

文章目录 1.1 LLM 概述1.1.1 语言模型和NLP基础1.1.2 Transformer及在LLM中的作用1.1.3 解密 GPT 模型的标记化和预测步骤 1.2 GPT 模型简史:从 GPT-1 到 GPT-41.2.1 GPT11.2.2 GPT21.2.3 GPT-31.2.4 从 GPT-3 到 InstructGPT1.2.5 GPT-3.5、Codex 和 ChatGPT1.2.6 …

2025春秋杯冬季赛 day1 crypto

文章目录 通往哈希的旅程小哈斯RSA1ez_rsa 通往哈希的旅程 根据提示推断是哈希函数,ai一下,推测大概率是一个sha1,让ai写一个爆破脚本即可 import hashlib# 给定目标 SHA-1 哈希值 target_hash "ca12fd8250972ec363a16593356abb1f3cf…

广播网络实验

1 实验内容 1、构建星性拓扑下的广播网络,实现hub各端口的数据广播,验证网络的连通性并测试网络效率 2、构建环形拓扑网络,验证该拓扑下结点广播会产生数据包环路 2 实验流程与结果分析 2.1 实验环境 ubuntu、mininet、xterm、wireshark、iperf 2.2 实验方案与结果分析…

RustDesk ID更新脚本

RustDesk ID更新脚本 此PowerShell脚本自动更新RustDesk ID和密码,并将信息安全地存储在Bitwarden中。 特点 使用以下选项更新RustDesk ID: 使用系统主机名生成一个随机的9位数输入自定义值 为RustDesk生成新的随机密码将RustDesk ID和密码安全地存储…

JavaEE之常见的锁策略

前面我们学习过线程不安全问题,我们通过给代码加锁来解决线程不安全问题,在生活中我们也知道有很多种类型的锁,同时在代码的世界当中,也对应着很多类型的锁,今天我们对锁一探究竟! 1. 常见的锁策略 注意: …

数字图像处理:实验二

任务一: 将不同像素(32、64和256)的原图像放大为像素大 小为1024*1024的图像(图像自选) 要求:1)输出一幅图,该图包含六幅子图,第一排是原图,第 二排是对应放大…

WEB渗透技术研究与安全防御

目录 作品简介I IntroductionII 1 网络面临的主要威胁1 1.1 技术安全1 2 分析Web渗透技术2 2.1 Web渗透技术的概念2 2.2 Web漏洞产生的原因2 2.3 注入测试3 2.3.1 注入测试的攻击流程3 2.3.2 进行一次完整的Sql注入测试4 2.3.3 Cookie注入攻击11 3 安全防御方案设计…

使用 Thermal Desktop 进行航天器热分析

介绍 将航天器保持在运行温度下的轨道上是一个具有挑战性的问题。航天器需要处理太空非常寒冷的背景温度,同时还要管理来自内部组件、地球反照率和太阳辐射的高热负荷。航天器在轨道上可以进行的各种轨道机动使解决这个问题变得更加复杂。 Thermal Desktop 是一款…

乘联会:1月汽车零售预计175万辆 环比暴跌33.6%

快科技1月18日消息,据乘联会的初步推算,2025年1月狭义乘用车零售总市场规模预计将达到约175万辆左右。与去年同期相比,这一数据呈现了-14.6%的同比下降态势;而相较于上个月,则出现了-33.6%的环比暴跌情况。 为了更清晰…

SQL 递归 ---- WITH RECURSIVE 的用法

SQL 递归 ---- WITH RECURSIVE 的用法 开发中遇到了一个需求,传递一个父类id,获取父类的信息,同时获取其所有子类的信息。 首先想到的是通过程序中去递归查,但这种方法着实孬了一点,于是想,sql能不能递归查…

【机器学习实战入门项目】使用深度学习创建您自己的表情符号

深度学习项目入门——让你更接近数据科学的梦想 表情符号或头像是表示非语言暗示的方式。这些暗示已成为在线聊天、产品评论、品牌情感等的重要组成部分。这也促使数据科学领域越来越多的研究致力于表情驱动的故事讲述。 随着计算机视觉和深度学习的进步,现在可以…

windows 搭建flutter环境,开发windows程序

环境安装配置: 下载flutter sdk https://docs.flutter.dev/get-started/install/windows 下载到本地后,随便找个地方解压,然后配置下系统环境变量 编译windows程序本地需要安装vs2019或更新的开发环境 主要就这2步安装后就可以了&#xff0…

【Linux】15.Linux进程概念(4)

文章目录 程序地址空间前景回顾C语言空间布局图:代码1代码2代码3代码4代码5代码6代码7 程序地址空间前景回顾 历史核心问题: pid_t id fork(); if(id 0) else if(id>0) 为什么一个id可以放两个值呢?之前没有仔细讲。 C语言空间布局图&am…

一文读懂服务器的HBA卡

什么是 HBA 卡 HBA 卡,全称主机总线适配器(Host Bus Adapter) ,是服务器与存储装置间的关键纽带,承担着输入 / 输出(I/O)处理及物理连接的重任。作为一种电路板或集成电路适配器,HBA…

oracle使用case when报错ORA-12704字符集不匹配原因分析及解决方法

问题概述 使用oracle的case when函数时,报错提示ORA-12704字符集不匹配,如下图,接下来分析报错原因并提出解决方法。 样例演示 现在有一个TESTTABLE表,本表包含的字段如下图所示,COL01字段是NVARCHAR2类型&#xff0…