LeNet实验 四分类 与 四分类变为多个二分类

news2024/11/26 10:41:49

 

目录

1. 划分二分类

2. 训练独立的二分类模型

3. 二分类预测结果代码

4. 二分类预测结果

5 改进训练模型

6 优化后 预测结果代码

 7 优化后预测结果

8 训练四分类模型 

9 预测结果代码

10 四分类结果识别


1. 划分二分类

可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented

2. 训练独立的二分类模型

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from 文件准备 import data_dir

# 数据生成器
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=0.2  # 20%用于验证
)

train_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(28, 28),
    batch_size=32,
    class_mode='binary',
    subset='training'
)

validation_generator = train_datagen.flow_from_directory(
    data_dir,
    target_size=(28, 28),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)

# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

# 编译模型
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // train_generator.batch_size,
    epochs=10,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // validation_generator.batch_size
)

# 保存模型
model.save('lenet_binary_classification_model.h5')

3. 预测结果代码

import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt

# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')

# 预处理图像
def preprocess_image(img_path):
    img = image.load_img(img_path, target_size=(28, 28))
    img_array = image.img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)
    return img_array

# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg'  # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'

print(f'The predicted class is: {predicted_class}')

# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()

4. 预测结果

Demented结果

 NonDemented结果没有。。。。。。

竟然全都没有。。。。因为预测的全部都是Demented

疯狂找原因中

猜测是像素太低使得训练的模型准确率太低

于是重新训练

5 改进训练模型

进行重新训练

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# 定义LeNet模型
def create_lenet_model(input_shape):
    model = Sequential([
        Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),
        MaxPooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation='relu'),
        MaxPooling2D((2, 2), strides=2),
        Flatten(),
        Dense(120, activation='relu'),
        Dense(84, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# 训练数据生成器
train_generator = train_datagen.flow_from_directory(
    'D:\Pycharm_workspace\LeNet实验_二分类\image',
    target_size=(176, 208),
    batch_size=32,
    class_mode='binary',
    subset='training'
)

# 验证数据生成器
validation_generator = train_datagen.flow_from_directory(
    'D:\Pycharm_workspace\LeNet实验_二分类\image',
    target_size=(176, 208),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)

# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)

# 保存模型
model.save('dementia_classification_model.h5')

# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()

# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

 可以看到loss值非常小而且准确率是100

6 优化后 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os

# 加载模型
model = load_model('dementia_classification_model.h5')

# 定义类别标签
class_labels = ['Demented', 'NonDemented']


# 预测函数
def predict_image(img_path):
    img = image.load_img(img_path, target_size=(176, 208))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array /= 255.0

    prediction = model.predict(img_array)
    predicted_class = class_labels[int(prediction[0] > 0.5)]

    # 显示图像和预测结果
    plt.imshow(image.load_img(img_path))
    plt.title(f'Predicted: {predicted_class}')
    plt.axis('off')
    plt.show()


# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg'  # 替换为你的图片路径
predict_image(img_path)

 7 优化后预测结果

 图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)

 NonDem的也是对应上了

 就此训练完成

 

