卷积神经网络(CNN):乳腺癌识别.ipynb

news2024/12/27 14:56:06

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keras

warnings.filterwarnings("ignore")             #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlib

data_dir = "./32-data"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 13403
batch_size = 16
img_height = 50
img_width  = 50
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 10723 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 13403 files belonging to 2 classes.
Using 2680 files for validation.
class_names = train_ds.class_names
print(class_names)
['0', '1']

2. 检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(16, 50, 50, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

class_names = ["乳腺癌细胞","正常细胞"]

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

在这里插入图片描述

三、构建模型

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu",input_shape=[img_width, img_height, 3]),
    tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),

    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Conv2D(filters=16,kernel_size=(3,3),padding="same",activation="relu"),
    tf.keras.layers.MaxPooling2D((2,2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(2, activation="softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 50, 50, 16)        448       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 50, 50, 16)        2320      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 25, 25, 16)        0         
_________________________________________________________________
dropout (Dropout)            (None, 25, 25, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 25, 25, 16)        2320      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 12, 12, 16)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 12, 12, 16)        2320      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 2)                 1154      
=================================================================
Total params: 8,562
Trainable params: 8,562
Non-trainable params: 0
_________________________________________________________________

四、编译

model.compile(optimizer="adam",
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateScheduler

NO_EPOCHS = 100
PATIENCE  = 5
VERBOSE   = 1

# 设置动态学习率
annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))

# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)

# 
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=VERBOSE,
                                save_best_only=True,
                                save_weights_only=True)
