卷积神经网络(CNN):艺术作品识别

news2024/12/28 21:43:04

文章目录

  • 一、前言
  • 一、设置GPU
  • 二、导入数据
    • 1. 导入数据
    • 2. 检查数据
    • 3. 配置数据集
    • 4. 数据可视化
  • 三、构建模型
  • 四、编译
  • 五、训练模型
  • 六、评估模型
    • 1. Accuracy与Loss图
    • 2. 混淆矩阵
    • 3. 各项指标评估

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

一、设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
import matplotlib.pyplot as plt
import os,PIL,pathlib
import numpy as np
import pandas as pd
import warnings
from tensorflow import keras

warnings.filterwarnings("ignore")#忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

二、导入数据

1. 导入数据

import pathlib

data_dir = "./27-data/"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:",image_count)
图片总数为: 3776
batch_size = 16
img_height = 224
img_width  = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 3021 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 3776 files belonging to 10 classes.
Using 755 files for validation.
class_names = train_ds.class_names
print(class_names)
['Alfred_Sisley', 'Edgar_Degas', 'Francisco_Goya', 'Marc_Chagall', 'Pablo_Picasso', 'Paul_Gauguin', 'Peter_Paul_Rubens', 'Rembrandt', 'Titian', 'Vincent_van_Gogh']

2. 检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(16, 224, 224, 3)
(16,)

3. 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(2000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)           # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(2000)
    .map(train_preprocessing)    # 这里可以设置预处理函数
#     .batch(batch_size)         # 在image_dataset_from_directory处已经设置了batch_size
    .prefetch(buffer_size=AUTOTUNE)
)

4. 数据可视化

plt.figure(figsize=(10, 8))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # 显示图片
        plt.imshow(images[i])
        # 显示标签
        plt.xlabel(class_names[labels[i]-1])

plt.show()

在这里插入图片描述

三、构建模型

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout,BatchNormalization,Activation

# Load pre-trained model
base_model = keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(img_width,img_height,3))

for layer in base_model.layers:
    layer.trainable = True
    
# Add layers at the end
X = base_model.output
X = Flatten()(X)

X = Dense(512, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)

X = Dense(16, kernel_initializer='he_uniform')(X)
#X = Dropout(0.5)(X)
X = BatchNormalization()(X)
X = Activation('relu')(X)

output = Dense(len(class_names), activation='softmax')(X)

model = Model(inputs=base_model.input, outputs=output)

四、编译

optimizer = tf.keras.optimizers.Adam(lr=1e-4)

model.compile(optimizer=optimizer,
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

五、训练模型

from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping, ReduceLROnPlateau, LearningRateScheduler

NO_EPOCHS = 15
PATIENCE  = 5
VERBOSE   = 1

# 设置动态学习率
# annealer = LearningRateScheduler(lambda x: 1e-3 * 0.99 ** (x+NO_EPOCHS))

# 设置早停
earlystopper = EarlyStopping(monitor='loss', patience=PATIENCE, verbose=VERBOSE)

# 
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=VERBOSE,
                                save_best_only=True,
                                save_weights_only=True)
train_model  = model.fit(train_ds,
                  epochs=NO_EPOCHS,
                  verbose=1,
                  validation_data=val_ds,
                  callbacks=[earlystopper, checkpointer])

六、评估模型

1. Accuracy与Loss图

acc = train_model.history['accuracy']
val_acc = train_model.history['val_accuracy']

loss = train_model.history['loss']
val_loss = train_model.history['val_loss']

epochs_range = range(len(acc))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

2. 混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('混淆矩阵',fontsize=15)
    plt.ylabel('真实值',fontsize=14)
    plt.xlabel('预测值',fontsize=14)
val_pre   = []
val_label = []

for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
    for image, label in zip(images, labels):
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(image, 0) 
        # 使用模型预测图片中的人物
        prediction = model.predict(img_array)

        val_pre.append(class_names[np.argmax(prediction)])
        val_label.append(class_names[label])
plot_cm(val_label, val_pre)

3. 各项指标评估

from sklearn import metrics

