计算机视觉-01-基于SegNet和U-Net的遥感图像语义分割(包含代码和数据)

news2025/2/24 21:02:44

文章目录

    • 0. 数据下载
    • 1、介绍
      • 1.1 任务介绍
      • 1.2 数据集介
        • 1.2.1 数据集介绍
        • 1.2.2 数据集处理步骤
      • 1.3 数据处理
      • 1.4 卷积神经网络
        • 1.4.1 SegNet
          • 1.4.1.1 定义SegNet网络
          • 1.4.1.2 读入数据集
          • 1.4.1.3 定义训练过程
          • 1.4.1.4 预测遥感图像
        • 1.4.2 U-Net网络
          • 1.4.2.1 定义U-Net网络
          • 1.4.2.2 读入数据集
          • 1.4.2.3 预测思路
      • 1.5 模型融合
    • 0. 数据下载

0. 数据下载

关注公众号:『AI学习星球
回复:遥感图像语义分割 即可获取数据下载。
论文辅导算法学习可以通过公众号滴滴我
在这里插入图片描述

1、介绍

1.1 任务介绍

基于SegNet和U-Net进行对遥感图像语义分割。

1.2 数据集介

1.2.1 数据集介绍

首先介绍一下数据,我们这次采用的数据集是CCF大数据比赛提供的数据(2015年中国南方某城市的高清遥感图像),这是一个小数据集,里面包含了5张带标注的大尺寸RGB遥感图像(尺寸范围从3000×3000到6000×6000),里面一共标注了4类物体,植被(标记1)、建筑(标记2)、水体(标记3)、道路(标记4)以及其他(标记0)。其中,耕地、林地、草地均归为植被类,为了更好地观察标注情况,我们将其中三幅训练图片可视化如下:蓝色-水体,黄色-房屋,绿色-植被,棕色-马路。
在这里插入图片描述

1.2.2 数据集处理步骤

现在说一说我们的数据处理的步骤。我们现在拥有的是5张大尺寸的遥感图像,我们不能直接把这些图像送入网络进行训练,因为内存承受不了而且他们的尺寸也各不相同。因此,我们首先将他们做随机切割,即随机生成x,y坐标,然后抠出该坐标下256*256的小图,并做以下数据增强操作:

  1. 原图和label图都需要旋转:90度,180度,270度
  2. 原图和label图都需要做沿y轴的镜像操作
  3. 原图做模糊操作
  4. 原图做光照调整操作
  5. 原图做增加噪声操作(高斯噪声,椒盐噪声)

1.3 数据处理

这里我没有采用Keras自带的数据增广函数,而是自己使用opencv编写了相应的增强函数。

img_w = 256  
img_h = 256  
 
image_sets = ['1.png','2.png','3.png','4.png','5.png']
 
def gamma_transform(img, gamma):
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    return cv2.LUT(img, gamma_table)
 
def random_gamma_transform(img, gamma_vari):
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)
    
 
def rotate(xb,yb,angle):
    M_rotate = cv2.getRotationMatrix2D((img_w/2, img_h/2), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h))
    return xb,yb
    
def blur(img):
    img = cv2.blur(img, (3, 3));
    return img
 
def add_noise(img):
    for i in range(200): #添加点噪声
        temp_x = np.random.randint(0,img.shape[0])
        temp_y = np.random.randint(0,img.shape[1])
        img[temp_x][temp_y] = 255
    return img
    
    
def data_augment(xb,yb):
    if np.random.random() < 0.25:
        xb,yb = rotate(xb,yb,90)
    if np.random.random() < 0.25:
        xb,yb = rotate(xb,yb,180)
    if np.random.random() < 0.25:
        xb,yb = rotate(xb,yb,270)
    if np.random.random() < 0.25:
        xb = cv2.flip(xb, 1)  # flipcode > 0:沿y轴翻转
        yb = cv2.flip(yb, 1)
        
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb,1.0)
        
    if np.random.random() < 0.25:
        xb = blur(xb)
    
    if np.random.random() < 0.2:
        xb = add_noise(xb)
        
    return xb,yb
 
