昇思MindSpore学习笔记6-01LLM原理和实践--FCN图像语义分割

news2024/9/22 19:26:54

摘要:

        记录MindSpore AI框架使用FCN全卷积网络理解图像进行图像语议分割的过程、步骤和方法。包括环境准备、下载数据集、数据集加载和预处理、构建网络、训练准备、模型训练、模型评估、模型推理等。

一、

1.语义分割

图像语义分割

semantic segmentation

        图像处理

        机器视觉

                图像理解

        AI领域重要分支

        应用

                人脸识别

                物体检测

                医学影像

                卫星图像分析

                自动驾驶感知

        目的

                图像每个像素点分类

                输出与输入大小相同的图像

                输出图像的每个像素对应了输入图像每个像素的类别

        图像领域语义

                图像的内容

                对图片意思的理解

实例

2.FCN全卷积网络

Fully Convolutional Networks

图像语义分割框架

        2015年UC Berkeley提出

        端到端(end to end)像素级(pixel level)预测全卷积网络

全卷积神经网络主要使用三种技术:

1.卷积化Convolutional

VGG-16

        FCN的backbone

        输入224*224RGB图像

                固定大小的输入

                丢弃了空间坐标

                产生非空间输出

        输出1000个预测值

卷积层

        输出二维矩阵

        生成输入图片映射的heatmap

2.上采样Upsample

卷积过程

        卷积操作

        池化操作

特征图尺寸变小

上采样操作

        得到原图大小的稠密图像预测

双线性插值参数

初始化上采样逆卷积参数

反向传播学习非线性上采样

3.跳跃结构Skip Layer

将深层的全局信息与浅层的局部信息相结合

                             底层stride 32的预测FCN-32s    2倍上采样

融合(相加)  pool4层stride 16的预测FCN-16s    2倍上采样

融合(相加)  pool3层stride 8的预测FCN-8s

特点:

(1)不含全连接层(fc)的全卷积(fully conv)网络,可适应任意尺寸输入。

(2)增大数据尺寸的反卷积(deconv)层,能够输出精细的结果。

(3)结合不同深度层结果的跳级(skip)结构,同时确保鲁棒性和精确性。

二、环境准备

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore

输出:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

三、数据处理

1.下载数据集

from download import download
​
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
​
download(url, "./dataset", kind="tar", replace=True)

输出:

Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)

file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 160MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset
'./dataset'

2.数据预处理

PASCAL VOC 2012数据集图像分辨率不一致

        标准化处理

3.数据加载

混合PASCAL VOC 2012数据集SDB数据集

import numpy as np
import cv2
import mindspore.dataset as ds
​
class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):
​
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        max_scale > min_scale
​
    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
​
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out
​
    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset
​
​
# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
​
# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21
​
# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
​
dataset = dataset.get_dataset()

4.训练集可视化

import numpy as np
import matplotlib.pyplot as plt
​
plt.figure(figsize=(16, 8))
​
# 对训练集中的数据进行展示
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
# 将图片转换HWC格式后进行展示
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

输出:

四、网络构建

FCN网络流程

        输入图像image

        pool1池化

                尺寸变为原始尺寸的1/2

        pool2池化

                尺寸变为原始尺寸的1/4

        pool3池化

                尺寸变为原始尺寸的1/8

        pool4池化

                尺寸变为原始尺寸的1/16

        pool5池化

                尺寸变为原始尺寸的1/32

        conv6-7卷积

                输出尺寸原图的1/32

        FCN-32s

                反卷积扩大到原始尺寸

        FCN-16s

                融合

                        conv7反卷积尺寸扩大两倍至原图的1/16

                        pool4特征图

                反卷积扩大到原始尺寸

        FCN-8s

                融合

                        conv7反卷积尺寸扩大4倍

                        pool4特征图反卷积扩大2倍

                        pool3特征图

                反卷积扩大到原始尺寸

构建FCN-8s网络代码:

import mindspore.nn as nn
​
class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')
​
    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

五、训练准备

1.导入VGG-16部分预训练权重

from download import download
from mindspore import load_checkpoint, load_param_into_net
​
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

输出:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt (513.2 MB)

