基于深度学习的遥感图像分类识别系统,使用PyTorch框架实现

news2024/9/21 16:28:45

取5个场景

['海滩', '灌木丛', '沙漠', '森林', '草地']

划分数据集 train:val:test = 7:2:1

环境依赖

pytorch==1.1 or 1.0

tensorboard==1.8

tensorboardX

pillow

注意调低batch_size参数特别是像我这样的渣渣显卡

使用方法

只需要指明数据集路径参数即可,就可以得到最终模型以及log、tensorboard_log了

train:开始训练

python train_resnet.py 

数据集文件夹应为 之下的目录结构应如下:

your data_dir/
       |->train
       |->val
       |->test
       |->ClassnameID.txt

#运行生成result.txt

python infer.py

cpu跑的验证集:

correct:697/704=0.9901

trian_resnet.py
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter

import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from rs_dataset import RSDataset
from get_logger import get_logger
from res_network import Resnet50
# gpu or not
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument('--epoch',type=int,default=10)       #迭代次数
    parse.add_argument('--schedule_step',type=int,default=2)

    parse.add_argument('--batch_size',type=int,default=64)           #根据自己的口袋,调大小   batch_size
    parse.add_argument('--test_batch_size',type=int,default=32)
    parse.add_argument('--num_workers', type=int, default=2)        #default = 8 线程数

    parse.add_argument('--eval_fre',type=int,default=2)
    parse.add_argument('--msg_fre',type=int,default=10)
    parse.add_argument('--save_fre',type=int,default=2)
    parse.add_argument('--name',type=str,default='res50_baseline', help='unique out file name of this task include log/model_out/tensorboard log')

    #local  :数据集路径
    parse.add_argument('--data_dir',type=str,default='D:/RSdata_dir/gra_data_dir/')            #/mnt/rssrai_cls/
    parse.add_argument('--log_dir',type=str, default='./logs')
    parse.add_argument('--tensorboard_dir',type=str,default='./tensorboard')
    parse.add_argument('--model_out_dir',type=str,default='./model_out')
    parse.add_argument('--model_out_name',type=str,default='final_model.pth')
    parse.add_argument('--seed',type=int,default=5,help='random seed')
    return parse.parse_args()

def evalute(net,val_loader,writer,epoch,logger):
    logger.info('------------after epo {}, eval...-----------'.format(epoch))
    total=0
    correct=0
    net.eval()
    with torch.no_grad():
        for img,lb in val_loader:
            img, lb = img.to(device), lb.to(device)
            outputs = net(img)
            outputs = F.softmax(outputs,dim=1)
            predicted = torch.max(outputs,dim=1)[1]
            total += lb.size()[0]

            correct += (predicted == lb).sum().cpu().item()
    logger.info('correct:{}/{}={:.4f}'.format(correct,total,correct*1./total,epoch))
    #tensorboard-acc
    writer.add_scalar('acc',correct*1./total,epoch)
    net.train()

