CNN实现卫星图像分类(tensorflow)

news2025/1/23 6:04:49

使用的数据集卫星图像有两类,airplane和lake,每个类别样本量各700张,大小为256*256,RGB三通道彩色卫星影像。搭建深度卷积神经网络,实现卫星影像二分类。
数据链接百度网盘地址,提取码: cq47

1、查看tensorflow版本

import tensorflow as tf

print('Tensorflow Version:{}'.format(tf.__version__))
print(tf.config.list_physical_devices())

在这里插入图片描述

2、加载并显示训练数据

从文件夹中获取所有数据路径

import glob
import random

all_image_path = glob.glob('./data/air_lake_dataset/*/*.jpg')  # glob相比于pathlib更简洁
random.shuffle(all_image_path)

读取并处理图像

def load_and_preprocess_image(path):
    img_raw = tf.io.read_file(path)
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)
    img_tensor = tf.image.resize(img_tensor,[256,256])
    img_tensor = tf.cast(img_tensor,tf.float32)
    img_tensor = img_tensor/255
    return img_tensor

处理标签

label_to_index = {'airplane':0,'lake':1}
index_to_label = dict((v,k) for k,v in label_to_index.items())
labels = [label_to_index.get(img.split('/')[3]) for img in all_image_path]

显示卫星影像

import matplotlib.pyplot as plt

def plot_images_lables(all_image_path,labels,start_idx,num=5):
    fig = plt.gcf()
    fig.set_size_inches(12,14)
    images = [load_and_preprocess_image(img_path) for img_path in all_image_path[start_idx:start_idx+5]]
    for i in range(num):
        ax = plt.subplot(1,num,1+i)
        ax.imshow(images[i])
        title = 'label=' + index_to_label.get(labels[start_idx+i])
        ax.set_title(title,fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

plot_images_lables(all_image_path,labels,0,5)

在这里插入图片描述

4、使用tf.data.Dataset制作训练/测试数据

制作 Dataset

img_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
img_ds = img_ds.map(load_and_preprocess_image)
label_ds = tf.data.Dataset.from_tensor_slices(labels)
img_label_ds = tf.data.Dataset.zip((img_ds,label_ds))

训练集、测试集划分

test_count = int(len(labels)*0.2) 
train_count = len(labels) - test_count

train_ds = img_label_ds.skip(test_count)
test_ds = img_label_ds.take(test_count)

分批次加载数据

BATCH_SIZE = 16
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.repeat().batch(BATCH_SIZE)

5、CNN模型构建

from keras.layers import Input,Dense,Dropout
from keras.layers import Conv2D,MaxPool2D,GlobalAvgPool2D

model = tf.keras.Sequential([
    Input(shape=(256,256,3)),
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),  # 增加filter个数,增加模型拟合能力
    Conv2D(filters=64,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),  # 默认2*2. 池化层扩大视野
    Dropout(0.2),  # 防止过拟合
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=128,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=256,kernel_size=(3,3),activation='relu',padding='same'),
    MaxPool2D(),
    Dropout(0.2),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    Conv2D(filters=512,kernel_size=(3,3),activation='relu',padding='same'),
    GlobalAvgPool2D(),  # 全局平均池化
    Dense(1024,activation='relu'),
    Dense(256,activation='relu'),
    Dense(1,activation='sigmoid') 
])

model.summary()

在这里插入图片描述

6、模型编译与训练

model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),  # 已经使用sigmoid激活过了
              metrics=['acc'])

steps_per_epoch = train_count//BATCH_SIZE
val_step = test_count//BATCH_SIZE

H = model.fit(train_ds,
             epochs=10,
             steps_per_epoch=steps_per_epoch,
             validation_data=test_ds,
             validation_steps=val_step,
             verbose=1)

在这里插入图片描述

7、模型评估

import matplotlib.pyplot as plt

fig = plt.gcf()
fig.set_size_inches(12,4)
plt.subplot(1,2,1)
plt.plot(H.epoch, H.history['loss'], label='loss')
plt.plot(H.epoch, H.history['val_loss'], label='val_loss')
plt.legend()
plt.title('loss')

