python基于卷积神经网络实现自定义数据集训练与测试

news2024/11/17 14:26:06

样本取自岩心照片,识别岩心是最基础的地质工作,如果用机器来划分岩心类型则会大大削减工作量。

下面叙述中0指代Anhydrite_rock(膏岩),1指代Limestone(灰岩),2指代Gray Anhydrite_rock(灰质膏岩)。

原本自定义训练集与测试集是这样的:

训练集x_train: 

标签是这样的y_train:

 测试集x_test:

标签是这样的y_test:

但是由于图片像素为3456*5184,电脑内存不足,所以只能统一修i该为下面(256*256): 

训练集: 

  测试集:

两个数据集的标签没有更改。

#导入库
import os
import cv2
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import tensorflow.keras as ka
import datetime
import tensorflow as tf
import os
import PySide2
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPooling2D,Dropout,Flatten,Dense
from tensorflow.keras import Model
import tensorflow as tf
'''
gpus = tf.config.list_physical_devices("GPU")

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 'cuda' 这里如果没有指定具体的卡号,系统默认cuda:0
device = torch.device('cuda:0') 		# 使用2号卡
'''

'''
if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
'''

'''加载数据集'''
#创建自定义数据集类,参考可见:http://t.csdn.cn/gkVNC
class Custom_Dataset(Dataset):
    #函数,设置图像集路径索引、图像标签文件读取
    def __init__(self, img_dir, img_label_dir, transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(img_label_dir)
        self.transform = transform
        
    #函数,设置数据集长度
    def __len__(self):
        return len(self.img_labels)
    
    #函数,设置指定图像读取、指定图像标签索引
    def __getitem__(self, index):
        #'所在文件路径+指定图像名'
        img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 1])
        #读指定图像
        #image = cv2.imdecode(np.fromfile(img_path,dtype=np.uint8),-1)
        image=plt.imread(img_path)
        #height,width = image.shape[0],image.shape[1]  #获取原图像的垂直方向尺寸和水平方向尺寸。
        #image = image.resize((height//4,width//4))
        
        #'指定图像标签'
        label = self.img_labels.iloc[index, 0]
        return image, label

#画图函数
def tensorToimg(img_tensor):
    img=img_tensor
    plt.imshow(img)
    #python3.X必须加下行
    plt.show()

#标签指示含义
label_dic = {0: '膏岩', 1: '灰岩', 2: '灰质膏岩'}

#图像集及标签路径
label_path = "C:/Users/yeahamen/AppData/Local/Programs/Python/Python310/train_label.csv"
img_root_path = "C:/Users/yeahamen/Desktop/custom_dataset/train_revise/"

test_image_path="C:/Users/yeahamen/Desktop/custom_dataset/test_revise/"
test_label_path="C:/Users/yeahamen/AppData/Local/Programs/Python/Python310/test_label.csv"
#加载图像集与标签路径到函数
#实例化类
dataset = Custom_Dataset(img_root_path, label_path)
dataset_test = Custom_Dataset(test_image_path,test_label_path)

'''查看指定图像(18)'''
#索引指定位置的图像及标签
image, label = dataset.__getitem__(18)
#展示图片及其形状(tensor)
print('单张图片(18)形状:',image.shape)
print('单张图片(18)标签:',label_dic[label])
#tensorToimg(image)


#批量输出
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

#查看图像的形状
for imgs, labels in dataloader:
    print('一批训练为1张图片(随机)形状:',imgs.shape)
    #一批图像形状:torch.Size([5, 3456, 5184, 3])
    print('一批训练为1张图片(随机)标签:',labels)
    #标签:tensor([3, 2, 3, 3, 1])
    break
    #仅需要查看一批

'''查看自定义数据集'''
showimages=[]
showlabels=[]
#把图片信息依次加载到列表
for imgs, labels in dataloader:
    c = torch.squeeze(imgs, 0)#减去一维数据形成图片固定三参数
    d = torch.squeeze(labels,0)
    showimages.append(c)
    showlabels.append(d)