def main_worker(args,logger):
    try:
        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)

        train_set = RSDataset(rootpth=args.data_dir,mode='train')
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

        val_set = RSDataset(rootpth=args.data_dir,mode='val')
        val_loader = DataLoader(val_set,
                                batch_size=args.test_batch_size,
                                drop_last=True,
                                shuffle=True,
                                pin_memory=True,
                                num_workers=args.num_workers)

        net = Resnet50()
        net = net.train()
        input_ = torch.randn((1,3,224,224))
        #writer.add_graph(net,input_)

        #设置cuda
        #net = net.cuda()
        net = net.to(device)

        criterion = nn.CrossEntropyLoss().to(device)

        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #优化器

        scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=args.schedule_step,gamma=0.3)

        loss_record = []

        iter = 0
        running_loss = []
        st = glob_st = time.time()
        total_iter = len(train_loader)*args.epoch
        for epoch in range(args.epoch):

            # 评估
            if epoch!=0 and epoch%args.eval_fre == 0:
            # if epoch%args.eval_fre == 0:
                evalute(net, val_loader, writer, epoch, logger)

            if epoch!=0 and epoch%args.save_fre == 0:
                model_out_name = osp.join(args.sub_model_out_dir,'out_{}.pth'.format(epoch))
                # 防止分布式训练保存失败
                state_dict = net.modules.state_dict() if hasattr(net, 'module') else net.state_dict()
                torch.save(state_dict,model_out_name)

            for img, lb in train_loader:
                iter += 1
                img = img.to(device)
                lb = lb.to(device)

                optimizer.zero_grad()

                outputs = net(img)
                loss = criterion(outputs,lb)

                loss.backward()
                optimizer.step()

                running_loss.append(loss.item())

                if iter%args.msg_fre ==0:
                    ed = time.time()
                    spend = ed-st
                    global_spend = ed-glob_st
                    st=ed

                    eta = int((total_iter-iter)*(global_spend/iter))
                    eta = str(datetime.timedelta(seconds=eta))
                    global_spend = str(datetime.timedelta(seconds=(int(global_spend))))

                    avg_loss = np.mean(running_loss)
                    loss_record.append(avg_loss)
                    running_loss = []

                    lr = optimizer.param_groups[0]['lr']

                    msg = '. '.join([
                        'epoch:{epoch}',
                        'iter/total_iter:{iter}/{total_iter}',
                        'lr:{lr:.5f}',
                        'loss:{loss:.4f}',
                        'spend/global_spend:{spend:.4f}/{global_spend}',
                        'eta:{eta}'
                    ]).format(
                        epoch=epoch,
                        iter=iter,
                        total_iter=total_iter,
                        lr=lr,
                        loss=avg_loss,
                        spend=spend,
                        global_spend=global_spend,
                        eta=eta
                    )
                    logger.info(msg)
                    writer.add_scalar('loss',avg_loss,iter)
                    writer.add_scalar('lr',lr,iter)

            scheduler.step()
        # 训练完最后评估一次
        evalute(net, val_loader, writer, args.epoch, logger)

        out_name = osp.join(args.sub_model_out_dir,args.model_out_name)
        torch.save(net.cpu().state_dict(),out_name)

        logger.info('-----------Done!!!----------')

    except:
        logger.exception('Exception logged')
    finally:
        writer.close()

if __name__ == '__main__':

    args = parse_args()
    #为CPU设置种子用于生成随机数,以使得结果是确定的
    torch.manual_seed(args.seed)
    torch.manual_seed(args.seed)

    # 唯一标识
    unique_name = time.strftime('%y-%m%d-%H%M%S_') + args.name
    args.unique_name = unique_name

    # 每次创建作业使用不同的tensorboard目录
    args.sub_tensorboard_dir = osp.join(args.tensorboard_dir, args.unique_name)
    # 保存模型的目录
    args.sub_model_out_dir = osp.join(args.model_out_dir, args.unique_name)

    # 创建所有用到的目录
    for sub_dir in [args.sub_tensorboard_dir,args.sub_model_out_dir,  args.log_dir]:
        if not osp.exists(sub_dir):
            os.makedirs(sub_dir)

    log_file_name = osp.join(args.log_dir,args.unique_name + '.log')
    logger = get_logger(log_file_name)

    for k, v in args.__dict__.items():
        logger.info(k)
        logger.info(v)

    main_worker(args,logger=logger)

界面gui.py

#!/bin/python

import wx
from PIL import Image
import numpy as np
import os
from tst import prediect


#文件-hello
def OnHello(event):
    wx.MessageBox("遥感图像识别")

#关于
def OnAbout(event):
    """Display an About Dialog"""
    wx.MessageBox("这是一个识别单张图像的界面-罗",
                  "关于 Hello World",
                  wx.OK | wx.ICON_INFORMATION)