def creat_dataset(image_num = 100000, mode = 'original'):
    print('creating dataset...')
    image_each = image_num / len(image_sets)
    g_count = 0
    for i in tqdm(range(len(image_sets))):
        count = 0
        src_img = cv2.imread('./data/src/' + image_sets[i])  # 3 channels
        label_img = cv2.imread('./data/label/' + image_sets[i],cv2.IMREAD_GRAYSCALE)  # single channel
        X_height,X_width,_ = src_img.shape
        while count < image_each:
            random_width = random.randint(0, X_width - img_w - 1)
            random_height = random.randint(0, X_height - img_h - 1)
            src_roi = src_img[random_height: random_height + img_h, random_width: random_width + img_w,:]
            label_roi = label_img[random_height: random_height + img_h, random_width: random_width + img_w]
            if mode == 'augment':
                src_roi,label_roi = data_augment(src_roi,label_roi)
            
            visualize = np.zeros((256,256)).astype(np.uint8)
            visualize = label_roi *50
            
            cv2.imwrite(('./aug/train/visualize/%d.png' % g_count),visualize)
            cv2.imwrite(('./aug/train/src/%d.png' % g_count),src_roi)
            cv2.imwrite(('./aug/train/label/%d.png' % g_count),label_roi)
            count += 1 
            g_count += 1

经过上面数据增强操作后,我们得到了较大的训练集:100000张256*256的图片。
在这里插入图片描述

1.4 卷积神经网络

面对这类图像语义分割的任务,我们可以选取的经典网络有很多,比如FCN,U-Net,SegNet,DeepLab,RefineNet,Mask Rcnn,Hed Net这些都是非常经典而且在很多比赛都广泛采用的网络架构。所以我们就可以从中选取一两个经典网络作为我们这个分割任务的解决方案。我们根据我们小组的情况,选取了U-Net和SegNet作为我们的主体网络进行实验。

1.4.1 SegNet

SegNet已经出来好几年了,这不是一个最新、效果最好的语义分割网络,但是它胜在网络结构清晰易懂,训练快速坑少,所以我们也采取它来做同样的任务。SegNet网络结构是编码器-解码器的结构,非常优雅,值得注意的是,SegNet做语义分割时通常在末端加入CRF模块做后处理,旨在进一步精修边缘的分割结果。
在这里插入图片描述

1.4.1.1 定义SegNet网络
def SegNet():  
    model = Sequential()  
    #encoder  
    model.add(Conv2D(64,(3,3),strides=(1,1),input_shape=(3,img_w,img_h),padding='same',activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(64,(3,3),strides=(1,1),padding='same',activation='relu'))  
    model.add(BatchNormalization())  
    model.add(MaxPooling2D(pool_size=(2,2)))  
    #(128,128)  
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(MaxPooling2D(pool_size=(2, 2)))  
    #(64,64)  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(MaxPooling2D(pool_size=(2, 2)))  
    #(32,32)  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(MaxPooling2D(pool_size=(2, 2)))  
    #(16,16)  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(MaxPooling2D(pool_size=(2, 2)))  
    #(8,8)  
    #decoder  
    model.add(UpSampling2D(size=(2,2)))  
    #(16,16)  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(UpSampling2D(size=(2, 2)))  
    #(32,32)  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(UpSampling2D(size=(2, 2)))  
    #(64,64)  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(UpSampling2D(size=(2, 2)))  
    #(128,128)  
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(UpSampling2D(size=(2, 2)))  
    #(256,256)  
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3,img_w, img_h), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu'))  
    model.add(BatchNormalization())  
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))  
    model.add(Reshape((n_label,img_w*img_h)))  
    #axis=1和axis=2互换位置,等同于np.swapaxes(layer,1,2)  
    model.add(Permute((2,1)))  
    model.add(Activation('softmax'))  
    model.compile(loss='categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])  
    model.summary()  
    return model  
1.4.1.2 读入数据集

这里使用的验证集大小是训练集的0.25。

def get_train_val(val_rate = 0.25):
    train_url = []    
    train_set = []
    val_set  = []
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)
    total_num = len(train_url)
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i]) 
        else:
            train_set.append(train_url[i])
    return train_set,val_set
    