def test_accuracy_report(model):
    print(metrics.classification_report(val_label, val_pre, target_names=class_names)) 
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model)

											precision    recall  f1-score   support

    Alfred_Sisley       0.76      0.98      0.86        53
      Edgar_Degas       0.89      0.94      0.92       132
   Francisco_Goya       0.89      0.69      0.77        70
     Marc_Chagall       0.85      0.94      0.89        48
    Pablo_Picasso       0.89      0.74      0.81        90
     Paul_Gauguin       0.94      0.84      0.89        57
Peter_Paul_Rubens       0.71      0.86      0.78        29
        Rembrandt       0.66      0.92      0.77        48
           Titian       0.90      0.72      0.80        65
 Vincent_van_Gogh       0.88      0.87      0.87       163

         accuracy                           0.85       755
        macro avg       0.84      0.85      0.84       755
     weighted avg       0.86      0.85      0.85       755

Loss function: 0.5761227011680603, accuracy: 0.8490065932273865

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1283938.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python获取指定用户csdn博客列表并查询质量分,将结果保存到excel

API接口 获取博文总数接口 usernamehougang,表示获取用户hougang的所有博文数量 https://blog.csdn.net/community/home-api/v1/get-tab-total?usernamehougang 获取博文列表接口 https://blog.csdn.net/community/home-api/v1/get-business-list 质量分接口…

在文本框中添加单位

<el-col :span"12"><el-form-item label"进度" prop"schedule":rules"[{required: true, message:进度不能为空, trigger:blur},{validator: validator.isFloatGteZero, trigger:blur}]"><el-input v-model"input…

2024搞钱方式,这些你都了解吗?

气温日渐降低&#xff0c;凛冬已至&#xff0c;年关将近&#xff0c;咱还得多多搞钱才能喜气洋洋过大年不是&#xff1f;拿满全勤搞绩效&#xff0c;累死累活KPI……为了生活咱也是付出了太多。可是咱程序员该咋办呢&#xff1f; 相信有机智的小伙伴已经脱口而出了&#xff1a…

分布式搜索引擎elasticsearch(二)

1.DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1.DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询。常见的查询类型包括: 查询所有:查询出所有数据,一般测试用。例如:match_all 全文检索(full text)查…

从0开始使用Maven

文章目录 一.Maven的介绍即相关概念1.为什么使用Maven/Maven的作用2.Maven的坐标 二.Maven的安装三.IDEA编译器配置Maven环境1.在IDEA的单个工程中配置Maven环境2.方式2&#xff1a;配置Maven全局参数 四.IDEA编译器创建Maven项目五.IDEA中的Maven项目结构六.IDEA编译器导入Mav…

关于rocketMQ踩坑的那些事

在最近&#xff0c;我所写的这个项目需要使用到rocketMQ&#xff0c;为了图方便我便使用的是Windows版本的&#xff0c;但是在使用的过程中首先是发现无法发送消息出去&#xff0c;报错信息为 org.apache.rocketmq.client.exception.MQClientException: Send [3] times, still …

做一件荒谬的事:用AI推理下一次双色球结果 v0.1

做一件荒谬的事&#xff1a;用AI推理下一次双色球结果 v0.1 引言 事情的起因是父亲被亲戚安利&#xff0c;突然喜欢上了双色球&#xff0c;连规则和开奖结果怎么看都不懂的他&#xff0c;让我研究研究这个事&#xff0c;给他选个号。他还说老家有好几个人中了几百万&#xff…

unity学习笔记18

模型文件属性简介 1.动画类型&#xff1a;一共有四种&#xff1a;无 表示没有动画&#xff0c;旧版 就表示这个模型文件里面的动画片段可以用animation组件来播放的&#xff0c;最后两个 ”泛型“和“人形”都是animator组件来播放的。区别是泛型支持所有类型的动画播放&#x…

CoreDNS实战(一)-构建高性能、插件化的DNS服务器

1 概述 在企业高可用DNS架构部署方案中我们使用的是传统老牌DNS软件Bind, 但是现在不少企业内部流行容器化部署&#xff0c;所以也可以将Bind替换为 CoreDNS &#xff0c;由于 CoreDNS 是 Kubernetes 的一个重要组件&#xff0c;稳定性不必担心&#xff0c;于此同时还可将K8S集…

微信扫码登录修改二维码的样式