file_sizes: 100%|█████████████████████████████| 538M/538M [00:03<00:00, 179MB/s]
Successfully downloaded file to fcn8s_vgg16_pretrain.ckpt

2.损失函数

交叉熵损失函数

mindspore.nn.CrossEntropyLoss()

计算FCN网络输出与mask之间的交叉熵损失

3.自定义评价指标 Metrics

用于评估模型效果

设共有 K+1个类

        从L_0 到L_{ki}

        其中包含一个空类或背景

P_{ij}表示本属于i类但被预测为j类的像素数量

P_{ii}表示真正的数量

P_{ij}P_{ji}则分别被解释为假正和假负

Pixel Accuracy

PA像素精度

        标记正确的像素占总像素的比例。

PA=\frac{\sum_{i=0}^{k}P_{ii}}{\sum_{i=0}^{k}\sum_{j=0}^{k}P_{ij}}

Mean Pixel Accuracy

MPA均像素精度

计算每个类内正确分类像素数的比例

求所有类的平均

MPA=\frac{1}{K+1}\sum \sum_{i=0}^{k}\frac{P_{ii}}{\sum_{j=0}^{k}P_{ij}}

Mean Intersection over Union

MloU均交并比

        语义分割的标准度量

                计算两个集合的交集和并集之比

                        交集为真实值(ground truth)

                        并集为预测值(predicted segmentation)

                两者之比:正真数 (intersection) /(真正+假负+假正(并集))

                在每个类上计算loU

                平均

MIoU=\frac{1}{K+1} \sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k}p_{ij}+{\sum_{j=0}^{k}p_{ji}}-p_{ii}}

Frequency Weighted Intersection over Union

FWIoU频权交井比

根据每个类出现的频率设置权重

FWIoU=\frac{1}{\sum_{i=0}^{k}\sum_{j=0}^{k}p_{ij}} \sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k}p_{ij}+{\sum_{j=0}^{k}p_{ji}}-p_{ii}}

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train
​
class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class
​
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
​
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
​
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
​
    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy
​
​
class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class
​
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
​
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
​
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
​
    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy
​
​
class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class
​
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
​
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
​
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
​
    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou
​
​
class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class
​
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
​
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
​
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
​
    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
​
        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

六、模型训练

导入VGG-16预训练参数

实例化损失函数、优化器

Model接口编译网络

训练FCN-8s网络

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model
​
device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)
​
train_batch_size = 4
num_classes = 21
# 初始化模型结构
net = FCN8s(n_class=21)
# 导入vgg16预训练参数
load_vgg16()
# 计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs
​
lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])
​
# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# 定义loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# 初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
​
# 设置ckpt文件保存的参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)

输出:

epoch: 1 step: 1, loss is 3.0504844
epoch: 1 step: 2, loss is 3.017057
epoch: 1 step: 3, loss is 2.9523003
epoch: 1 step: 4, loss is 2.9488814
epoch: 1 step: 5, loss is 2.666231
epoch: 1 step: 6, loss is 2.7145326
epoch: 1 step: 7, loss is 1.796408
epoch: 1 step: 8, loss is 1.5167583
epoch: 1 step: 9, loss is 1.6862022
epoch: 1 step: 10, loss is 2.4622822
......
epoch: 1 step: 1141, loss is 1.70966
epoch: 1 step: 1142, loss is 1.434751
epoch: 1 step: 1143, loss is 2.406475
Train epoch time: 762889.258 ms, per step time: 667.445 ms

七、模型评估

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
​
# 下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)
​
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
​
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
​
# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)

输出:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)

file_sizes: 100%|██████████████████████████| 1.08G/1.08G [00:10<00:00, 99.7MB/s]
Successfully downloaded file to FCN8s.ckpt
/
{'pixel accuracy': 0.9734831394168291,
 'mean pixel accuracy': 0.9423324801371116,
 'mean IoU': 0.8961453779807752,
 'frequency weighted IoU': 0.9488883312345654}

八、模型推理

使用训练的网络对模型推理结果进行展示。

import cv2
import matplotlib.pyplot as plt
​
net = FCN8s(n_class=num_classes)
# 设置超参
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# 推理效果展示(上方为输入图片,下方为推理效果图片)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