train_model  = model.fit(train_ds,
                  epochs=NO_EPOCHS,
                  verbose=1,
                  validation_data=val_ds,
                  callbacks=[earlystopper, checkpointer, annealer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']

loss = train_model.history['loss']
val_loss = train_model.history['val_loss']

epochs_range = range(len(acc))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('混淆矩阵',fontsize=15)
    plt.ylabel('真实值',fontsize=14)
    plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []

for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
    for image, label in zip(images, labels):
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(image, 0) 
        # 使用模型预测图片中的人物
        prediction = model.predict(img_array)

        val_pre.append(class_names[np.argmax(prediction)])
        val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metrics

def test_accuracy_report(model):
    print(metrics.classification_report(val_label, val_pre, target_names=class_names)) 
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model)
             precision    recall  f1-score   support

       乳腺癌细胞       0.92      0.90      0.91      1339
        正常细胞       0.91      0.92      0.91      1341

    accuracy                           0.91      2680
   macro avg       0.91      0.91      0.91      2680
weighted avg       0.91      0.91      0.91      2680

Loss function: 0.22688131034374237, accuracy: 0.9138059616088867

pport

   乳腺癌细胞       0.92      0.90      0.91      1339
    正常细胞       0.91      0.92      0.91      1341

accuracy                           0.91      2680

macro avg 0.91 0.91 0.91 2680
weighted avg 0.91 0.91 0.91 2680

Loss function: 0.22688131034374237, accuracy: 0.9138059616088867


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1284018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

线程池、及Springboot线程池实践

摘要 本文介绍了线程池基本概念、线程及线程池状态、java中线程池提交task后执行流程、Executors线程池工具类、最后介绍在springboot框架下使用线程池和定时线程池,以及task取消 线程池基本 背景 线程池 线程池是一种多线程处理形式,处理过程中将任务…

HCIP——交换综合实验

一、实验拓扑图 二、实验需求 1、PC1和PC3所在接口为access,属于vlan2;PC2/4/5/6处于同一网段,其中PC2可以访问PC4/5/6;但PC4可以访问PC5,不能访问PC6 2、PC5不能访问PC6 3、PC1/3与PC2/4/5/6/不在同一网段 4、所有PC通…

【Java】类和对象之超级详细的总结!!!

文章目录 前言1. 什么是面向对象?1.2面向过程和面向对象 2.类的定义和使用2.1什么是类?2.2类的定义格式2.3类的实例化2.3.1什么是实例化2.3.2类和对象的说明 3.this引用3.1为什么会有this3.2this的含义与性质3.3this的特性 4.构造方法4.1构造方法的概念4…

数据结构第六课 -----链式二叉树的实现

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

Java生成word[doc格式转docx]

引入依赖 <!-- https://mvnrepository.com/artifact/org.freemarker/freemarker --><dependency><groupId>org.freemarker</groupId><artifactId>freemarker</artifactId><version>2.3.32</version></dependency> doc…

针对net core 使用CSRedis 操作redis的三种连接实例方式

1、主从访问 2、哨兵模式 3、集群访问 写法一&#xff1a;写任意一个地址即可&#xff0c;其它节点在运行过程中自动增加&#xff0c;确保每个节点密码一致。如&#xff1a;Console.WriteLine("集群测试");RedisHelper.Initialization(new CSRedis.CSRedisClient(&q…

Http和WebSocket

客户端发送一次http请求&#xff0c;服务器返回一次http响应。 问题&#xff1a;如何在客户端没有发送请求的情况下&#xff0c;返回服务端的响应&#xff0c;网页可以得服务器数据&#xff1f; 1&#xff1a;http定时轮询 客户端定时发送http请求&#xff0c;eg&#…

回溯和分支算法

状态空间图 “图”——状态空间图 例子&#xff1a;农夫过河问题——“图”状态操作例子&#xff1a;n后问题、0-1背包问题、货郎问题(TSP) 用向量表示解&#xff0c;“图”由解向量扩张得到的解空间树。 ——三种图&#xff1a;n叉树、子集树、排序树 ​ 剪枝 不满住条件的…

链表【3】

文章目录 &#x1f433;23. 合并 K 个升序链表&#x1f41f;题目&#x1f42c;算法原理&#x1f420;代码实现 &#x1f437;25. K 个一组翻转链表&#x1f416;题目&#x1f43d;算法原理&#x1f367;代码实现 &#x1f433;23. 合并 K 个升序链表 &#x1f41f;题目 题目链…

Sentinel基础知识

Sentinel基础知识 资源 1、官方网址&#xff1a;https://sentinelguard.io/zh-cn/ 2、os-china: https://www.oschina.net/p/sentinel?hmsraladdin1e1 3、github: https://github.com/alibaba/Sentinel 一、软件简介 Sentinel 是面向分布式服务架构的高可用流量防护组件…

Unity 关于SetParent方法的使用情况

在设置子物体的父物体时&#xff0c;我们使用SetParent再常见不过了。 但是通常我们只是使用其中一个语法&#xff1a; public void SetParent(Transform parent);使用改方法子对象会保持原来位置&#xff0c;跟使用以下方法效果一样&#xff1a; public Transform tran; ga…

【数值计算方法(黄明游)】函数插值与曲线拟合(二):Newton插值【理论到程序】

​ 文章目录 一、近似表达方式1. 插值&#xff08;Interpolation&#xff09;2. 拟合&#xff08;Fitting&#xff09;3. 投影&#xff08;Projection&#xff09; 二、Lagrange插值1. 拉格朗日插值方法2. Lagrange插值公式a. 线性插值&#xff08;n1&#xff09;b. 抛物插值&…

UDS 诊断报文格式

文章目录 网络层目的N_PDU 格式诊断报文的分类&#xff1a;单帧、多帧 网络层目的 N_PDU(network protocol data unit)&#xff0c;即网络层协议数据单元 网络层最重要的目的就是把数据转换成符合标准的单一数据帧&#xff08;符合can总线规范的&#xff09;&#xff0c;从而…

原生横向滚动条 吸附 页面底部

效果图 /** 横向滚动条 吸附 页面底部 */ export class StickyHorizontalScrollBar {constructor(options {}) {const { el, style } optionsthis.createScrollbar(style)this.insertScrollbar(el)this.setScrollbarSize()this.onEvent()}/** 创建滚轴组件元素 */createS…

【踩坑】解决maven的编译报错Cannot connect to the Maven process. Try again later

背景 新公司新项目, 同事拷给我maven的setting配置文件, 跑项目编译发现maven报 Cannot connect to the Maven process. Try again later. If the problem persists, check the Maven Importing JDK settings and restart IntelliJ IDEA 虽然好像不影响, 项目最终还是能跑起来…

计算机组成学习-存储系统总结

复习本章时&#xff0c;思考以下问题&#xff1a; 1)存储器的层次结构主要体现在何处&#xff1f;为何要分这些层次&#xff1f;计算机如何管理这些层次&#xff1f;2)存取周期和存取时间有何区别&#xff1f;3)在虚拟存储器中&#xff0c;页面是设置得大一些好还是设置得小一…

视频剪辑转码:mp4批量转成wmv视频,高效转换格式

在视频编辑和处理的领域&#xff0c;转换格式是一项常见的任务。在某些编辑和发布工作中&#xff0c;可能需要使用WMV格式。提前将素材转换为WMV可以节省在编辑过程中的时间和精力。从MP4到WMV的批量转换&#xff0c;不仅能使视频素材在不同的平台和设备上得到更好的兼容性&…

JavaScript基础—for语句、循环嵌套、数组、操作数组、综合案例—根据数据生成柱形图、冒泡排序

版本说明 当前版本号[20231129]。 版本修改说明20231126初版20231129完善部分内容 目录 文章目录 版本说明目录JavaScript 基础第三天笔记for 语句for语句的基本使用循环嵌套倒三角九九乘法表 数组数组是什么&#xff1f;数组的基本使用定义数组和数组单元访问数组和数组索引…

centos7 设置静态ip

文章目录 设置VMware主机设置centos7 设置 设置VMware 主机设置 centos7 设置 vim /etc/sysconfig/network-scripts/ifcfg-ens33重启网络服务 service network restart检验配置是否成功 ifconfig ip addr