1. Introduction
In the rapidly evolving field of artificial intelligence (AI), medical imaging is an area undergoing a profound transformation. Riding this wave, a research group at Facebook (now Meta) developed a groundbreaking model architecture called the Segment Anything Model (SAM). What makes SAM remarkable is its ability to generate segmentation masks for arbitrary objects in an image. This adaptability allows it to tackle a wide range of tasks, from segmenting everyday objects to delineating specific structures in medical images.
The code snippets are available on my GitHub page.
Figure designed by Reza Kalantar
2. A Closer Look at SAM
Fine-tuning SAM for a specific medical imaging task involves several steps. Here is a breakdown:
- Data loading and preprocessing: The first step is handling the medical imaging data, which is typically stored in formats such as DICOM or NIfTI and read with libraries such as pydicom or nibabel. The images are then preprocessed: reoriented, intensity-normalized, and converted, along with their masks, into a model-friendly format.
- Bounding-box prompt creation: Bounding-box prompts guide SAM's segmentation. Each box should loosely enclose the structure you want to segment. Conveniently, SAM can handle multiple bounding boxes, allowing several objects to be segmented in one pass.
- Model and processor preparation: This involves loading the pretrained SAM model and its associated processor, which prepares the model's inputs and prompts.
- Model fine-tuning: This key step runs the training loop, computing the loss (comparing the model output with the ground-truth mask), backpropagating gradients, and updating the model's weights.
- Model evaluation: Once trained, the model is evaluated on a validation set to gauge how well it performs on unseen data. Metrics such as the Dice coefficient or Intersection over Union (IoU) come in handy here; a short sketch of both follows this list.
- Inference: The final step uses the trained model to segment new medical images. This involves preparing the image and bounding-box prompts, feeding them to the model, and post-processing the output into the final segmentation mask.
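As a concrete illustration of the evaluation metrics mentioned above, Dice and IoU for binary masks can be computed in a few lines of NumPy. This is only a minimal sketch; the helper names dice_coefficient and iou_score are illustrative, and pred and target are assumed to be binary arrays of the same shape:
import numpy as np

def dice_coefficient(pred, target, eps=1e-7):
    # Dice = 2 * |A ∩ B| / (|A| + |B|) for binary masks
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

def iou_score(pred, target, eps=1e-7):
    # IoU = |A ∩ B| / |A ∪ B| for binary masks
    pred, target = pred.astype(bool), target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return (intersection + eps) / (union + eps)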
In this article, I will walk you through fine-tuning SAM in Google Colab to segment the lungs from CT scans. We will also cover the steps needed to preprocess the medical images and convert them into 2D slices.
3. Getting Started with the Kaggle Dataset
First, install the Kaggle library:
!pip install -q kaggle
Next, create a directory named ".kaggle" in your home directory:
!mkdir -p ~/.kaggle
Then, upload your Kaggle API token, which you can download from the Kaggle website:
from google.colab import files
files.upload() # upload your kaggle.json API token
After uploading the token, place it in the ".kaggle" directory:
!cp kaggle.json ~/.kaggle/
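Optionally, you can also tighten the token's file permissions so the Kaggle CLI does not warn about it being readable by other users:
!chmod 600 ~/.kaggle/kaggle.json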
Now you are ready to download the dataset. In this tutorial, we use the "finding-lungs-in-ct-data" dataset:
!kaggle datasets download -d kmader/finding-lungs-in-ct-data
Finally, unzip the downloaded dataset:
!unzip -q /content/finding-lungs-in-ct-data.zip
4. Preprocessing the Data
Before working with the data, we need to install and import a few required libraries. These include MONAI and SimpleITK for medical image processing and training in PyTorch, as well as the Hugging Face Transformers library:
!pip install -q monai
!pip install -q SimpleITK
!pip install -q git+https://github.com/huggingface/transformers.git
import os
import glob
import monai
import torch
import numpy as np
from PIL import Image
from tqdm import tqdm
import SimpleITK as sitk
from statistics import mean
from torch.optim import Adam
from natsort import natsorted
import matplotlib.pyplot as plt
from transformers import SamModel
import matplotlib.patches as patches
from transformers import SamProcessor
from IPython.display import clear_output
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import threshold, normalize
%matplotlib inline
from monai.transforms import (
EnsureChannelFirstd,
EnsureTyped,
Compose,
CropForegroundd,
CopyItemsd,
LoadImaged,
CenterSpatialCropd,
Invertd,
OneOf,
Orientationd,
MapTransform,
NormalizeIntensityd,
RandSpatialCropSamplesd,
RandSpatialCropd,
SpatialPadd,
ScaleIntensityRanged,
Spacingd,
RepeatChanneld,
ToTensord,
)
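Optionally, you can fix the random seeds so that the random bounding-box padding and data shuffling used later are reproducible across runs. A small optional sketch using MONAI's helper:
# optional: make augmentations and shuffling reproducible
from monai.utils import set_determinism
set_determinism(seed=42)
np.random.seed(42)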
The dataset consists of 3D .nii.gz volumes and contours for 4 patients. We first split the data into 2 patients for training, 1 for validation, and 1 for testing, and save 2D axial slices to the corresponding directories:
data_dir = '/content/3d_images'
images = sorted(
glob.glob(os.path.join(data_dir, "IMG*.nii.gz")))
labels = sorted(
glob.glob(os.path.join(data_dir, "MASK*.nii.gz")))
print('No. of images:', len(images), ' labels:', len(labels))
No. of images: 4 labels: 4
base_dir = '/content'
datasets = ['train', 'val', 'test']
data_types = ['2d_images', '2d_masks']
# Create directories
dir_paths = {}
for dataset in datasets:
for data_type in data_types:
# Construct the directory path
dir_path = os.path.join(base_dir, f'{dataset}_{data_type}')
dir_paths[f'{dataset}_{data_type}'] = dir_path
# Create the directory
os.makedirs(dir_path, exist_ok=True)
# Assuming first 2 patients for training, next 1 for validation and last 1 for testing
for idx, (img_path, mask_path) in enumerate(zip(images, labels)):
# Load the 3D image and mask
img = sitk.ReadImage(img_path)
mask = sitk.ReadImage(mask_path)
print('processing patient', idx, img.GetSize(), mask.GetSize())
# Get the mask data as numpy array
mask_data = sitk.GetArrayFromImage(mask)
# Select appropriate directories
if idx < 2: # Training
img_dir = dir_paths['train_2d_images']
mask_dir = dir_paths['train_2d_masks']
elif idx == 2: # Validation
img_dir = dir_paths['val_2d_images']
mask_dir = dir_paths['val_2d_masks']
else: # Testing
img_dir = dir_paths['test_2d_images']
mask_dir = dir_paths['test_2d_masks']
# Iterate over the axial slices
for i in range(img.GetSize()[0]):
# If the mask slice is not empty, save the image and mask slices
if np.any(mask_data[i, :, :]):
# Prepare the new ITK images
img_slice = img[i, :, :]
mask_slice = mask[i, :, :]
# Define the output paths
img_slice_path = os.path.join(img_dir, f"{os.path.basename(img_path).replace('.nii.gz', '')}_{i}.nii.gz")
mask_slice_path = os.path.join(mask_dir, f"{os.path.basename(mask_path).replace('.nii.gz', '')}_{i}.nii.gz")
# Save the slices as NIfTI files
sitk.WriteImage(img_slice, img_slice_path)
sitk.WriteImage(mask_slice, mask_slice_path)
processing patient 0 (325, 512, 512) (325, 512, 512)
processing patient 1 (465, 512, 512) (465, 512, 512)
processing patient 2 (301, 512, 512) (301, 512, 512)
processing patient 3 (117, 512, 512) (117, 512, 512)
In the snippet below, we initialize a dictionary, data_paths, to store the paths of the image and label files. We loop over the directories for each split (train, validation, and test) and, for each data type (images and masks), build the directory path and collect all file paths ending in ".nii.gz" into a sorted list. Each list is stored in data_paths under a key that combines the split and the data type.
# Initialize dictionary for storing image and label paths
data_paths = {}
# Create directories and print the number of images and masks in each
for dataset in datasets:
for data_type in data_types:
# Construct the directory path
dir_path = os.path.join(base_dir, f'{dataset}_{data_type}')
# Find images and labels in the directory
files = sorted(glob.glob(os.path.join(dir_path, "*.nii.gz")))
# Store the image and label paths in the dictionary
data_paths[f'{dataset}_{data_type.split("_")[1]}'] = files
print('Number of training images', len(data_paths['train_images']))
print('Number of validation images', len(data_paths['val_images']))
print('Number of test images', len(data_paths['test_images']))
Number of training images 655
Number of validation images 265
Number of test images 49
The next snippet creates an instance of SamProcessor for image preprocessing. SamProcessor is part of the Hugging Face Transformers library and prepares images for use with the Segment Anything Model (SAM). We initialize it from Facebook's pretrained "sam-vit-base" checkpoint; the processor will format our images appropriately before they are fed into the SAM model:
# create an instance of the processor for image preprocessing
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
print(processor)
SamProcessor:
- image_processor: SamImageProcessor {
"do_convert_rgb": true,
"do_normalize": true,
"do_pad": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.485,
0.456,
0.406
],
"image_processor_type": "SamImageProcessor",
"image_std": [
0.229,
0.224,
0.225
],
"pad_size": {
"height": 1024,
"width": 1024
},
"processor_class": "SamProcessor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"size": {
"longest_edge": 1024
}
}
The get_bounding_box function creates bounding-box coordinates for a given segmentation map. The coordinates are derived from the foreground of the mask and adjusted with a randomly chosen padding for variability. If no foreground is present, the bounding box is set to the full image size.
The SAMDataset class defines a custom dataset for our application. It applies several transforms to the data, such as loading the images, enforcing a consistent orientation, normalizing intensities, and cropping them to a fixed size. The class also prepares the image and prompt for the model by converting the image to the expected input format, generating the bounding box, and arranging the segmentation mask:
def get_bounding_box(ground_truth_map):
'''
This function creates varying bounding box coordinates based on the segmentation contours as prompt for the SAM model
The padding is random int values between 5 and 20 pixels
'''
if len(np.unique(ground_truth_map)) > 1:
# get bounding box from mask
y_indices, x_indices = np.where(ground_truth_map > 0)
x_min, x_max = np.min(x_indices), np.max(x_indices)
y_min, y_max = np.min(y_indices), np.max(y_indices)
# add perturbation to bounding box coordinates
H, W = ground_truth_map.shape
x_min = max(0, x_min - np.random.randint(5, 20))
x_max = min(W, x_max + np.random.randint(5, 20))
y_min = max(0, y_min - np.random.randint(5, 20))
y_max = min(H, y_max + np.random.randint(5, 20))
bbox = [x_min, y_min, x_max, y_max]
return bbox
else:
return [0, 0, 256, 256] # if there is no mask in the array, set bbox to image size
class SAMDataset(Dataset):
def __init__(self, image_paths, mask_paths, processor):
self.image_paths = image_paths
self.mask_paths = mask_paths
self.processor = processor
self.transforms = Compose([
# load .nii or .nii.gz files
LoadImaged(keys=['img', 'label']),
# add channel id to match PyTorch configurations
EnsureChannelFirstd(keys=['img', 'label']),
# reorient images for consistency and visualization
Orientationd(keys=['img', 'label'], axcodes='RA'),
# resample all training images to a fixed spacing
Spacingd(keys=['img', 'label'], pixdim=(1.5, 1.5), mode=("bilinear", "nearest")),
# rescale image and label dimensions to 256x256
CenterSpatialCropd(keys=['img', 'label'], roi_size=(256,256)),
# scale intensities to 0 and 255 to match the expected input intensity range
ScaleIntensityRanged(keys=['img'], a_min=-1000, a_max=2000,
b_min=0.0, b_max=255.0, clip=True),
ScaleIntensityRanged(keys=['label'], a_min=0, a_max=255,
b_min=0.0, b_max=1.0, clip=True),
SpatialPadd(keys=["img", "label"], spatial_size=(256,256))
# RepeatChanneld(keys=['img'], repeats=3, allow_missing_keys=True)
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image_path = self.image_paths[idx]
mask_path = self.mask_paths[idx]
# create a dict of images and labels to apply Monai's dictionary transforms
data_dict = self.transforms({'img': image_path, 'label': mask_path})
# squeeze extra dimensions
image = data_dict['img'].squeeze()
ground_truth_mask = data_dict['label'].squeeze()
# convert to int type for huggingface's models expected inputs
image = image.astype(np.uint8)
# convert the grayscale array to RGB (3 channels)
array_rgb = np.dstack((image, image, image))
# convert to PIL image to match the expected input of processor
image_rgb = Image.fromarray(array_rgb)
# get bounding box prompt (returns xmin, ymin, xmax, ymax)
# in this dataset, the contours are -1 so we change them to 1 for label and 0 for background
ground_truth_mask[ground_truth_mask < 0] = 1
prompt = get_bounding_box(ground_truth_mask)
# prepare image and prompt for the model
inputs = self.processor(image_rgb, input_boxes=[[prompt]], return_tensors="pt")
# remove batch dimension which the processor adds by default
inputs = {k: v.squeeze(0) for k, v in inputs.items()}
# add ground truth segmentation (ground truth image size is 256x256)
inputs["ground_truth_mask"] = torch.from_numpy(ground_truth_mask.astype(np.int8))
return inputs
This code creates the dataloaders for the training and validation sets. SAMDataset objects are built from the image and mask paths together with the SamProcessor for preprocessing. PyTorch's DataLoader is then used to batch and shuffle the data so it can be fed to the model efficiently and in random order during training.
# create train and validation dataloaders
train_dataset = SAMDataset(image_paths=data_paths['train_images'], mask_paths=data_paths['train_masks'], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataset = SAMDataset(image_paths=data_paths['val_images'], mask_paths=data_paths['val_masks'], processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)
Now we can visualize the processed data:
example = train_dataset[50]
for k,v in example.items():
print(k,v.shape)
xmin, ymin, xmax, ymax = get_bounding_box(example['ground_truth_mask'])
fig, axs = plt.subplots(1, 2)
axs[0].imshow(example['pixel_values'][1], cmap='gray')
axs[0].axis('off')
axs[1].imshow(example['ground_truth_mask'], cmap='copper')
# create a Rectangle patch for the bounding box
rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')
# add the patch to the second Axes
axs[1].add_patch(rect)
axs[1].axis('off')
plt.tight_layout()
plt.show()
pixel_values torch.Size([3, 1024, 1024])
original_sizes torch.Size([2])
reshaped_input_sizes torch.Size([2])
input_boxes torch.Size([1, 4])
ground_truth_mask torch.Size([256, 256])
A sample chest CT slice and the corresponding lung mask
5. Training
Now that our dataloaders are ready, we can configure the model for fine-tuning. We preserve the pretrained SAM encoder weights by freezing them:
# load the pretrained weights for finetuning
model = SamModel.from_pretrained("facebook/sam-vit-base")
# make sure we only compute gradients for mask decoder (encoder weights are frozen)
for name, param in model.named_parameters():
if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
param.requires_grad_(False)
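A quick sanity check that only the mask decoder remains trainable is to count parameters. This small sketch is not part of the original pipeline, but it is handy to run before training:
# confirm that the encoders are frozen
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} / {total_params:,}")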
Finally, we can start training our model:
# define training loop
num_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# define optimizer
optimizer = Adam(model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
# define segmentation loss with sigmoid activation applied to predictions from the model
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
# track mean train and validation losses
mean_train_losses, mean_val_losses = [], []
# start with an arbitrarily large validation loss value
best_val_loss = 100.0
best_val_epoch = 0
# set model to train mode for gradient updating
model.train()
for epoch in range(num_epochs):
# create temporary list to record training losses
epoch_losses = []
for i, batch in enumerate(tqdm(train_dataloader)):
# forward pass
outputs = model(pixel_values=batch["pixel_values"].to(device),
input_boxes=batch["input_boxes"].to(device),
multimask_output=False)
# compute loss
predicted_masks = outputs.pred_masks.squeeze(1)
ground_truth_masks = batch["ground_truth_mask"].float().to(device)
loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
# backward pass (compute gradients of parameters w.r.t. loss)
optimizer.zero_grad()
loss.backward()
# optimize
optimizer.step()
epoch_losses.append(loss.item())
# visualize training predictions every 50 iterations
if i % 50 == 0:
# clear jupyter cell output
clear_output(wait=True)
fig, axs = plt.subplots(1, 3)
xmin, ymin, xmax, ymax = get_bounding_box(batch['ground_truth_mask'][0])
rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')
axs[0].set_title('input image')
axs[0].imshow(batch["pixel_values"][0,1], cmap='gray')
axs[0].axis('off')
axs[1].set_title('ground truth mask')
axs[1].imshow(batch['ground_truth_mask'][0], cmap='copper')
axs[1].add_patch(rect)
axs[1].axis('off')
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
axs[2].set_title('predicted mask')
axs[2].imshow(medsam_seg, cmap='copper')
axs[2].axis('off')
plt.tight_layout()
plt.show()
# create temporary list to record validation losses
val_losses = []
# disable gradient tracking for validation
with torch.no_grad():
for val_batch in tqdm(val_dataloader):
# forward pass
outputs = model(pixel_values=val_batch["pixel_values"].to(device),
input_boxes=val_batch["input_boxes"].to(device),
multimask_output=False)
# calculate val loss
predicted_val_masks = outputs.pred_masks.squeeze(1)
ground_truth_masks = val_batch["ground_truth_mask"].float().to(device)
val_loss = seg_loss(predicted_val_masks, ground_truth_masks.unsqueeze(1))
val_losses.append(val_loss.item())
# visualize the last validation prediction
fig, axs = plt.subplots(1, 3)
xmin, ymin, xmax, ymax = get_bounding_box(val_batch['ground_truth_mask'][0])
rect = patches.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin, linewidth=1, edgecolor='r', facecolor='none')
axs[0].set_title('input image')
axs[0].imshow(val_batch["pixel_values"][0,1], cmap='gray')
axs[0].axis('off')
axs[1].set_title('ground truth mask')
axs[1].imshow(val_batch['ground_truth_mask'][0], cmap='copper')
axs[1].add_patch(rect)
axs[1].axis('off')
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
axs[2].set_title('predicted mask')
axs[2].imshow(medsam_seg, cmap='copper')
axs[2].axis('off')
plt.tight_layout()
plt.show()
# save the best weights and record the best performing epoch
if mean(val_losses) < best_val_loss:
torch.save(model.state_dict(), f"best_weights.pth")
print(f"Model Was Saved! Current Best val loss {best_val_loss}")
best_val_loss = mean(val_losses)
best_val_epoch = epoch
else:
print("Model Was Not Saved!")
print(f'EPOCH: {epoch}')
print(f'Mean loss: {mean(epoch_losses)}')
mean_train_losses.append(mean(epoch_losses))
mean_val_losses.append(mean(val_losses))
Segmentation progress displayed during training
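Because the loop records mean_train_losses and mean_val_losses per epoch, you can also plot the learning curves once training finishes. A short sketch:
# plot mean training and validation loss per epoch
plt.plot(mean_train_losses, label='train loss')
plt.plot(mean_val_losses, label='val loss')
plt.xlabel('epoch')
plt.ylabel('DiceCE loss')
plt.legend()
plt.show()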
6. Inference
Once training is complete, we use the best weights to predict segmentation masks on the test data:
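The snippet below reuses the model as it stands after the last training epoch. To restore the checkpoint with the lowest validation loss instead, you can load best_weights.pth explicitly; a minimal sketch:
# restore the best-performing weights saved during training
model.load_state_dict(torch.load("best_weights.pth"))
model.to(device)
model.eval()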
# create test dataloader
test_dataset = SAMDataset(image_paths=data_paths['test_images'], mask_paths=data_paths['test_masks'], processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Iterate through the test images
with torch.no_grad():
for batch in tqdm(test_dataloader):
# forward pass
outputs = model(pixel_values=batch["pixel_values"].cuda(),
input_boxes=batch["input_boxes"].cuda(),
multimask_output=False)
# compute loss
predicted_masks = outputs.pred_masks.squeeze(1)
ground_truth_masks = batch["ground_truth_mask"].float().to(device)
# loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))
# apply sigmoid
medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# convert soft mask to hard mask
medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
plt.figure(figsize=(12,4))
plt.subplot(1,3,1)
plt.imshow(batch["pixel_values"][0,1], cmap='gray')
plt.axis('off')
plt.subplot(1,3,2)
plt.imshow(batch["ground_truth_mask"][0], cmap='copper')
plt.axis('off')
plt.subplot(1,3,3)
plt.imshow(medsam_seg, cmap='copper')
plt.axis('off')
plt.tight_layout()
plt.show()
Predictions from the fine-tuned SAM model on test data
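To put a number on test performance, you can also accumulate a Dice score per slice. The sketch below reuses the dice_coefficient helper sketched in section 2 and assumes the hard predictions and ground-truth masks are both 256x256 binary arrays:
# evaluate mean Dice over the test set
dice_scores = []
with torch.no_grad():
    for batch in tqdm(test_dataloader):
        outputs = model(pixel_values=batch["pixel_values"].to(device),
                        input_boxes=batch["input_boxes"].to(device),
                        multimask_output=False)
        probs = torch.sigmoid(outputs.pred_masks.squeeze(1)).cpu().numpy().squeeze()
        pred_mask = (probs > 0.5).astype(np.uint8)
        gt_mask = batch["ground_truth_mask"][0].numpy().astype(np.uint8)
        dice_scores.append(dice_coefficient(pred_mask, gt_mask))
print('Mean test Dice:', mean(dice_scores))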
Congratulations! You have successfully fine-tuned a SAM model with bounding-box prompts to segment the lungs in CT scans. Happy coding!