输出:

九、总结

FCN

        使用全卷积层

        通过学习让图片实现端到端分割。

        优点:

                输入接受任意大小的图像

                高效,避免了由于使用像素块而带来的重复存储和计算卷积的问题。

        待改进之处:

                结果不够精细。比较模糊和平滑,边界处细节不敏感。

                像素分类,没有考虑像素之间的关系(如不连续性和相似性)

                忽略空间规整(spatial regularization)步骤,缺乏空间一致性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1907798.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ARCGIS PRO 要素标注

一、普通模式 1、标注&#xff1a;名称和面积&#xff08;无分数线&#xff09; 语言&#xff1a;Arcade $feature.QLR \nRound($feature.Shape_Area,2) 语言&#xff1a;vbscript [QLR] & " " & Round([Shape_Area],2) 2、标注&#xff1a;名称…

Leetcode 295.数据流的中位数

295.数据流的中位数 问题描述 中位数是有序整数列表中的中间值。如果列表的大小是偶数&#xff0c;则没有中间值&#xff0c;中位数是两个中间值的平均值。 例如 arr [2,3,4] 的中位数是 3 。例如 arr [2,3] 的中位数是 (2 3) / 2 2.5 。 实现 MedianFinder 类: Media…

redis运维:sentinel模式如何查看所有从节点

1. 连接到sentinel redis-cli -h sentinel_host -p sentinel_port如&#xff1a; redis-cli -h {域名} -p 200182. 发现Redis主服务器 连接到哨兵后&#xff0c;我们可以使用SENTINEL get-master-addr-by-name命令来获取当前的Redis主服务器的地址。 SENTINEL get-master-a…

STM32对数码管显示的控制

1、在项目开发过程中会遇到STM32控制的数码管显示应用&#xff0c;这里以四位共阴极数码管显示控制为例讲解&#xff1b;这里采用的控制芯片为STM32F103RCT6。 2、首先要确定数码管的段选的8个引脚连接的单片机的引脚是哪8个&#xff0c;然后确认位选的4个引脚连接的单片机的4…

SpringBoot 启动流程六

SpringBoot启动流程六 这句话是创建一个上下文对象 就是最终返回的那个上下文 我们这个creatApplicationContext方法 是调用的这个方法 传入一个类型 我们通过打断点的方式 就可以看到context里面的东西 加载容器对象 当我们把依赖改成starter-web时 这个容器对象会进行…

【网络安全】对称加密算法

文章目录 非对称加密对称加密&#xff1a;Des 加密3 DES 加密Des 加密java ApiAES 加密算法AES 加密过程AES 密钥拓展 非对称加密 非对称加密算法需要两个密钥&#xff1a;公开密钥&#xff08;publickey:简称公钥&#xff09;和私有密钥&#xff08;privatekey:简称私钥&…

一键配置PCL环境+VTK环境(最简单的方法)

系列文章目录 1. Windows系统下5分钟配置好PCL&#xff08;debug和release&#xff09; 2. PCL1.11.0Qt5.14.2VTK8.2VS2019 环境配置&#xff08;超详细&#xff09; 文章目录 系列文章目录前言一、下载解压文件二、双击运行Setup.bat三、测试视频四、所需文件 前言 之前写过…

Android 开发中 C++ 和Java 日志调试

在 C 中添加堆栈日志 先在 Android.bp 中 添加 ‘libutilscallstack’ shared_libs:["liblog"," libutilscallstack"]在想要打印堆栈的代码中添加 #include <utils/CallStack.h> using android::CallStack;// 在函数中添加 int VisualizerLib_Crea…

20240708 每日AI必读资讯

&#x1f916;破解ChatGPT惊人耗电&#xff01;DeepMind新算法训练提效13倍&#xff0c;能耗暴降10倍 - 谷歌DeepMind研究团队提出了一种加快AI训练的新方法——多模态对比学习与联合示例选择&#xff08;JEST&#xff09;&#xff0c;大大减少了所需的计算资源和时间。 - JE…

python基础篇(9):模块