plt.subplot(1,2,2)
plt.plot(H.epoch, H.history['acc'], label='acc')
plt.plot(H.epoch, H.history['val_acc'], label='val_acc')
plt.legend()
plt.title('acc')
plt.show()

在这里插入图片描述

8、模型预测

def pred_img(img_path):
    img = load_and_preprocess_image(img_path)
    img = tf.expand_dims(img, axis=0)
    pred = model.predict(img)
    pred = index_to_label.get((pred>0.5).astype('int')[0][0])
    return pred
    
img_path = './data/air_lake_dataset/airplane/airplane_240.jpg'
pred = pred_img(img_path)
img_tensor = load_and_preprocess_image(img_path)
plt.imshow(img_tensor)
title = 'label=' + img_path.split('/')[3].strip() + ', pred=' + pred
plt.title(title)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1644378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【一刷《剑指Offer》】面试题 14:调整数组顺序使奇数位于偶数前面

力扣对应题目链接:LCR 139. 训练计划 I - 力扣(LeetCode) 牛客对应题目链接:调整数组顺序使奇数位于偶数前面(二)_牛客题霸_牛客网 (nowcoder.com) 核心考点:数组操作,排序思想的扩展使用。 一、《剑指Off…

LAME及 iOS 编译

文章目录 关于 LAME编译 for iOS 关于 LAME 官网:https://lame.sourceforge.io LAME是根据LGPL许可的高质量MPEG音频层III(MP3)编码器。 LAME的开发始于1998年年中左右。Mike Cheng 最开始将它作为针对8hz-MP3编码器源的补丁。在其他人提出…

docker-本地私有仓库、harbor私有仓库部署与管理

一、本地私有仓库: 1、本地私有仓库简介: docker本地仓库,存放镜像,本地的机器上传和下载,pull/push。 使用私有仓库有许多优点: 节省网络带宽,针对于每个镜像不用每个人都去中央仓库上面去下…

实现 Trie (前缀树) - LeetCode 热题 54

大家好!我是曾续缘💜 今天是《LeetCode 热题 100》系列 发车第 54 天 图论第 4 题 ❤️点赞 👍 收藏 ⭐再看,养成习惯 实现 Trie (前缀树) Trie(发音类似 "try")或者说 前缀树 是一种树形数据结构…

C#知识|上位机项目主窗体设计思路及流程(实例)

哈喽,你好啊,我是雷工! 昨天练习了登录窗体的设计实现,今天练习上位机项目主窗体的设计实现。 01 主窗体效果展示 02 实现步骤 2.1、添加主窗体 添加窗体,名称:FrmMain.cs 2.2、窗体属性设置 将FrmMain窗体属性FormBorderStyle设置为None,无边框; 将FrmMain窗体属性…

神经网络中的算法优化(皮毛讲解)

抛砖引玉 在深度学习中,优化算法是训练神经网络时至关重要的一部分。 优化算法的目标是最小化(或最大化)一个损失函数,通常通过调整神经网络的参数来实现。 这个过程可以通过梯度下降法来完成,其中梯度指的是损失函数…

Windows查找JDK的安装路径

如果很久之前安装了JDK,或者在别人的电脑上,想要快速指导JDK 的安装路径,可以通过啥方式指导JDK的安装路径是在哪里呢? 一、确认是否安装了JDK 首先我们打开命令行,如果输入 java -version 如果显示这种,…

IBM收购HashiCorp:开源工具的未来与“好软件的坟墓”

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

VBA 创建透视表,录制宏,自动化报表

目录 一. 数据准备二. 需求三. 准备好报表模板四. 执行统计操作,录制宏4.1 根据数据源创建透视表4.2 填充数据到报表4.3 结束宏录制 五. 执行录制好的宏,自动化报表 一. 数据准备 ⏹数据源1 姓名学科成绩丁志敏语文91李平平语文81王刚语文64张伊语文50…