#依次画出图片
def show_image(nrow, ncol, sharex, sharey):
    fig, axs = plt.subplots(nrow, ncol, sharex=sharex, sharey=sharey, figsize=(10, 10))
    for i in range(0,nrow):
        for j in range(0,ncol):
            axs[i,j].imshow(showimages[i*4+j])
            axs[i,j].set_title('Label={}'.format(showlabels[i*4+j]))
    plt.show()
    plt.tight_layout()
#给定参数
#show_image(2, 4, False, False)



'''创建训练集与测试集'''
dataloader_train = DataLoader(dataset, batch_size=30, shuffle=True)
for imgs, labels in dataloader_train:
#    imgs = imgs.copy()
#    height,width = imgs.shape[1],imgs.shape[2]  #获取原图像的水平方向尺寸和垂直方向尺寸。
#    imgs = cv2.resize(imgs,(width/4,height/4),interpolation=cv2.INTER_CUBIC)
    x_train=imgs
    y_train=labels
print('训练集图像形状:',x_train.shape)
print('训练集标签形状:',y_train.shape)
dataloader_test = DataLoader(dataset_test, batch_size=5, shuffle=True)
for imgs, labels in dataloader_test:
    x_test=imgs
    y_test=labels
print('测试集图像形状:',x_test.shape)
print('测试集标签形状:',y_test.shape)
'''
print(x_train.shape[0])
print(x_train.shape[1])
print(x_train.shape[2])
print(x_train.shape[3])
'''
#x_train_d = x_train.squeeze(labels,dim=3)
#x_test_d = x_test.squeeze(labels,dim=3)
#以行为单位将二维数组拉成一维的向量
#x_train = x_train.reshape(x_train_d.shape[0],3456,5184,1).astype('float32')
#x_test = x_test.reshape(x_test_d.shape[0],3456,5184,1).astype('float32')
#x_train = x_train.flatten(1).dtype('float32')
#x_test = x_test.flatten(1).dtype('float32')

#转变数据类型
x_train,x_test = tf.cast(x_train/255.0,tf.float32),tf.cast(x_test/255.0,tf.float32)
y_train,y_test = tf.cast(y_train,tf.int16),tf.cast(y_test,tf.int16)


#参考:http://t.csdn.cn/eRQX2
print('注意:',x_train.shape)
#归一化灰度值
x_train = x_train/255
x_test = x_test/255

#标签转为独热编码,注意:如果标签不是从0开始,独热编码会增加1位(即0)
y_train = ka.utils.to_categorical(y_train)
y_test = ka.utils.to_categorical(y_test)
print('独热后训练集标签形状:',y_train.shape)
print('独热后测试集标签形状:',y_test.shape)
#获取测试集特征数
num_classes = y_test.shape[1]

'''CNN模型'''
#输入3456*5184*3
model = ka.Sequential([ka.layers.Conv2D(filters = 32,kernel_size=(5,5),input_shape=(256,256,3),data_format="channels_last",activation='relu'),
                       #卷积3456*5184*32、卷积层;参量依次为:卷积核个数、卷积核尺寸、单个像素点尺寸、使用ReLu激活函数、解释可见:http://t.csdn.cn/6s3dz
                       ka.layers.MaxPooling2D(pool_size=(4,4),strides = None,padding='VALID'),
                       #池化1—864*1296*32、最大池化层,池化核尺寸4*4、步长默认为4、无填充、解释可见:http://t.csdn.cn/sES2u
                       ka.layers.MaxPooling2D(pool_size=(2,2),strides = None,padding='VALID'),
                       #池化2—432*648*32再加一个最大池化层,池化核尺寸为2*2、步长默认为2、无填充
                       ka.layers.Dropout(0.2),
                       #模型正则化防止过拟合, 只会在训练时才会起作用,随机设定输入的值x的某一维=0,这个概率为输入的百分之20,即丢掉1/5神经元不激活
                       #在模型预测时,不生效,所有神经元均保留也就是不进行dropout。解释可见:http://t.csdn.cn/RXbmS、http://t.csdn.cn/zAIuJ
                       ka.layers.Flatten(),
                       #拉平432*648*32=8957952;拉平池化层为一个向量
                       ka.layers.BatchNormalization(),
                       #批标准化层,提高模型准确率
                       ka.layers.Dense(10,activation='relu'),
                       #全连接层1,10个神经元,激活函数为ReLu
                       ka.layers.Dense(num_classes,activation='softmax')])
                       #全连接层2,3个神经元(对应标签0-2),激活函数为softmax,作用是把神经网络的输出转化为概率,参考可见:http://t.csdn.cn/bcWgu;http://t.csdn.cn/A1Jyn