# data for training  
def generateData(batch_size,data=[]):  
    #print 'generateData...'
    while True:  
        train_data = []  
        train_label = []  
        batch = 0  
        for i in (range(len(data))): 
            url = data[i]
            batch += 1 
            #print (filepath + 'src/' + url)
            #img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))  
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img) 
            # print img
            # print img.shape  
            train_data.append(img)  
            #label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))  
            # print label.shape  
            train_label.append(label)  
            if batch % batch_size==0: 
                #print 'get enough bacth!\n'
                train_data = np.array(train_data)  
                train_label = np.array(train_label).flatten()  
                train_label = labelencoder.transform(train_label)  
                train_label = to_categorical(train_label, num_classes=n_label)  
                train_label = train_label.reshape((batch_size,img_w * img_h,n_label))  
                yield (train_data,train_label)  
                train_data = []  
                train_label = []  
                batch = 0  
 
# data for validation 
def generateValidData(batch_size,data=[]):  
    #print 'generateValidData...'
    while True:  
        valid_data = []  
        valid_label = []  
        batch = 0  
        for i in (range(len(data))):  
            url = data[i]
            batch += 1  
            #img = load_img(filepath + 'src/' + url, target_size=(img_w, img_h))
            img = load_img(filepath + 'src/' + url)
            #print img
            #print (filepath + 'src/' + url)
            img = img_to_array(img)  
            # print img.shape  
            valid_data.append(img)  
            #label = load_img(filepath + 'label/' + url, target_size=(img_w, img_h),grayscale=True)
            label = load_img(filepath + 'label/' + url, grayscale=True)
            label = img_to_array(label).reshape((img_w * img_h,))  
            # print label.shape  
            valid_label.append(label)  
            if batch % batch_size==0:  
                valid_data = np.array(valid_data)  
                valid_label = np.array(valid_label).flatten()  
                valid_label = labelencoder.transform(valid_label)  
                valid_label = to_categorical(valid_label, num_classes=n_label)  
                valid_label = valid_label.reshape((batch_size,img_w * img_h,n_label))  
                yield (valid_data,valid_label)  
                valid_data = []  
                valid_label = []  
                batch = 0  
1.4.1.3 定义训练过程

在这个任务上,我们把batch size定为16,epoch定为30,每次都存储最佳model(save_best_only=True),并且在训练结束时绘制loss/acc曲线,并存储起来。

def train(args): 
    EPOCHS = 30
    BS = 16
    model = SegNet()  
    modelcheck = ModelCheckpoint(args['model'],monitor='val_acc',save_best_only=True,mode='max')  
    callable = [modelcheck]  
    train_set,val_set = get_train_val()
    train_numb = len(train_set)  
    valid_numb = len(val_set)  
    print ("the number of train data is",train_numb)  
    print ("the number of val data is",valid_numb)
    H = model.fit_generator(generator=generateData(BS,train_set),steps_per_epoch=train_numb//BS,epochs=EPOCHS,verbose=1,  
                    validation_data=generateValidData(BS,val_set),validation_steps=valid_numb//BS,callbacks=callable,max_q_size=1)  
 
    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    N = EPOCHS
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])

训练之后的结果:
在这里插入图片描述

训练loss降到0.1左右,acc可以去到0.9,但是验证集的loss和acc都没那么好,貌似存在点问题。

1.4.1.4 预测遥感图像

这里需要思考一下怎么预测整张遥感图像,我们知道,我们训练模型时选择的图片输入是256×256,所以我们预测时也要采用256×256的图片尺寸送进模型预测。现在我们要考虑一个问题,我们该怎么将这些预测好的小图重新拼接成一个大图呢?这里给出一个最基础的方案:先给大图做padding 0操作,得到一副padding过的大图,同时我们也生成一个与该图一样大的全0图A,把图像的尺寸补齐为256的倍数,然后以256为步长切割大图,依次将小图送进模型预测,预测好的小图则放在A的相应位置上,依次进行,最终得到预测好的整张大图(即A),再做图像切割,切割成原先图片的尺寸,完成整个预测流程。