1 模块简介 Python 模块(Module)&#xff0c;是一个 Python 文件&#xff0c;以 .py 结尾. 模块能定义函数&#xff0c;类和变量&#xff0c;模块里也能包含可执行的代码. 模块的作用: python中有很多各种不同的模块, 每一个模块都可以帮助我们快速的实现一些功能, 比如实现…

实在智能荣获WAIC 2024机器之心重量级奖项——AI隐形冠军TOP 10!

近日&#xff0c;世界人工智能大会&#xff08;WAIC 2024&#xff09;如火如荼召开&#xff0c;自2018年首届举办以来&#xff0c;WAIC已成为全球AI领域最具影响力的国际盛会之一。本届WAIC再度集聚了来自世界各地的政府代表、顶尖科学家、行业领袖和创新企业等&#xff0c;共同…

Redis存储原理与数据模型

Redis存储结构 存储转换 redis-value编码 string int&#xff1a;字符串长度小于等于20切能转成整数raw&#xff1a;字符串长度大于44embstr&#xff1a;字符串长度小于等于44 list quicklist&#xff08;双向链表&#xff09;ziplist&#xff08;压缩链表&#xff09; hash …

7.8作业

一、思维导图 二、 1】按值修改 2】按值查找&#xff0c;返回当前节点的地址 &#xff08;先不考虑重复&#xff0c;如果有重复&#xff0c;返回第一个&#xff09; 3】反转 4】销毁链表 //按值修改 int value_change(linklistptr H,datatype e,int value) {if(HNULL||empty(H…

自动化测试及生成测试报告

Linux安装Selenium进行自动化测试 首先需要安装python、Chrome&ChromeDirver ChromeDriver与Chrome对应版本 #查看chrome版本google-chrome --version或者在浏览器搜索chrome://version/ChromeDriver下载地址这里下载114版本 wget https://chromedriver.storage.googleap…

数据库图形化管理界面应用 Navicat Premium 使用教程

经同学介绍的一个把数据库可视化的软件Navicat Premium&#xff0c;很好用&#xff0c;在这里分享一下&#xff0c;需要的同学可以去了解看看 一&#xff1a;下载并解压 链接&#xff1a;https://pan.baidu.com/s/1ZcDH6m7EAurAp_QmXWx81A 提取码&#xff1a;e5f6 解压到合…

Windows下载、配置Java JDK开发环境的方法

本文介绍在Windows电脑中&#xff0c;安装JDK&#xff08;Java Development Kit&#xff09;&#xff0c;也就是Java开发工具包的详细方法。 JDK是Java软件开发的基础&#xff0c;由Oracle公司提供&#xff0c;用于构建在Java平台上运行的应用程序与组件等&#xff1b;其已经包…

CnosDB:深入理解时序数据修复函数

CnosDB是一个专注于时序数据处理的数据库。CnosDB针对时序数据的特点设计并实现了三个强大的数据修复函数&#xff1a; timestamp_repair – 对时间戳列进行有效修复&#xff0c;支持插入、删除、不变等操作。value_repair – 对值列进行智能修复&#xff0c;根据时间戳间隔和…

Django 新增数据 save()方法

1&#xff0c;添加模型 Test/app11/models.py from django.db import modelsclass Book(models.Model):title models.CharField(max_length100)author models.CharField(max_length100)publication_date models.DateField()price models.DecimalField(max_digits5, decim…

数据分析与挖掘实战案例-电商产品评论数据情感分析

数据分析与挖掘实战案例-电商产品评论数据情感分析 文章目录 数据分析与挖掘实战案例-电商产品评论数据情感分析1. 背景与挖掘目标2. 分析方法与过程2.1 评论预处理1. 评论去重2. 数据清洗 2.2 评论分词1. 分词、词性标注、去除停用词2. 提取含名词的评论3. 绘制词云查看分词效…

OS-HACKNOS-2.1

确定靶机IP地址 扫描靶机开放端口信息 目录扫描 访问后发现个邮箱地址 尝试爆破二级目录 确定为wordpress站 利用wpscan进行漏洞扫描 #扫描所有插件 wpscan --url http://192.168.0.2/tsweb -e ap 发现存在漏洞插件 cat /usr/share/exploitdb/exploits/php/webapps/46537.txt…