'''
model = Transformer()	#例子中,采用Transformer模型
model.to(device)
# 只有Tensor类型的数据可以放入GPU中
# 可以一个个【batch_size】进行转换
x_train = x_train.to(device)
y_train = y_train.to(device)
'''
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
startdate = datetime.datetime.now()
#训练轮数epochs=n,即训练n轮
model.fit(x_train,y_train,validation_data=(x_test,y_test),epochs=100,batch_size=1,verbose=2)
#训练样本、训练标签、指定验证数据为测试集、训练轮数、显示每一轮训练进程,参考可见:http://t.csdn.cn/oE46K
#获取训练结束时间
enndate=datetime.datetime.now()
print("训练用时:"+str(enndate-startdate))

程序运行结果是这样的:

 显然由于样本过少,模型训练精度并不高,3轮训练达到0.4;如果有时间再进一步增加样本数量并完善。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/492999.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习-第T6周——好莱坞明星识别

深度学习-第T6周——好莱坞明星识别 深度学习-第T6周——好莱坞明星识别一、前言二、我的环境三、前期工作1、导入数据集2、查看图片数目3、查看数据 四、数据预处理1、 加载数据1、设置图片格式2、划分训练集3、划分验证集4、查看标签 2、数据可视化3、检查数据4、配置数据集 …

Flutter学习之旅 - 页面布局Stack层叠组件

文章目录 StackPositioned定位布局浮动导航(StackPositioned)FlutterMediaQuery获取屏幕宽度和高度StackAlign Stack Stack意思是堆的意思,我们可以用Stack结合Align或者Stack结合Positioned来实现页面的定位布局 属性说明alignment配置所有元素显示位置children子组…

23.Lambda表达式

Lambda表达式 一、Lambda表达式背景 Lambda 表达式(lambda expression)是一个匿名函数,Lambda表达式基于数学中的λ演算得名,直接对应于其中的lambda抽象(lambda abstraction),是一个匿名函数,即没有函数名的函数。Lambda表达式…

2023-05-05 背包问题

背包问题 1 01背包和完全背包问题 01背包问题 有N件物品和一个容量为V的背包,第i件物品的体积是v[i]、价值是w[i],每种物品只可以使用一次,求将哪些物品放入背包可以使得价值总和最大。这里的w是weight即权重的意思 这是最基础的背包问题&a…

【飞书ChatGPT机器人】飞书接入ChatGPT,打造智能问答助手

文章目录 前言环境列表视频教程1.飞书设置2.克隆feishu-chatgpt项目3.配置config.yaml文件4.运行feishu-chatgpt项目5.安装cpolar内网穿透6.固定公网地址7.机器人权限配置8.创建版本9.创建测试企业10. 机器人测试 转载自远控源码文章:飞书接入ChatGPT - 将ChatGPT集…

Ubuntu 如何查看 CPU 架构、系统信息、内核版本、版本代号?

Ubuntu 查看 CPU 架构、系统信息、内核版本、版本代号等相关信息有很多方式,本文介绍几种常用的命令。 x86 架构与 ARM 架构的 CPU 架构不同,如果回显为 aarch64 表示为 ARM 架构,如果回显为 x86_64 表示为 x86 架构,参考《CPU 架…

Prometheus快速入门

Prometheus快速入门 环境准备 三台主机,配置好主机名 各自配置好主机名 # hostnamectl set-hostname --static server.cluster.com ... 三台都互相绑定IP与主机名 # vim /etc/hosts 192.168.126.143 server.cluster.com 192.168.126.142 agent.clu…

归并排序(看了就会)

目录 概念1. 基本思想2. 实现逻辑3. 复杂度分析4、代码 概念 归并排序,是创建在归并操作上的一种有效的排序算法。算法是采用分治法(Divide and Conquer)的一个非常典型的应用,且各层分治递归可以同时进行。归并排序思路简单&…