def predict(args):
    # load the trained convolutional neural network
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    for n in range(len(TEST_SET)):
        path = TEST_SET[n]
        #load the image
        image = cv2.imread('./test/' + path)
        # pre-process the image for classification
        #image = image.astype("float") / 255.0
        #image = img_to_array(image)
        h,w,_ = image.shape
        padding_h = (h//stride + 1) * stride 
        padding_w = (w//stride + 1) * stride
        padding_img = np.zeros((padding_h,padding_w,3),dtype=np.uint8)
        padding_img[0:h,0:w,:] = image[:,:,:]
        padding_img = padding_img.astype("float") / 255.0
        padding_img = img_to_array(padding_img)
        print 'src:',padding_img.shape
        mask_whole = np.zeros((padding_h,padding_w),dtype=np.uint8)
        for i in range(padding_h//stride):
            for j in range(padding_w//stride):
                crop = padding_img[:3,i*stride:i*stride+image_size,j*stride:j*stride+image_size]
                _,ch,cw = crop.shape
                if ch != 256 or cw != 256:
                    print 'invalid size!'
                    continue
                    
                crop = np.expand_dims(crop, axis=0)
                #print 'crop:',crop.shape
                pred = model.predict_classes(crop,verbose=2)  
                pred = labelencoder.inverse_transform(pred[0])  
                #print (np.unique(pred))  
                pred = pred.reshape((256,256)).astype(np.uint8)
                #print 'pred:',pred.shape
                mask_whole[i*stride:i*stride+image_size,j*stride:j*stride+image_size] = pred[:,:]
 
        
        cv2.imwrite('./predict/pre'+str(n+1)+'.png',mask_whole[0:h,0:w])

预测之后的图片:
在这里插入图片描述

1.4.2 U-Net网络

对于这个语义分割任务,我们毫不犹豫地选择了U-Net作为我们的方案,原因很简单,我们参考很多类似的遥感图像分割比赛的资料,绝大多数获奖的选手使用的都是U-Net模型。在这么多的好评下,我们选择U-Net也就毫无疑问了。
U-Net有很多优点,最大卖点就是它可以在小数据集上也能train出一个好的模型,这个优点 对于我们这个任务来说真的非常适合。而且,U-Net在训练速度上也是非常快的,这对于需要短时间就得出结果的期末project来说也是非常合适。U-Net在网络架构上还是非常优雅的,整个呈现U形,故起名U-Net。这里不打算详细介绍U-Net结构,有兴趣的深究的可以看看论文。
在这里插入图片描述

1.4.2.1 定义U-Net网络

现在开始谈谈代码细节。首先我们定义一下U-Net的网络结构,这里用的deep learning框架还是Keras。
注意到,我们这里训练的模型是一个多分类模型,其实更好的做法是,
训练一个二分类模型(使用二分类的标签)
对每一类物体进行预测,得到4张预测图,再做预测图叠加,合并成一张完整的包含4类的预测图,
这个策略在效果上肯定好于一个直接4分类的模型。所以,U-Net这边我们采取的思路就是对于每一类的分类都训练一个二分类模型,最后再将每一类的预测结果组合成一个四分类的结果。
定义U-Net结构,注意这里的loss function我们选了binary_crossentropy,因为我们要训练的是二分类模型。

def unet():
    inputs = Input((3, img_w, img_h))
 
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
 
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
 
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
 
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
 
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)
 
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)
 
    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)
 
    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)
 
    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)
 
    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)
 
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
1.4.2.2 读入数据集
# data for training  
def generateData(batch_size,data=[]):  
    #print 'generateData...'
    while True:  
        train_data = []  
        train_label = []  
        batch = 0  
        for i in (range(len(data))): 
            url = data[i]
            batch += 1 
            img = load_img(filepath + 'src/' + url)
            img = img_to_array(img) 
            train_data.append(img)  
            label = load_img(filepath + 'label/' + url, grayscale=True) 
            label = img_to_array(label)
            #print label.shape  
            train_label.append(label)  
            if batch % batch_size==0: 
                #print 'get enough bacth!\n'
                train_data = np.array(train_data)  
                train_label = np.array(train_label)  
 
                yield (train_data,train_label)  
                train_data = []  
                train_label = []  
                batch = 0  
 
# data for validation 
def generateValidData(batch_size,data=[]):  
    #print 'generateValidData...'
    while True:  
        valid_data = []  
        valid_label = []  
        batch = 0  
        for i in (range(len(data))):  
            url = data[i]
            batch += 1  
            img = load_img(filepath + 'src/' + url)
            #print img
            img = img_to_array(img)  
            # print img.shape  
            valid_data.append(img)  
            label = load_img(filepath + 'label/' + url, grayscale=True)
            valid_label.append(label)  
            if batch % batch_size==0:  
                valid_data = np.array(valid_data)  
                valid_label = np.array(valid_label)  
                yield (valid_data,valid_label)  
                valid_data = []  
                valid_label = []  
                batch = 0  