class HelloFrame(wx.Frame):

    def __init__(self,*args,**kw):
        super(HelloFrame,self).__init__(*args,**kw)

        pnl = wx.Panel(self)

        self.pnl = pnl

        st = wx.StaticText(pnl, label="选择一张图像进行识别", pos=(200, 0))
        font = st.GetFont()
        font.PointSize += 10
        font = font.Bold()
        st.SetFont(font)

        # 选择图像文件按钮
        btn = wx.Button(pnl, -1, "select",pos=(4,4))
        btn.SetBackgroundColour("#0a74f7")
        #事件
        btn.Bind(wx.EVT_BUTTON, self.OnSelect)


        self.makeMenuBar()

        self.CreateStatusBar()
        self.SetStatusText("欢迎来到图像识别系统")

    #菜单栏
    def makeMenuBar(self):
        fileMenu = wx.Menu()
        helloItem = fileMenu.Append(-1, "&Hello...\tCtrl-H",
                                    "Help string shown in status bar for this menu item")
        fileMenu.AppendSeparator()

        exitItem = fileMenu.Append(wx.ID_EXIT)
        helpMenu = wx.Menu()
        aboutItem = helpMenu.Append(wx.ID_ABOUT)

        menuBar = wx.MenuBar()
        menuBar.Append(fileMenu, "&File")
        menuBar.Append(helpMenu, "Help")

        self.SetMenuBar(menuBar)

        self.Bind(wx.EVT_MENU, OnHello, helloItem)
        self.Bind(wx.EVT_MENU, self.OnExit, exitItem)
        self.Bind(wx.EVT_MENU, OnAbout, aboutItem)
    #退出
    def OnExit(self, event):
        self.Close(True)

    #select按钮设置
    def OnSelect(self, event):
        wildcard = "image source(*.jpg)|*.jpg|" \
                   "Compile Python(*.pyc)|*.pyc|" \
                   "All file(*.*)|*.*"
        dialog = wx.FileDialog(None, "Choose a file", os.getcwd(),
                               "", wildcard, wx.ID_OPEN)
        if dialog.ShowModal() == wx.ID_OK:

            print(dialog.GetPath())
            img = Image.open(dialog.GetPath())
            imag = img.resize([128, 128])
            image = np.array(img)

            self.initimage(name= dialog.GetPath())

            #从tst.py获取 结果
            result = prediect(image)


            result_text = wx.StaticText(self.pnl, label='', pos=(600, 400), size=(150,50))
            result_text.SetLabel(result)

            font = result_text.GetFont()
            font.PointSize += 8
            result_text.SetFont(font)
            self.initimage(name= dialog.GetPath())

# 生成图片控件
    def initimage(self, name):
        imageShow = wx.Image(name, wx.BITMAP_TYPE_ANY)
        sb = wx.StaticBitmap(self.pnl, -1, imageShow.ConvertToBitmap(), pos=(0,30), size=(600,400))
        return sb


if __name__ == '__main__':

    app = wx.App()
    frm = HelloFrame(None, title='老罗的识别器', size=(1000,600))
    frm.Show()
    app.MainLoop()

 

影像分类


打开ENVI,选择主菜单->Classificatio->Unsupervised->IsoData或者K-mean。如选择IsoData,在选择文件时,可以设置空间或光谱裁剪区。如选择Can-tmr.ing,按默认设置,之后跳出参数设置,如ISODATA非监督分类结果。

 类别定义
在display中显示原始影像,在display->overlay->classification,选择ISODATA分类结果,如图所示,在Interactive Class Tool面板中,可以选择 各个分类结果显示。如图


Interactive Class Tool面板中,选择Option->Edit class colors/names。 通过目视或者其他方式识别分类结果,填写相应的类型名称和颜色。

如图

示为最终结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2113957.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MCU4.逻辑门电路的符号

