[My Creation Anniversary] Processing the BraTS2020 Dataset with pix2pix GAN (Complete Version)


Basic code for T1 -> T2 synthesis with pix2pix GAN (PyTorch).

The implementation is based on the pix2pix code from https://github.com/eriklindernoren/PyTorch-GAN/.

Starting from that repo, we need to re-process the dataset; the original code also trains the generator first and the discriminator afterwards.

In general, training the discriminator first and the generator second is more stable and effective, both in theory and in practice, so we need to swap the order and change some of the code.

Here are the main reasons:

  1. The discriminator's task is relatively simple: it only has to separate real samples from generated ones, which is a binary classification problem. Training it first ensures it has enough capacity to reliably tell the two apart.
  2. The generator depends on the discriminator's feedback: its goal is to produce samples realistic enough to fool the discriminator. Training the discriminator first gives meaningful feedback on sample quality, which the generator can then use to improve.
  3. Training stability: early in GAN training, the generator's samples are usually very unrealistic. If the generator is trained first, the discriminator can trivially reject these low-quality samples, leading to unstable gradient updates. Updating the discriminator first lets the generator adapt to its feedback and stabilizes training.
  4. Avoiding mode collapse: GAN training can suffer from mode collapse, where the generator only learns to produce a few kinds of samples instead of covering the whole data distribution. A discriminator trained first provides more diverse feedback and helps the generator avoid collapsing onto a few modes.

Although discriminator-first training is the common practice, it is not the only valid choice. Depending on the problem and dataset, other schedules (e.g., generator first) can be worth trying; the right order ultimately depends on the setting and on experimental results. The per-batch update order used here is sketched below.
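
A minimal sketch of the reordered per-batch update (variable names such as `real_A`, `real_B`, `valid`, and `fake` follow the full pix2pix.py listing later in this post):

# 1) Discriminator step: real pair vs. detached fake pair
fake_B = generator(real_A)
optimizer_D.zero_grad()
loss_D = 0.5 * (criterion_GAN(discriminator(real_B, real_A), valid)
                + criterion_GAN(discriminator(fake_B.detach(), real_A), fake))
loss_D.backward()
optimizer_D.step()

# 2) Generator step: adversarial loss plus weighted pixel-wise L1
optimizer_G.zero_grad()
loss_G = criterion_GAN(discriminator(fake_B, real_A), valid) \
         + lambda_pixel * criterion_pixelwise(fake_B, real_B)
loss_G.backward()
optimizer_G.step()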

The dataset is BraTS2020; its description and preprocessing steps are covered in my linked notes. Training ran on a personal GPU, using only the first 200 subjects of the training data, with 20% of them held out as a test set.

During training, every fixed number of batches we use matplotlib to display the T1 input, the generated T2, and the real T2, and we also plot the generator and discriminator losses.

Comparing the results shows that adding the pixel-wise L1 loss makes the generated images noticeably better.
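
For reference, this corresponds to the standard pix2pix objective, with λ = 100 for the L1 term in the code below:

$$
G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y}\big[\lVert y - G(x) \rVert_1\big]
$$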

[figure omitted: comparison of generated results with and without the L1 loss]

Results after training for 10 epochs:

[figure omitted: sample results after 10 epochs]

Test results at this point:

PSNR mean: 21.1621928375993 PSNR std: 1.1501189362634836
NMSE mean: 0.14920212 NMSE std: 0.03501928
SSIM mean: 0.5401535398016223 SSIM std: 0.019281408927679166

Code:

dataloader.py

# dataloader for fine-tuning
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import torch.utils.data as data
import numpy as np
from PIL import ImageEnhance, Image
import random
import os

def cv_random_flip(img, label):
    # randomly flip both image and label: 0 = no flip, 1 = up-down (axis 0), 2 = left-right (axis 1)
    flip_flag = random.randint(0, 2)
    if flip_flag == 1:
        img = np.flip(img, 0).copy()
        label = np.flip(label, 0).copy()
    if flip_flag == 2:
        img = np.flip(img, 1).copy()
        label = np.flip(label, 1).copy()
    return img, label

def randomCrop(image, label):
    # random centered crop, shrinking width/height by up to a 30-pixel border (expects PIL images)
    border = 30
    image_width = image.size[0]
    image_height = image.size[1]
    crop_win_width = np.random.randint(image_width - border, image_width)
    crop_win_height = np.random.randint(image_height - border, image_height)
    random_region = (
        (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
        (image_height + crop_win_height) >> 1)
    return image.crop(random_region), label.crop(random_region)

def randomRotation(image, label):
    # with probability 0.5, rotate both arrays by a random multiple of 90 degrees
    rotate = random.randint(0, 1)
    if rotate == 1:
        rotate_time = random.randint(1, 3)
        image = np.rot90(image, rotate_time).copy()
        label = np.rot90(label, rotate_time).copy()
    return image, label

def colorEnhance(image):
    bright_intensity = random.randint(7, 13) / 10.0
    image = ImageEnhance.Brightness(image).enhance(bright_intensity)
    contrast_intensity = random.randint(4, 11) / 10.0
    image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
    color_intensity = random.randint(7, 13) / 10.0
    image = ImageEnhance.Color(image).enhance(color_intensity)
    sharp_intensity = random.randint(7, 13) / 10.0
    image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
    return image

def randomGaussian(img, mean=0.002, sigma=0.002):
    # with probability 1/4, add elementwise Gaussian noise to a 2-D image
    def gaussianNoisy(im, mean=mean, sigma=sigma):
        for _i in range(len(im)):
            im[_i] += random.gauss(mean, sigma)
        return im

    flag = random.randint(0, 3)
    if flag == 1:
        width, height = img.shape
        img = gaussianNoisy(img[:].flatten(), mean, sigma)
        img = img.reshape([width, height])

    return img


def randomPeper(img):
    # with probability 1/4, add salt-and-pepper noise to ~0.15% of the pixels
    flag = random.randint(0, 3)
    if flag == 1:
        noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
        for i in range(noiseNum):
            randX = random.randint(0, img.shape[0] - 1)
            randY = random.randint(0, img.shape[1] - 1)
            if random.randint(0, 1) == 0:
                img[randX, randY] = 0
            else:
                img[randX, randY] = 1
    return img


class BraTS_Train_Dataset(data.Dataset):
    def __init__(self, source_modal, target_modal, img_size,
                 image_root, data_rate, sort=False, argument=False, random=False):

        self.source = source_modal
        self.target = target_modal
        self.modal_list = ['t1', 't2']
        self.image_root = image_root
        self.data_rate = data_rate
        self.images = [self.image_root + f for f in os.listdir(self.image_root) if f.endswith('.npy')]
        # sort slices numerically by file name (1.npy, 2.npy, ...)
        self.images.sort(key=lambda x: int(x.split(image_root)[1].split(".npy")[0]))
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(img_size)
        ])
        # nearest-neighbour resize, intended for label maps (unused in this pipeline)
        self.gt_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(img_size, Image.NEAREST)
        ])
        self.sort = sort
        self.argument = argument
        self.random = random
        self.subject_num = len(self.images) // 60  # each subject contributes 60 axial slices
        if self.random:
            # shuffle at subject level, keeping each subject's 60 slices contiguous
            subject = np.arange(self.subject_num)
            np.random.shuffle(subject)
            self.LUT = []
            for i in subject:
                for j in range(60):
                    self.LUT.append(i * 60 + j)
        # print('slice number:', self.__len__())

    def __getitem__(self, index):
        if self.random:
            index = self.LUT[index]
        npy = np.load(self.images[index])  # shape (2, H, W): [t1, t2]
        img = npy[self.modal_list.index(self.source), :, :]
        gt = npy[self.modal_list.index(self.target), :, :]

        if self.argument:
            # geometric augmentations are applied to both modalities to keep them aligned
            img, gt = cv_random_flip(img, gt)
            img, gt = randomRotation(img, gt)
            # intensity augmentation is applied to the source image only
            img = img * 255
            img = Image.fromarray(img.astype(np.uint8))
            img = colorEnhance(img)
            img = img.convert('L')

        img = self.img_transform(img)
        gt = self.img_transform(gt)
        return img, gt

    def __len__(self):
        return int(len(self.images) * self.data_rate)

def get_loader(batchsize, shuffle, pin_memory=True, source_modal='t1', target_modal='t2',
               img_size=256, img_root='data/train/', data_rate=0.1, num_workers=8, sort=False, argument=False,
               random=False):
    dataset = BraTS_Train_Dataset(source_modal=source_modal, target_modal=target_modal,
                                  img_size=img_size, image_root=img_root, data_rate=data_rate, sort=sort,
                                  argument=argument, random=random)
    data_loader = data.DataLoader(dataset=dataset, batch_size=batchsize, shuffle=shuffle,
                                  pin_memory=pin_memory, num_workers=num_workers)
    return data_loader




# if __name__=='__main__':
#     data_loader = get_loader(batchsize=1, shuffle=True, pin_memory=True, source_modal='t1',
#                              target_modal='t2', img_size=256, num_workers=8,
#                              img_root='data/train/', data_rate=0.1, argument=True, random=False)
#     length = len(data_loader)
#     print("length of data_loader:", length)
#     # turn data_loader into an iterator
#     data_iter = iter(data_loader)
#
#     # fetch the first batch
#     batch = next(data_iter)
#
#     # print the shapes of the first batch
#     print("input tensor shape:", batch[0].shape)   # source-modality images
#     print("target tensor shape:", batch[1].shape)  # target-modality images

models.py

import torch.nn as nn
import torch.nn.functional as F
import torch


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        # note: this model uses InstanceNorm2d without affine parameters,
        # so this branch never matches any layer here
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           U-NET
##############################


class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)


##############################
#        Discriminator
##############################


class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalization:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 64, normalization=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1, bias=False)
        )

    def forward(self, img_A, img_B):
        # Concatenate image and condition image by channels to produce input
        img_input = torch.cat((img_A, img_B), 1)
        return self.model(img_input)
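
A quick shape check (a minimal sketch, not part of the training code) confirms that the 1-channel generator preserves the 256x256 resolution and that the PatchGAN discriminator output matches the `patch = (1, img_height // 2 ** 4, img_width // 2 ** 4)` target shape used in pix2pix.py below:

import torch
from models import GeneratorUNet, Discriminator

G = GeneratorUNet(in_channels=1, out_channels=1)
D = Discriminator(in_channels=1)
x = torch.randn(2, 1, 256, 256)  # dummy T1 batch
fake = G(x)
pred = D(fake, x)
print(fake.shape)  # torch.Size([2, 1, 256, 256])
print(pred.shape)  # torch.Size([2, 1, 16, 16]) == (batch, *patch) for 256x256 inputs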

pix2pix.py

import argparse
import os
import numpy as np
import math
import itertools
import time
import datetime
import sys
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

from models import *
from dataloader import *

import torch.nn as nn
import torch.nn.functional as F
import torch
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
    parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
    parser.add_argument("--dataset_name", type=str, default="basta2020", help="name of the dataset")
    parser.add_argument("--batch_size", type=int, default=2, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_height", type=int, default=256, help="size of image height")
    parser.add_argument("--img_width", type=int, default=256, help="size of image width")
    parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    parser.add_argument(
        "--sample_interval", type=int, default=500, help="interval between sampling of images from generators"
    )
    parser.add_argument("--checkpoint_interval", type=int, default=10, help="interval between model checkpoints")
    opt = parser.parse_args()
    print(opt)

    os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
    os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)

    cuda = True if torch.cuda.is_available() else False

    # Loss functions
    criterion_GAN = torch.nn.MSELoss()
    criterion_pixelwise = torch.nn.L1Loss()

    # Loss weight of L1 pixel-wise loss between translated image and real image
    lambda_pixel = 100

    # Calculate output of image discriminator (PatchGAN)
    patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)

    # Initialize generator and discriminator
    generator = GeneratorUNet(in_channels=1, out_channels=1)
    discriminator = Discriminator(in_channels=1)

    if cuda:
        generator = generator.cuda()
        discriminator = discriminator.cuda()
        criterion_GAN.cuda()
        criterion_pixelwise.cuda()

    if opt.epoch != 0:
        # Load pretrained models
        generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))
        discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
    else:
        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Configure dataloaders

    # note: batchsize and img_size are hard-coded here and ignore opt.batch_size / opt.img_height
    dataloader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',
                            target_modal='t2', img_size=256, num_workers=8,
                            img_root='data/train/', data_rate=0.1, argument=True, random=False)
    # dataloader = DataLoader(
    #     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_),
    #     batch_size=opt.batch_size,
    #     shuffle=True,
    #     num_workers=opt.n_cpu,
    # )

    # val_dataloader = DataLoader(
    #     ImageDataset("../../data/%s" % opt.dataset_name, transforms_=transforms_, mode="val"),
    #     batch_size=10,
    #     shuffle=True,
    #     num_workers=1,
    # )

    # Tensor type
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


    # def sample_images(batches_done):
    #     """Saves a generated sample from the validation set"""
    #     imgs = next(iter(val_dataloader))
    #     real_A = Variable(imgs["B"].type(Tensor))
    #     real_B = Variable(imgs["A"].type(Tensor))
    #     fake_B = generator(real_A)
    #     img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)
    #     save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)


    # ----------
    #  Training
    # ----------

    prev_time = time.time()

    # lists for recording loss values
    losses_G = []
    losses_D = []

    for epoch in range(opt.epoch, opt.n_epochs):
        for i, batch in enumerate(dataloader):

            # Model inputs
            real_A = Variable(batch[0].type(Tensor))
            real_B = Variable(batch[1].type(Tensor))

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(real_B, real_A)
            loss_real = criterion_GAN(pred_real, valid)

            # Fake loss
            fake_B = generator(real_A)
            pred_fake = discriminator(fake_B.detach(), real_A)
            loss_fake = criterion_GAN(pred_fake, fake)

            # Total loss
            loss_D = 0.5 * (loss_real + loss_fake)

            loss_D.backward()
            optimizer_D.step()

            # ------------------
            #  Train Generators
            # ------------------

            optimizer_G.zero_grad()

            # GAN loss

            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = criterion_GAN(pred_fake, valid)
            # Pixel-wise loss
            loss_pixel = criterion_pixelwise(fake_B, real_B)

            # Total loss
            loss_G = loss_GAN + lambda_pixel * loss_pixel  # push D's score on fakes toward "real" while staying close to the target in L1

            loss_G.backward()

            optimizer_G.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = opt.n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
                % (
                    epoch,
                    opt.n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_pixel.item(),
                    loss_GAN.item(),
                    time_left,
                )
            )

            mat = [real_A, fake_B, real_B]

            if (batches_done + 1) % 200 == 0:
                plt.figure(dpi=400)
                # note: the loop variable must not shadow the batch index i
                for k, img in enumerate(mat):
                    ax = plt.subplot(1, 3, k + 1)
                    img = img.permute([0, 2, 3, 1])  # b c h w -> b h w c
                    if img.shape[0] != 1:  # pick a single sample when the batch holds several
                        img = img[1]
                    img = img.squeeze(0)  # drop the batch dimension: 1 h w c -> h w c
                    if img.shape[2] == 1:
                        img = img.repeat(1, 1, 3)  # expand grayscale to 3 channels for imshow
                    img = img.cpu()
                    ax.imshow(img.data)
                    ax.set_xticks([])
                    ax.set_yticks([])

                plt.show()

            if (batches_done + 1) % 20 == 0:  # record losses every 20 batches
                losses_G.append(loss_G.item())
                losses_D.append(loss_D.item())

            if (batches_done + 1) % 200 == 0:  # plot the loss curves every 200 batches
                plt.figure(figsize=(10, 5))
                plt.plot(range(len(losses_G)), losses_G, label="Generator Loss")
                plt.plot(range(len(losses_D)), losses_D, label="Discriminator Loss")
                plt.xlabel("Recording step (x20 batches)")
                plt.ylabel("Loss")
                plt.title("GAN Training Loss Curve")
                plt.legend()
                plt.show()


            # # If at sample interval save image
            # if batches_done % opt.sample_interval == 0:
            #     sample_images(batches_done)

        if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))
            torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))

processing.py (data preprocessing)

import numpy as np
from matplotlib import pylab as plt
import nibabel as nib
import random
import glob
import os
from PIL import Image
import imageio


def normalize(image, mask=None, percentile_lower=0.2, percentile_upper=99.8):
    # clip intensities inside the brain mask to the given percentiles, then scale to [0, 1]
    if mask is None:
        mask = image != image[0, 0, 0]  # assume the corner voxel is background
    cut_off_lower = np.percentile(image[mask != 0].ravel(), percentile_lower)
    cut_off_upper = np.percentile(image[mask != 0].ravel(), percentile_upper)
    res = np.copy(image)
    res[(res < cut_off_lower) & (mask != 0)] = cut_off_lower
    res[(res > cut_off_upper) & (mask != 0)] = cut_off_upper
    res = res / res.max()  # scale to [0, 1]

    return res


def visualize(t1_data, t2_data, flair_data, t1ce_data, gt_data):
    plt.figure(figsize=(8, 8))
    plt.subplot(231)
    plt.imshow(t1_data[:, :], cmap='gray')
    plt.title('Image t1')
    plt.subplot(232)
    plt.imshow(t2_data[:, :], cmap='gray')
    plt.title('Image t2')
    plt.subplot(233)
    plt.imshow(flair_data[:, :], cmap='gray')
    plt.title('Image flair')
    plt.subplot(234)
    plt.imshow(t1ce_data[:, :], cmap='gray')
    plt.title('Image t1ce')
    plt.subplot(235)
    plt.imshow(gt_data[:, :])
    plt.title('GT')
    plt.show()


def visualize_to_gif(t1_data, t2_data, t1ce_data, flair_data):
    transversal = []
    coronal = []
    sagittal = []
    slice_num = t1_data.shape[2]
    for i in range(slice_num):
        sagittal_plane = np.concatenate((t1_data[:, :, i], t2_data[:, :, i],
                                         t1ce_data[:, :, i], flair_data[:, :, i]), axis=1)
        coronal_plane = np.concatenate((t1_data[i, :, :], t2_data[i, :, :],
                                        t1ce_data[i, :, :], flair_data[i, :, :]), axis=1)
        transversal_plane = np.concatenate((t1_data[:, i, :], t2_data[:, i, :],
                                            t1ce_data[:, i, :], flair_data[:, i, :]), axis=1)
        transversal.append(transversal_plane)
        coronal.append(coronal_plane)
        sagittal.append(sagittal_plane)
    imageio.mimsave("./transversal_plane.gif", transversal, duration=0.01)
    imageio.mimsave("./coronal_plane.gif", coronal, duration=0.01)
    imageio.mimsave("./sagittal_plane.gif", sagittal, duration=0.01)
    return


if __name__ == '__main__':

    t1_list = sorted(glob.glob(
        '../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t1.*'))
    t2_list = sorted(glob.glob(
        '../data/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/*/*t2.*'))


    data_len = len(t1_list)
    train_len = int(data_len * 0.8)
    test_len = data_len - train_len

    train_path = '../data/train/'
    test_path = '../data/test/'

    os.makedirs(train_path, exist_ok=True)
    os.makedirs(test_path, exist_ok=True)

    for i, (t1_path, t2_path) in enumerate(zip(t1_list, t2_list)):

        print('preprocessing the', i + 1, 'th subject')

        t1_img = nib.load(t1_path)  # volume shape (240, 240, 155)
        t2_img = nib.load(t2_path)

        # to numpy
        t1_data = t1_img.get_fdata()
        t2_data = t2_img.get_fdata()

        t1_data = normalize(t1_data)  # normalize to [0,1]
        t2_data = normalize(t2_data)

        tensor = np.stack([t1_data, t2_data])  # (2, 240, 240, 155)

        # crop each slice to 200x200 ([10:210, 25:225]) and keep 60 axial slices (indices 50-109)
        if i < train_len:
            for j in range(60):
                Tensor = tensor[:, 10:210, 25:225, 50 + j]
                np.save(train_path + str(60 * i + j + 1) + '.npy', Tensor)
        else:
            for j in range(60):
                Tensor = tensor[:, 10:210, 25:225, 50 + j]
                np.save(test_path + str(60 * (i - train_len) + j + 1) + '.npy', Tensor)
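
After running processing.py, a small sanity check (a sketch; the file name `1.npy` assumes at least one subject was processed) verifies the saved slice format:

import numpy as np

arr = np.load('../data/train/1.npy')
print(arr.shape)             # expected: (2, 200, 200) -> stacked [t1, t2] slice pair
print(arr.min(), arr.max())  # intensities normalized to roughly [0, 1]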

testutil.py

# -*- coding: utf-8 -*-
# @Time : 2023/9/23 17:21
# @Author : Tom
# @File : testutil.py
# @Software : PyCharm
import argparse

from math import log10, sqrt
import numpy as np
from skimage.metrics import structural_similarity as ssim

def psnr(res, gt):
    # peak signal-to-noise ratio; images are normalized, so the peak value is 1
    mse = np.mean((res - gt) ** 2)
    if mse == 0:
        return 100
    max_pixel = 1
    psnr = 20 * log10(max_pixel / sqrt(mse))
    return psnr


def nmse(res, gt):
    # normalized mean squared error; note that ord=2 on a 2-D array is the
    # spectral norm, so this is a spectral-norm-based NMSE variant
    Norm = np.linalg.norm((gt * gt), ord=2)
    if np.all(Norm == 0):
        return 0
    else:
        nmse = np.linalg.norm(((res - gt) * (res - gt)), ord=2) / Norm
    return nmse

test.py

# -*- coding: utf-8 -*-
# @Time : 2023/9/23 16:14
# @Author : Tom
# @File : test.py
# @Software : PyCharm

import torch
from models import *
from dataloader import *
from testutil import *



if __name__ == '__main__':

    images_save = "images_save/"
    slice_num = 4
    os.makedirs(images_save, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = GeneratorUNet(in_channels=1, out_channels=1)

    data_loader = get_loader(batchsize=4, shuffle=True, pin_memory=True, source_modal='t1',
                             target_modal='t2', img_size=256, num_workers=8,
                             img_root='data/test/', data_rate=1, argument=False, random=False)  # no augmentation at test time

    model = model.to(device)
    model.load_state_dict(torch.load("saved_models/brats2020/generator_0.pth", map_location=torch.device(device)), strict=False)

    PSNR = []
    NMSE = []
    SSIM = []

    for i, (img, gt) in enumerate(data_loader):
        batch_size = img.size()[0]
        img = img.to(device, dtype=torch.float)
        gt = gt.to(device, dtype=torch.float)

        with torch.no_grad():
            pred = model(img)

        for j in range(batch_size):
            save_image([pred[j]], images_save + str(i * batch_size + j + 1) + '.png', normalize=True)
            print(images_save + str(i * batch_size + j + 1) + '.png')

        pred, gt = pred.cpu().detach().numpy().squeeze(), gt.cpu().detach().numpy().squeeze()

        for j in range(batch_size):
            PSNR.append(psnr(pred[j], gt[j]))
            NMSE.append(nmse(pred[j], gt[j]))
            SSIM.append(ssim(pred[j], gt[j]))  # note: newer skimage versions require data_range= for float inputs

    PSNR = np.asarray(PSNR)
    NMSE = np.asarray(NMSE)
    SSIM = np.asarray(SSIM)

    # average the per-slice metrics over groups of slice_num slices
    PSNR = PSNR.reshape(-1, slice_num)
    NMSE = NMSE.reshape(-1, slice_num)
    SSIM = SSIM.reshape(-1, slice_num)

    PSNR = np.mean(PSNR, axis=1)
    NMSE = np.mean(NMSE, axis=1)
    SSIM = np.mean(SSIM, axis=1)

    print("PSNR mean:", np.mean(PSNR), "PSNR std:", np.std(PSNR))
    print("NMSE mean:", np.mean(NMSE), "NMSE std:", np.std(NMSE))
    print("SSIM mean:", np.mean(SSIM), "SSIM std:", np.std(SSIM))
