卫星图片的Classification_model

news2024/10/6 22:21:00

Tensorflow版本:2.6.0
使用的是CNN神经网络,网络结构在最后给出
飞机和湖泊的卫星图片二分类网络
数据集请点击链接:https://www.kaggle.com/datasets/yo7oyo/lake-plane-binaryclass
数据集的构成:airplane: 700 张, lake 700 张

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import glob
%matplotlib inline
# 获取数据路径
# 'D:\\datasets\\dset\\dataset1\\卫星图像识别数据\\2_class\\airplane\\airplane_427.jpg 路径格式
img_path = glob.glob(r'D:\datasets\dset\dataset1\卫星图像识别数据\2_class/*/*.jpg')
import random
random.shuffle(img_path)
img_path[1].split('\\')[-2]

output: 'airplane'

#得到标签值label
#path: 
dict_img = {'lake':1, 'airplane':0}
dict_label = dict((v, k) for k, v in dict_img.items())
label = [dict_img.get(img.split('\\')[-2]) for img in img_path]
print(dict_label)

读取图片:

"""
读取: x = tf.io.read_file(path) 读取二进制数据
解码: x = tf.image.decode_jpeg(x) 返回Tensor张量 支持多种图片解码
类型转化 tf.cast(x, tf.float64)
"""
def load_image(path):
    # 读取图片
    img = tf.io.read_file(path)
    # 解码数据
    img = tf.image.decode_jpeg(img, channels=3)    # 转换图片大小 统一大小为 256 X 256
    img = tf.image.resize(img, [256, 256])
    # 数据转化
    img = tf.cast(img, tf.float32)
    # 归一化
    img /= 255.0
    return img
#%%
p, l = img_path[20], label[20]
img_tensor = load_image(p)
plt.title(dict_label.get(l))
plt.imshow(img_tensor)
# 创建dataset
ds_img = tf.data.Dataset.from_tensor_slices(img_path)
ds_img = ds_img.map(load_image)
# 标签的dataset
ds_label = tf.data.Dataset.from_tensor_slices(label)

print(ds_img, '\n', ds_label)
# 使用zip函数将数据进行合并
datas = tf.data.Dataset.zip((ds_img, ds_label))
print(datas)

all_count = len(img_path)
BATCH_SIZE = 16
# 划分训练集和测试集
train_ds = datas.take(int(all_count * 0.8))
test_ds = datas.skip(int(all_count * 0.8))
# 设置训练集和测试集
"""
测试集无需乱序 和 重复
"""
train_ds = train_ds.repeat().shuffle(100).batch(BATCH_SIZE)
test_ds = test_ds.batch(BATCH_SIZE)
print(train_ds, test_ds)

构建模型

model = tf.keras.Sequential([
    layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3), activation='relu'),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(512, (3, 3), activation='relu'),
    layers.Conv2D(512, (3, 3), activation='relu'),
    layers.MaxPool2D(),
    layers.Conv2D(512, (3, 3), activation='relu'),
    layers.Conv2D(512, (3, 3), activation='relu'),
    layers.Conv2D(512, (3, 3), activation='relu'),
    layers.MaxPool2D(),
    layers.GlobalAveragePooling2D(),
    layers.Dense(1024, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

model.summary()

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['acc']
)

steps_pre_epoch = int(all_count*0.8) // BATCH_SIZE
val_steps = int(all_count * 0.2) // BATCH_SIZE

history = model.fit(
    train_ds,
    epochs=10,
    steps_per_epoch=steps_pre_epoch,
    validation_data=test_ds,
    validation_steps=val_steps
)

在这里插入图片描述

plt.plot(history.epoch, history.history.get('loss'), label='loss')
plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()

在这里插入图片描述

plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()

在这里插入图片描述

# 使用模型进行预测
plt.figure(figsize=(20,20))
for i in range(10):
    plt.subplot(2,5,i+1)
    index = random.randint(0, 1400)
    plt.imshow(load_image(img_path[index]))
    img = load_image(img_path[index])
    plt.title(f'Predict:{dict_label.get(float(model.predict(tf.expand_dims(img, axis=0))[0]) > 0.5)} \nTrue:{dict_label.get(label[index])}')
plt.show()

在这里插入图片描述
在这里插入图片描述

加入BatchNormalization()的批量归一化,防止出现梯度消失的梯度爆炸并加快神经网络的传递和更新

构建模型

