一,数据集介绍:
数据预处理:
将所有图像统一缩放到相同大小(28×28):
数据集:
PathMNIST:结直肠癌组织学切片;ChestMNIST:胸部X光(X-ray)数据集,来源于NIH ChestX-ray14数据集;DermaMNIST:色素沉着性皮肤病变的多源皮肤镜图像等
具体看一个数据集:
"pathmnist": {
"description": "PathMNIST: A dataset based on a prior study for predicting survival from colorectal cancer histology slides, which provides a dataset NCT-CRC-HE-100K of 100,000 non-overlapping image patches from hematoxylin & eosin stained histological images, and a test dataset CRC-VAL-HE-7K of 7,180 image patches from a different clinical center. 9 types of tissues are involved, resulting a multi-class classification task. We resize the source images of 3 x 224 x 224 into 3 x 28 x 28, and split NCT-CRC-HE-100K into training and valiation set with a ratio of 9:1.",
"url": "https://zenodo.org/record/4269852/files/pathmnist.npz?download=1",
"MD5": "a8b06965200029087d5bd730944a56c1",
"task": "multi-class",
"label": {
"0": "adipose",
"1": "background",
"2": "debris",
"3": "lymphocytes",
"4": "mucus",
"5": "smooth muscle",
"6": "normal colon mucosa",
"7": "cancer-associated stroma",
"8": "colorectal adenocarcinoma epithelium"
}
二,基本网络结构
基本的resnet结构:
为了解决传统CNN随着网络深度增加、训练效果反而变差(网络退化)的问题,采用了残差网络(ResNet):
核心模块:
效果如下:
三,构建模型
基本代码单元:dataset.py、models.py、evaluator.py、train.py
首先看一下模型代码models.py:
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3), the block used by ResNet18.

    The output has ``expansion * planes`` channels. When the block changes the
    stride or the channel count, the skip path is a 1x1 conv + BatchNorm
    projection; otherwise the skip path is the identity.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        # Main path: 3x3 (possibly strided) conv -> BN -> ReLU -> 3x3 conv -> BN.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Skip path: project only when x and the main-path output differ in shape.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.bn2(self.conv2(F.relu(self.bn1(self.conv1(x)))))
        return F.relu(residual + self.shortcut(x))
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block, the block used by ResNet50.

    The middle 3x3 conv carries the stride; the output has
    ``expansion * planes`` (= 4 * planes) channels. The skip path is the
    identity when stride and channel count are unchanged, otherwise a
    1x1 conv + BatchNorm projection.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + self.shortcut(x))
class ResNet(nn.Module):
    """ResNet for small images: 3x3 stem conv (no max-pool), four residual
    stages, global average pooling and a linear classifier.

    :param block: residual block class (e.g. BasicBlock or Bottleneck); must
        expose an ``expansion`` attribute and accept ``(in_planes, planes, stride)``
    :param num_blocks: number of blocks in each of the four stages
    :param in_channels: number of channels of the input image
    :param num_classes: number of output logits
    """

    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super(ResNet, self).__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``.

        Updates ``self.in_planes`` to the stage's output channel count.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # layer1 keeps 64 channels and stride 1, so its shortcut needs no projection;
        # layers 2-4 each start with a stride-2 block.
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling to 1x1 whatever the final feature-map size is
        # (28x28 inputs give a 4x4 map here), so the flattened vector always
        # matches self.linear's 512 * expansion inputs.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        return self.linear(out)