默认是这个样子二维码都没有展示全 微信的了的 js 对象是这个样子&#xff0c;既然大家看到我这篇文章&#xff0c;想必里面的属性已经知道了&#xff0c;这里不做赘述。 let href data:text/css;base64,LmltcG93ZXJCb3ggLnFyY29kZSB7d2lkdGg6ODAlO21hcmdpbi10b3A6MH0uaW1wb3d…

价差后的几种方向,澳福如何操作才能盈利

在价差出现时&#xff0c;澳福认为会出现以下几种方向。 昂贵资产的贬值和便宜资产的平行升值。昂贵的资产贬值&#xff0c;而便宜的资产保持不变。昂贵资产的贬值和便宜资产的平行贬值&#xff0c;但昂贵资产的贬值速度更快&#xff0c;超过便宜资产。更贵的一对的进一步升值和…

Pycharm配置jupyter使用notebook详细指南(可换行conda环节)

本教程为事后记录&#xff0c;部分图片非实操图片。 详细记录了pycharm配置jupyter的方法&#xff0c;jupyter添加其他conda环境的方法&#xff0c;远程密码调用jupyter的方法&#xff0c;修改jupyter工作目录的方法。 文章目录 一、入门级配置1. Pycharm配置Conda自带的jupyt…

Python Opencv实践 - Yolov3目标检测

本文使用CPU来做运算&#xff0c;未使用GPU。练习项目&#xff0c;参考了网上部分资料。 如果要用TensorFlow做检测&#xff0c;可以参考这里 使用GPU运行基于pytorch的yolov3代码的准备工作_little han的博客-CSDN博客文章浏览阅读943次。记录一下自己刚拿到带独显的电脑&a…

springboot数据源配置

springboot数据源配置 数据层解决方案——持久化技术 内置持久化解决方案——jdbcTemplate 内置数据库 H2一般用于测试环境&#xff0c;配置profiels&#xff0c;只在开发阶段使用&#xff0c;让他在上线的时候不走这里就可以了 要使用内嵌的数据库H2,要先导入jar包

python提取通话记录中的时间信息

您需要安装适合中文的SpaCy模型。您可以通过运行 pip install spacypython -m spacy download zh_core_web_sm来安装和下载所需的模型。 import spacy# 加载中文模型 nlp spacy.load(zh_core_web_sm)# 示例电话记录文本 text """ Agent: 今天我们解决一下这…

语音识别从入门到精通——1-基本原理解释

文章目录 语音识别算法1. 语音识别简介1.1 **语音识别**1.1.1 自动语音识别1.1.2 应用 1.2 语音识别流程1.2.1 预处理1.2.2 语音检测和断句1.2.3 音频场景分析1.2.4 识别引擎(**语音识别的模型**)1. 传统语音识别模型2. 端到端的语音识别模型基于Transformer的ASR模型基于CNN的…

14、pytest像用参数一样使用fixture

官方实例 # content of test_fruit.py import pytestclass Fruit:def __init__(self, name):self.name nameself.cubed Falsedef cube(self):self.cubed Trueclass FruitSalad:def __init__(self, *fruit_bowl):self.fruit fruit_bowlself._cube_fruit()def _cube_fruit(s…

【从零开始学习Redis | 第六篇】爆改Setnx实现分布式锁

前言&#xff1a; 在Java后端业务中&#xff0c; 如果我们开启了均衡负载模式&#xff0c;也就是多台服务器处理前端的请求&#xff0c;就会产生一个问题&#xff1a;多台服务器就会有多个JVM&#xff0c;多个JVM就会导致服务器集群下的并发问题。我们在这里提出的解决思路是把…

Spring Security 自定义异常失效?源码分析与解决方案

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f33a; 仓库主页&#xff1a; Gitee &#x1f4ab; Github &#x1f4ab; GitCode &#x1f496; 欢迎点赞…

python pyaudio对音频进行端点检测,检测出说话区间

python pyaudio对音频进行端点检测&#xff0c;检测出说话区间 主要采用过零率和语音能量来进行检测&#xff0c;并设置双阈值。 代码如下&#xff1a; # -*- coding: utf-8 -*- import wave import os import matplotlib.pyplot as plt import numpy as np# 判断是否变号 de…