model = tf.keras.Sequential([
    layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(64, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(128, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(128, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(256, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(256, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(),
    # layers.Dropout(0.5),
    layers.Conv2D(512, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(512, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(),
    layers.Conv2D(512, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(512, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Conv2D(512, (3, 3)),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.MaxPool2D(),
    layers.GlobalAveragePooling2D(),
    layers.Dense(1024),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Dense(256),
    layers.BatchNormalization(),
    layers.Activation('relu'),
    layers.Dense(1, activation='sigmoid')
])

model.summary()
# 训练模型
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['acc']
)
steps_pre_epoch = int(all_count*0.8) // BATCH_SIZE
val_steps = int(all_count * 0.2) // BATCH_SIZE
history = model.fit(
    train_ds,
    epochs=10,
    steps_per_epoch=steps_pre_epoch,
    validation_data=test_ds,
    validation_steps=val_steps
)

在这里插入图片描述

未加入批量归一化的网络结构

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 254, 254, 64)      1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 252, 252, 64)      36928     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 126, 126, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 124, 124, 128)     73856     
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 122, 122, 128)     147584    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 61, 61, 128)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 59, 59, 256)       295168    
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 57, 57, 256)       590080    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 28, 28, 256)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 26, 26, 512)       1180160   
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 24, 24, 512)       2359808   
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 12, 12, 512)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 10, 10, 512)       2359808   
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 8, 8, 512)         2359808   
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 6, 6, 512)         2359808   
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 3, 3, 512)         0         
_________________________________________________________________
global_average_pooling2d (Gl (None, 512)               0         
_________________________________________________________________
dense (Dense)                (None, 1024)              525312    
_________________________________________________________________
dense_1 (Dense)              (None, 256)               262400    
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 257       
=================================================================
Total params: 12,552,769
Trainable params: 12,552,769
Non-trainable params: 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/747929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

着眼未来砥砺前行,知了汇智携西南交大学生走进企业参观学习

随着数字化转型推进的深入,企业对数字化人才的需求量大幅增长,人才需求结构也发生显著在变化。为加强学生与企业的接触,拓展专业视野,对接行业需求,激发学生对所学专业的兴趣,明确自己学习的目标&#xff0…

NC19 连续子数组的最大和

import java.util.*; public class Solution {public int FindGreatestSumOfSubArray(int[] array) {//记录到下标i为止的最大连续子数组和int[] dp new int[array.length]; dp[0] array[0];int maxsum dp[0];for(int i 1; i < array.length; i){//状态转移&#xff1a;…

优雅实现垂直SeekBar:不继承Seekbar、不自定义View

目录 0 前言 关于自定义View 1 实现竖直SeekBar 1.1 XML布局解析 1.1.1 套一层FrameLayout 1.1.2 SeekBar去除左右间距 1.1.3 SeekBar高度无法设置 1.1.4 SeekBar背景设置 1.1.5 底部View尺寸和距底部距离不硬编码 1.2 自定义样式属性与主题 1.2.1 自定义样式属性 …

应急管理大屏助力暴雨天气下的水灾防范

随着气候变化和城市化进程的加剧&#xff0c;暴雨天气引发的水灾风险日益凸显。在面对这种自然灾害时&#xff0c;如何高效、及时地应对、减轻损失成为了当务之急。水灾应急管理平台的可视化大屏为相关部门和决策者提供了实时、全面的信息展示和决策支持&#xff0c;大大提升了…

每天5个好用的实用工具链接分享(第1弹)

每天5个好用的实用工具链接分享&#xff08;第1弹&#xff09; 1、免费PPT模板网站2、科研狗租用GPU跑模型网站3、在线正则测试网站4、免费数据集下载网站5、在线curl命令转代码网站6、号外 1、免费PPT模板网站 【链接】&#xff1a;https://www.ypppt.com/ 【网站名】&#x…

性能测试工具 Jmeter 做 Http 接口测试 :编写自定义函数

目录 一、 前言 二、 编写自定义函数的步骤 1. 新建一个工程&#xff0c;导入 jmeter jar 包。 2. 新建 package&#xff1a;stressTest.functions 3. 新建一个类继承 AbstractFunction&#xff0c;重写以下方法&#xff1a; 4. 打包 5. 将打出来的 jar 包拷贝至 jmeter…

学习记录——BiSeNetV1、BiSeNetV2、BiSeNetV3、PIDNet、CMNeXt

BiSeNetV1 BiSeNetV1为了在不影响速度的情况下&#xff0c;同时收集到空间信息和语义信息&#xff0c;设计了两条路&#xff1a; Spatial Path: 用了三层stride为 2 的卷积&#xff0c;卷积BNRELU模块。最后提取了相当于原图像 1/8 的输出特征图。由于它利用了较大尺度的特征图…

怎样把手机录音转换成文字免费?分享3个免费方法给给大家!

将手机录音转换为文字可以提高工作和学习效率&#xff0c;但很多人不知道如何实现。在本文中&#xff0c;我将分享三个免费的方法来帮助您将手机录音转换为文字&#xff0c;分别是使用记灵在线工具&#xff08;网页&#xff09;、微信和剪映。无论您是需要转录会议记录、课堂笔…

界面控件DevExtreme UI组件——增强的API功能

虽然DevExtreme刚刚发布了v23.1&#xff0c;但今天我们仍然要继续总结一下之前的主要更新&#xff08;v22.2&#xff09;中发布的一些与DevExtreme API相关的重要特性。 DevExtreme拥有高性能的HTML5 / JavaScript小部件集合&#xff0c;使您可以利用现代Web开发堆栈&#xff…

C语言 — 指针进阶篇(上)

前言 指针基础篇回顾可以详见&#xff1a; 指针基础篇&#xff08;1&#xff09;指针基础篇&#xff08;2&#xff09; 指针进阶篇分为上下两篇,上篇介绍1 — 4&#xff0c;下篇介绍5 — 6 字符指针数组指针指针数组数组传参和指针传参函数指针函数指针数组指向函数指针数组的…

Leetcode-每日一题【147.对链表进行插入排序】

题目 给定单个链表的头 head &#xff0c;使用 插入排序 对链表进行排序&#xff0c;并返回 排序后链表的头 。 插入排序 算法的步骤: 插入排序是迭代的&#xff0c;每次只移动一个元素&#xff0c;直到所有元素可以形成一个有序的输出列表。 每次迭代中&#xff0c;插入排序…

MySQL分区表详解

目录 分区表介绍 分区适用场景 分区方式 分区策略 常见分区命令 1. 分区表介绍 MySQL 数据库中的数据是以文件的形势存在磁盘上的&#xff0c;默认放在 /var/lib/mysql/ 目录下面&#xff0c;我们可以通过 show variables like %datadir%; 命令来查看&#xff1a; 我们进入…

a标签form表单,转发,重定向

a标签需要写项目名,form表单需要写项目名,转发写请求路径,重定向需要写项目名 // window.location.href"请求路径" 需要写项目名 // window.location"请求路径" 需要写项目名 // document.location.href"请求路径" 需要写项目名 …

业务流程图怎么画?这几种绘制方法看一看

业务流程图怎么画&#xff1f;流程图提供了对业务流程的清晰概述&#xff0c;帮助人们理解工作流程中涉及的活动、决策和步骤。它定义了任务的顺序和依赖关系&#xff0c;使人们能够更好地了解整个流程。通过绘制流程图&#xff0c;可以更容易地识别出潜在的问题、瓶颈和延迟。…

SOPC之NiosⅡ系统(一)

1. 基础概念 1.1 CPU软核与硬核 简单来说 CPU硬核就是在FPGA上的一颗硬件结构固定并且用户不能对其结构进行任何更改、只能进行编程控制的芯片。 CPU软核则是FPGA上本来不存在这样的硬件结构&#xff0c;用户可根据硬件描述语言利用NIOS Ⅱ软核搭建出一个CPU。 1.2 SOC和S…

面试题大杂烩-记不住

1、分库分表图啥 分库是为了解决单库io连接数的瓶颈 分表是为了解决单表效率瓶颈 2、分表后如何limit分页 如果是根据xxx字段进行分表的话 那么shardingjdbc会根据字段进行笛卡尔积计算 去对应表里面执行sql到内存中计算&#xff0c;比如根据用户id进行hash算法进行查表&…

如何选择软文发布平台?软文发布平台的分类、特点及推广策略

在现今的市场竞争中&#xff0c;软文作为一种重要的推广方式&#xff0c;受到了越来越多企业的关注和运用。而软文发布平台&#xff0c;则是软文营销过程中不可或缺的一个环节&#xff0c;不同的软文发布平台会对软文的传播效果产生重要影响。本文将从软文发布平台的分类、特点…

淘宝APP商品详情源数据接口代码分享(API接口系列可高并发)

电商平台APP商品详情源数据接口代码分享如下&#xff1a; 商品数据&#xff1a;淘宝提供了商品的基本信息&#xff0c;包括商品名称、描述、规格、价格、销量、库存等信息。此外&#xff0c;也可以通过淘宝提供的API接口来获取商品的图片、评价、物流信息等详细数据。 公共参数…

深入理解DRF中的Mixin类

DRF官网地址&#xff1a; Home - Django REST framework Generic views - Django REST framework 一、Mixin类介绍 1.1 Mixin类介绍 Mixin类是一种常见的设计模式&#xff0c;在多个类之间共享功能或行为时非常有用。 一个Mixin类通常包含一组方法或属性&#xff0c;可以被…

unity 使用vrtk4的插件 打包htv vive VR客户端包,手柄不生效

背景&#xff1a; 目的&#xff1a;u3d使用vrtk开发pico应用(vrtk是为了到时候无缝衔接后续要买的htc vive pro 2) 先导入了tilia importer&#xff08;也就是vrtk4.0&#xff0c;根据教程模块化使用功能&#xff09;和pico官网下的“PICO Unity IntegrationSDK-214-20230302…