Python图像识别实战(五):卷积神经网络CNN模型图像二分类预测结果评价(附源码和实现效果)

news2024/11/22 15:00:11

前面我介绍了可视化的一些方法以及机器学习在预测方面的应用,分为分类问题(预测值是离散型)和回归问题(预测值是连续型)(具体见之前的文章)。

从本期开始,我将做一个关于图像识别的系列文章,让读者慢慢理解python进行图像识别的过程、原理和方法,每一篇文章从实现功能、实现代码、实现效果三个方面进行展示。

实现功能:

卷积神经网络CNN模型图像二分类预测结果评价

实现代码:

import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from collections import Counter
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
import itertools
from pylab import mpl
import seaborn as sns

class Solution():
    #==================读取图片=================================
    def read_image(self,paths):
        os.listdir(paths)
        filelist = []
        for root, dirs, files in os.walk(paths):
            for file in files:
                if os.path.splitext(file)[1] == ".png":
                    filelist.append(os.path.join(root, file))

        print(filelist)
        return filelist

    #==================图片数据转化为数组==========================
    def im_array(self,paths):
        M=[]
        for filename in paths:
            im=Image.open(filename)
            im_L=im.convert("L")                #模式L
            Core=im_L.getdata()
            arr1=np.array(Core,dtype='float32')/255.0
            list_img=arr1.tolist()
            M.extend(list_img)
        return M

    def CNN_model(self,train_images, train_lables):
        # ============构建卷积神经网络并保存=========================
        model = models.Sequential()
        model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 1)))  # 过滤器个数,卷积核尺寸,激活函数,输入形状
        model.add(layers.MaxPooling2D((2, 2)))  # 池化层
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.MaxPooling2D((2, 2)))
        model.add(layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(layers.Flatten())  # 降维
        model.add(layers.Dense(64, activation='relu'))  # 全连接层
        model.add(layers.Dense(2, activation='softmax'))  # 注意这里参数,我只有两类图片,所以是2.
        model.summary()  # 显示模型的架构
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
        return model

if __name__=='__main__':
    Object1=Solution()
    # =================数据读取===============
    path1="D:\DCTDV2\dataset\\train\\"
    test1 = "D:\DCTDV2\dataset\\test\\"

    pathDir = os.listdir(path1)
    print(pathDir)
    pathDir=pathDir[1:5]
    print(pathDir)

    for a in pathDir:
        path2=path1+a
        test2=test1+a
        filelist_1=Object1.read_image(path1+"Norm")
        filelist_2=Object1.read_image(path2)
        filelist_all=filelist_1+filelist_2
        M=Object1.im_array(filelist_all)
        train_images=np.array(M).reshape(len(filelist_all),128,128)#输出验证一下(400, 128, 128)
        label=[0]*len(filelist_1)+[1]*len(filelist_2)
        train_lables=np.array(label)        #数据标签
        train_images = train_images[..., np.newaxis]        #数据图片
        print(train_images.shape)#输出验证一下(400, 128, 128, 1)
        print(train_lables.shape)

        # ===================准备测试数据==================
        filelist_1T = Object1.read_image(test1+"Norm")
        filelist_2T = Object1.read_image(test2)
        filelist_allT = filelist_1T + filelist_2T
        print(filelist_allT)
        N = Object1.im_array(filelist_allT)
        dict_label = {0: 'norm', 1: 'IgaK'}
        test_images = np.array(N).reshape(len(filelist_allT), 128, 128)
        label = [0] * len(filelist_1T) + [1] * len(filelist_2T)
        test_lables = np.array(label)  # 数据标签
        test_images = test_images[..., np.newaxis]  # 数据图片
        print(test_images.shape)  # 输出验证一下(100, 128, 128, 1)
        print(test_lables.shape)

        # #===================训练模型=============
        model=Object1.CNN_model(train_images, train_lables)
        CnnModel=model.fit(train_images, train_lables, epochs=20)
        # model.save('D:\电池条带V2\model\my_model.h5')  # 保存为h5模型
        # tf.keras.models.save_model(model,"F:\python\moxing\model")#这样是pb模型
        print("模型保存成功!")

        # history列表
        print(CnnModel.history.keys())
        font = {'family': 'Times New Roman','size': 12,}
        sns.set(font_scale=1.2)

        plt.plot(CnnModel.history['loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.savefig('D:\\DCTDV2\\result\\V1\\loss' + "\\" + '%s.tif' % a,bbox_inches='tight',dpi=600)
        plt.show()

        plt.plot(CnnModel.history['accuracy'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.savefig('D:\\DCTDV2\\result\\V1\\accuracy' + "\\" + '%s.tif' % a,bbox_inches='tight',dpi=600)
        plt.show()

        # #===================预测图像=============
        predict_label=[]
        prob_label=[]
        for i in test_images:
            i=np.array([i])
            predictions_single=model.predict(i)
            print(np.argmax(predictions_single))
            out_c_1 = np.array(predictions_single)[:, 1]
            prob_label.extend(out_c_1)
            predict_label.append(np.argmax(predictions_single))

        print(prob_label)
        print(predict_label)
        print(list(test_lables))
        count = Counter(predict_label)
        print(count)

        TP = FP = FN = TN = 0
        for i in range(len(predict_label)):
            if predict_label[i]==1:
                if list(test_lables)[i]==1:
                    TP=TP+1
                elif list(test_lables)[i]==0:
                    FP=FP+1
            elif predict_label[i]==0:
                if list(test_lables)[i]==1:
                    FN=FN+1
                elif list(test_lables)[i]==0:
                    TN=TN+1

        print(TP,FP,FN,TN)
        deathc_recall=TP/(TP+FN)
        savec_recall=TN/(FP+TN)
        print(deathc_recall)
        print(savec_recall)

        print("--------------------")

        cm = np.arange(4).reshape(2, 2)
        cm[0, 0] = TN
        cm[0, 1] = FP
        cm[1, 0] = FN
        cm[1, 1] = TP
        classes = [0, 1]

        plt.figure()
        plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
        plt.title('Confusion matrix')
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=0)
        plt.yticks(tick_marks, classes)
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, cm[i, j], horizontalalignment="center", color="red" if cm[i, j] > thresh else "black")
            plt.tight_layout()
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
        plt.savefig('D:\\DCTDV2\\result\\V1\\cm' + "\\" + '%s.tif' % a,bbox_inches='tight',dpi=600)
        plt.show()

        fpr, tpr, thresholds = roc_curve(list(test_lables), prob_label, pos_label=1)
        Auc_score = roc_auc_score(list(test_lables), predict_label)
        Auc = auc(fpr, tpr)
        print(Auc_score, Auc)
        plt.plot(fpr, tpr, 'b', label='AUC = %0.2f' % Auc)  # 生成ROC曲线
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('True positive rate')
        plt.xlabel('False positive rate')
        plt.savefig('D:\\DCTDV2\\result\\V1\\roc' + "\\" + '%s.tif' % a,bbox_inches='tight',dpi=600)
        plt.show()

        plt.figure()
        precision, recall, thresholds = precision_recall_curve(list(test_lables), predict_label)
        plt.title('Precision/Recall Curve')  # give plot a title
        plt.xlabel('Recall')  # make axis labels
        plt.ylabel('Precision')
        plt.plot(precision, recall)
        plt.savefig('D:\\DCTDV2\\result\\V1\\pr' + "\\" + '%s.tif' % a,bbox_inches='tight',dpi=600)
        plt.show()

实现效果:

由于数据为非公开数据,仅展示几个图像的效果,有问题可以后台联系我。

 

 

 

本人读研期间发表5篇SCI数据挖掘相关论文,现在在某研究院从事数据挖掘相关工作,对数据挖掘有一定的认知和理解,会不定期分享一些关于python机器学习、深度学习、数据挖掘基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

关注V订阅号:数据杂坛可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/126973.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BOOT进程控制模式与故障排错

1. BOOT reboot and shutdown—使用systemctl 命令。 systemctl poweroff–关机 systemctl reboot --重启 systemctl halt 禁用CPU 在7版本中使用systemctl 工具。 选择systemd target graphical.target 桌面图形模式 multi-user.target 多用户模式–命令行 rescue.target 救援…

Linux驱动开发基础__总线设备驱动模型

目录 1 驱动编写的三种方法 1.1 传统写法 1.2 总线设备驱动模型 1.3 设备树 2 在 Linux 中实现“分离”:Bus/Dev/Drv 模型 3 匹配原则 4 函数调用关系 1 驱动编写的三种方法 1.1 传统写法 1.2 总线设备驱动模型 引入platform_device、platform_driver&…

二叉数题型2

目录 二叉搜索树的众数 二叉树的最近公共祖先 修剪二叉树 二叉搜索树的众数 问题描述: 给你一个含重复值的二叉搜索树(BST)的根节点 root ,找出并返回 BST 中的所有 众数(即,出现频率最高的元素&#…

PROJ 9.1.1源码下载编译(Win10+VS2022)

目录PROJ什么是PROJPROJ下载方式资源结构编译PROJ打包编译成功的库PROJ 什么是PROJ Proj是一个免费的GIS工具。 它专注于地图投影的表达,以及转换。采用一种非常简单明了的投影表达PROJ,比其它的投影定义简单,但很明显。很容易就能看到各种…

无人机倾斜摄影测量技术的优势有哪些?

传统的地理信息获取工作一般是通过人工测量的方式进行,但这样的测量方式具有工作强度大、成本高等问题。随着现代科技的不断发展,测绘行业对地理信息数据的准确性、时效性要求也越来越高,人工成本和时间成本也为行业带来了巨大的压力。因此&a…

GIT回退到指定版本的两种方法(reset/revert)

实现多人合作程序开发的过程中,我们有时会出现错误提交的情况,此时我们希望能撤销提交操作,让程序回到提交前的样子,本文总结了两种解决方法:reset、revert。 命令特点reset该命令会强行覆盖当前版本和要回退的版本之…

ArcGIS基础实验操作100例--实验15设置字段属性域

本实验专栏来自于汤国安教授《地理信息系统基础实验操作100例》一书 实验平台:ArcGIS 10.6 实验数据:请访问实验1(传送门) 基础编辑篇--实验15 设置字段属性域 目录 一、实验背景 二、实验数据 三、实验步骤 (1&a…

如何用Sonic云真机打王者

使用Sonic进行跨网段部署,助力海外业务的公司进行专项检测。提供定时任务充分利用无人值守时间回归UI测试,省时省力。自研随机事件测试与UI遍历测试,支持打通Jenkins的DevOps流程,Sonic提供图像识别,后续还会添加poco控…

ECS-弹性容器服务 - Part 2

68-ECS-弹性容器服务 - Part 2 Hello大家好,我们今天继续ECS的内容。 Service load balancing 之前的课时讨论过,在ECS集群上创建的ECS服务支持AWS负载均衡器,而应用程序负载均衡器和ECS服务通常是一个很好的搭配,因为应用程序负…

Docker 基础概念介绍

一 什么是 docker ? Docker 是一个开源的应用容器引擎,让开发者可以打包他们的应用以及依赖包到一个可移植的镜像中,然后发布到任何流行的 Linux或Windows操作系统的机器上,也可以实现虚拟化。容器是完全使用沙箱机制,…

【nowcoder】笔试强训Day13

目录 一、选择题 二、编程题 2.1参数解析 2.2跳石板 一、选择题 1.一个关系数据库文件中的各条记录 () 。 A. 前后顺序不能任意颠倒,一定要按照输入的顺序排列 B. 前后顺序可以任意颠倒,不影响库中的数据关系 C. 前后顺序…

前端面试题之计算机网络篇--HTTP协议

HTTP协议 1. GET和POST的请求的区别 GET和POST方法 GET和POST方法都是HTTP中的方法 什么是 HTTP? 超文本传输协议(Hypertext Transfer Protocol,缩写 HTTP)旨在启用客户端和服务器之间的通信。 HTTP 充当客户端和服务器之间的…

Android进阶——Javac编译解析

Javac编译器 1.Javac的源码与调试 Javac的源码下载地址:Javac的源码下载地址,在Myeclipse中新建项目Compiler_javac,把源码复制到项目中。 Javac的源码目录: 从Sun Javac的代码来看,编译过程大致可以分为3个过程&…

测试工程师正遭「革命」 AI将改写测试模式

文章目录❤️‍🔥 软件测试的现状❣️ 功能测试的短板❣️ 过于的依赖工具❤️‍🔥 测试行业的两极分化❤️‍🔥 纯功能测试人员应该如何破局❣️ 龙测 AI TestOps 云平台❣️ AI TestOps 亮相 TICA❣️ AI TestOps 所实现的混合模型解决方案…

相关系数(皮尔逊pearson相关系数和斯皮尔曼spearman等级相关系数)

目录 总体皮尔逊Person相关系数: 样本皮尔逊Person相关系数: 两点总结: 假设检验:(可结合概率论课本假设检验部分) 皮尔逊相关系数假设检验: 更好的方法:p值判断方法 皮尔逊相…

lua调用c动态库实例

简介 Lua 是一种轻量小巧的脚本语言,用标准C语言编写并以源代码形式开放, 其设计目的是为了嵌入应用程序中,从而为应用程序提供灵活的扩展和定制功能。 特点 轻量级: 它用标准C语言编写并以源代码形式开放,编译后仅仅一百余K&a…

STM32/51单片机实训day4——RFID数据读取|RC522|串口数据收发、可模拟RFID (三) 仿真

目录 1 任务指导 2 实验步骤 3 串口调试 4 USART配置 5 fputs函数重写 内 容:能够读取RFID卡S50的ID——编程实现串口数据收发 学 时:3学时 知识点:电路图设计、USART配置 重点: USART配置 难点:USART配置 时…

赶快升级吧!PHP8比PHP5快41倍,比PHP7快3倍

本文得出的结论,归结于仅运行纯CPU任务的脚本的基准测试结果,不需要I/O操作的任务,例如访问文件、网络或数据库连接。 这些是纯 CPU 基准测试。它们并未涵盖 PHP 性能的所有方面,并且它们可能无法代表实际情况。然而,结…

用Python标准库统计CSDN阅读量

urllib基础 一般做爬虫其实很少有推荐urllib的,但urllib乃是Python标准库成员,在要求比较简单的情况下,采用urllib还是比较方便的。 作为爬虫入门必学包,urllib最常用的函数一定是urllib.request中的urlopen。其返回对象是HTTPR…

ES学习路程(二)

关于ES第一篇是在Linux安装,为了方便我在windows搭建一套ES和kibana版本(7.15.0) 第一步:下载安装ES在windows 官网下载相应版本的es和kibana: https://www.elastic.co/cn/downloads/past-releases/elasticsearch-7…