智头条|欧盟达成《人工智能法》协议,全球前沿科技齐聚AWE 2023

行业动态 华为云联手多方推进数字化,软通动力深度参与 华为云宣布启动“‘百城万企’应用现代化中国行”,旨在推动应用现代化进程、助力数字中国高质量落地。软通动力是该行动的参与者之一,共同探索符合区域特点、产业趋势、政企现状的数字化…

Python进阶(Linux操作系统)

一,操作系统 1.1,Linux系统基础操作 1.2,linux进程与线程 1.2.1并发,并行 (1)并发:在一段时间内交替的执行多个任务:对于单核CPU处理多任务,操作系统轮流让让各个任务…

BasicVSR++代码解读(总体介绍)

本文代码主要来自于OpenMMLab提供的MMEditing开源工具箱中的BasicVSR代码。第一部分的解读主要是针对每一个部分是在做什么提供一个解释,便于后续细读每一个块的细节代码。 (1)导入库     basicvsr_plusplus_net中主要继承了torch,mmcv,m…

信号的产生——线性调频函数

信号的产生——线性调频函数 产生线性调频扫频信号函数chirp的调用格式如下: (1)y chirp(t,f0, t1,f1) 功能:产生一个线性(频率随时间线性变化)信号,其时间轴设置由数组t定义。时刻0的瞬间频…

SpringBoot的配置文件、日志文件

一、配置文件( .properties、.yml) 1、.properties 配置文件 1.1、格式 1.2、基本语法 1.2.1、如:一般配置(以键值的形式配置的,key 和 value 之间是以“”连接的。) 1.2.2、如:自定义配置&a…

tcc-transaction 源码分析

tcc介绍 tcc介绍查看我之前的文章: https://caicongyang.blog.csdn.net/article/details/119721282?spm1001.2014.3001.5502 tcc-transaction 介绍: http://mengyun.org/zh-cn/index.html 本文基于2.x 最新版本:https://github.com/changmingxie/tcc…

以京东为例,分析优惠价格叠加规则

一、平行优惠计算原则 1、什么是“平行式门槛计算规则”? 平行式门槛计算规则,即每一层级优惠都直接根据商品的单品基准价来计算是否符合门槛,店铺/平台促销、优惠券类优惠之间是并列关系,只要单品基准价或单品基准价总和&#x…

c++(类和对象中)

【本节目标】 1. 类的6个默认成员函数 2. 构造函数 3. 析构函数 4. 拷贝构造函数 5. 赋值运算符重载 6. const成员函数 7. 取地址及const取地址操作符重载 目录 1、类的6个默认成员函数 2、构造函数 2.1概念 2.2特性 3.析构函数 3.1概念 3.2特性 4.拷贝构造函数…

Kafka生产者

一、生产者发送流程 在消息发送的过程中,涉及到了两个线程——main 线程和 Sender 线程。在 main 线程中创建了一个双端队列 RecordAccumulator。main 线程将消息发送给 RecordAccumulator,Sender 线程不断从 RecordAccumulator 中拉取消息发送到 Kafka …

网络应用基础 ——(2023新星计划文章二)

一,TCP/UDP报头 1.1TCP报文头部详解 Source port:源端口号与Destination port目标端口号: 在TCP(传输控制协议)协议中,源端口和目标端口字段用于标识通信会话的端点。 (1)源端口是一个16位字段…

LeetCode 1206. 实现跳表

不使用任何库函数,设计一个跳表。 跳表是在 O(log(n)) 时间内完成增加、删除、搜索操作的数据结构。跳表相比于树堆与红黑树,其功能与性能相当,并且跳表的代码长度相较下更短,其设计思想与链表相似。 例如,一个跳表包…

3.数据查询(实验报告)

目录 一﹑实验目的 二﹑实验平台 三﹑实验内容和步骤 四﹑命令(代码)清单 五﹑调试和测试清单 一﹑实验目的 掌握使用Transact-SQL的SELECT语句进行基本查询的方法;掌握使用SELECT语句进行条件查询的方法;掌握SELECT语句的GROUP BY、ORDER BY以及UN…