8 训练四分类模型 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# 定义LeNet模型
def create_lenet_model(input_shape):
    model = Sequential([
        Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),
        MaxPooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation='relu'),
        MaxPooling2D((2, 2), strides=2),
        Flatten(),
        Dense(120, activation='relu'),
        Dense(84, activation='relu'),
        Dense(4, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# 训练数据生成器
train_generator = train_datagen.flow_from_directory(
    'D:\Pycharm_workspace\LeNet实验_四分类\image',
    target_size=(176, 208),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)

# 验证数据生成器
validation_generator = train_datagen.flow_from_directory(
    'D:\Pycharm_workspace\LeNet实验_四分类\image',
    target_size=(176, 208),
    batch_size=32,
    class_mode='categorical',
    subset='validation'
)

# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)

# 保存模型
model.save('dementia_classification_model.h5')

# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()

# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了 

9 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt

# 加载模型
model = load_model('dementia_classification_model.h5')

# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

# 预测函数
def predict_image(img_path):
    img = image.load_img(img_path, target_size=(176, 208))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array /= 255.0

    prediction = model.predict(img_array)
    predicted_class = class_labels[np.argmax(prediction)]

    # 显示图像和预测结果
    plt.imshow(image.load_img(img_path))
    plt.title(f'Predicted: {predicted_class}')
    plt.axis('off')
    plt.show()

# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg'  # 你的图片路径
predict_image(img_path)

 

10 四分类结果识别

1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别

 4 VeryMildDem成功识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938619.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据分享】2013-2022年我国省市县三级的逐月SO2数据(excel\shp格式\免费获取)

空气质量数据是在我们日常研究中经常使用的数据!之前我们给大家分享了2000——2022年的省市县三级的逐月PM2.5数据和2013-2022年的省市县三级的逐月CO数据(均可查看之前的文章获悉详情)! 本次我们分享的是我国2013——2022年的省…

Langchain-Chatchat-Ubuntu服务器本地安装部署笔记

Langchain-Chatchat(原Langchain-ChatGLM)基于 Langchain 与 ChatGLM 等语言模型的本地知识库问答 | Langchain-Chatchat (formerly langchain-ChatGLM), local knowledge based LLM (like ChatGLM) QA app with langchain。 开源网址:https:…

基于NeRF的路面重建算法——RoME / EMIE-MAP / RoGS

基于NeRF的路面重建算法——RoME / EMIE-MAP / RoGS 1. RoMe1.1 Mesh Initialization / Waypoint Sampling1.2 Optimization1.3 Experiments 2. EMIE-MAP2.1 Road Surface Representation based on Explicit mesh and Implicit Encoding2.2 Optimizing Strategies2.3 Experimen…

基于面向对象和递归的拦截器设计模式

1 定义 拦截器模式(Interceptor Pattern),是指提供一种通用的扩展机制,可以在业务操作前后提供一些切面的(Cross-Cutting)的操作。这些切面操作通常是和业务无关的,比如日志记录、性能统计、安…

SciPy版本与Python和NumPy各个版本的兼容性

但是现在我用Scipy1.13.1,Python3.10,NumPy2.0.0,使用Scipy时会报错,将NumPy 版本降低为1.26.4以后,就没有报错了。

C++ | Leetcode C++题解之第268题丢失的数字

题目&#xff1a; 题解&#xff1a; class Solution { public:int missingNumber(vector<int>& nums) {int n nums.size();int total n * (n 1) / 2;int arrSum 0;for (int i 0; i < n; i) {arrSum nums[i];}return total - arrSum;} };

【MySQL】一些业务场景常见的查询,比如实现多表字段同步,递归查询等

目录 快速加注释多表关联查询更新多个字段循环查询子级方法1&#xff1a;递归查询方法2&#xff1a;循环查询 快速加注释 使用ALTER TABLE语句可以修改表结构&#xff0c;包括添加注释。以下是添加注释的语法&#xff1a; ALTER TABLE 表名 MODIFY COLUMN 列名 列类型 COMMEN…

【开源库学习】libodb库学习(三)

4 查询数据库 如果我们不知道我们正在寻找的对象的标识符&#xff0c;我们可以使用查询在数据库中搜索符合特定条件的对象。ODB查询功能是可选的&#xff0c;我们需要使用--generate-query ODB编译器选项显式请求生成必要的数据库支持代码。 ODB提供了一个灵活的查询API&#x…

LeetCode 热题 HOT 100 (001/100)【宇宙最简单版】

【链表】 No. 0160 相交链表 【简单】&#x1f449;力扣对应题目指路 希望对你有帮助呀&#xff01;&#xff01;&#x1f49c;&#x1f49c; 如有更好理解的思路&#xff0c;欢迎大家留言补充 ~ 一起加油叭 &#x1f4a6; ⭐题目描述&#xff1a;两个单链表的头节点 headA 和 …

51单片机嵌入式开发:13、STC89C52RC 之 RS232与电脑通讯

STC89C52RC 之 RS232与电脑通讯 第十三节课&#xff0c;RS232与电脑通讯1 概述2 Uart介绍2.1 概述2.2 STC89C52UART介绍2.3 STC89C52 UART寄存器介绍2.4 STC89C52 UART操作 3 C51 UART总结 第十三节课&#xff0c;RS232与电脑通讯 1 概述 RS232&#xff08;Recommended Stand…

huawei USG6001v1学习----NAT和智能选路

目录 1.NAT的分类 2.智能选路 1.就近选路 2.策略路由 3.智能选路 NAT:&#xff08;Network Address Translation&#xff0c;网络地址转换&#xff09; 指网络地址转换&#xff0c;1994年提出的。NAT是用于在本地网络中使用私有地址&#xff0c;在连接互联网时转而使用全局…

Java | Leetcode Java题解之第263题丑数

题目&#xff1a; 题解&#xff1a; class Solution {public boolean isUgly(int n) {if (n < 0) {return false;}int[] factors {2, 3, 5};for (int factor : factors) {while (n % factor 0) {n / factor;}}return n 1;} }

数学建模--优劣解距离法TOPSIS

目录 简介 TOPSIS法的基本步骤 延伸 优劣解距离法&#xff08;TOPSIS&#xff09;的历史发展和应用领域有哪些&#xff1f; 历史发展 应用领域 如何准确计算TOPSIS中的理想解&#xff08;PIS&#xff09;和负理想解&#xff08;NIS&#xff09;&#xff1f; TOPSIS方法在…

<数据集>手势识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;2400张 标注数量(xml文件个数)&#xff1a;2400 标注数量(txt文件个数)&#xff1a;2400 标注类别数&#xff1a;5 标注类别名称&#xff1a;[fist, no_gesture, like, ok, palm] 序号类别名称图片数框数1fist597…

Qt中在pro中实现一些宏定义

在pro文件中利用 DEFINES 定义一些宏定义供工程整体使用。&#xff08;和在cpp/h文件文件中定义使用有点类似&#xff09;可以利用pro的中的宏定义实现一些全局的判断 pro中实现 #自定义一个变量 DEFINES "PI\"3.1415926\"" #自定义宏 DEFINES "T…

XLua原理(一)

项目中活动都是用xlua开发的&#xff0c;项目周更热修也是用xlua的hotfix特性来做的。现研究底层原理&#xff0c;对于项目性能有个更好的把控。 本文认为看到该文章的人已具备使用xlua开发的能力&#xff0c;只研究介绍下xlua的底层实现原理。 一.lua和c#交互原理 概括&…

香橙派AIpro部署边缘端夜莺监控

文章目录 硬件信息硬件简介技术路线硬件参数到手实拍接口详情图应用场景相关资源香橙派官方昇腾论坛 开箱使用准备工作上电准备启动设备开发板状态 连接设备方式一、显示器直连方式二、Micro Usb 数据线串口连接方式三、Micro Usb 数据线方式网络直连方式四、Micro Usb数据线方…

R语言画散点图-饼图-折线图-柱状图-箱线图-直方图-等高线图-曲线图-热力图-雷达图-韦恩图(二D)

R语言画散点图-饼图-折线图-柱状图-箱线图-直方图-等高线图-曲线图-热力图-雷达图-韦恩图&#xff08;二D&#xff09; 散点图示例解析效果 饼图示例解析效果 折线图示例解析效果 柱状图示例解析效果 箱线图示例解析效果 直方图示例解析效果 等高线图使用filled.contour函数示例…

Pixel6 GKI 内核编译

前言 前段时间写了一篇关于pixel4 Android内核编译编译内核的流程。 但是随着Android版本的提升Google开始推崇GKI方式发内核模式,这种模式可以方便供应商剥离内核和驱动的捆绑性&#xff0c;官方抽象出一部分接口(GKI)提供给产生使用极大便利和解耦开发复杂性。 在pixel4 And…

python-爬虫实例(1):获取京东商品评论

目录 前言 道路千万条&#xff0c;安全第一条 爬虫不谨慎&#xff0c;亲人两行泪 获取京东商品评论信息 一、实例示范 二、爬虫四步走 1.UA伪装 2.获取Url 3.发送请求 4获取响应数据进行解析并保存 总结 前言 道路千万条&#xff0c;安全第一条 爬虫不谨慎&#xff0c;亲…