def ResNet18(in_channels, num_classes):
    """Build an 18-layer ResNet: four BasicBlock stages of depths 2, 2, 2, 2.

    :param in_channels: number of channels of the input image
    :param num_classes: number of output logits
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)
def ResNet50(in_channels, num_classes):
    """Build a 50-layer ResNet: four Bottleneck stages of depths 3, 4, 6, 3.

    :param in_channels: number of channels of the input image
    :param num_classes: number of output logits
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)
数据处理:dataset.py
import os
import json
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
# Path of the MedMNIST metadata JSON, keyed by subset flag (task, label names,
# n_channels, download url, MD5, ...).
# NOTE(review): relative path, resolved against the current working directory,
# not against this file's location.
INFO = "medmnist/medmnist.json"
class MedMNIST(Dataset):
    '''Base dataset for one MedMNIST subset stored as ``<root>/<flag>.npz``.

    Subclasses only set ``flag``, which selects both the entry in the INFO
    metadata file and the ``.npz`` file name.
    '''

    flag = ...

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        ''' dataset
        :param root: directory containing (or receiving) ``<flag>.npz``
        :param split: 'train', 'val' or 'test', select subset
        :param transform: data transformation applied to each PIL image
        :param target_transform: target transformation
        :param download: if True, download the ``.npz`` file first
        :raises ValueError: if ``split`` is not 'train', 'val' or 'test'
        :raises RuntimeError: if the ``.npz`` file is missing or the download fails
        '''
        # Validate first so an unknown split fails here instead of leaving
        # ``img``/``label`` unset.
        if split not in ('train', 'val', 'test'):
            raise ValueError("split must be 'train', 'val' or 'test', got {!r}".format(split))
        with open(INFO, 'r', encoding='utf-8') as f:
            self.info = json.load(f)[self.flag]
        self.root = root
        if download:
            self.download()
        npz_path = os.path.join(self.root, "{}.npz".format(self.flag))
        if not os.path.exists(npz_path):
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        # Read the requested split's arrays, then close the archive.
        with np.load(npz_path) as npz_file:
            self.img = npz_file['{}_images'.format(split)]
            self.label = npz_file['{}_labels'.format(split)]

    def __getitem__(self, index):
        '''Return ``(img, target)`` for one sample.

        The image array is converted to a uint8 PIL image before ``transform``;
        the label is cast to int before ``target_transform``.
        '''
        img, target = self.img[index], self.label[index].astype(int)
        img = Image.fromarray(np.uint8(img))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return self.img.shape[0]

    def __repr__(self):
        '''Adapted from torchvision.
        '''
        _repr_indent = 4
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        body.append("Root location: {}".format(self.root))
        body.append("Split: {}".format(self.split))
        body.append("Task: {}".format(self.info["task"]))
        body.append("Number of channels: {}".format(self.info["n_channels"]))
        body.append("Meaning of labels: {}".format(self.info["label"]))
        body.append("Number of samples: {}".format(self.info["n_samples"]))
        body.append("Description: {}".format(self.info["description"]))
        body.append("License: {}".format(self.info["license"]))
        # Show the transforms actually stored by __init__.
        for attr in ("transform", "target_transform"):
            value = getattr(self, attr, None)
            if value is not None:
                body += [repr(value)]
        lines = [head] + [" " * _repr_indent + line for line in body]
        return '\n'.join(lines)

    def download(self):
        '''Download ``<flag>.npz`` into ``self.root`` with torchvision's
        ``download_url``, passing the expected MD5 from INFO.

        :raises RuntimeError: if torchvision is unavailable or the download fails
            (the original exception is chained as ``__cause__``)
        '''
        try:
            from torchvision.datasets.utils import download_url
            download_url(url=self.info["url"], root=self.root,
                         filename="{}.npz".format(self.flag), md5=self.info["MD5"])
        except Exception as err:
            raise RuntimeError('Something went wrong when downloading! ' +
                               'Go to the homepage to download manually. ' +
                               'https://github.com/MedMNIST/MedMNIST') from err
class PathMNIST(MedMNIST):
    """MedMNIST subset read from ``pathmnist.npz`` (colorectal cancer histology patches, 9 tissue classes)."""
    flag = "pathmnist"
class OCTMNIST(MedMNIST):
    """MedMNIST subset read from ``octmnist.npz``."""
    flag = "octmnist"
class PneumoniaMNIST(MedMNIST):
    """MedMNIST subset read from ``pneumoniamnist.npz``."""
    flag = "pneumoniamnist"
