小罗碎碎念
这期推文算是hover net
系列的一个补充文档,把几个非常重要的脚本拿出来单独做了一个分析,感兴趣的自取。
- extract_patches.py
- config.py
- dataset.py
- opt.py
- run_infer.py
一、extract_patches.py
1-1:加载和处理图像数据集
注意
dataset
属于自建函数,所以一定要保证这个文件与你的代码执行文件位于同一路径下,否则代码会报错。如果对于dataset
代码细节不清楚的,可以跳转至本文第三部分。
import re
import glob
import os
import tqdm
import pathlib
import numpy as np
from misc.patch_extractor import PatchExtractor
from misc.utils import rm_n_mkdir
from dataset import get_dataset
下面是对每一行的详细解释:
import re
: 导入正则表达式模块,用于字符串的搜索和匹配。import glob
: 导入glob模块,用于从目录中匹配符合特定规则的文件路径列表。import os
: 导入os模块,提供了一种方便的方式来使用操作系统相关的功能。import tqdm
: 导入tqdm模块,用于在循环中显示进度条,提高用户体验。import pathlib
: 导入pathlib模块,用于处理文件路径。import numpy as np
: 导入numpy模块,并将其重命名为np,用于进行科学计算。from misc.patch_extractor import PatchExtractor
: 从misc.patch_extractor
模块中导入PatchExtractor
类,用于提取图像中的小块(patch)。from misc.utils import rm_n_mkdir
: 从misc.utils
模块中导入rm_n_mkdir
函数,用于删除旧目录并创建新目录。from dataset import get_dataset
: 从dataset
模块中导入get_dataset
函数,用于获取图像数据集。
这个脚本的主要功能是加载和处理图像数据集,具体包括提取图像中的小块、删除旧目录并创建新目录等操作。
1-2:定义一个处理图像数据集的流程
if __name__ == "__main__":
# Determines whether to extract type map (only applicable to datasets with class labels).
type_classification = True
win_size = [540, 540]
step_size = [164, 164]
extract_type = "mirror" # Choose 'mirror' or 'valid'. 'mirror'- use padding at borders. 'valid'- only extract from valid regions.
# Name of dataset - use Kumar, CPM17 or CoNSeP.
# This used to get the specific dataset img and ann loading scheme from dataset.py
dataset_name = "consep"
save_root = "dataset/training_data/%s/" % dataset_name
# a dictionary to specify where the dataset path should be
dataset_info = {
"train": {
"img": (".png", "dataset/CoNSeP/Train/Images/"),
"ann": (".mat", "dataset/CoNSeP/Train/Labels/"),
},
"valid": {
"img": (".png", "dataset/CoNSeP/Test/Images/"),
"ann": (".mat", "dataset/CoNSeP/Test/Labels/"),
},
}
patterning = lambda x: re.sub("([\[\]])", "[\\1]", x)
parser = get_dataset(dataset_name)
xtractor = PatchExtractor(win_size, step_size)
for split_name, split_desc in dataset_info.items():
img_ext, img_dir = split_desc["img"]
ann_ext, ann_dir = split_desc["ann"]
out_dir = "%s/%s/%s/%dx%d_%dx%d/" % (
save_root,
dataset_name,
split_name,
win_size[0],
win_size[1],
step_size[0],
step_size[1],
)
file_list = glob.glob(patterning("%s/*%s" % (ann_dir, ann_ext)))
file_list.sort() # ensure same ordering across platform
rm_n_mkdir(out_dir)
pbar_format = "Process File: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
pbarx = tqdm.tqdm(
total=len(file_list), bar_format=pbar_format, ascii=True, position=0
)
for file_idx, file_path in enumerate(file_list):
base_name = pathlib.Path(file_path).stem
img = parser.load_img("%s/%s%s" % (img_dir, base_name, img_ext))
ann = parser.load_ann(
"%s/%s%s" % (ann_dir, base_name, ann_ext), type_classification
)
# *
img = np.concatenate([img, ann], axis=-1)
sub_patches = xtractor.extract(img, extract_type)
pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
pbar = tqdm.tqdm(
total=len(sub_patches),
leave=False,
bar_format=pbar_format,
ascii=True,
position=1,
)
for idx, patch in enumerate(sub_patches):
np.save("{0}/{1}_{2:03d}.npy".format(out_dir, base_name, idx), patch)
pbar.update()
pbar.close()
# *
pbarx.update()
pbarx.close()
下面是对代码的详细解释:
if __name__ == "__main__":
: 这是一个Python常用的条件语句,用于检查当前脚本是否作为主程序运行。如果是,则执行下面的代码块。type_classification = True
: 设置一个标志,指示是否进行类型分类。这通常用于有类别标签的数据集。win_size = [540, 540]
: 定义提取图像块(patch)的窗口大小为540x540像素。step_size = [164, 164]
: 定义在图像上移动窗口以提取块时的步长为164x164像素。extract_type = "mirror"
: 定义提取块时的边界处理方式为"mirror",即使用镜像填充。另一个选项是"valid",只从有效区域提取。dataset_name = "consep"
: 设置数据集的名称为"consep"。save_root = "dataset/training_data/%s/" % dataset_name
: 定义保存提取块的根目录。dataset_info
: 一个字典,包含了数据集的路径信息,包括训练集和验证集的图像和标注路径。patterning = lambda x: re.sub("([\[\]])", "[\\1]", x)
: 定义一个函数,用于处理文件路径中的特殊字符,以便与glob
模块兼容。parser = get_dataset(dataset_name)
: 获取数据集的解析器,用于加载和解析图像和标注。xtractor = PatchExtractor(win_size, step_size)
: 创建一个PatchExtractor
实例,用于提取图像块。for split_name, split_desc in dataset_info.items():
: 遍历dataset_info
字典中的每个条目,每个条目代表数据集的一个分割(如训练集或验证集)。img_ext, img_dir = split_desc["img"]
: 从分割描述中获取图像的扩展名和目录。ann_ext, ann_dir = split_desc["ann"]
: 从分割描述中获取标注的扩展名和目录。out_dir = "%s/%s/%s/%dx%d_%dx%d/" % (...)
: 构建输出目录的路径,包括数据集名称、分割名称和图像块的大小和步长。file_list = glob.glob(patterning("%s/*%s" % (ann_dir, ann_ext)))
: 获取所有标注文件的列表。file_list.sort()
: 对文件列表进行排序,以确保在不同平台上的一致性。rm_n_mkdir(out_dir)
: 删除旧目录(如果存在)并创建新目录。pbarx = tqdm.tqdm(...)
: 创建一个进度条,用于显示处理文件的进度。for file_idx, file_path in enumerate(file_list):
: 遍历所有标注文件。base_name = pathlib.Path(file_path).stem
: 获取文件的基本名称(不包括路径和扩展名)。img = parser.load_img(...)
: 加载图像。ann = parser.load_ann(...)
: 加载标注。img = np.concatenate([img, ann], axis=-1)
: 将图像和标注沿最后一个轴(通常是通道轴)合并。sub_patches = xtractor.extract(img, extract_type)
: 使用PatchExtractor
提取图像块。pbar = tqdm.tqdm(...)
: 创建一个进度条,用于显示提取块的进度。for idx, patch in enumerate(sub_patches):
: 遍历所有提取的块。np.save("{0}/{1}_{2:03d}.npy".format(out_dir, base_name, idx), patch)
: 将每个块保存为.npy
文件。pbar.update()
: 更新进度条。pbar.close()
: 关闭进度条。pbarx.update()
: 更新文件处理进度条。pbarx.close()
: 关闭文件处理进度条。
整个脚本的目的是遍历一个图像数据集,提取图像块,并将这些块保存到磁盘上。这个过程中使用了进度条来提供反馈,确保用户知道脚本的运行状态。
二、config.py
2-1:导入库
import importlib
import random
import cv2
import numpy as np
from dataset import get_dataset
下面是对每一行的详细解释:
import importlib
: 导入importlib
模块,它提供了导入Python模块的函数。这允许在运行时动态导入模块,而不需要在编写代码时就知道所有模块的名字。import random
: 导入random
模块,它提供了生成随机数的函数。这些函数可以用于随机化数据集、打乱数据顺序等。import cv2
: 导入cv2
模块,即OpenCV库,它是一个强大的计算机视觉库,提供了很多图像处理和图像分析的函数。import numpy as np
: 导入numpy
模块,并将其重命名为np
。numpy
是一个用于科学计算的Python库,它提供了多维数组对象和一系列用于操作这些数组的函数。from dataset import get_dataset
: 从dataset
模块中导入get_dataset
函数。这个函数用于获取或创建一个数据集,以便进行图像处理或机器学习任务。
总的来说,这段代码导入了用于图像处理、计算机视觉和科学计算的常用模块和函数,以及一个特定于数据集的函数。这些工具将在后续的代码中用于处理和分析图像数据。
2-2:定义一个名为Config
的类
这个类的设计用于存储和管理配置信息,是为了在运行实验或训练模型时提供灵活的配置选项。
class Config(object):
"""Configuration file."""
def __init__(self):
self.seed = 10
self.logging = True
# turn on debug flag to trace some parallel processing problems more easily
self.debug = False
model_name = "hovernet"
model_mode = "original" # choose either `original` or `fast`
if model_mode not in ["original", "fast"]:
raise Exception("Must use either `original` or `fast` as model mode")
nr_type = 5 # number of nuclear types (including background)
# whether to predict the nuclear type, availability depending on dataset!
self.type_classification = True
# shape information -
# below config is for original mode.
# If original model mode is used, use [270,270] and [80,80] for act_shape and out_shape respectively
# If fast model mode is used, use [256,256] and [164,164] for act_shape and out_shape respectively
aug_shape = [540, 540] # patch shape used during augmentation (larger patch may have less border artefacts)
act_shape = [270, 270] # patch shape used as input to network - central crop performed after augmentation
out_shape = [80, 80] # patch shape at output of network
if model_mode == "original":
if act_shape != [270,270] or out_shape != [80,80]:
raise Exception("If using `original` mode, input shape must be [270,270] and output shape must be [80,80]")
if model_mode == "fast":
if act_shape != [256,256] or out_shape != [164,164]:
raise Exception("If using `fast` mode, input shape must be [256,256] and output shape must be [164,164]")
self.dataset_name = "consep" # extracts dataset info from dataset.py
self.log_dir = "logs/" # where checkpoints will be saved
# paths to training and validation patches
self.train_dir_list = [
"train_patches_path"
]
self.valid_dir_list = [
"valid_patches_path"
]
self.shape_info = {
"train": {"input_shape": act_shape, "mask_shape": out_shape,},
"valid": {"input_shape": act_shape, "mask_shape": out_shape,},
}
# * parsing config to the running state and set up associated variables
self.dataset = get_dataset(self.dataset_name)
module = importlib.import_module(
"models.%s.opt" % model_name
)
self.model_config = module.get_config(nr_type, model_mode)
下面是对代码的详细解释:
class Config(object):
: 定义一个名为Config
的类,继承自object
类。"""Configuration file."""
: 类的文档字符串,简要描述了这个类的作用。def __init__(self):
: 定义类的构造函数。self.seed = 10
: 设置一个种子值,用于初始化随机数生成器,确保实验的可重复性。self.logging = True
: 设置一个标志,指示是否启用日志记录。self.debug = False
: 设置一个标志,指示是否启用调试模式,用于追踪并行处理中的问题。model_name = "hovernet"
: 设置模型名称为"hovernet"。model_mode = "original"
: 设置模型模式为"original",另一个选项是"fast"。如果model_mode
不是"original"或"fast",则抛出异常,因为只接受这两个值。
nr_type = 5
: 设置核类型的数量,包括背景类型。self.type_classification = True
: 设置一个标志,指示是否进行类型分类。aug_shape = [540, 540]
: 设置数据增强时使用的图像块的大小。act_shape = [270, 270]
: 设置作为网络输入的图像块的大小。out_shape = [80, 80]
: 设置网络输出图像块的大小。- 根据模型模式检查输入和输出形状是否正确,如果不正确,则抛出异常。
self.dataset_name = "consep"
: 设置数据集名称。self.log_dir = "logs/"
: 设置日志文件的保存目录。self.train_dir_list
和self.valid_dir_list
: 设置训练和验证图像块目录的列表。self.shape_info
: 一个字典,包含了训练和验证的输入和输出形状信息。self.dataset = get_dataset(self.dataset_name)
: 获取数据集配置。module = importlib.import_module("models.%s.opt" % model_name)
: 动态导入与模型名称对应的配置模块。self.model_config = module.get_config(nr_type, model_mode)
: 从导入的模块中获取模型配置。
这个Config
类的实例将用于存储和访问配置信息,这些信息将在训练模型和处理数据时使用。通过这种方式,可以轻松地修改配置,而无需更改代码的其他部分。
三、dataset.py
3-1:导入常用的模块
import glob
import cv2
import numpy as np
import scipy.io as sio
下面是对每一行的详细解释:
import glob
: 导入glob
模块,它提供了一个函数用于从目录中匹配符合特定规则的文件路径列表。这对于批量处理文件非常有用。import cv2
: 导入cv2
模块,即OpenCV库,它是一个强大的计算机视觉库,提供了很多图像处理和图像分析的函数,如读取、显示、保存图像,图像滤波,特征检测等。import numpy as np
: 导入numpy
模块,并将其重命名为np
。numpy
是一个用于科学计算的Python库,它提供了多维数组对象和一系列用于操作这些数组的函数,对于图像处理和机器学习非常重要。import scipy.io as sio
: 导入scipy.io
模块,并将其重命名为sio
。scipy.io
是SciPy库的一部分,它提供了读写多种数据格式的函数,特别是matlab
格式的文件。这对于加载和处理MATLAB数据文件非常有用。
总的来说,这段代码导入了用于文件操作、图像处理、科学计算和数据分析的常用模块。这些工具将在后续的代码中用于处理和分析图像数据,以及加载和保存不同格式的数据文件。
3-2:__AbstractDataset
class __AbstractDataset(object):
"""Abstract class for interface of subsequent classes.
Main idea is to encapsulate how each dataset should parse
their images and annotations.
"""
def load_img(self, path):
raise NotImplementedError
def load_ann(self, path, with_type=False):
raise NotImplementedError
这段代码定义了一个名为__AbstractDataset
的抽象类,它旨在为后续的类提供一个接口。这个类的主要思想是将每个数据集应该如何解析它们的图像和标注封装起来。
通过定义抽象方法,它确保了所有继承自这个类的具体数据集类都实现了这些方法。
class __AbstractDataset(object):
: 定义一个名为__AbstractDataset
的类,继承自object
类。在Python中,类名前加上两个下划线表示这是一个特殊方法或属性,通常意味着它是不应该被直接实例化的抽象类。def load_img(self, path):
: 定义一个名为load_img
的抽象方法,它接受一个参数path
,表示图像文件的路径。raise NotImplementedError
: 在load_img
方法中,直接抛出NotImplementedError
异常。这表示这个方法必须在子类中被重写。def load_ann(self, path, with_type=False):
: 定义一个名为load_ann
的抽象方法,它接受两个参数path
和with_type
,分别表示标注文件的路径和一个标志,指示是否加载类型信息。raise NotImplementedError
: 在load_ann
方法中,直接抛出NotImplementedError
异常。这表示这个方法必须在子类中被重写。
这个抽象类的目的是为数据集的处理提供一个通用的接口。任何继承自__AbstractDataset
的类都需要实现load_img
和load_ann
方法,以确保它们能够正确地加载和处理图像和标注数据。这样,即使不同的数据集可能有不同的数据格式和解析方式,只要它们实现了这个接口,就可以使用相同的方法来加载和处理数据。
3-3:__Kumar
class __Kumar(__AbstractDataset):
"""Defines the Kumar dataset as originally introduced in:
Kumar, Neeraj, Ruchika Verma, Sanuj Sharma, Surabhi Bhargava, Abhishek Vahadane,
and Amit Sethi. "A dataset and a technique for generalized nuclear segmentation for
computational pathology." IEEE transactions on medical imaging 36, no. 7 (2017): 1550-1560.
"""
def load_img(self, path):
return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
def load_ann(self, path, with_type=False):
# assumes that ann is HxW
assert not with_type, "Not support"
ann_inst = sio.loadmat(path)["inst_map"]
ann_inst = ann_inst.astype("int32")
ann = np.expand_dims(ann_inst, -1)
return ann
这段代码定义了一个名为__Kumar
的类,它继承自__AbstractDataset
抽象类。__Kumar
类是专门为Kumar数据集设计的,该数据集最初是在2017年的IEEE Transactions on Medical Imaging期刊上介绍的,用于计算机病理学中的广义核分割。
class __Kumar(__AbstractDataset):
: 定义一个名为__Kumar
的类,继承自__AbstractDataset
。这意味着它必须实现__AbstractDataset
中的抽象方法load_img
和load_ann
。def load_img(self, path):
: 实现了load_img
方法,用于加载图像。return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
: 使用OpenCV的imread
函数读取图像,并使用cvtColor
函数将图像从BGR颜色空间转换为RGB颜色空间。def load_ann(self, path, with_type=False):
: 实现了load_ann
方法,用于加载标注。assert not with_type, "Not support"
: 断言with_type
为False
,因为该方法不支持加载类型信息。ann_inst = sio.loadmat(path)["inst_map"]
: 使用scipy.io
的loadmat
函数加载MATLAB格式的标注文件,并提取实例映射(inst_map
)。ann_inst = ann_inst.astype("int32")
: 将标注数据转换为int32
类型。ann = np.expand_dims(ann_inst, -1)
: 在标注数据的最后一个轴上增加一个维度,使其形状与图像数据兼容。return ann
: 返回处理后的标注数据。
__Kumar
类的目的是为Kumar数据集提供一个具体的实现,使其能够加载和处理图像和标注数据。这样,其他代码可以通过这个类来访问和利用Kumar数据集,而不必担心数据集的具体细节。
3-4:__CPM17
class __CPM17(__AbstractDataset):
"""Defines the CPM 2017 dataset as originally introduced in:
Vu, Quoc Dang, Simon Graham, Tahsin Kurc, Minh Nguyen Nhat To, Muhammad Shaban,
Talha Qaiser, Navid Alemi Koohbanani et al. "Methods for segmentation and classification
of digital microscopy tissue images." Frontiers in bioengineering and biotechnology 7 (2019).
"""
def load_img(self, path):
return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
def load_ann(self, path, with_type=False):
assert not with_type, "Not support"
# assumes that ann is HxW
ann_inst = sio.loadmat(path)["inst_map"]
ann_inst = ann_inst.astype("int32")
ann = np.expand_dims(ann_inst, -1)
return ann
这段代码定义了一个名为__CPM17
的类,它同样继承自__AbstractDataset
抽象类。
__CPM17
类是专门为CPM 2017数据集设计的,该数据集最初是在2019年的Frontiers in Bioengineering and Biotechnology期刊上介绍的,用于数字显微镜组织图像的分割和分类。
class __CPM17(__AbstractDataset):
: 定义一个名为__CPM17
的类,继承自__AbstractDataset
。这意味着它必须实现__AbstractDataset
中的抽象方法load_img
和load_ann
。def load_img(self, path):
: 实现了load_img
方法,用于加载图像。return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
: 使用OpenCV的imread
函数读取图像,并使用cvtColor
函数将图像从BGR颜色空间转换为RGB颜色空间。def load_ann(self, path, with_type=False):
: 实现了load_ann
方法,用于加载标注。assert not with_type, "Not support"
: 断言with_type
为False
,因为该方法不支持加载类型信息。ann_inst = sio.loadmat(path)["inst_map"]
: 使用scipy.io
的loadmat
函数加载MATLAB格式的标注文件,并提取实例映射(inst_map
)。ann_inst = ann_inst.astype("int32")
: 将标注数据转换为int32
类型。ann = np.expand_dims(ann_inst, -1)
: 在标注数据的最后一个轴上增加一个维度,使其形状与图像数据兼容。return ann
: 返回处理后的标注数据。
__CPM17
类的目的是为CPM 2017数据集提供一个具体的实现,使其能够加载和处理图像和标注数据。这样,其他代码可以通过这个类来访问和利用CPM 2017数据集,而不必担心数据集的具体细节。
3-5:__CoNSeP
class __CoNSeP(__AbstractDataset):
"""Defines the CoNSeP dataset as originally introduced in:
Graham, Simon, Quoc Dang Vu, Shan E. Ahmed Raza, Ayesha Azam, Yee Wah Tsang, Jin Tae Kwak,
and Nasir Rajpoot. "Hover-Net: Simultaneous segmentation and classification of nuclei in
multi-tissue histology images." Medical Image Analysis 58 (2019): 101563
"""
def load_img(self, path):
return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
def load_ann(self, path, with_type=False):
# assumes that ann is HxW
ann_inst = sio.loadmat(path)["inst_map"]
if with_type:
ann_type = sio.loadmat(path)["type_map"]
# merge classes for CoNSeP (in paper we only utilise 3 nuclei classes and background)
# If own dataset is used, then the below may need to be modified
ann_type[(ann_type == 3) | (ann_type == 4)] = 3
ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
ann = np.dstack([ann_inst, ann_type])
ann = ann.astype("int32")
else:
ann = np.expand_dims(ann_inst, -1)
ann = ann.astype("int32")
return ann
这段代码定义了一个名为__CoNSeP
的类,它同样继承自__AbstractDataset
抽象类。
__CoNSeP
类是专门为CoNSeP数据集设计的,该数据集最初是在2019年的Medical Image Analysis期刊上介绍的,用于多组织组织学图像中核的同步分割和分类。
class __CoNSeP(__AbstractDataset):
: 定义一个名为__CoNSeP
的类,继承自__AbstractDataset
。这意味着它必须实现__AbstractDataset
中的抽象方法load_img
和load_ann
。def load_img(self, path):
: 实现了load_img
方法,用于加载图像。return cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
: 使用OpenCV的imread
函数读取图像,并使用cvtColor
函数将图像从BGR颜色空间转换为RGB颜色空间。def load_ann(self, path, with_type=False):
: 实现了load_ann
方法,用于加载标注。这个方法有一个参数with_type
,指示是否加载类型信息。ann_inst = sio.loadmat(path)["inst_map"]
: 使用scipy.io
的loadmat
函数加载MATLAB格式的标注文件,并提取实例映射(inst_map
)。if with_type:
: 如果with_type
为True
,则还加载类型映射。ann_type = sio.loadmat(path)["type_map"]
: 加载类型映射(type_map
)。ann_type[(ann_type == 3) | (ann_type == 4)] = 3
: 将类型映射中的一些类合并,这里将类别3和4合并为类别3。ann_type[(ann_type == 5) | (ann_type == 6) | (ann_type == 7)] = 4
: 将类别5、6和7合并为类别4。ann = np.dstack([ann_inst, ann_type])
: 将实例映射和类型映射合并成一个三维数组。ann = ann.astype("int32")
: 将标注数据转换为int32
类型。else:
: 如果with_type
为False
,则只加载实例映射。ann = np.expand_dims(ann_inst, -1)
: 在实例映射的最后一个轴上增加一个维度。ann = ann.astype("int32")
: 将标注数据转换为int32
类型。return ann
: 返回处理后的标注数据。
__CoNSeP
类的目的是为CoNSeP数据集提供一个具体的实现,使其能够加载和处理图像和标注数据。如果需要,它还可以加载类型信息,用于核的分类任务。这样,其他代码可以通过这个类来访问和利用CoNSeP数据集,而不必担心数据集的具体细节。
3-6:get_dataset
这段代码定义了一个名为get_dataset
的函数,用于根据提供的数据集名称返回一个预定义的数据集对象。
def get_dataset(name):
"""Return a pre-defined dataset object associated with `name`."""
name_dict = {
"kumar": lambda: __Kumar(),
"cpm17": lambda: __CPM17(),
"consep": lambda: __CoNSeP(),
}
if name.lower() in name_dict:
return name_dict[name]()
else:
assert False, "Unknown dataset `%s`" % name
这个函数通过一个字典来映射数据集名称到相应的数据集类,并创建该类的实例。
def get_dataset(name):
: 定义了一个名为get_dataset
的函数,它接受一个参数name
,表示要获取的数据集名称。name_dict = {
: 定义了一个字典name_dict
,用于存储数据集名称和对应的函数(即数据集类的实例化函数)。"kumar": lambda: __Kumar(),
: 在字典中,"kumar"
键映射到一个匿名函数,该函数返回一个__Kumar
类的实例。"cpm17": lambda: __CPM17(),
:"cpm17"
键映射到一个匿名函数,该函数返回一个__CPM17
类的实例。"consep": lambda: __CoNSeP(),
:"consep"
键映射到一个匿名函数,该函数返回一个__CoNSeP
类的实例。}
: 结束字典定义。if name.lower() in name_dict:
: 检查name
是否为字典name_dict
中的键,这里将名称转换为小写以进行比较。return name_dict[name]()
: 如果name
是字典中的键,则返回对应的函数(即数据集类的实例化函数)的返回值,即数据集类的实例。else:
: 如果name
不是字典中的键,则执行以下代码。assert False, "Unknown dataset
%s"
: 抛出一个断言错误,错误信息为“Unknown dataset%s
”,其中%s
会被name
的值替换。
这个get_dataset
函数的作用是提供一个统一的方法来获取不同数据集的实例,而不需要知道具体的数据集类名。它通过字典映射,将数据集名称转换为相应的数据集类实例。如果提供的数据集名称不在字典中,则抛出一个错误。
四、opt.py
4-1:配置和初始化
import torch.optim as optim
from run_utils.callbacks.base import (
AccumulateRawOutput,
PeriodicSaver,
ProcessAccumulatedRawOutput,
ScalarMovingAverage,
ScheduleLr,
TrackLr,
VisualizeOutput,
TriggerEngine,
)
from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient
from run_utils.engine import Events
from .targets import gen_targets, prep_sample
from .net_desc import create_model
from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output
这段代码是Python脚本的一部分,主要用于配置和初始化一个深度学习训练流程。它导入了必要的模块和类,并定义了一些函数,这些函数在训练过程中被用来处理数据、构建模型、执行训练和验证步骤等。
import torch.optim as optim
: 导入torch.optim
模块,它是PyTorch库的一部分,提供了一系列优化算法,如SGD、Adam、RMSprop等。from run_utils.callbacks.base import (
: 导入run_utils.callbacks.base
模块中的多个类,这些类提供了各种回调功能,用于在训练过程中进行监控、保存、可视化等操作。AccumulateRawOutput,
: 导入AccumulateRawOutput
类,用于累积和处理原始输出。PeriodicSaver,
: 导入PeriodicSaver
类,用于定期保存模型和日志。ProcessAccumulatedRawOutput,
: 导入ProcessAccumulatedRawOutput
类,用于处理累积的原始输出。ScalarMovingAverage,
: 导入ScalarMovingAverage
类,用于计算标量值的移动平均。ScheduleLr,
: 导入ScheduleLr
类,用于动态调整学习率。TrackLr,
: 导入TrackLr
类,用于跟踪学习率。VisualizeOutput,
: 导入VisualizeOutput
类,用于可视化输出。TriggerEngine,
: 导入TriggerEngine
类,用于触发操作。
from run_utils.callbacks.logging import LoggingEpochOutput, LoggingGradient
: 导入run_utils.callbacks.logging
模块中的两个类,用于日志记录。LoggingEpochOutput,
: 导入LoggingEpochOutput
类,用于记录每个epoch的输出。LoggingGradient,
: 导入LoggingGradient
类,用于记录梯度信息。from run_utils.engine import Events
: 导入run_utils.engine
模块中的Events
类,它定义了训练过程中的一些事件,如开始、结束、保存等。from .targets import gen_targets, prep_sample
: 导入targets
模块中的两个函数,用于生成目标和预处理样本。from .net_desc import create_model
: 导入net_desc
模块中的create_model
函数,用于创建模型。from .run_desc import proc_valid_step_output, train_step, valid_step, viz_step_output
: 导入run_desc
模块中的四个函数,分别用于处理验证步骤的输出、执行训练步骤、执行验证步骤和可视化步骤的输出。
这个脚本通过导入必要的模块和类,以及定义相关的函数,为深度学习训练流程提供了必要的组件和工具。这些组件和工具将在后续的代码中用于构建模型、处理数据、执行训练和验证步骤等。
4-2:训练和验证过程中所需的各种参数和设置
def get_config(nr_type, mode):
return {
# ------------------------------------------------------------------
# ! All phases have the same number of run engine
# phases are run sequentially from index 0 to N
"phase_list": [
{
"run_info": {
# may need more dynamic for each network
"net": {
"desc": lambda: create_model(
input_ch=3, nr_types=nr_type,
freeze=True, mode=mode
),
"optimizer": [
optim.Adam,
{ # should match keyword for parameters within the optimizer
"lr": 1.0e-4, # initial learning rate,
"betas": (0.9, 0.999),
},
],
# learning rate scheduler
"lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25),
"extra_info": {
"loss": {
"np": {"bce": 1, "dice": 1},
"hv": {"mse": 1, "msge": 1},
"tp": {"bce": 1, "dice": 1},
},
},
# path to load, -1 to auto load checkpoint from previous phase,
# None to start from scratch
"pretrained": "../pretrained/ImageNet-ResNet50-Preact_pytorch.tar",
# 'pretrained': None,
},
},
"target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})},
"batch_size": {"train": 16, "valid": 16,}, # engine name : value
"nr_epochs": 50,
},
{
"run_info": {
# may need more dynamic for each network
"net": {
"desc": lambda: create_model(
input_ch=3, nr_types=nr_type,
freeze=False, mode=mode
),
"optimizer": [
optim.Adam,
{ # should match keyword for parameters within the optimizer
"lr": 1.0e-4, # initial learning rate,
"betas": (0.9, 0.999),
},
],
# learning rate scheduler
"lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25),
"extra_info": {
"loss": {
"np": {"bce": 1, "dice": 1},
"hv": {"mse": 1, "msge": 1},
"tp": {"bce": 1, "dice": 1},
},
},
# path to load, -1 to auto load checkpoint from previous phase,
# None to start from scratch
"pretrained": -1,
},
},
"target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})},
"batch_size": {"train": 4, "valid": 8,}, # batch size per gpu
"nr_epochs": 50,
},
],
# ------------------------------------------------------------------
# TODO: dynamically for dataset plugin selection and processing also?
# all enclosed engine shares the same neural networks
# as the on at the outer calling it
"run_engine": {
"train": {
# TODO: align here, file path or what? what about CV?
"dataset": "", # whats about compound dataset ?
"nr_procs": 16, # number of threads for dataloader
"run_step": train_step, # TODO: function name or function variable ?
"reset_per_run": False,
# callbacks are run according to the list order of the event
"callbacks": {
Events.STEP_COMPLETED: [
# LoggingGradient(), # TODO: very slow, may be due to back forth of tensor/numpy ?
ScalarMovingAverage(),
],
Events.EPOCH_COMPLETED: [
TrackLr(),
PeriodicSaver(),
VisualizeOutput(viz_step_output),
LoggingEpochOutput(),
TriggerEngine("valid"),
ScheduleLr(),
],
},
},
"valid": {
"dataset": "", # whats about compound dataset ?
"nr_procs": 8, # number of threads for dataloader
"run_step": valid_step,
"reset_per_run": True, # * to stop aggregating output etc. from last run
# callbacks are run according to the list order of the event
"callbacks": {
Events.STEP_COMPLETED: [AccumulateRawOutput(),],
Events.EPOCH_COMPLETED: [
# TODO: is there way to preload these ?
ProcessAccumulatedRawOutput(
lambda a: proc_valid_step_output(a, nr_types=nr_type)
),
LoggingEpochOutput(),
],
},
},
},
}
这段代码定义了一个名为get_config
的函数,用于生成一个配置字典,该字典包含了训练和验证过程中所需的各种参数和设置。这个配置字典将用于定义神经网络、优化器、学习率调度器、损失函数、回调函数等,以指导整个训练过程。
def get_config(nr_type, mode):
: 定义了一个名为get_config
的函数,它接受两个参数nr_type
和mode
。nr_type
表示核类型的数量(包括背景),而mode
表示模型的模式(如"original"或"fast")。return {
: 开始返回一个字典,该字典包含了所有配置信息。# ! All phases have the same number of run engine
: 注释,说明所有阶段都使用相同的运行引擎。# phases are run sequentially from index 0 to N
: 注释,说明阶段是按顺序从索引0到N依次运行的。"phase_list": [
: 定义了一个名为phase_list
的列表,用于存储各个阶段的配置信息。
-
{
: 开始第一个阶段的配置字典。"run_info": {
: 定义了一个名为run_info
的字典,包含了与运行相关的信息。"net": {
: 在run_info
字典中,定义了一个名为net
的字典,包含了神经网络的配置信息。"desc": lambda: create_model(...),
: 定义了一个匿名函数,用于创建神经网络模型。该函数使用了create_model
函数,它根据输入通道数、核类型数、是否冻结权重和模型模式来创建模型。"optimizer": [...],
: 定义了优化器的配置,包括优化器类型和参数。"lr_scheduler": lambda x: optim.lr_scheduler.StepLR(x, 25),
: 定义了学习率调度器的配置,使用StepLR
调度器,每25个epoch调整一次学习率。"extra_info": {...},
: 定义了一些额外的信息,如损失函数的配置。"pretrained": "../pretrained/ImageNet-ResNet50-Preact_pytorch.tar",
: 定义了预训练模型的路径,用于初始化权重。},
: 结束net
字典的定义。
-
"target_info": {"gen": (gen_targets, {}), "viz": (prep_sample, {})},
: 定义了目标信息的配置,包括生成目标和预处理样本的函数和参数。 -
"batch_size": {"train": 16, "valid": 16,},
: 定义了训练和验证的批量大小。 -
"nr_epochs": 50,
: 定义了训练的epoch数量。 -
},
: 结束第一个阶段的配置字典。
{
: 开始第二个阶段的配置字典。"run_info": {...},
: 重复第一个阶段的run_info
字典的定义,但有一些变化,如不冻结权重和调整批量大小。},
: 结束第二个阶段的配置字典。],
: 结束phase_list
列表的定义。
"run_engine": {
: 定义了一个名为run_engine
的字典,包含了运行引擎的配置信息。"train": {
: 在run_engine
字典中,定义了一个名为train
的字典,包含了训练引擎的配置信息。"dataset": "",
: 定义了训练数据集的路径,但这里留空。"nr_procs": 16,
: 定义了数据加载器使用的线程数量。"run_step": train_step,
: 定义了run_step
,即执行训练步骤的函数,为train_step
。"reset_per_run": False,
: 定义了一个标志,指示在每个运行周期中是否重置数据加载器。"callbacks": {
: 定义了回调函数的配置,这些回调函数将在特定的事件发生时执行。Events.STEP_COMPLETED: [...],
: 定义了在训练步骤完成后执行的回调函数列表,这里使用ScalarMovingAverage
。Events.EPOCH_COMPLETED: [...],
: 定义了在每个epoch完成后执行的回调函数列表,包括TrackLr
、PeriodicSaver
、VisualizeOutput
、LoggingEpochOutput
、TriggerEngine
(触发验证)和ScheduleLr
。},
: 结束callbacks
字典的定义。
},
: 结束train
字典的定义。
"valid": {
: 在run_engine
字典中,定义了一个名为valid
的字典,包含了验证引擎的配置信息。"dataset": "",
: 定义了验证数据集的路径,但这里留空。"nr_procs": 8,
: 定义了数据加载器使用的线程数量。"run_step": valid_step,
: 定义了执行验证步骤的函数,为valid_step
。"reset_per_run": True,
: 定义了一个标志,指示在每个运行周期中是否重置数据加载器。"callbacks": {
: 定义了回调函数的配置,这些回调函数将在特定的事件发生时执行。Events.STEP_COMPLETED: [...],
: 定义了在验证步骤完成后执行的回调函数列表,这里使用AccumulateRawOutput
。Events.EPOCH_COMPLETED: [...],
: 定义了在每个epoch完成后执行的回调函数列表,包括ProcessAccumulatedRawOutput
和LoggingEpochOutput
。
},
: 结束callbacks
字典的定义。},
: 结束valid
字典的定义。},
: 结束run_engine
字典的定义。}
: 结束返回的字典的定义。
这个get_config
函数的作用是为训练和验证过程生成一个详细的配置字典。这个字典包含了神经网络、优化器、学习率调度器、损失函数、回调函数等所有必要的信息,以指导整个训练过程。通过这个函数,可以方便地配置和定制训练过程,以适应不同的数据集和模型。
五、run_infer.py
5-1:定义脚本的用法、选项和命令模式
"""run_infer.py
Usage:
run_infer.py [options] [--help] <command> [<args>...]
run_infer.py --version
run_infer.py (-h | --help)
Options:
-h --help Show this string.
--version Show version.
--gpu=<id> GPU list. [default: 0]
--nr_types=<n> Number of nuclei types to predict. [default: 0]
--type_info_path=<path> Path to a json define mapping between type id, type name,
and expected overlaid color. [default: '']
--model_path=<path> Path to saved checkpoint.
--model_mode=<mode> Original HoVer-Net or the reduced version used PanNuke and MoNuSAC,
'original' or 'fast'. [default: fast]
--nr_inference_workers=<n> Number of workers during inference. [default: 8]
--nr_post_proc_workers=<n> Number of workers during post-processing. [default: 16]
--batch_size=<n> Batch size per 1 GPU. [default: 32]
Two command mode are `tile` and `wsi` to enter corresponding inference mode
tile run the inference on tile
wsi run the inference on wsi
Use `run_infer.py <command> --help` to show their options and usage.
"""
这段代码是Python脚本run_infer.py
的文档字符串,它定义了脚本的用法、选项和命令模式。
"""run_infer.py
: 文档字符串的开始。Usage:
: 显示脚本的用法。run_infer.py [options] [--help] <command> [<args>...]
: 第一个用法,展示了如何运行脚本,包括可用的选项和命令。run_infer.py --version
: 显示脚本的版本。run_infer.py (-h | --help)
: 显示脚本的帮助信息。
Options:
: 列出可用的选项。-h --help Show this string.
: 显示帮助信息。--version Show version.
: 显示脚本的版本。--gpu=<id> GPU list. [default: 0]
: 指定用于推理的GPU ID。--nr_types=<n> Number of nuclei types to predict. [default: 0]
: 指定要预测的核类型数量。--type_info_path=<path> Path to a json define mapping between type id, type name,
: 指定一个JSON文件路径,该文件定义了类型ID、类型名称和预期覆盖的颜色之间的映射。--model_path=<path> Path to saved checkpoint.
: 指定保存的检查点(checkpoint)文件的路径。--model_mode=<mode> Original HoVer-Net or the reduced version used PanNuke and MoNuSAC,
: 指定模型模式,可以是原始的HoVer-Net或用于PanNuke和MoNuSAC的简化版本。--nr_inference_workers=<n> Number of workers during inference. [default: 8]
: 指定推理过程中使用的工人数量。--nr_post_proc_workers=<n> Number of workers during post-processing. [default: 16]
: 指定后处理过程中使用的工人数量。--batch_size=<n> Batch size per 1 GPU. [default: 32]
: 指定每个GPU的批量大小。
- Two command mode are
tile
andwsi
to enter corresponding inference mode: 描述了两种命令模式,tile
和wsi
,用于进入相应的推理模式。tile run the inference on tile
: 描述了tile
模式,用于在切片上运行推理。wsi run the inference on wsi
: 描述了wsi
模式,用于在WSI(Whole Slide Image,全切片图像)上运行推理。
- Use
run_infer.py <command> --help
to show their options and usage.: 提供了如何显示每个命令选项和用法的说明。
这个文档字符串为用户提供了运行脚本的详细指导,包括可用的选项、命令模式以及如何获取每个命令的帮助信息。
5-2:描述处理切片(tiles)的参数和选项
tile_cli = """
Arguments for processing tiles.
usage:
tile (--input_dir=<path>) (--output_dir=<path>) \
[--draw_dot] [--save_qupath] [--save_raw_map] [--mem_usage=<n>]
options:
--input_dir=<path> Path to input data directory. Assumes the files are not nested within directory.
--output_dir=<path> Path to output directory..
--mem_usage=<n> Declare how much memory (physical + swap) should be used for caching.
By default it will load as many tiles as possible till reaching the
declared limit. [default: 0.2]
--draw_dot To draw nuclei centroid on overlay. [default: False]
--save_qupath To optionally output QuPath v0.2.3 compatible format. [default: False]
--save_raw_map To save raw prediction or not. [default: False]
"""
这段代码是一个命令行接口(CLI)的文档字符串,用于描述处理切片(tiles)的参数和选项,用于处理和分析切片上的核(nuclei)。
tile_cli = """
: 文档字符串的开始。Arguments for processing tiles.
: 描述了这段文档字符串的主题,即处理切片的参数。usage:
: 显示命令的用法。tile (--input_dir=<path>) (--output_dir=<path>) [--draw_dot] [--save_qupath] [--save_raw_map] [--mem_usage=<n>]
: 展示了如何使用命令,包括可选的参数和它们的默认值。
options:
: 列出可用的选项。--input_dir=<path> Path to input data directory. Assumes the files are not nested within directory.
: 指定输入数据目录的路径。假设文件不在目录中嵌套。--output_dir=<path> Path to output directory..
: 指定输出目录的路径。--mem_usage=<n> Declare how much memory (physical + swap) should be used for caching.
: 指定用于缓存的内存量(物理内存 + 交换内存)。默认情况下,它将加载尽可能多的切片,直到达到声明的极限。--draw_dot To draw nuclei centroid on overlay. [default: False]
: 指定是否在叠加图上绘制核质心。--save_qupath To optionally output QuPath v0.2.3 compatible format. [default: False]
: 指定是否以QuPath v0.2.3兼容的格式输出。--save_raw_map To save raw prediction or not. [default: False]
: 指定是否保存原始预测。
这个文档字符串为用户提供了使用tile
命令行工具处理切片的详细指导,包括可用的选项、它们的默认值以及如何指定输入和输出目录。通过这个文档字符串,用户可以了解如何使用这个工具来处理医学图像数据。
5-3:描述处理全切片图像的参数和选项
wsi_cli = """
Arguments for processing wsi
usage:
wsi (--input_dir=<path>) (--output_dir=<path>) [--proc_mag=<n>]\
[--cache_path=<path>] [--input_mask_dir=<path>] \
[--ambiguous_size=<n>] [--chunk_shape=<n>] [--tile_shape=<n>] \
[--save_thumb] [--save_mask]
options:
--input_dir=<path> Path to input data directory. Assumes the files are not nested within directory.
--output_dir=<path> Path to output directory.
--cache_path=<path> Path for cache. Should be placed on SSD with at least 100GB. [default: cache]
--mask_dir=<path> Path to directory containing tissue masks.
Should have the same name as corresponding WSIs. [default: '']
--proc_mag=<n> Magnification level (objective power) used for WSI processing. [default: 40]
--ambiguous_size=<int> Define ambiguous region along tiling grid to perform re-post processing. [default: 128]
--chunk_shape=<n> Shape of chunk for processing. [default: 10000]
--tile_shape=<n> Shape of tiles for processing. [default: 2048]
--save_thumb To save thumb. [default: False]
--save_mask To save mask. [default: False]
"""
这段代码是一个命令行接口(CLI)的文档字符串,用于描述处理全切片图像(WSI,Whole Slide Image)的参数和选项,用于处理和分析全切片图像上的核(nuclei)。
wsi_cli = """
: 文档字符串的开始。Arguments for processing wsi
: 描述了这段文档字符串的主题,即处理全切片图像的参数。usage:
: 显示命令的用法。wsi (--input_dir=<path>) (--output_dir=<path>) [--proc_mag=<n>] [--cache_path=<path>] [--input_mask_dir=<path>] [--ambiguous_size=<n>] [--chunk_shape=<n>] [--tile_shape=<n>] [--save_thumb] [--save_mask]
: 展示了如何使用命令,包括可选的参数和它们的默认值。
options:
: 列出可用的选项。--input_dir=<path> Path to input data directory. Assumes the files are not nested within directory.
: 指定输入数据目录的路径。假设文件不在目录中嵌套。--output_dir=<path> Path to output directory.
: 指定输出目录的路径。--cache_path=<path> Path for cache. Should be placed on SSD with at least 100GB. [default: cache]
: 指定用于缓存的路径。应该放在至少100GB的SSD上。--mask_dir=<path> Path to directory containing tissue masks. Should have the same name as corresponding WSIs. [default: '']
: 指定包含组织掩码的目录路径。应该与相应的WSI具有相同的名称。--proc_mag=<n> Magnification level (objective power) used for WSI processing. [default: 40]
: 指定用于WSI处理的放大倍数(目标功率)。--ambiguous_size=<int> Define ambiguous region along tiling grid to perform re-post processing. [default: 128]
: 指定沿切分网格的模糊区域的大小,用于重新进行后处理。--chunk_shape=<n> Shape of chunk for processing. [default: 10000]
: 指定用于处理的块的形状。--tile_shape=<n> Shape of tiles for processing. [default: 2048]
: 指定用于处理的切片的形状。--save_thumb To save thumb. [default: False]
: 指定是否保存缩略图。--save_mask To save mask. [default: False]
: 指定是否保存掩码。
这个文档字符串为用户提供了使用wsi
命令行工具处理全切片图像的详细指导,包括可用的选项、它们的默认值以及如何指定输入和输出目录。通过这个文档字符串,用户可以了解如何使用这个工具来处理医学图像数据。
5-4:导入常用的模块和函数
import torch
import logging
import os
import copy
from misc.utils import log_info
from docopt import docopt
这段代码是一个Python脚本的导入部分,它导入了一些常用的模块和函数,以便在后续的代码中使用。
import torch
: 导入torch
模块,它是PyTorch库的一部分,提供了张量(Tensor)操作、自动求导、深度神经网络等。import logging
: 导入logging
模块,它提供了日志记录的功能。import os
: 导入os
模块,它提供了与操作系统交互的功能,如文件操作、环境变量等。import copy
: 导入copy
模块,它提供了复制Python对象的功能。from misc.utils import log_info
: 从misc.utils
模块中导入log_info
函数。这个函数可能用于日志记录。from docopt import docopt
: 从docopt
模块中导入docopt
函数。docopt
是一个命令行解析器,它可以根据命令行参数生成一个包含命令行选项和参数的字典。
总的来说,这段代码导入了用于科学计算、日志记录、文件操作、对象复制和命令行解析的常用模块和函数。这些工具将在后续的代码中用于处理和分析数据、记录日志、操作文件以及解析命令行参数。
5-5:定义处理全切片图像(WSI)或切片(Tile)的推理流程
if __name__ == '__main__':
sub_cli_dict = {'tile' : tile_cli, 'wsi' : wsi_cli}
args = docopt(__doc__, help=False, options_first=True,
version='HoVer-Net Pytorch Inference v1.0')
sub_cmd = args.pop('<command>')
sub_cmd_args = args.pop('<args>')
# ! TODO: where to save logging
logging.basicConfig(
level=logging.INFO,
format='|%(asctime)s.%(msecs)03d| [%(levelname)s] %(message)s',datefmt='%Y-%m-%d|%H:%M:%S',
handlers=[
logging.FileHandler("debug.log"),
logging.StreamHandler()
]
)
if args['--help'] and sub_cmd is not None:
if sub_cmd in sub_cli_dict:
print(sub_cli_dict[sub_cmd])
else:
print(__doc__)
exit()
if args['--help'] or sub_cmd is None:
print(__doc__)
exit()
sub_args = docopt(sub_cli_dict[sub_cmd], argv=sub_cmd_args, help=True)
args.pop('--version')
gpu_list = args.pop('--gpu')
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
nr_gpus = torch.cuda.device_count()
log_info('Detect #GPUS: %d' % nr_gpus)
args = {k.replace('--', '') : v for k, v in args.items()}
sub_args = {k.replace('--', '') : v for k, v in sub_args.items()}
if args['model_path'] == None:
raise Exception('A model path must be supplied as an argument with --model_path.')
nr_types = int(args['nr_types']) if int(args['nr_types']) > 0 else None
method_args = {
'method' : {
'model_args' : {
'nr_types' : nr_types,
'mode' : args['model_mode'],
},
'model_path' : args['model_path'],
},
'type_info_path' : None if args['type_info_path'] == '' \
else args['type_info_path'],
}
# ***
run_args = {
'batch_size' : int(args['batch_size']) * nr_gpus,
'nr_inference_workers' : int(args['nr_inference_workers']),
'nr_post_proc_workers' : int(args['nr_post_proc_workers']),
}
if args['model_mode'] == 'fast':
run_args['patch_input_shape'] = 256
run_args['patch_output_shape'] = 164
else:
run_args['patch_input_shape'] = 270
run_args['patch_output_shape'] = 80
if sub_cmd == 'tile':
run_args.update({
'input_dir' : sub_args['input_dir'],
'output_dir' : sub_args['output_dir'],
'mem_usage' : float(sub_args['mem_usage']),
'draw_dot' : sub_args['draw_dot'],
'save_qupath' : sub_args['save_qupath'],
'save_raw_map': sub_args['save_raw_map'],
})
if sub_cmd == 'wsi':
run_args.update({
'input_dir' : sub_args['input_dir'],
'output_dir' : sub_args['output_dir'],
'input_mask_dir' : sub_args['input_mask_dir'],
'cache_path' : sub_args['cache_path'],
'proc_mag' : int(sub_args['proc_mag']),
'ambiguous_size' : int(sub_args['ambiguous_size']),
'chunk_shape' : int(sub_args['chunk_shape']),
'tile_shape' : int(sub_args['tile_shape']),
'save_thumb' : sub_args['save_thumb'],
'save_mask' : sub_args['save_mask'],
})
# ***
if sub_cmd == 'tile':
from infer.tile import InferManager
infer = InferManager(**method_args)
infer.process_file_list(run_args)
else:
from infer.wsi import InferManager
infer = InferManager(**method_args)
infer.process_wsi_list(run_args)
这段代码是一个Python脚本的主体部分,它定义了一个处理全切片图像(WSI)或切片(Tile)的推理流程。
if __name__ == '__main__':
: 这是一个Python常用的条件语句,用于检查当前脚本是否作为主程序运行。如果是,则执行下面的代码块。sub_cli_dict = {'tile' : tile_cli, 'wsi' : wsi_cli}
: 定义了一个字典sub_cli_dict
,其中包含两个键值对。键是命令模式(‘tile’ 或 ‘wsi’),值是相应的命令行接口(CLI)文档字符串。args = docopt(__doc__, help=False, options_first=True, version='HoVer-Net Pytorch Inference v1.0')
: 使用docopt
函数解析命令行参数。__doc__
是脚本的文档字符串,它包含了所有命令行选项和参数的描述。help=False
和options_first=True
是传递给docopt
的参数,用于控制命令行解析的行为。sub_cmd = args.pop('<command>')
: 从解析后的参数中提取<command>
参数,即用户输入的命令模式。sub_cmd_args = args.pop('<args>')
: 从解析后的参数中提取<args>
参数,即用户输入的命令模式参数。logging.basicConfig(...)
: 设置日志记录的基本配置,包括日志级别、日志格式、日期格式以及日志处理器。这里使用了文件和标准输出两种日志处理器。
if args['--help'] and sub_cmd is not None:
: 如果用户输入了--help
选项并且指定了命令模式,则打印相应的命令模式CLI文档字符串。if args['--help'] or sub_cmd is None:
: 如果用户输入了--help
选项或者没有指定命令模式,则打印整个脚本的CLI文档字符串。
sub_args = docopt(sub_cli_dict[sub_cmd], argv=sub_cmd_args, help=True)
: 使用docopt
函数解析命令行参数,这次是基于命令模式指定的CLI文档字符串。args.pop('--version')
: 从解析后的参数中移除--version
选项,因为它在后续的代码中不再需要。gpu_list = args.pop('--gpu')
: 从解析后的参数中提取--gpu
选项,即用户指定的GPU列表。os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
: 将--gpu
选项的值设置为环境变量CUDA_VISIBLE_DEVICES
,这样PyTorch可以识别并使用指定的GPU。nr_gpus = torch.cuda.device_count()
: 获取当前可用的GPU数量。log_info('Detect #GPUS: %d' % nr_gpus)
: 打印当前检测到的GPU数量。args = {k.replace('--', '') : v for k, v in args.items()}
: 转换args
字典的键,移除前缀--
,以便与sub_args
字典的键保持一致。sub_args = {k.replace('--', '') : v for k, v in sub_args.items()}
: 同样,转换sub_args
字典的键。if args['model_path'] == None:
: 如果--model_path
选项没有指定,则抛出一个异常。nr_types = int(args['nr_types']) if int(args['nr_types']) > 0 else None
: 转换--nr_types
选项的值,并检查它是否大于0。
method_args = {...}
: 定义了一个字典method_args
,包含了模型参数和类型信息路径。run_args = {...}
: 定义了一个字典run_args
,包含了推理过程中的各种参数,如批量大小、推理工人数量、后处理工人数量等。
if args['model_mode'] == 'fast':
: 如果--model_mode
选项设置为’fast’,则更新run_args
字典中的切片大小。if sub_cmd == 'tile':
: 如果命令模式为’tile’,则更新run_args
字典中的切片输入和输出目录、内存使用量、绘制点、保存QuPath格式和保存原始映射等参数。if sub_cmd == 'wsi':
: 如果命令模式为’wsi’,则更新run_args
字典中的全切片图像输入和输出目录、掩码目录、缓存路径、处理放大倍数、模糊区域大小、块形状、切片形状、保存缩略图和保存掩码等参数。if sub_cmd == 'tile':
: 如果命令模式为’tile’,则导入infer.tile
模块,并创建一个InferManager
实例,使用method_args
和run_args
参数进行文件列表的处理。else:
: 如果命令模式为’wsi’,则导入infer.wsi
模块,并创建一个InferManager
实例,使用method_args
和run_args
参数进行全切片图像列表的处理。
整个脚本的目的是根据用户输入的命令行参数,选择相应的命令模式(‘tile’ 或 ‘wsi’),并使用相应的参数来运行推理过程。这个过程中,它还处理了日志记录、GPU设置、模型路径和类型信息路径的验证、推理参数的更新以及根据命令模式调用相应的推理管理器来处理文件或全切片图像列表。