1.与运算 C语言符号:&(按位与)和&&(逻辑与) 逻辑门电路的符号: 2.或运算 符号:|(按位或)和||(逻辑或) 逻辑门电路的符号: 3.非运算 C语言符号:!(按位非) 逻辑门电路的符号: 4.同或运算 相同为真(0⊙01,1⊙11),否则为假(0⊙10,1⊙00) 符号:⊙(按位同或) 图…

网络学习-eNSP配置ACL

AR1路由器配置 <Huawei>system-view Enter system view, return user view with CtrlZ. [Huawei]undo info-center enable Info: Information center is disabled. [Huawei]interface gigabitethernet 0/0/0 [Huawei-GigabitEthernet0/0/0]ip address 192.168.2.254 24 …

头脑风暴必备:四款在线思维导图工具详解

在快节奏的现代生活中&#xff0c;工作和学习常常需要我们去挖掘新的思维与灵感&#xff1b;在这个过程中&#xff0c;在线思维导图工具无疑是我们的重要伙伴&#xff1b;今天&#xff0c;我们将详细介绍四款在工作和学习中常用的在线思维导图工具给大家&#xff01;&#xff0…

网络安全(sql注入)

这里写目录标题 一. information_schema.tables 和 information_schema.schemata是information_schema数据库中的两张表1. information_schema.schemata2. information_schema.tables 二. 判断注入类型1. 判断数字型还是字符型注入2. 判断注入闭合是""还是 三. 判断表…

数据结构(邓俊辉)学习笔记】排序 5——选取:通用算法

文章目录 1. 尝试2. quickSelect3.linearSelect&#xff1a;算法4. linearSelect&#xff1a;性能分析5. linearSelect&#xff1a;性能分析B6. linearSelect&#xff1a;性能分析C 1. 尝试 在讨论过众数以及特殊情况下中位数的计算方法以后&#xff0c;接下来针对一般性的选取…

「大数据分析」图形可视化,如何选择大数据可视化图形?

​图形可视化技术&#xff0c;在大数据分析中&#xff0c;是一个非常重要的关键部分。我们前期通过数据获取&#xff0c;数据处理&#xff0c;数据分析&#xff0c;得出结果&#xff0c;这些过程都是比较抽象的。如果是非数据分析专业人员&#xff0c;很难清楚我们这些工作&…

【网络安全】服务基础第二阶段——第三节:Linux系统管理基础----Linux用户与组管理

目录 一、用户与组管理命令 1.1 用户分类与UID范围 1.2 用户管理命令 1.2.1 useradd 1.2.2 groupadd 1.2.3 usermod 1.2.4 userdel 1.3 组管理命令 1.3.1 groupdel 1.3.2 查看密码文件 /etc/shadow 1.3.4 passwd 1.4 Linux密码暴力破解 二、权限管理 2.1 文件与目…

RISC-V (十)任务同步和锁

并发与同步 并发&#xff1a;指多个控制流同时执行。 多处理器多任务。一般在多处理器架构下内存是共享的。 单处理器多任务&#xff0c;通过调度器&#xff0c;一会调度这个任务&#xff0c;一会调度下个任务。 共享一个处 理器一个内存。…

C语言指针详解-包过系列(一)目录版

C语言指针详解-包过系列&#xff08;一&#xff09;目录版 1.内存和地址1.1内存1.2 深入理解编址 2.指针变量和地址2.1 取地址操作符&#xff08;&&#xff09;2.2 指针变量和解引用操作符&#xff08;*&#xff09;2.2.1 指针变量2.2.2 指针变量各部分理解2.2.3 解引用操作…

重启顺风车的背后,是高德难掩的“野心”

以史鉴今&#xff0c;我们往往可以从今天的事情中&#xff0c;看到古人的智慧&#xff0c;也看到时代的进步。就如西汉后期文学家恒宽曾说的&#xff0c;“明者因时而变&#xff0c;知者随事而制”。 图源来自高德官方 近日&#xff0c;高德就展现了这样的智慧。在网约车市场陷…