class ChestMNIST(MedMNIST):
    """MedMNIST subset read from ``chestmnist.npz`` (chest images derived from NIH ChestX-ray14)."""
    flag = "chestmnist"
class DermaMNIST(MedMNIST):
    """MedMNIST subset read from ``dermamnist.npz`` (dermatoscopic images of pigmented skin lesions)."""
    flag = "dermamnist"
class RetinaMNIST(MedMNIST):
    """MedMNIST subset read from ``retinamnist.npz``."""
    flag = "retinamnist"
class BreastMNIST(MedMNIST):
    """MedMNIST subset read from ``breastmnist.npz``."""
    flag = "breastmnist"
class OrganMNISTAxial(MedMNIST):
    """MedMNIST subset read from ``organmnist_axial.npz``."""
    flag = "organmnist_axial"
class OrganMNISTCoronal(MedMNIST):
    """MedMNIST subset read from ``organmnist_coronal.npz``."""
    flag = "organmnist_coronal"
class OrganMNISTSagittal(MedMNIST):
    """MedMNIST subset read from ``organmnist_sagittal.npz``."""
    flag = "organmnist_sagittal"
训练模型:train.py
import os
import argparse
import json
from tqdm import trange
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from medmnist.models import ResNet18, ResNet50
from medmnist.dataset import INFO, PathMNIST, ChestMNIST, DermaMNIST, OCTMNIST, PneumoniaMNIST, RetinaMNIST, \
BreastMNIST, OrganMNISTAxial, OrganMNISTCoronal, OrganMNISTSagittal
from medmnist.evaluator import getAUC, getACC, save_results
def main(flag, input_root, output_root, end_epoch, download):
    ''' Train a ResNet-18 on one MedMNIST subset and evaluate its best checkpoint.

    After every epoch the validation AUC is recorded and a checkpoint is saved;
    the checkpoint with the highest validation AUC is then reloaded and
    evaluated on the train, val and test splits.

    :param flag: name of subset (e.g. 'pathmnist')
    :param input_root: directory containing '<flag>.npz'
    :param output_root: directory for checkpoints and per-split result CSVs
    :param end_epoch: number of training epochs (must be >= 1)
    :param download: whether to download the dataset before loading it
    :raises ValueError: for an unknown subset name or a non-positive epoch count
    '''
    dataclass = {
        "pathmnist": PathMNIST,
        "chestmnist": ChestMNIST,
        "dermamnist": DermaMNIST,
        "octmnist": OCTMNIST,
        "pneumoniamnist": PneumoniaMNIST,
        "retinamnist": RetinaMNIST,
        "breastmnist": BreastMNIST,
        "organmnist_axial": OrganMNISTAxial,
        "organmnist_coronal": OrganMNISTCoronal,
        "organmnist_sagittal": OrganMNISTSagittal,
    }
    if flag not in dataclass:
        raise ValueError('Unknown MedMNIST subset %r; expected one of: %s'
                         % (flag, ', '.join(sorted(dataclass))))

    with open(INFO, 'r', encoding='utf-8') as f:
        info = json.load(f)
    task = info[flag]['task']
    n_channels = info[flag]['n_channels']
    n_classes = len(info[flag]['label'])

    start_epoch = 0
    if end_epoch <= start_epoch:
        # Without at least one finished epoch there is no checkpoint to restore below.
        raise ValueError('end_epoch must be at least %d, got %r' % (start_epoch + 1, end_epoch))
    lr = 0.001
    batch_size = 128
    val_auc_list = []
    dir_path = os.path.join(output_root, '%s_checkpoints' % (flag))
    os.makedirs(dir_path, exist_ok=True)

    print('==> Preparing data...')
    # Same deterministic preprocessing for every split (no augmentation).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
    ])
    dataset_cls = dataclass[flag]
    train_dataset = dataset_cls(root=input_root, split='train', transform=transform, download=download)
    val_dataset = dataset_cls(root=input_root, split='val', transform=transform, download=download)
    test_dataset = dataset_cls(root=input_root, split='test', transform=transform, download=download)
    # Only training batches are shuffled. Evaluation loaders keep dataset order so
    # the row ids that test() writes via save_results match dataset indices.
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    train_eval_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
    val_loader = data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

    print('==> Building and training model...')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('==> Using device: %s' % device)
    model = ResNet18(in_channels=n_channels, num_classes=n_classes).to(device)
    if task == "multi-label, binary-class":
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in trange(start_epoch, end_epoch):
        train(model, optimizer, criterion, train_loader, device, task)
        val(model, val_loader, device, val_auc_list, task, dir_path, epoch)

    auc_list = np.array(val_auc_list)
    index = int(auc_list.argmax())
    # val_auc_list holds one entry per epoch starting at start_epoch.
    best_epoch = start_epoch + index
    print('epoch %s is the best model' % (best_epoch))

    print('==> Testing model...')
    restore_model_path = os.path.join(dir_path, 'ckpt_%d_auc_%.5f.pth' % (best_epoch, auc_list[index]))
    # map_location restores the checkpoint onto the current device.
    checkpoint = torch.load(restore_model_path, map_location=device)
    model.load_state_dict(checkpoint['net'])
    test(model, 'train', train_eval_loader, device, flag, task, output_root=output_root)
    test(model, 'val', val_loader, device, flag, task, output_root=output_root)
    test(model, 'test', test_loader, device, flag, task, output_root=output_root)