训练:指定输出model名字和训练集位置。

python unet.py --model unet_buildings20.h5 --data ./unet_train/buildings/
1.4.2.3 预测思路

预测单张遥感图像时我们分别使用4个模型做预测,那我们就会得到4张mask(比如下图就是我们用训练好的buildings模型预测的结果),我们现在要将这4张mask合并成1张,那么怎么合并会比较好呢?我思路是,通过观察每一类的预测结果,我们可以从直观上知道哪些类的预测比较准确,那么我们就可以给这些mask图排优先级了,比如:priority:building>water>road>vegetation,那么当遇到一个像素点,4个mask图都说是属于自己类别的标签时,我们就可以根据先前定义好的优先级,把该像素的标签定为优先级最高的标签。代码思路可以参照下面的代码:
在这里插入图片描述

def combind_all_mask():
    for mask_num in tqdm(range(3)):
        if mask_num == 0:
            final_mask = np.zeros((5142,5664),np.uint8)#生成一个全黑全0图像,图片尺寸与原图相同
        elif mask_num == 1:
            final_mask = np.zeros((2470,4011),np.uint8)
        elif mask_num == 2:
            final_mask = np.zeros((6116,3356),np.uint8)
        #final_mask = cv2.imread('final_1_8bits_predict.png',0)
        
        if mask_num == 0:
            mask_pool = mask1_pool
        elif mask_num == 1:
            mask_pool = mask2_pool
        elif mask_num == 2:
            mask_pool = mask3_pool
        final_name = img_sets[mask_num]
        for idx,name in enumerate(mask_pool):
            img = cv2.imread('./predict_mask/'+name,0)
            height,width = img.shape
            label_value = idx+1  #coressponding labels value
            for i in tqdm(range(height)):    #priority:building>water>road>vegetation
                for j in range(width):
                    if img[i,j] == 255:
                        if label_value == 2:
                            final_mask[i,j] = label_value
                        elif label_value == 3 and final_mask[i,j] != 2:
                            final_mask[i,j] = label_value
                        elif label_value == 4 and final_mask[i,j] != 2 and final_mask[i,j] != 3:
                            final_mask[i,j] = label_value
                        elif label_value == 1 and final_mask[i,j] == 0:
                            final_mask[i,j] = label_value
                        
        cv2.imwrite('./final_result/'+final_name,final_mask)           
                
                
print 'combinding mask...'
combind_all_mask() 

1.5 模型融合

集成学习的方法在比赛中经常使用,要想获得好成绩集成学习必须做得好。在这里简单谈谈思路,我们使用了两个模型,我们模型也会采取不同参数去训练和预测,那么我们就会得到很多预测MASK图,此时 我们可以采取模型融合的思路,对每张结果图的每个像素点采取投票表决的思路,对每张图相应位置的像素点的类别进行预测,票数最多的类别即为该像素点的类别。正所谓“三个臭皮匠,胜过诸葛亮”,我们这种ensemble的思路,可以很好地去掉一些明显分类错误的像素点,很大程度上改善模型的预测能力。
少数服从多数的投票表决策略代码:

import numpy as np
import cv2
import argparse
 
RESULT_PREFIXX = ['./result1/','./result2/','./result3/']
 
# each mask has 5 classes: 0~4
 
def vote_per_image(image_id):
    result_list = []
    for j in range(len(RESULT_PREFIXX)):
        im = cv2.imread(RESULT_PREFIXX[j]+str(image_id)+'.png',0)
        result_list.append(im)
        
    # each pixel
    height,width = result_list[0].shape
    vote_mask = np.zeros((height,width))
    for h in range(height):
        for w in range(width):
            record = np.zeros((1,5))
            for n in range(len(result_list)):
                mask = result_list[n]
                pixel = mask[h,w]
                #print('pix:',pixel)
                record[0,pixel]+=1
           
            label = record.argmax()
            #print(label)
            vote_mask[h,w] = label
    
    cv2.imwrite('vote_mask'+str(image_id)+'.png',vote_mask)
        
 