团队比赛活动如何记分?

团队比赛时如何记分&#xff1f; 在当今快节奏的社会中&#xff0c;团队合作和竞争已成为推动个人和集体发展的重要方式。无论是在学校的体育赛事、公司的团建活动&#xff0c;还是社区的娱乐竞赛中&#xff0c;团队比赛都扮演着不可或缺的角色。然而&#xff0c;组织一场成功的…

ubuntu 安装配置 ollama ,添加open-webui

ubuntu 安装配置 ollama 下载安装 [https://ollama.com/download](https://ollama.com/download)一 安装方法 1 命令行下载安装一 安装方法 2 , 手动下载安装 二 配置模型下载路径三 运行1 启动 ollama 服务2 运行大模型 四 添加开机自启服务 ollama serve1 关闭 ollama 服务2 …

2158. 直播获奖(live)

代码 #include<bits/stdc.h> using namespace std; int main() {int n,w,a[100000],cnt[601]{0},i,j,s;cin>>n>>w;for(i0;i<n;i){scanf("%d",&a[i]);cnt[a[i]];int x(i1)*w/100;if(!x) x1;for(j600,s0;j>0;j--){scnt[j];if(s>x){cou…

离线语音通断器

很长一段时间没有更新了 这段时间丰富了离线语音通断器的款式&#xff0c;目前成熟的模块包括单路&#xff0c;三路&#xff0c;四路&#xff0c;八路&#xff0c;输出支持无源输出&#xff0c;有源输出&#xff0c;目前主要针对房车改装市场做一些定制化的客户&#xff0c;同…

快排Java

快速排序的复杂度 快排代码 package leetcode;import java.util.Arrays;public class QuickSort {public static void quickSort(int[] array, int low, int high) {if (low < high) {int pivotIndex partition(array, low, high);quickSort(array, low, pivotIndex - 1);…

winserver2012 关闭iis

1&#xff0c;打开控制面板&#xff0c;搜索管理工具 2&#xff0c;打开管理工具 3&#xff0c;双击打开&#xff1a; 3.1&#xff0c;看到服务列表&#xff1a;右键管理网站-浏览&#xff0c;可以查看web服务&#xff1b; 4&#xff0c;如何关闭&#xff1a;右键结束即可。

【Linux】探索进程优先级的奥秘,解锁进程的调度与切换

目录 进程优先级&#xff1a; 是什么&#xff1f; 为什么存在进程优先级的概念呢&#xff1f; Linux为什么调整优先级是要受限制的&#xff1f; PRI vs NICE Linux的调度与切换 概念准备&#xff1a; 那我们到底怎样完成进程的调度和切换呢&#xff1f; 区分&#xff…

2024国赛论文拿奖快对照这几点及评阅要点,勿踩雷区!(国赛最后冲刺,提高获奖概率)

↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑↑ 2024“高教社杯”全国大学生数学建模竞赛已过去第三个夜晚&#xff0c;小伙伴们都累了没…

无法用 FileZilla 传送文件的解决方案

以下内容源于学习过程中的总结&#xff0c;欢迎交流。 参考博客 &#xff08;1&#xff09;无法用FileZilla 传送文件的解决方案_filezilla could not transfer-CSDN博客 连接时的设置如下&#xff0c;可见我没有以root用户身份进行连接。 我打算把某个文件从本地PC机传送到虚…

【重学 MySQL】十四、显示表结构

【重学 MySQL】十四、显示表结构 使用DESCRIBE或DESC命令使用SHOW COLUMNS命令查询information_schema数据库使用SHOW CREATE TABLE命令总结 在MySQL中&#xff0c;查看或显示表结构是一个常见的需求&#xff0c;它可以帮助你了解表中包含哪些列、每列的数据类型、是否允许为空…