场景文本检测识别学习 day08(无监督的Loss Function、代理任务、特征金字塔)

无监督的Loss Function(无监督的目标函数) 根据有无标签,可以将模型的学习方法分为:无监督、有监督两种。而自监督是无监督的一种无监督的目标函数可以分为以下几种: 生成式网络的做法,衡量模型的输出和固…

【C++STL详解(六)】--------list的模拟实现

目录 前言 一、接口总览 一、节点类的模拟实现 二、迭代器类的模拟实现 迭代器的目的 list迭代器为何要写成类? 迭代器类模板参数说明 模拟实现 1.构造函数 2.*运算符重载 3.->运算符重载 4.前置 5.后置 6.前置-- 7.后置-- 8.! 9. 三、list类的…

【知识加油站】——机电产品数字孪生机理模型构建

明确一种多领域、多层次、参数化、一致性的机电一体化装备数字孪生机理模型构建准则! 关键词英文简称: 数字孪生:DT物联网:IoT网络物理系统:CPS高级架构:HLA统一建模语言:UML数控机床&#xf…

2-qt之信号与槽-简单实例讲解

前言、因实践课程讲解需求,简单介绍下qt的信号与槽。 一、了解信号与槽 怎样使用信号与槽? 概览 还记得 X-Window 上老旧的回调函数系统吗?通常它不是类型安全的并且很复杂。(使用)它(会)有很多…

精析React与Vue架构异同及React核心技术——涵盖JSX、组件、Props、State、生命周期与16.8版后Hooks深化解析

React,Facebook开源的JavaScript库,用于构建高性能用户界面。通过组件化开发,它使UI的构建、维护变得简单高效。利用虚拟DOM实现快速渲染更新,适用于单页应用、移动应用(React Native)。React极大推动了现代…

【链表】:链表的带环问题

🎁个人主页:我们的五年 🔍系列专栏:数据结构 🌷追光的人,终会万丈光芒 前言: 链表的带环问题在链表中是一类比较难的问题,它对我们的思维有一个比较高的要求,但是这一类…

51单片机入门:DS1302时钟

51单片机内部含有晶振,可以实现定时/计数功能。但是其缺点有:精度往往不高、不能掉电使用等。 我们可以通过DS1302时钟芯片来解决以上的缺点。 DS1302时钟芯片 功能:DS1302是一种低功耗实时时钟芯片,内部有自动的计时功能&#x…

Spring Boot:国际化

Spring Boot 前言国际化 前言 在 Spring MVC:视图与视图解析器 的文章中,介绍过使用 Jstl 的 fmt 标签实现国际化,Spring MVC 会把视图由 InternalResourceViewResolver 转换为 JstlView(InternalResourceView 的子类&#xff09…

【DPU系列之】如何通过带外口登录到DPU上的ARM服务器?(Bluefield2举例)

文章目录 1. 背景说明2. 详细操作步骤2.1 目标拓扑结构2.2 连接DPU带外口网线,并获取IP地址2.3 ssh登录到DPU 3. 进一步看看系统的一些信息3.1 CPU信息:8核A723.2 内存信息 16GB3.3 查看ibdev设备 3.4 使用小工具pcie2netdev查看信息3.5 查看PCIe设备信息…

【JavaEE 初阶(一)】初识线程

❣博主主页: 33的博客❣ ▶️文章专栏分类:JavaEE◀️ 🚚我的代码仓库: 33的代码仓库🚚 🫵🫵🫵关注我带你了解更多线程知识 目录 1.前言2.进程3.线程4.线程和进程的区别5.Thread创建线程5.1继承Thread创建线程5.2实现R…

【深度优先搜索 图论 树】2872. 可以被 K 整除连通块的最大数目

本文涉及知识点 深度优先搜索 图论 树 图论知识汇总 LeetCode 2872. 可以被 K 整除连通块的最大数目 给你一棵 n 个节点的无向树,节点编号为 0 到 n - 1 。给你整数 n 和一个长度为 n - 1 的二维整数数组 edges ,其中 edges[i] [ai, bi] 表示树中节点…