vote_per_image(3)

模型融合后的预测结果:
在这里插入图片描述

0. 数据下载

关注公众号:『AI学习星球
回复:遥感图像语义分割 即可获取数据下载。
论文辅导算法学习可以通过公众号滴滴我
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1285812.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全息图着色器插件:Hologram Shaders Pro for URP, HDRP Built-in

8个新的Unity全息图着色器,具有故障效果,扫描线,网格线,和更多其他效果!与所有渲染管线兼容。 软件包添加了一系列的全息图着色器到Unity。从基本的全息图与菲涅耳亮点,先进的全息图与两种故障效应,扫描线,文体点阵和网格线全息图! 特色全息效果 Basic-支持菲涅耳发光照…

计算机速成课Crash Course - 07. 中央处理器

今天开始计算机速成课Crash Course的系列讲解。 更多技术文章&#xff0c;全网首发公众号 “摸鱼IT” 锁定 -上午11点 - &#xff0c;感谢大家关注、转发、点赞&#xff01; 计算机速成课Crash Course - 07. 中央处理器 07. 中央处理器 提示下&#xff0c;这集可能是最难的一…

python之记录程序运行时长工具

python之记录程序运行时长工具 废话不多话&#xff0c;上代码 from datetime import datetime, timedelta import timestart_time datetime.now()while True:current_time datetime.now()elapsed_time current_time - start_timeformatted_time str(elapsed_time).split(…

QGIS之二十六pbf转osm转shp

效果 步骤 1、下载工具 用于转换osm.pbf–>.osm,当然也可以反过来,还支持其它格式互相转换 osmconvert64-0.8.8p.exe 链接:https://pan.baidu.com/s/1Mj-6b30f6voOkQI8QFh_rw 提取码:1111 2、国内各省下载OSM数据 http://download.openstreetmap.fr/extracts/asia/c…

AI入侵B站鬼畜区!网友辣评:不如传统“活字乱刷术”

11月27日&#xff0c;B站UP主“女孩为何穿短裙”突破传统&#xff0c;投稿一则使用AI合成语音制作的鬼畜视频&#xff0c;标志着AI视频制作正式进入B站鬼畜区。视频播放量截至目前已达167.3万&#xff0c;获得14.5万的点赞和2.8万个投币。 鬼畜视频一直以其独特之处引起关注&a…

vue el-select多选封装及使用

使用了Element UI库中的el-select和el-option组件来构建多选下拉框。同时&#xff0c;也包含了一个el-input组件用于过滤搜索选择项&#xff0c;以及el-checkbox-group和el-checkbox组件用于显示多选项。 创建组件index.vue (src/common-ui/selectMultiple/index.vue) <tem…

宝塔部署appache部署ssl证书无法访问443端口

原因&#xff1a; 不是部署方法错误&#xff0c;而是操作不当&#xff0c;原来一开始为了测试我去修改了appache默配置路径下的httpd-ssl.donf&#xff0c;此文件一般 在appche/conf/extra/目录下&#xff08;版本不同目录可能有所区别&#xff09;。 导致问题&#xff1a; 在…

java连接池 理解及解释(DBCP、druid、c3p0、HikariCP)

一、在Java开发中&#xff0c;有许多常见的数据库连接池可供选择。以下是一些常见的Java数据库连接池&#xff1a;不使用数据库连接池的特性&#xff1a; 优点&#xff1a;实现简单 缺点&#xff1a;网络 IO 较多数据库的负载较高响应时间较长及 QPS 较低应用频繁的创建连接和关…

高低压供配电智能监控系统

高低压供配电智能监控系统是一种综合运用物联网、云计算、大数据和人工智能等技术的智能化监控系统&#xff0c;用于实时监测高低压供配电设备的运行状态和电能质量&#xff0c;及时发现和处理供配电系统中存在的问题&#xff0c;提高供配电系统的安全性和可靠性。依托电易云-智…

漏洞复现--万户ezoffice wpsservlet任意文件上传

免责声明&#xff1a; 文章中涉及的漏洞均已修复&#xff0c;敏感信息均已做打码处理&#xff0c;文章仅做经验分享用途&#xff0c;切勿当真&#xff0c;未授权的攻击属于非法行为&#xff01;文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