def train(model, optimizer, criterion, train_loader, device, task):
    ''' Run one training epoch over ``train_loader``.

    :param model: the model to train (switched to train mode)
    :param optimizer: optimizer used in training
    :param criterion: loss function
    :param train_loader: DataLoader (or any iterable) of ``(inputs, targets)`` batches
    :param device: cpu or cuda
    :param task: task of current dataset, binary-class/multi-class/multi-label, binary-class
    '''
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        if task == 'multi-label, binary-class':
            # BCEWithLogitsLoss needs float targets shaped like the logits.
            targets = targets.to(device=device, dtype=torch.float32)
        else:
            # Class-index labels, presumably (batch, 1) from the dataset, become
            # (batch,). reshape(-1) keeps a final batch of one sample 1-D,
            # whereas squeeze() collapsed it to a 0-d tensor.
            targets = targets.reshape(-1).long().to(device)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
def val(model, val_loader, device, val_auc_list, task, dir_path, epoch):
    ''' Evaluate on the validation set, record the AUC and save a checkpoint.

    :param model: the model to validate
    :param val_loader: DataLoader of validation set
    :param device: cpu or cuda
    :param val_auc_list: list to which this epoch's AUC is appended (mutated in place)
    :param task: task of current dataset, binary-class/multi-class/multi-label, binary-class
    :param dir_path: directory where the checkpoint 'ckpt_<epoch>_auc_<auc>.pth' is written
    :param epoch: current epoch
    :raises ValueError: if val_loader yields no batches
    '''
    model.eval()
    # Collect per-batch results and concatenate once; calling torch.cat inside
    # the loop re-copies everything accumulated so far on every batch.
    true_batches, score_batches = [], []
    with torch.no_grad():
        for inputs, targets in val_loader:
            logits = model(inputs.to(device))
            if task == 'multi-label, binary-class':
                true_batches.append(targets.to(device=device, dtype=torch.float32))
                score_batches.append(torch.sigmoid(logits))
            else:
                # Class-index labels -> float column (batch, 1). reshape keeps a
                # final batch of one sample 2-D, where squeeze() made it 0-d.
                true_batches.append(targets.reshape(-1, 1).long().float().to(device))
                score_batches.append(torch.softmax(logits, dim=1))
    if not true_batches:
        raise ValueError('val_loader yielded no batches')
    y_true = torch.cat(true_batches, 0).cpu().numpy()
    y_score = torch.cat(score_batches, 0).cpu().numpy()

    auc = getAUC(y_true, y_score, task)
    val_auc_list.append(auc)
    state = {
        'net': model.state_dict(),
        # Plain Python float rather than a NumPy scalar, so the checkpoint
        # contains only basic types.
        'auc': float(auc),
        'epoch': epoch,
    }
    path = os.path.join(dir_path, 'ckpt_%d_auc_%.5f.pth' % (epoch, auc))
    torch.save(state, path)
