选取 5 个场景类别:
['海滩', '灌木丛', '沙漠', '森林', '草地']
划分数据集 train:val:test = 7:2:1
环境依赖
pytorch==1.1 or 1.0
tensorboard==1.8
tensorboardX
pillow
注意:显存较小的显卡(比如我这样的渣渣显卡)请调低 batch_size 参数。
使用方法
只需要指定数据集路径参数(--data_dir),即可得到最终模型以及 log 和 tensorboard log。
train:开始训练
python train_resnet.py
数据集文件夹下的目录结构应如下:
your data_dir/
|->train
|->val
|->test
|->ClassnameID.txt
infer:运行以下命令生成 result.txt
python infer.py
在 CPU 上运行验证集的结果:
correct:697/704=0.9901
train_resnet.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from tensorboardX import SummaryWriter
import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from rs_dataset import RSDataset
from get_logger import get_logger
from res_network import Resnet50
# Select the compute device once for the whole script: CUDA when available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def parse_args():
    """Parse the training command line.

    Returns:
        argparse.Namespace with the training hyper-parameters and the
        input/output directory settings used by ``main_worker``.
    """
    # (flag, type, default, help) -- registered on the parser in this order.
    option_table = (
        ('--epoch', int, 10, None),            # number of training epochs
        ('--schedule_step', int, 2, None),     # StepLR step size, in epochs
        ('--batch_size', int, 64, None),       # training batch size; lower it on small GPUs
        ('--test_batch_size', int, 32, None),  # validation batch size
        ('--num_workers', int, 2, None),       # DataLoader num_workers (8 suggested for bigger machines)
        ('--eval_fre', int, 2, None),          # evaluate every N epochs
        ('--msg_fre', int, 10, None),          # log training progress every N iterations
        ('--save_fre', int, 2, None),          # save a checkpoint every N epochs
        ('--name', str, 'res50_baseline',
         'unique out file name of this task include log/model_out/tensorboard log'),
        # Dataset root directory (e.g. /mnt/rssrai_cls/ on the server).
        ('--data_dir', str, 'D:/RSdata_dir/gra_data_dir/', None),
        ('--log_dir', str, './logs', None),
        ('--tensorboard_dir', str, './tensorboard', None),
        ('--model_out_dir', str, './model_out', None),
        ('--model_out_name', str, 'final_model.pth', None),
        ('--seed', int, 5, 'random seed'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)
    return parser.parse_args()
def evalute(net, val_loader, writer, epoch, logger):
    """Compute top-1 accuracy of ``net`` on ``val_loader`` and log it.

    The network is put in eval mode for the pass and is always switched back to
    train mode afterwards (even if an exception escapes), because this is called
    in the middle of training.

    Args:
        net: model to evaluate. NOTE(review): assumed to return per-class scores
            of shape (batch, num_classes), as implied by the arg-max over dim=1.
        val_loader: iterable of ``(img, lb)`` batches.
        writer: tensorboardX SummaryWriter; accuracy is written under tag 'acc'.
        epoch: epoch index used in the log message and as the tensorboard step.
        logger: logger receiving progress messages.

    Returns:
        The accuracy as a float, or None when the loader yielded no samples.
    """
    logger.info('------------after epo {}, eval...-----------'.format(epoch))
    total = 0
    correct = 0
    net.eval()
    try:
        with torch.no_grad():
            for img, lb in val_loader:
                img, lb = img.to(device), lb.to(device)
                # softmax is monotonic, so the arg-max of the raw outputs already
                # gives the predicted class.
                predicted = net(img).argmax(dim=1)
                total += lb.size(0)
                correct += (predicted == lb).sum().item()
    finally:
        net.train()

    if total == 0:
        # E.g. an empty split, or drop_last=True with fewer samples than one batch.
        logger.info('validation loader yielded no samples; accuracy not computed')
        return None

    acc = float(correct) / total
    logger.info('correct:{}/{}={:.4f}'.format(correct, total, acc))
    writer.add_scalar('acc', acc, epoch)
    return acc
def _unwrapped_state_dict(net):
    """Return the state dict of ``net``, unwrapping a (Distributed)DataParallel wrapper.

    Wrappers expose the real model as ``net.module``; saving that model's state
    dict keeps checkpoint keys identical to those of an unwrapped model.
    """
    return net.module.state_dict() if hasattr(net, 'module') else net.state_dict()


def _build_loader(args, mode, batch_size, train):
    """Create a DataLoader over the RSDataset split ``mode`` under ``args.data_dir``."""
    dataset = RSDataset(rootpth=args.data_dir, mode=mode)
    return DataLoader(dataset,
                      batch_size=batch_size,
                      # Training shuffles and drops the ragged last batch; evaluation
                      # keeps every sample so accuracy covers the whole split.
                      drop_last=train,
                      shuffle=train,
                      pin_memory=True,
                      num_workers=args.num_workers)


def _log_progress(logger, writer, epoch, it, total_iter, lr, avg_loss, st, glob_st):
    """Log one training-progress line plus the loss/lr scalars.

    Returns the current time, which becomes the start of the next logging interval.
    """
    ed = time.time()
    spend = ed - st
    elapsed = ed - glob_st
    # ETA extrapolates the mean time per iteration so far to the remaining iterations.
    eta = str(datetime.timedelta(seconds=int((total_iter - it) * (elapsed / it))))
    global_spend = str(datetime.timedelta(seconds=int(elapsed)))
    msg = '. '.join([
        'epoch:{epoch}',
        'iter/total_iter:{iter}/{total_iter}',
        'lr:{lr:.5f}',
        'loss:{loss:.4f}',
        'spend/global_spend:{spend:.4f}/{global_spend}',
        'eta:{eta}',
    ]).format(
        epoch=epoch,
        iter=it,
        total_iter=total_iter,
        lr=lr,
        loss=avg_loss,
        spend=spend,
        global_spend=global_spend,
        eta=eta,
    )
    logger.info(msg)
    writer.add_scalar('loss', avg_loss, it)
    writer.add_scalar('lr', lr, it)
    return ed


def main_worker(args, logger):
    """Train Resnet50 on ``args.data_dir`` and save checkpoints and the final model.

    - Every ``args.eval_fre`` epochs, and once after the last epoch, the model is
      evaluated on the 'val' split.
    - Every ``args.save_fre`` epochs a checkpoint ``out_<epoch>.pth`` is written
      to ``args.sub_model_out_dir``.
    - The final weights are saved as ``args.model_out_name`` in that directory.
    - Loss and learning rate are logged every ``args.msg_fre`` iterations to
      ``logger`` and to tensorboard under ``args.sub_tensorboard_dir``.

    Any ``Exception`` raised during training is logged via ``logger.exception``
    and not re-raised. The tensorboard writer is always closed.
    """
    writer = None
    try:
        writer = SummaryWriter(logdir=args.sub_tensorboard_dir)
        train_loader = _build_loader(args, 'train', args.batch_size, train=True)
        val_loader = _build_loader(args, 'val', args.test_batch_size, train=False)

        net = Resnet50().train().to(device)
        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        # Multiply the learning rate by 0.3 every ``schedule_step`` epochs.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.schedule_step, gamma=0.3)

        it = 0  # global iteration counter across all epochs
        running_loss = []
        st = glob_st = time.time()
        total_iter = len(train_loader) * args.epoch
        for epoch in range(args.epoch):
            # Periodic evaluation/checkpointing runs before epoch ``epoch``,
            # i.e. after ``epoch`` complete epochs of training.
            if epoch != 0 and epoch % args.eval_fre == 0:
                evalute(net, val_loader, writer, epoch, logger)
            if epoch != 0 and epoch % args.save_fre == 0:
                ckpt_path = osp.join(args.sub_model_out_dir, 'out_{}.pth'.format(epoch))
                torch.save(_unwrapped_state_dict(net), ckpt_path)

            for img, lb in train_loader:
                it += 1
                img, lb = img.to(device), lb.to(device)
                optimizer.zero_grad()
                loss = criterion(net(img), lb)
                loss.backward()
                optimizer.step()
                running_loss.append(loss.item())
                if it % args.msg_fre == 0:
                    st = _log_progress(logger, writer, epoch, it, total_iter,
                                       optimizer.param_groups[0]['lr'],
                                       np.mean(running_loss), st, glob_st)
                    running_loss = []
            scheduler.step()

        # Evaluate once more after the final epoch.
        evalute(net, val_loader, writer, args.epoch, logger)
        out_name = osp.join(args.sub_model_out_dir, args.model_out_name)
        # Move to CPU before saving so the file loads on machines without a GPU.
        torch.save(_unwrapped_state_dict(net.cpu()), out_name)
        logger.info('-----------Done!!!----------')
    except Exception:
        logger.exception('Exception logged')
    finally:
        # writer is still None if SummaryWriter itself failed.
        if writer is not None:
            writer.close()
if __name__ == '__main__':
    args = parse_args()
    # Seed the random generators used during training to reduce run-to-run variation.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # ignored when CUDA is unavailable
    np.random.seed(args.seed)

    # Unique identifier of this run: timestamp prefix + task name.
    unique_name = time.strftime('%y-%m%d-%H%M%S_') + args.name
    args.unique_name = unique_name
    # Each run writes tensorboard logs and model files into its own sub-directory.
    args.sub_tensorboard_dir = osp.join(args.tensorboard_dir, args.unique_name)
    args.sub_model_out_dir = osp.join(args.model_out_dir, args.unique_name)
    # Create every output directory the run uses.
    for sub_dir in (args.sub_tensorboard_dir, args.sub_model_out_dir, args.log_dir):
        if not osp.exists(sub_dir):
            os.makedirs(sub_dir)

    log_file_name = osp.join(args.log_dir, args.unique_name + '.log')
    logger = get_logger(log_file_name)
    # Record the full configuration, one "key: value" line per argument.
    for k, v in args.__dict__.items():
        logger.info('{}: {}'.format(k, v))
    main_worker(args, logger=logger)
界面:gui.py
#!/bin/python
import wx
from PIL import Image
import numpy as np
import os
from tst import prediect
#文件-hello
def OnHello(event):
wx.MessageBox("遥感图像识别")
def OnAbout(event):
    """Handle Help->About by showing an informational About dialog."""
    text = "这是一个识别单张图像的界面-罗"
    caption = "关于 Hello World"
    flags = wx.OK | wx.ICON_INFORMATION
    wx.MessageBox(text, caption, flags)
class HelloFrame(wx.Frame):
    """Main window: select an image, display it and show the predicted label."""

    def __init__(self, *args, **kw):
        super(HelloFrame, self).__init__(*args, **kw)
        pnl = wx.Panel(self)
        self.pnl = pnl
        # Created on first use and reused afterwards, so that repeated selections
        # update the same controls instead of stacking new ones on top.
        self._image_ctrl = None
        self._result_text = None

        # Title label: 10 points larger than the default font, bold.
        st = wx.StaticText(pnl, label="选择一张图像进行识别", pos=(200, 0))
        font = st.GetFont()
        font.PointSize += 10
        font = font.Bold()
        st.SetFont(font)

        # Button that opens the image file dialog.
        btn = wx.Button(pnl, -1, "select", pos=(4, 4))
        btn.SetBackgroundColour("#0a74f7")
        btn.Bind(wx.EVT_BUTTON, self.OnSelect)

        self.makeMenuBar()
        self.CreateStatusBar()
        self.SetStatusText("欢迎来到图像识别系统")

    def makeMenuBar(self):
        """Build the File (Hello, Exit) and Help (About) menus and bind their handlers."""
        fileMenu = wx.Menu()
        helloItem = fileMenu.Append(-1, "&Hello...\tCtrl-H",
                                    "Help string shown in status bar for this menu item")
        fileMenu.AppendSeparator()
        exitItem = fileMenu.Append(wx.ID_EXIT)

        helpMenu = wx.Menu()
        aboutItem = helpMenu.Append(wx.ID_ABOUT)

        menuBar = wx.MenuBar()
        menuBar.Append(fileMenu, "&File")
        menuBar.Append(helpMenu, "Help")
        self.SetMenuBar(menuBar)

        self.Bind(wx.EVT_MENU, OnHello, helloItem)
        self.Bind(wx.EVT_MENU, self.OnExit, exitItem)
        self.Bind(wx.EVT_MENU, OnAbout, aboutItem)

    def OnExit(self, event):
        """Close the frame (File->Exit)."""
        self.Close(True)

    def OnSelect(self, event):
        """Ask for an image file, display it and show the result of ``prediect``."""
        wildcard = "image source(*.jpg)|*.jpg|" \
                   "All file(*.*)|*.*"
        # wx.FD_OPEN is the file-dialog style flag; wx.ID_OPEN is a command id
        # and is not a valid style.
        dialog = wx.FileDialog(self, "Choose a file", os.getcwd(), "",
                               wildcard, wx.FD_OPEN)
        try:
            if dialog.ShowModal() != wx.ID_OK:
                return
            path = dialog.GetPath()
        finally:
            # Dialogs are not destroyed automatically; release it explicitly.
            dialog.Destroy()

        print(path)
        # np.array forces PIL to decode the pixels, so the file can be closed right after.
        # NOTE(review): the image is passed at its original size; confirm that
        # tst.prediect performs any resizing the model needs.
        with Image.open(path) as img:
            image = np.array(img)

        self.initimage(name=path)
        # Prediction result from tst.py.
        result = prediect(image)
        if self._result_text is None:
            self._result_text = wx.StaticText(self.pnl, label='', pos=(600, 400),
                                              size=(150, 50))
            font = self._result_text.GetFont()
            font.PointSize += 8
            self._result_text.SetFont(font)
        self._result_text.SetLabel(result)

    def initimage(self, name):
        """Display the image file ``name`` at (0, 30) and return the StaticBitmap showing it."""
        bitmap = wx.Image(name, wx.BITMAP_TYPE_ANY).ConvertToBitmap()
        if self._image_ctrl is None:
            self._image_ctrl = wx.StaticBitmap(self.pnl, -1, bitmap,
                                               pos=(0, 30), size=(600, 400))
        else:
            self._image_ctrl.SetBitmap(bitmap)
        return self._image_ctrl
if __name__ == '__main__':
    # The wx.App object must exist before any window is created.
    application = wx.App()
    main_frame = HelloFrame(None, title='老罗的识别器', size=(1000, 600))
    main_frame.Show()
    # Run the GUI event loop until the last top-level window is closed.
    application.MainLoop()
影像分类
打开 ENVI,选择主菜单 -> Classification -> Unsupervised -> IsoData 或 K-Means。若选择 IsoData,在选择输入文件时可以设置空间或光谱裁剪区。以 can_tmr.img 文件为例,按默认设置即可,随后会弹出参数设置对话框,确认后得到 ISODATA 非监督分类结果。
类别定义
在 Display 中显示原始影像,然后选择 Display -> Overlay -> Classification,选择 ISODATA 分类结果。如图所示,在 Interactive Class Tool 面板中,可以选择显示各个分类结果,如图。
在 Interactive Class Tool 面板中,选择 Option -> Edit class colors/names,通过目视或其他方式识别分类结果,并填写相应的类别名称和颜色。
如图所示为最终结果。