A股风格因子看板 (2023.12第1期)

该因子看板跟踪A股风格因子&#xff0c;该因子主要解释沪深两市的市场收益、刻画市场风格趋势的系列风格因子&#xff0c;用以分析市场风格切换、组合风格景 露组合等。 今日为该因子跟踪第1期&#xff0c;指数组合数据截止日2023-11-30&#xff0c;要点如下 近1年A股风格因子收…

如何理解方块电阻与宽度的关系(RPSQ_VS_SI_WIDTH)

我正在「拾陆楼」和朋友们讨论有趣的话题&#xff0c;你⼀起来吧&#xff1f; 拾陆楼知识星球入口 来自星球提问: 解释如下: 单看方块电阻的公式&#xff0c;Rs&#xff1d;电阻率/厚度 在半导体制造过程中&#xff0c;由于工艺偏差&#xff0c;电阻跟金属线的density是相关的…

windows10系统下替换、修改jar中的文件并重新打包成jar文件然后运行

目录 1、jar文件简述2、问题来源3、操作步骤3.1 解压jar包3.2 替换或者更改操作3.3 重新打成jar包3.4 确认是否修改成功3.5 运行程序 附录&#xff1a;常见命令参数 1、jar文件简述 JAR 文件就是 Java Archive &#xff08; Java 档案文件&#xff09;&#xff0c;它是 Java 的…

IntelliJ IDEA 2023.2新特性详解第三弹!Docker、Kubernetes等支持!

9 Docker 在 Docker 镜像层内预览文件 现在可以在 Services&#xff08;服务&#xff09;工具窗口中轻松访问和预览 Docker 镜像层的内容。 从列表选择镜像&#xff0c;选择 Show layers&#xff08;显示层&#xff09;&#xff0c;然后点击 Analyze image for more informati…

电气自动化专业求职简历11篇

电气自动化专业求职简历下载&#xff08;可在线编辑制作&#xff09;&#xff1a;来幻主简历&#xff0c;做好简历&#xff01; 电气自动化专业简历1&#xff1a; 求职意向 求职类型&#xff1a;全职 意向岗位&#xff1a;自动化工程师 意向城市&#xff1a;广东广州…

二维码智慧门牌管理系统升级解决方案:优化制牌存疑管理

文章目录 前言一、解决方案关键特点二、解决方案的优势 前言 二维码智慧门牌管理系统在城市管理中发挥着重要作用&#xff0c;然而&#xff0c;制牌审核过程中遇到存疑地址数据是常见问题&#xff0c;需要更有效的处理方法。为此&#xff0c;我们提出了二维码智慧门牌管理系统…

编写并调试运行一个简单的 Java 应用程序,显示自己的学号、姓名、兴趣爱好等。

源代码&#xff1a; public class Main { public static void main(String[] args) { System.out.println("学号是:""0233217821"); System.out.println("姓名是:""赵港"); System.out.println("兴趣爱好是:""运动&qu…

elasticsearch 内网下如何以离线的方式上传任意的huggingFace上的NLP模型(国内避坑指南)

es自2020年的8.x版本以来&#xff0c;就提供了机器学习的能力。我们可以使用es官方提供的工具eland&#xff0c;将hugging face上的NLP模型&#xff0c;上传到es集群中。利用es的机器学习模块&#xff0c;来运维部署管理模型。配合es的管道处理&#xff0c;来更加便捷的处理数据…

软件性能测试之压力测试详解

压力测试 压力测试是一种软件测试&#xff0c;用于验证软件应用程序的稳定性和可靠性。压力测试的目标是在极其沉重的负载条件下测量软件的健壮性和错误处理能力&#xff0c;并确保软件在危急情况下不会崩溃。它甚至可以测试超出正常工作点的测试&#xff0c;并评估软件在极端条…

15:00的面试,15:06就出来了,问的问题过于变态了。。。

从小厂出来&#xff0c;没想到在另一家公司又寄了。 到这家公司开始上班&#xff0c;加班是每天必不可少的&#xff0c;看在钱给的比较多的份上&#xff0c;就不太计较了。没想到5月一纸通知&#xff0c;所有人不准加班&#xff0c;加班费不仅没有了&#xff0c;薪资还要降40%…