def test(model, split, data_loader, device, flag, task, output_root=None):
    ''' Evaluate the model on one split, print AUC/ACC and optionally save a CSV.

    :param model: the model to test
    :param split: the data to test, 'train/val/test'; used in the printout and CSV name
    :param data_loader: DataLoader of data; CSV rows follow its iteration order
    :param device: cpu or cuda
    :param flag: subset name; results go to '<output_root>/<flag>/<split>.csv'
    :param task: task of current dataset, binary-class/multi-class/multi-label, binary-class
    :param output_root: root directory for result CSVs; None skips saving
    :raises ValueError: if data_loader yields no batches
    '''
    model.eval()
    # Collect per-batch results and concatenate once at the end.
    true_batches, score_batches = [], []
    with torch.no_grad():
        for inputs, targets in data_loader:
            logits = model(inputs.to(device))
            if task == 'multi-label, binary-class':
                true_batches.append(targets.to(device=device, dtype=torch.float32))
                score_batches.append(torch.sigmoid(logits))
            else:
                # Class-index labels -> float column (batch, 1). reshape keeps a
                # final batch of one sample 2-D, where squeeze() made it 0-d.
                true_batches.append(targets.reshape(-1, 1).long().float().to(device))
                score_batches.append(torch.softmax(logits, dim=1))
    if not true_batches:
        raise ValueError('data_loader yielded no batches')
    y_true = torch.cat(true_batches, 0).cpu().numpy()
    y_score = torch.cat(score_batches, 0).cpu().numpy()

    auc = getAUC(y_true, y_score, task)
    acc = getACC(y_true, y_score, task)
    print('%s AUC: %.5f ACC: %.5f' % (split, auc, acc))
    if output_root is not None:
        output_dir = os.path.join(output_root, flag)
        # makedirs also creates a missing output_root and accepts an existing dir.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, '%s.csv' % (split))
        save_results(y_true, y_score, output_path)
if __name__ == '__main__':
    def _str2bool(value):
        '''Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

        argparse's ``type=bool`` turns every non-empty string (including
        'False') into True, so it cannot be used for this flag.
        '''
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser(description='RUN Baseline model of MedMNIST')
    # Arguments: subset name, input root (dataset files), output root (models/results).
    # Example: --data_name pathmnist --input_root ./medmnist --output_root ./output
    parser.add_argument('--data_name', default='pathmnist', help='subset of MedMNIST', type=str)
    parser.add_argument('--input_root', default='./medmnist', help='input root, the source of dataset files', type=str)
    parser.add_argument('--output_root', default='./output', help='output root, where to save models and results',
                        type=str)
    parser.add_argument('--num_epoch', default=10, help='num of epochs of training', type=int)
    # Accepts a bare "--download" as well as "--download true/false".
    parser.add_argument('--download', default=False, nargs='?', const=True,
                        help='whether download the dataset or not', type=_str2bool)
    args = parser.parse_args()
    data_name = args.data_name.lower()
    input_root = args.input_root
    output_root = args.output_root
    end_epoch = args.num_epoch
    download = args.download
    main(data_name, input_root, output_root, end_epoch=end_epoch, download=download)
注意:运行train.py(执行main函数)时需要通过命令行配置参数,例如:--data_name pathmnist --input_root ./medmnist --output_root ./output
模型评估AUC和ACC:evaluator.py
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd
def getAUC(y_true, y_score, task):
    '''Average ROC-AUC for the given task.

    :param y_true: ground truth; an (n_samples, n_classes) 0/1 matrix for
        'multi-label, binary-class', otherwise class indices shaped
        (n_samples,) or (n_samples, 1)
    :param y_score: predicted score of each class, shape (n_samples, n_classes)
    :param task: 'binary-class', 'multi-label, binary-class'; any other value
        is treated as multi-class
    :return: binary-class: AUC of the last score column; multi-label: mean of
        the per-label AUCs; multi-class: mean one-vs-rest AUC over all classes
    '''
    if task == 'binary-class':
        # Only the score of the positive (last) class matters.
        return roc_auc_score(y_true, y_score[:, -1])

    n_columns = y_score.shape[1]
    if task == 'multi-label, binary-class':
        per_label = (roc_auc_score(y_true[:, col], y_score[:, col]) for col in range(n_columns))
        return sum(per_label) / n_columns

    # Multi-class: one-vs-rest AUC per class, unweighted mean.
    total = 0
    for cls in range(n_columns):
        is_cls = (y_true == cls).astype(y_true.dtype)
        total += roc_auc_score(is_cls, y_score[:, cls])
    return total / n_columns
def getACC(y_true, y_score, task, threshold=0.5):
    '''Accuracy metric.

    :param y_true: the ground truth labels; (n_samples, n_classes) for
        multi-label, otherwise (n_samples,) or (n_samples, 1)
    :param y_score: the predicted score of each class, shape (n_samples, n_classes)
    :param task: the task of current dataset
    :param threshold: the threshold for multilabel and binary-class tasks
    :return: accuracy; for multi-label, the mean of the per-label accuracies
    '''
    if task == 'multi-label, binary-class':
        # A label is predicted positive unless its score is below threshold.
        y_pre = np.where(y_score < threshold, np.zeros_like(y_score), np.ones_like(y_score))
        n_labels = y_true.shape[1]
        per_label = (accuracy_score(y_true[:, label], y_pre[:, label]) for label in range(n_labels))
        return sum(per_label) / n_labels
    elif task == 'binary-class':
        # Positive when the last-column score is strictly greater than threshold
        # (note: the multi-label branch above counts score == threshold as positive).
        # Vectorized; cast/reshape to y_true's dtype and shape like the per-row loop did.
        y_pre = (y_score[:, -1] > threshold).astype(y_true.dtype).reshape(y_true.shape)
        return accuracy_score(y_true, y_pre)
    else:
        # Multi-class: predicted class = index of the highest score in each row.
        y_pre = np.argmax(y_score, axis=1).astype(y_true.dtype).reshape(y_true.shape)
        return accuracy_score(y_true, y_pre)
def save_results(y_true, y_score, outputpath):
    '''Save ground truth and scores to a CSV file, one row per sample.

    Columns are ``id`` (row position), ``true_0 .. true_{k-1}`` and
    ``score_0 .. score_{c-1}``, written with a header and encoding "utf_8_sig".

    :param y_true: the ground truth labels, shape (n_samples, k); a 1-D
        array of shape (n_samples,) is treated as a single column
    :param y_score: the predicted score of each class, shape (n_samples, n_classes)
    :param outputpath: path to save the result csv
    :raises ValueError: if y_true and y_score have different numbers of rows
    '''
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    n_samples = y_score.shape[0]
    if y_true.shape[0] != n_samples:
        raise ValueError('y_true has %d rows but y_score has %d' % (y_true.shape[0], n_samples))

    # Build every column at once: DataFrame.append was removed in pandas 2.0,
    # and appending row by row is quadratic in the number of samples.
    columns = {'id': np.arange(n_samples)}
    for i in range(y_true.shape[1]):
        columns['true_%s' % (i)] = y_true[:, i]
    for i in range(y_score.shape[1]):
        columns['score_%s' % (i)] = y_score[:, i]
    df = pd.DataFrame(columns)
    df.to_csv(outputpath, sep=',', index=False, header=True, encoding="utf_8_sig")