fairseq (Facebook AI Research) 包

news2024/10/5 8:08:16

0. Abstract

最近在看一个用 RNNs 网络做 Translation 任务的程序, 关于数据处理部分, 主要用到工具包 sentencepiecefairseq, 前者主要是对文本进行分词处理, 后者则是对已分词的文本进行二进制化快速加载. 包越方便使用, 就说明包装得越狠, 也就越令人一头雾水, 本文简要记录学习过程.

1. Installation

pip install sentencepiece
pip install fairseq

注意, Windows 下安装 fairseq 需要 Visual C++ 构建工具, 否则会报错:

error: Microsoft Visual C++ 14.0 or greater is required. Get it with "Microsoft C++ Build Tools": https://visualstudio.microsoft.com/visual-cpp-build-tools/"

去官网下载 Microsoft C++ Build Tools 并安装就好了, 安装时须选择以下组件:

[注]: 不建议在 Windows 下跑这些程序, 很多包在 Windows 下问题太多, 甚至不支持. Microsoft C++ Build Tools 会占用 C 盘好几个 G (即使安装在其他盘).

小结: 建议在 Linux 系统下玩这些程序.

2. 用 sentencepiece 分词

翻译任务的数据集一般是这样的:


以中英文翻译任务为例, 有 {train/test}.{en/zh} 共 4 个文件, 文件中每一行代表一段文字, 中英文件内相同的行是彼此对应的.

2.1 字符全角转半角

这些数据内容是原始的 (raw), 首先要做的是进行数据清理: 字符编码统一化, 去除过长或过短的文本. “字符编码统一化” 值得记录一下:

def _str_q2b(ustring: str):
	"""
	把字串全形轉半形, 參考來源: https://ithelp.ithome.com.tw/articles/10233122;
	全形(全角): 中文标点符号, 有时中文输入法输入英文字母也是全形, 类似 word 中输入英文, 字体是宋体
	东亚文字一般是全形, 而西文字母和数字是半形(且对应有全形值). 
	那么为了使相同的字符在后端表示统一, 则把 65281 <= unicode <= 65374 范围内的全形字符转换到半形. 
	"""
	rchars = []
	for uchar in ustring:
		unicode = ord(uchar)  # unicode 码
		if unicode == 12288:  # 全形空格直接轉換
unicode = 32
		elif 65281 <= unicode <= 65374:  # 全形字元(除空格)根據關係轉化
unicode -= 65248
		rchars.append(chr(unicode))
	return ''.join(rchars)

关于全角和半角:

print('A', 'a', '0')
print(
	chr(ord('A') + 65248),
	chr(ord('a') + 65248),
	chr(ord('0') + 65248)
)

会输出:

小结: Unicode 字符有全半角之分, 数据预处理时需先进行全角转半角, 统一化.

2.2 训练分词模型
spm.SentencePieceTrainer.train(
	input=','.join([
		f'{self._path_clean}/train.{self._lang_src}',
		f'{self._path_clean}/train.{self._lang_tgt}'
	]),
	# will get 2 files: model_prefix.model, model_prefix.vocab
	model_prefix=self._path_tokens / f'spm{vocab_size}',
	vocab_size=vocab_size,
	character_coverage=1,
	model_type='unigram',  # 'bpe' 也可
	input_sentence_size=1e6,
	shuffle_input_sentence=True,
	normalization_rule_name='nmt_nfkc_cf'
)

这段代码展示了如何使用 sentencepiece 库中的 SentencePieceTrainer.train() 方法来训练一个 SentencePiece 分词模型, 特别适用于机器翻译任务, 其中涉及到两种不同的语言: [来自通义千问]

  • input: 指定了用于训练模型的文本文件. 在这个例子中, 使用的是源语言和目标语言的训练数据, 它们被拼接成一个字符串并用逗号分隔;
  • model_prefix: 定义输出模型文件的前缀. 模型训练后会生成两个文件: 一个是 .model 模型文件, 另一个是 .vocab词汇表文件;
  • vocab_size: 指定了词汇表的大小, 即模型将学习的词汇单元数量. 较大的词汇表可能会提供更好的表达能力, 但也可能导致过拟合或增加计算成本;
  • character_coverage: 设置为 1 表示模型应该覆盖输入文本中的所有字符. 较小的值(如 0.995)意味着模型将忽略最不常见的 0.5% 的字符, 这可能有助于减少词汇表的大小;
  • model_type: 设置为 ‘unigram’ 意味着模型将使用 Unigram 算法进行训练. 另一种选择是 ‘bpe’, 即 Byte-Pair Encoding, 也是一个常用的子词分割算法;
  • input_sentence_size: 控制着用于训练的句子数量. 1e6 表示一百万句. 如果数据集非常大, 你可能想要抽样一定数量的句子以加快训练速度;
  • shuffle_input_sentence: 如果设置为 True, 则在训练过程中输入句子会被随机打乱, 有助于提高模型的泛化能力;
  • normalization_rule_name: 这个参数确定了如何规范化输入文本. 'nmt_nfkc_cf' 是一个预定义的规则, 通常用于机器翻译任务, 它会执行一些文本清洗操作, 如转换为半角字符等.

呃!!! 前面还特意搞了一下字符转换, 可能多虑了. 那就看一下在 normalization_rule_name='nmt_nfkc_cf' 的情况下不进行半角转换, 模型会不会自动转换字符吧:

input=','.join([
	f'{self._path_raw}/train.{self._lang_src}',  # 使用 raw 中的数据进行分词
	f'{self._path_raw}/train.{self._lang_tgt}'
]),

直接使用 raw 文件夹下的文本文件进行分词, 果然发现 spm.SentencePieceTrainer.train() 进行了半角转换:

没骗我们, 逗点变成英文的了, 但引号没有, 因为中文引号和英文引号不是全角和半角的关系.

小结: 使用 sentencepiece 库中的 SentencePieceTrainer.train() 方法可训练一个 SentencePiece 分词模型, 输入train.{en/zh} 文件, 输出为两个文件: model_prefix.model, model_prefix.vocab.

2.3 加载分词模型进行分词
spm_model = spm.SentencePieceProcessor()
spm_model.load(str(self._path_tokens / f'spm{vocab_size}.model'))
for file_name, lang in itertools.product(['train', 'valid', 'test'], [self._lang_src, self._lang_tgt]):
	out_path = self._path_tokens / f'{file_name}.{lang}'
	if not re_tokenize and out_path.exists():
		print(f'{out_path} exists. skipping spm_encode.')
	else:
		with (
			open(self._path_split / f'{file_name}.{lang}', 'r') as in_file,
			open(out_path, 'w') as out_file
		):
			for line in in_file:
				tokens = spm_model.encode(line.strip(), out_type=str)  # 编码结果是 str
				print(' '.join(tokens), file=out_file)

创建一个 SentencePieceProcessor() 实例, 并加载刚才训练所的的分词模型, 然后用 spm_model.encode(...) 对文本文件进行编码, 也就是分词.

3. fairseq 的使用

fairseq 的使用包括很多内容, 包括: 数据的二进制化、数据的加载、模型的定义与训练等.

3.1 数据的二进制化

经过 sentencepiece 的分词处理, 得到了分词后的文本. 下一步要对文本执行二进制化, 以便于快速加载到内存中. 注意不是写 python 代码, 而是执行命令行:

python -m fairseq_cli.preprocess \
	--source-lang {src_lang}\
	--target-lang {tgt_lang}\
	--trainpref {prefix/'train'}\
	--validpref {prefix/'valid'}\
	--testpref {prefix/'test'}\
	--destdir {binpath}\
	--joined-dictionary\
	--workers 2

此命令执行几个关键任务: [来自通义千问]

  1. 数据准备: 从指定的文件(‘train’, ‘valid’, ‘test’ 前缀)读取源语言和目标语言数据, 并将二进制格式的结果写入目的地目录(binpath);
  2. 字典创建: 通过指定 --joined-dictionary, 它为源语言和目标语言创建一个统一的字典. 这对于低资源场景或者当希望模型学习两种语言之间相似单词的映射时是有益的;

参数解析:

  • --source-lang {src_lang}: 源语言;
  • --target-lang {tgt_lang}: 目标语言;
  • --trainpref {prefix/'train'}: 训练数据文件的前缀; 实际的文件名将是 {prefix}/train.{src_lang}{prefix}/train.{tgt_lang}, 分别对应源语言和目标语言的文件;
  • --validpref {prefix/'valid'}: 与上述相同, 但针对验证数据;
  • --testpref {prefix/'test'}: 与上述相同, 但针对测试数据;
  • --destdir {binpath}: 数据以二进制格式保存的目录;
  • --joined-dictionary: 是否为源语言和目标语言创建共享词汇表.

运行此命令后, 应该能在 {binpath} 目录中找到多个二进制文件, 包括词典已标记好的数据, 这些数据已经准备好用于训练神经机器翻译模型.

=== dict.{en/zh}.txt ===
两个词典文件是特殊的, 可能是词典内容往往不会很多, 所以没必要用二进制格式存储. 注意, 词典的根据仅仅是 train.{en/zh} 中的文本内容, 而不包括 {valid/test}.{en/zh}. 其内容是这样的:

统计了每个 token 的词频. 由于使用了 --joined-dictionary, 两个词典的内容是一样的. 倘若不使用这个参数, 则会将中英文的词典分开.

=== preprocess.log ===
preprocess.log 文件是数据预处理过程的日志文件, 里面会记录配置参数, 以及处理结果的数据统计:

可以看到有少量词被 <unk> 替代了, 但我们不知道替代的机制是什么, 也不打算追究了.

== 二进制文件 ==
二进制文件命名的前缀都是 {train/valid/test}.en-zh.{en/zh}, 后缀是 {bin/idx}, 以下解释来自通义千问:
bin 目录下的 .bin.idx 文件是预处理阶段生成的二进制数据文件, 它们用于加速模型的训练过程. 这些文件由 fairseq-preprocess 工具生成, 当处理大规模文本数据时, 使用二进制格式可以显著减少数据加载和处理的时间.

=== .bin 文件 ===
.bin 文件包含了经过 Tokenization[?]NumericalizationPadding 处理后的序列数据. 每个 .bin 文件通常对应于一个特定的数据集(如训练集、验证集或测试集), 并且分为源语言和目标语言两部分. 例如, 对于英语到德语的翻译任务, 你可能会看到 train.en-de.en.bintrain.en-de.de.bin 这样的文件名.

.bin 文件中, 通常包含了以下信息:

  • 每个样本的 Token ID 序列;
  • 对应的长度信息;
  • 可能还包括特殊标记, 如开始标记(BOS)和结束标记(EOS).

=== .idx 文件 ===
.idx 文件是一个索引文件, 它为对应的 .bin 文件提供快速随机访问的能力. 这是因为直接从二进制文件中随机读取数据通常是低效的, 尤其是当数据量很大时, .idx 文件存储了每个样本在 .bin 文件中的起始位置和长度信息, 使得模型在训练时可以直接跳转到所需的样本位置, 无需从头开始顺序读取整个文件.

还不知道有没有可以直接写在程序中的 python 代码 API.
上面说 “经过 Tokenization[?]、… 和 … 处理后”, 也不知道 fairseq_cli.preprocess是否真的执行了分词操作.
将在附录中解答这两个问题.

小结: 使用 fairseq_cli.preprocess 命令可对数据进行二进制化, 可节省存储空间以及加快数据的加载速度, 同时提供数据的随机读取; 在过程中, 会对文本进行分词, 创建词典.

3.2 数据的加载

这种二进制文件让人讨厌的地方就在于你没法打开文件看一看.

任务配置及 task 对象的创建 → 用 task 对象 load 数据 → 遍历数据.

3.2.1 任务配置及 task 对象的创建

经过 fairseq_cli.preprocess 的处理, bin 目录中的各种文件就可以被 fairseq 中的数据加载工具加载了. 首先需要创建一个 task:

task_cfg = translation.TranslationConfig(
	data=str(self._path_bin),
	source_lang=self._lang_src,
	target_lang=self._lang_tgt,
	train_subset='train',
	required_seq_len_multiple=8,
	dataset_impl='mmap',
	upsample_primary=1
)
task = translation.TranslationTask.setup_task(task_cfg)

TranslationConfig 是一个用于封装翻译任务相关配置的类, 它允许你设定训练数据处理的多个方面. 下面详细解释一下这段代码中的参数:

  • data=str(self._path_bin): 数据目录的路径. 这个目录应该包含预处理后的数据, 如 .bin.idx 文件;
  • source_lang=self._lang_src: 源语言代码, 这将决定模型将哪种语言的文本作为输入. 例如, 如果源语言是英语, 则设置为 "en";
  • target_lang=self._lang_tgt: 目标语言代码, 这将决定模型将哪种语言的文本作为输出. 例如, 如果目标语言是汉语, 则设置为 "zh";
  • train_subset='train': 用于训练的子集名称. 在预处理阶段, 数据通常会被划分为训练、验证和测试三个子集, 这里选择的是 "train" 子集;
  • required_seq_len_multiple=8: 设定序列长度需要是某个正整数的倍数. 在某些情况下, 特别是涉及到卷积神经网络(CNN)或注意力机制时, 要求输入序列长度是特定数值的倍数以优化计算效率. 这里设置为 8, 意味着序列长度必须会被填充为 8 的倍数;
  • dataset_impl='mmap': 数据集的实现方式. "mmap" 表示使用内存映射文件的方式加载数据, 这种方式可以提高数据读取的速度, 尤其是在处理大量数据时;
  • upsample_primary=1: 如果你使用的是多个训练数据集, upsample_primary 参数用于控制主数据集(primary dataset)的上采样倍数. 设置为 1 意味着不进行上采样, 即主数据集的样本数量保持不变;

上采样(Upsampling)主要用于增加数据集中的样本数量. 上采样通常有以下两种主要的应用场景:

  • 处理不平衡数据集: 当数据集中某一类别的样本数量远少于其他类别时, 少数类别的模型训练效果可能会受到影响, 因为模型容易偏向于多数类别. 上采样可以通过复制少数类别的样本或合成新的样本(如通过 SMOTE 算法)来平衡数据集, 从而改进模型在少数类别上的预测能力.
  • 图像和序列数据的生成: 在生成对抗网络(GANs)或自编码器(Autoencoders)等模型中, 上采样层用于从较低分辨率或较短序列恢复高分辨率图像或长序列. 例如, 超分辨率重建就是一种典型的上采样应用, 旨在从低分辨率图像生成高分辨率图像.

上采样可以通过多种方式进行, 包括但不限于:

  • 重复样本: 直接复制数据集中的少数类别样本, 增加其在数据集中的比例.
    数据增强: 通过对现有样本进行变换(如旋转、翻转、缩放等)来合成新的样本, 这种方法常用于图像数据.
  • 插值: 在信号处理中, 上采样可以通过插入零值并在相邻点之间进行插值来实现, 从而提高信号的分辨率.
  • 基于模型的方法: 如 GANs 和 Autoencoders 可以生成新的样本, 这种方法通常用于图像和序列数据的上采样.

这些配置通常会在训练脚本中被传入到 fairseq-train相关的训练函数中, 以启动模型的训练过程.

点开 TranslationConfig, 查看源码, 可以看到里面还有很多其他的设置, 包括:

left_pad_source: bool = field(
	default=True, metadata={"help": "pad the source on the left"}
)  # 默认情况下是在左侧 pad source 句子
left_pad_target: bool = field(  # 而target 句子不 pad
	default=False, metadata={"help": "pad the target on the left"}
)
max_source_positions: int = field(
	default=1024, metadata={"help": "max number of tokens in the source sequence"}
)
max_target_positions: int = field(
	default=1024, metadata={"help": "max number of tokens in the target sequence"}
)

task = translation.TranslationTask.setup_task(task_cfg) 是根据 task_cfg 创建了一个 TranslationTask 实例对象, 其中 setup_task(...) 是工厂方法:

@classmethod
def setup_task(cls, cfg: TranslationConfig, **kwargs):
	"""
	Setup the task (e.g., load dictionaries).
	Args:
		args (argparse.Namespace): parsed command-line arguments
	"""
	paths = utils.split_paths(cfg.data)
	...
	# load dictionaries
	src_dict = cls.load_dictionary(
		os.path.join(paths[0], "dict.{}.txt".format(cfg.source_lang))
	)
	tgt_dict = cls.load_dictionary(
		os.path.join(paths[0], "dict.{}.txt".format(cfg.target_lang))
	)
	...
	return cls(cfg, src_dict, tgt_dict)

首先加载了 dictionary, 然后连同 cfg 一起作为参数, 创建了一个 TranslationTask 实例对象, 其属性包括(继承自父类FairseqTask的属性和两个 dictionary):


小结: 根据 TranslationConfig 创建 TranslationTask 实例对象是一切的开始, 里面封装了很多操作, 如 数据集的加载与管理, 模型输入输出格式的定义等.

3.2.2 用 task 对象 load 数据

值得注意的是, path 不是直接来源于 cfg.data, 而是经过处理的:

paths = utils.split_paths(cfg.data)
>>>
def split_paths(paths: str, separator=os.pathsep) -> List[str]:
	return (
		paths.split(separator) if "://" not in paths else paths.split(MANIFOLD_PATH_SEP)
	)

os.pathsep:, 这让我想起来 Ubuntu 中一些环境变量配置时的 :, 我估摸着这就是多个路径的分隔符. 所以, cfg.data 有可能是多个路径. 这也呼应了前面的:

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
	"""
	Load a given dataset split.
	Args:
		split (str): name of the split (e.g., train, valid, test)
	"""
	paths = utils.split_paths(self.cfg.data)
	assert len(paths) > 0
	if split != self.cfg.train_subset:
		# if not training data set, use the first shard for valid and test
		paths = paths[:1]
	data_path = paths[(epoch - 1) % len(paths)]

	# infer langcode
	src, tgt = self.cfg.source_lang, self.cfg.target_lang

	self.datasets[split] = load_langpair_dataset(
		data_path,
		split,
		src, self.src_dict,
		tgt, self.tgt_dict,
		combine=combine,
		dataset_impl=self.cfg.dataset_impl,
		...
	)

先解除一个疑惑: 训练文件的命名已经是 train.xxx, 且二进制预处理时也使用了参数 --trainpref, 文件名也是 fairseq 包生成的, 为何还要在 TranslationConfig 中使用 train_subset='train' 设置训练文件名?
load_dataset(...) 中也许有了答案:

if split != self.cfg.train_subset:
	# if not training data set, use the first shard for valid and test
	paths = paths[:1]

这里用于判断 split 是否是 train_subset, 如果不是, 则去掉其他数据路径, 只保留第一个路径. 而后面的一句:

data_path = paths[(epoch - 1) % len(paths)]  # path[(epoch-1) % 1] = path[0]

data_path 就只取第一个路径. 所以猜测:

  • cfg.data 是由 : 分隔的多个数据集路径;
  • 训练数据集可能存在于多个路径中, 而 valid/test 数据则只存在于第一个路径中, 即 primary 数据集;
  • split == self.cfg.train_subset', 即读取训练集, 则读取数据的路径epoch 滚动.

那为什么不直接 if split != 'train' 呢? 呃! 咱也不知道, 可能还能设置 train_subset='valid'?

好! 接下来使用 load_langpair_dataset(...) 函数加载数据集:

def load_langpair_dataset(...):
	...
	src_datasets = []
	tgt_datasets = []
	# 迭代 k, 意思是文件夹内可能存在 {train/valid/test}{k} 多个文件, 需要 append 到 [] 中
	for k in itertools.count():  # 无限迭代器
		split_k = split + (str(k) if k > 0 else "")  # train, train1, train2, ...
		# prefix = ./bin/train.en-zh.
		if split_exists(split_k, src, tgt, src, data_path):
			prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
		elif split_exists(split_k, tgt, src, src, data_path):
			prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
		else:
			if k > 0:  # 如果 ./bin/train1.en-zh.en.{bin/idx} 和 ./bin/train1.zh-en.en.{bin/idx} 都没找到
				break
			else:
				raise FileNotFoundError(
					"Dataset not found: {} ({})".format(split, data_path)
				)
		src_dataset = data_utils.load_indexed_dataset(
			prefix + src, src_dict, dataset_impl
		)
		...
		src_datasets.append(src_dataset)

		tgt_dataset = data_utils.load_indexed_dataset(
			prefix + tgt, tgt_dict, dataset_impl
		)
		if tgt_dataset is not None:
			tgt_datasets.append(tgt_dataset)
		...
		if not combine:
			break
	...
	# >>> 合并多个文件中的文本 >>>
	if len(src_datasets) == 1:
		src_dataset = src_datasets[0]
		tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
	else:
		sample_ratios = [1] * len(src_datasets)
		sample_ratios[0] = upsample_primary  # 增加主文件的采样权重
		src_dataset = ConcatDataset(src_datasets, sample_ratios)
		if len(tgt_datasets) > 0:
			tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
		else:
			tgt_dataset = None
	...
	eos = None
	...
	align_dataset = None
	...
	tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
	return LanguagePairDataset(
		src_dataset,
		src_dataset.sizes,
		src_dict,
		tgt_dataset,
		tgt_dataset_sizes,
		tgt_dict,
		...cfg.xxx...
	)

这段代码假设 bin 文件夹内存在 {train/valid/test}{k} 多个文件, 读到多个 {src/tgt}_datasets = [] 中, 并通过 ConcatDataset(src_datasets, sample_ratios) 合并到一个 dataset 中, 最后通过 LanguagePairDataset(...) 创建语言对数据集. 所以这里我们要继续看三个东西:

  • data_utils.load_indexed_dataset(...);
  • ConcatDataset(xxx_datasets, sample_ratios);
  • LanguagePairDataset(src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, ...).
3.2.2.1 data_utils.load_indexed_dataset(...)
src_dataset = data_utils.load_indexed_dataset(
	prefix + src, src_dict, dataset_impl
)

这里的 prefix+src = './bin/train.en-zh.en', 连同 src_dict, dataset_impl 传入 data_utils.load_indexed_dataset(...), 进行数据加载:

def load_indexed_dataset(
	path, dictionary=None, dataset_impl=None, combine=False, default="cached"
):
	import fairseq.data.indexed_dataset as indexed_dataset
	from fairseq.data.concat_dataset import ConcatDataset

	datasets = []
	for k in itertools.count():
		path_k = path + (str(k) if k > 0 else "")
		try:  # 有后缀 '.bin/.idx' 的 path 去掉后缀, 这里的 path = 
			# prefix+src = './bin/train.en-zh.en' 无, 故保持不变
			path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
		except Exception as e:
			...
		...
		dataset = indexed_dataset.make_dataset(
			path_k,
			impl=dataset_impl_k or default,
			fix_lua_indexing=True,
			dictionary=dictionary
		)
		if dataset is None:
			break  # 加载完毕
		datasets.append(dataset)
		if not combine:
			break  # not combine, 则仅使用一个数据文件

	if len(datasets) == 0:
		return None
	elif len(datasets) == 1:
		return datasets[0]
	else:
		return ConcatDataset(datasets)

这里依然假设有多个文件要加载, 由 combine 参数决定加载多个, 还是只加载头一个:

combine (bool, optional): automatically load and combine multiple datasets.
	For example, if *path* is 'data-bin/train', then we will combine
	'data-bin/train', 'data-bin/train1', ... and
	return a single ConcatDataset instance.

只可惜, 这里的调用使用了默认值 False, 也就是只有一个文件. 我们只需要看:

dataset = indexed_dataset.make_dataset(
	path_k,
	impl=dataset_impl_k or default,
	fix_lua_indexing=True,
	dictionary=dictionary
)
def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
	if ...:
	...
	elif impl == "mmap" and MMapIndexedDataset.exists(path):
		return MMapIndexedDataset(path)  // 我们使用的, 也是默认的
	...
	return None

它是根据 impl 决定加载方式的, 对于我们使用的 'mmap', 它不需要 fix_lua_indexing, dictionary 这两个参数, 仅仅有加载路径就行了. 下面是真正的重点, 展现了如何加载二进制文件:

class MMapIndexedDataset(torch.utils.data.Dataset):
	def __init__(self, path):
		super().__init__()
		self._path = None
		self._index = None
		self._bin_buffer = None
		self._do_init(path)

	def _do_init(self, path):
		self._path = path
		self._index = self.Index(index_file_path(self._path))  # idx 文件加载
		# >>> .bin 文件加载 >>>
		_warmup_mmap_file(data_file_path(self._path))  # 加载测试, 啥都没干
		self._bin_buffer_mmap = np.memmap(
			data_file_path(self._path), mode="r", order="C"
		)
		self._bin_buffer = memoryview(self._bin_buffer_mmap)

	@lru_cache(maxsize=8)
	def __getitem__(self, i):
		ptr, size = self._index[i]
		np_array = np.frombuffer(
			self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr
		)
		if self._index.dtype != np.int64:
			np_array = np_array.astype(np.int64)
		return torch.from_numpy(np_array)

可以看到, 构建这个 MMapIndexedDataset 主要干了两件事: (1) 加载 .idx 文件, (2) np.memmap(...) 构建了 .idx 文件的内存映射. 然后 __getitem__(self, i) 使用 np.frombuffer(...) 读取 .bin 文件, 将数据转化为 torch.long 返回. 至于 self.Index, 这个对象里有数据的一些元信息以及每条数据在 .bin 文件中的位置与长度, 详情在附录中(包含以内存映射方式读取二进制文件).

3.2.2.2 ConcatDataset(xxx_datasets, sample_ratios)

虽然此例中, 我们只有一个数据集, 但还是要了解一下多数据集是如何拼接的:

class ConcatDataset(FairseqDataset):
	def __init__(self, datasets, sample_ratios=1):
		super(ConcatDataset, self).__init__()
		self.datasets = list(datasets)
		if isinstance(sample_ratios, int):
			sample_ratios = [sample_ratios] * len(self.datasets)
		self.sample_ratios = sample_ratios
		# 根据每个数据集的采样权重, 设置整个数据集的累计下标, 也即把数据集挨个摆放, 样本下标是累计的;
		# 如果有 3 个数据集, 大小分别为: 100, 300, 200; sample_ratios = [2, 4, 3];
		# 那么, cumulative_sizes = [200, 1400, 2000]
		self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
		self.real_sizes = [len(d) for d in self.datasets]  # 各数据集的实际大小: [100, 300, 200]

	def __len__(self):
		return self.cumulative_sizes[-1]

	def __getitem__(self, idx):
		# 根据 real_sizes, cumulative_sizes 等信息, 计算出该取哪个数据集的哪个样本
		dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
		return self.datasets[dataset_idx][sample_idx]
3.2.2.3 LanguagePairDataset(src_dataset, src_dataset.sizes, src_dict, tgt_dataset, tgt_dataset_sizes, tgt_dict, ...)

其实这个 LanguagePairDataset 就没什么好说的了, 它只是将源和目标数据集放在一块, 一起加载而已. 如果非要说点什么, 无非就是它除了拼装之外, 还进行了一些额外的操作:

class LanguagePairDataset(FairseqDataset):
	"""
	A pair of torch.utils.data.Datasets.

	Args:
		left_pad_source (bool, optional): pad source tensors on the left side
			(default: True).
		left_pad_target (bool, optional): pad target tensors on the left side
			(default: False).
		input_feeding (bool, optional): create a shifted version of the targets
			to be passed into the model for teacher forcing (default: True).
		remove_eos_from_source (bool, optional): if set, removes eos from end
			of source if it's present (default: False).
		append_eos_to_target (bool, optional): if set, appends eos to end of
			target if it's absent (default: False).
		append_bos (bool, optional): if set, appends bos to the beginning of
			source/target sentence.
		num_buckets (int, optional): if set to a value greater than 0, then
			batches will be bucketed into the given number of batch shapes.
	"""

其中我们能通过 TranslationConfig 设置的参数有:

left_pad_source=self.cfg.left_pad_source,  # default=True
left_pad_target=self.cfg.left_pad_target,  # default=False
num_buckets=self.cfg.num_batch_buckets,    # default=0
pad_to_multiple=self.cfg.required_seq_len_multiple

起码在使用 task 对象加载数据的过程中, LanguagePairDataset其他参数都是默认固定好的.

其样本的样式是这样的:

def __getitem__(self, index):
	tgt_item = self.tgt[index] if self.tgt is not None else None
	src_item = self.src[index]
	
	if self.append_eos_to_target:
		...
	if self.append_bos:
		...
	if self.remove_eos_from_source:
		...
	example = {
		"id": index,
		"source": src_item,
		"target": tgt_item,
	}
	...
	return example

小结: 至此, 漫长的数据加载之路就结束了, 太过漫长, 以至于我们"不识大体". 画一张图来"识一下大体"吧:

3.2.3 数据集的遍历

加载完数据之后, 就可以遍历数据了, 我们把 valid 数据拿出来看看:

sample = task.dataset("valid")[0]  # valid 数据集的第 1 个样本
pprint.pprint(sample)
pprint.pprint(  # 将下标张量转化为 str 的句子
	"Source: " + task.source_dictionary.string(
		sample['source'],
		"sentencepiece",  # 说明预处理时的分词工具, 以便还原句子
	)
)
pprint.pprint(
	"Target: " + task.target_dictionary.string(
		sample['target'],
		"sentencepiece",
	)
)
{
	'id': 0,
	'source': tensor([22, 59, 12, 2601, 13, 452, 16, 5, 326, 13, 21, 6, 7, 2]),
	'target': tensor([51, 760, 1488, 792, 442, 1591, 1564, 10, 2])
}
'Source: we have to buy a lot of those ads .'
'Target: 我們必須買很多那些廣告 。'

可见结果如 LanguagePairDataset 中写的那样, 返回了一个形如

example = {
	"id": index,
	"source": src_item,
	"target": tgt_item,
}

的样本, 包含着样本的 id, 源语言的句子 src_item, 目标语言的句子 tgt_item. 句子是由单词对应的下标张量所表示, 还没有经过 padding 等操作, 还是比较原始的, 仅仅在末尾添加了一个 2, 从

class Dictionary:
	def __init__(
			self,
			*,  # begin keyword-only arguments
			bos="<s>",
			pad="<pad>",
			eos="</s>",
			unk="<unk>",
			extra_special_symbols=None,
	):
		self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
		self.symbols = []  # 单词列表
		self.count = []  # # 对应词频
		self.indices = {}  # 单词在 symbol 和 count 中的下标
		self.bos_index = self.add_symbol(bos)  # 首先添加四个特殊标记
		self.pad_index = self.add_symbol(pad)  # 并返回其下标
		self.eos_index = self.add_symbol(eos)
		self.unk_index = self.add_symbol(unk)

来看, 2eos="</s>", 表示句子的结束. 0 是句子的开始 <s>, 1<pad>, 3<unk>.

上面只是拿出来一个样本看看, 如何迭代呢?

for i, sample in enumerate(task.dataset("valid")):
	if i < 3:
		pprint.pprint(sample)

像普通数据集一样迭代? 自然不是, 还得组成 batch 呢! 然而并不是放到 DataLoader 里面, fairseq 里有专门的迭代器:

batch_iterator = task.get_batch_iterator(
	dataset=task.dataset(split),
	max_tokens=config.max_tokens,
	max_sentences=None,
	max_positions=utils.resolve_max_positions(task.max_positions(), config.max_tokens),
	ignore_invalid_inputs=True,
	seed=config.seed,
	num_workers=config.num_workers,
	epoch=epoch,
	disable_iterator_cache=not cached
)

似乎 batch_size 并不是由一个叫 batch_size 的参数决定, 而是由:

max_tokens=config.max_tokens,
max_sentences=None,
max_positions=utils.resolve_max_positions(task.max_positions(), config.max_tokens)

等参数决定. 输出两个batch 看看:

{
	'id': tensor([1960, 3713, 1096,  441, 2897, 2340,  493, 2309, 1383, 2517]),
	'net_input': {
		'prev_output_tokens': tensor(
			[[2, 138, ...,   10,  1, 1, 1, 1],
			 [2,   5, ...,  869,  1, 1, 1, 1],
			 ...,
			 [2, 115, ..., 2178, 10, 1, 1, 1]]
		),
		'src_lengths': tensor([30, 30, 30, 30, 30, 30, 30, 30, 30, 30]),
		'src_tokens': tensor(
			[[1, ..., 1, 17, ..., 6,    7,    2],
			 ...,
			 [1, ..., 1, 33, ..., 6,    7,    2]]
		)
	},
	'nsentences': 10,
	'ntokens': 207,
	'target': tensor(
		[[138, ...,  10, 2, 1, 1, 1, 1],
		 [  5, ..., 869, 2, 1, 1, 1, 1],
		 ...,
		 [115, ...,  10, 2, 1, 1, 1, 1]]
	)
}
{
	'id': tensor([2488, 3873, 1464,  387]),
	'net_input': {
		'prev_output_tokens': tensor(
			[[  2, 148, ..., 1170,   4, 1, ..., 1],
			 [  2,  51, ...,    8,  10, 1, ..., 1],
			 [  2,   5, ..., 4293,  10, 1, ..., 1],
			 [  2, 627, ..., 3082, 231, 1,   1, 1]]
		),
		'src_lengths': tensor([57, 57, 57, 57]),
		'src_tokens': tensor(
			[[1, 1, 1, 32, ..., 2405, 7, 2],
			 [1, 1, 1, 11, ...,    6, 7, 2],
			 [1, 1, 1, 12, ...,  643, 7, 2],
			 [1, 1, 1, 32, ..., 1456, 7, 2]]
		)
	},
	'nsentences': 4,
	'ntokens': 231,
	'target': tensor(
		[[148, ..., 1170,   4, 2, 1, ..., 1],
		 [ 51, ...,    8,  10, 2, 1, ..., 1],
		 [  5, ..., 4293,  10, 2, 1, ..., 1],
		 [627, ..., 3082, 231, 2, 1,   1, 1]]
	)
}

可以看到, 两个 batchntokensnsentences 都不一样, 说明 batch 中 sample 的数量不是由这两者决定. 来看看通义千问的解释:

max_tokens=None,            # 每个批次的最大令牌数
max_sentences=None,         # 每个批次的最大句子数
max_positions=None,         # 输入序列的最大长度

在函数 task.get_batch_iterator(...) 内确实能找到:

# filter examples that are too large
if max_positions is not None:
	indices = self.filter_indices_by_size(
		indices, dataset, max_positions, ignore_invalid_inputs
	)
# create mini-batches with given size constraints
batch_sampler = dataset.batch_by_size(
	indices,
	max_tokens=max_tokens,
	max_sentences=max_sentences,
	required_batch_size_multiple=required_batch_size_multiple
)

所以, max_position 是用来过滤掉过长的句子的. 实际上模型也有这个最大句长限制:

# 解决最大位置限制
max_positions = utils.resolve_max_positions(
	model.max_positions(),
	task.max_positions(),
	default_max_positions=(1024, 1024)
)

模型数据用户设置三者决定, 默认都是 None, 可能就是没有限制的意思. 具体细节就不追究了.

max_tokensmax_sentences 共同决定了 batch 的大小. 每个batch 所含的 tokens 数量不会超过 max_tokens, 句子数量也不会超过 max_sentences. (注意, 计数不包括填充.)

再来看一下 batch 的内部, 字典内有 'id', 'net_input', 'nsentences', 'ntokens', 'target' 等五部分组成, 真正输入到网络里的是 'net_input'(后面再说). 值得注意是:

  • 数据都是经过填充的, 一个 batch 内的 src_tokens, 'target''prev_output_tokens' 都是各自在内部对其的, 这样才能在表示在一个二维张量里. 默认情况下, src 是在左侧填充, tgt 是在右侧填充;
  • 真正输入到解码器的不是 'target' 而是 batch['net_input']['prev_output_tokens'], 其是把 'target'2 移到开端;
  • 一个 batch 内的句子长度都是接近的, 这似乎是由于句子已经按长度排好序了. 设置 shuffle=False, 会发现 batch[nsentences] 越来越小, 说明句子是越来越长的;
  • 填充句子时, 会将句子张量的长度 batch['src_tokens'].shape[-1] 等填充到 cfg.required_seq_len_multiple 的倍数, 据说是为了内存的高效利用.
3.3 Seq2Seq 模型
class Seq2Seq(FairseqEncoderDecoderModel):
	def __init__(self, args, encoder, decoder):
		super().__init__(encoder, decoder)
		self.args = args

	def forward(
			self,
			src_tokens,
			src_lengths,
			prev_output_tokens,
			return_all_hiddens: bool = True
	):
		"""
		Run the forward pass for an encoder-decoder model.
		"""
		encoder_out = self.encoder(
			src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
		)
		logits, extra = self.decoder(
			prev_output_tokens,
			encoder_out=encoder_out,
			src_lengths=src_lengths,
			return_all_hiddens=return_all_hiddens
		)
		return logits, extra

这是翻译任务的模型架构, 由一个 encoder 和一个 decoder 组成, 继承 fairseq 中的 FairseqEncoderDecoderModel, 并实现其抽象方法 forward 方法. 可以看到, 其所需参数恰好跟 batch_iterator = task.get_batch_iterator(...) 迭代器所返回的样本中 batch['net_input'] 内容是一致的. 然后, 数据经过 encoderdecoder 的处理, 得到了 logits, 我们知道, 它就是输入到损失函数的东西.

下面, 我们分别详细地看一下 encoderdecoder 是什么样的.

3.3.1 FairseqEncoder
class RNNEncoder(FairseqEncoder):
	def __init__(self, args, dictionary, embed_tokens):
		super().__init__(dictionary)
		self.embed_tokens = embed_tokens

		self.embed_dim = args.encoder_embed_dim
		self.hidden_dim = args.encoder_ffn_embed_dim
		self.num_layers = args.encoder_layers

		self.dropout_in_module = nn.Dropout(args.dropout)
		self.rnn = nn.GRU(
			self.embed_dim,
			self.hidden_dim,
			self.num_layers,
			dropout=args.dropout,
			batch_first=False,  # True 时, 输入和输出的形状应为 (batch, seq, feature), 而不是默认的 (seq, batch, feature)
			bidirectional=True
		)
		self.dropout_out_module = nn.Dropout(args.dropout)
		self.padding_idx = dictionary.pad()

	def combine_bidir(self, outs, bsz: int):
		out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
		return out.view(self.num_layers, bsz, -1)

	def forward(self, src_tokens, src_lengths=None, **kwargs):
		bsz, seqlen = src_tokens.size()  # batch_size 在前

		# get embeddings
		x = self.embed_tokens(src_tokens)  # (batch_size, seq_len, embd_dim)
		x = self.dropout_in_module(x)  # 输入前就进行 dropout

		# B x T x C -> T x B x C
		x = x.transpose(0, 1)  # 换了 (seq_len, batch_size, embd_dim)

		# 過雙向 RNN
		h0 = x.new_zeros(2 * self.num_layers, bsz, self.hidden_dim)
		x, final_hiddens = self.rnn(x, h0)
		outputs = self.dropout_out_module(x)  # 输出 dropout
		# outputs = [seq_len, batch_size, hid_dim * directions] 是最上層 RNN 的輸出
		#  hidden = [num_layers * directions, batch_size, hid_dim]

		# 因為 Encoder 是雙向的 RNN,所以需要將同一層兩個方向的 hidden_state 接在一起
		final_hiddens = self.combine_bidir(final_hiddens, bsz)
		# hidden = [num_layers x batch x num_directions*hidden]

		encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
		return (
			outputs,  # seq_len x batch x hidden
			final_hiddens,  # num_layers x batch x num_directions*hidden
			encoder_padding_mask  # seq_len x batch
		)

	def reorder_encoder_out(self, encoder_out, new_order):
		# 這個 beam search 時會用到,意義並不是很重要
		return (
			encoder_out[0].index_select(1, new_order),
			encoder_out[1].index_select(1, new_order),
			encoder_out[2].index_select(1, new_order),
		)

附录

1. fairseq_cli.preprocess 对应的 Python API

对于 fairseq_cli.preprocess 命令行, 想要把该功能融合进程序中, 得把命令行内容写进一个 shell 脚本, 这里假设脚本文件名为 'tobin.sh', 然后调用 python 的 subprocess.run(...) 函数以执行 shell 脚本:

#!/bin/bash
python -m fairseq_cli.preprocess \
	--source-lang $1\
	--target-lang $2\
	--trainpref $3\
	--validpref $4\
	--testpref $5\
	--destdir $6\
	--joined-dictionary\
	--workers 4
result = subprocess.run(
	[
		'bash', './tobin.sh',
		f'{self._lang_src}',
		f'{self._lang_tgt}',
		f'{self._path_tokens}/train',
		f'{self._path_tokens}/valid',
		f'{self._path_tokens}/test',
		f'{self._path_bin}'
	],
	capture_output=True, text=True
)
if result.returncode == 0:
	print(result.stdout)
else:
	print(result.stderr)

那有没有对应的 python API 可调用呢? 答案是肯定的, 我们找到了 fairseq_cli.preprocess.py 文件, 里面是数据处理代码:

def cli_main():
	parser = options.get_preprocessing_parser()
	args = parser.parse_args()
	main(args)

操作主要在 main(args) 函数中, 其中 args 中充满了各种配置, 来自于命令行, 那我们自然可以自己构建一个 args, 然后调用 main(args). 只可惜, 我们不知道里面到底需要哪些配置, 因为 args = parser.parse_args() 中可是有很多默认配置的, 要是盲目地这么干:

args = {  # 设置参数
	'source_lang': 'en',
	'target_lang': 'zh',
	'trainpref': './data/tokens/train',
	'validpref': './data/tokens/valid',
	'testpref': './data/tokens/test',
	'destdir': './data/bin',
	'joined_dictionary': True,
	'workers': 2
}
args = argparse.Namespace(**args)  # 转换命名空间
os.makedirs(args.destdir, exist_ok=True)  # 创建输出目录
preprocess.main(args)  # 调用预处理函数

则:

AttributeError: 'Namespace' object has no attribute 'dataset_impl'

那么我们可以这么干:

import argparse
import os

from fairseq import options
from fairseq_cli import preprocess

parser = options.get_preprocessing_parser()
args = vars(parser.parse_args()).update({  # 设置参数
	'source_lang': 'en',
	'target_lang': 'zh',
	'trainpref': './data/tokens/train',
	'validpref': './data/tokens/valid',
	'testpref': './data/tokens/test',
	'destdir': './data/bin',
	'joined_dictionary': True,
	'workers': 2
})
args = argparse.Namespace(**args)  # 转换命名空间
os.makedirs(args.destdir, exist_ok=True)  # 创建输出目录

preprocess.main(args)  # 调用预处理函数

拿到默认配置, 然后用设置我们的参数.

2. fairseq_cli.preprocess 预处理时 ‘Tokenization’ 了吗?

为了验证这个问题, 我们直接将未经 sentencepiece 处理过的数据fairseq_cli.preprocess 进行二进制化:

args.update({  # 设置参数
	...
	'trainpref': f'{self._path_split}/train',
	'validpref': f'{self._path_split}/valid',
	'testpref': f'{self._path_split}/test',
	...
})

会得到:

[en] Dictionary: 846872 types
[en] ./data/split/train.en: 389051 sents, 7972723 tokens, 0.0% replaced (by <unk>)
[en] Dictionary: 846872 types
[en] ./data/split/valid.en: 3929 sents, 80476 tokens, 0.588% replaced (by <unk>)
[en] Dictionary: 846872 types
[en] ./data/split/test.en: 1000 sents, 20242 tokens, 0.558% replaced (by <unk>)
[zh] Dictionary: 846872 types
[zh] ./data/split/train.zh: 389051 sents, 2011056 tokens, 0.0% replaced (by <unk>)
[zh] Dictionary: 846872 types
[zh] ./data/split/valid.zh: 3929 sents, 20342 tokens, 36.4% replaced (by <unk>)
[zh] Dictionary: 846872 types
[zh] ./data/split/test.zh: 1000 sents, 5903 tokens, 35.7% replaced (by <unk>)

天呐, 得到了 846904 个词单元? 点击 dict.en.txt 文件, 提示文件太大无法打开. 更令人震惊的是, {valid/test}.zh 文件中 35% 以上的 tokens 被 replaced (by <unk>).

怎么办? 问了 Kimi: fairseq_cli.preprocess 分词时如何设置词单元数量?

答:
--nwordssrc 10000 \
--nwordstgt 10000 \

好吧, 试试:

args.update({  # 设置参数
	...
	'nwordssrc': 10000,
	'nwordstgt': 10000,
	...
})

输出:

[en] Dictionary: 10000 types
[en] ./data/split/train.en: 389051 sents, 7972723 tokens, 4.82% replaced (by <unk>)
[en] Dictionary: 10000 types
[en] ./data/split/valid.en: 3929 sents, 80476 tokens, 5.01% replaced (by <unk>)
[en] Dictionary: 10000 types
[en] ./data/split/test.en: 1000 sents, 20242 tokens, 4.54% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/train.zh: 389051 sents, 2011056 tokens, 39.4% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/valid.zh: 3929 sents, 20342 tokens, 39.5% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/test.zh: 1000 sents, 5903 tokens, 38.4% replaced (by <unk>)

(⊙o⊙)?更多的 tokens 被 replaced (by <unk>) 了. 打开 dict.en.txt 文件, 发现里面是比较符合逻辑的整个单词或汉语词语. 那应该是分词算法的问题, 样本量太小, 以至于学习到的 tokens 太少, 导致 {valid/test}.zh 文件中大量 tokens 被 replaced (by <unk>).[英语还好]. 于是想起来 sentencepiece 的分词适合东亚语言, 也适合英语, 它分词的结果是字母片段, 汉字短语.

问 Kimi: 能设置设置分词算法吗?

 --tokenizer moses  # 指定分词算法, 例如 moses, nltk, space 等

好吧, 试试:

args.update({  # 设置参数
	...
	'nwordssrc': 10000,
	'nwordstgt': 10000,
	'tokenizer': 'moses'
	...
})

输出:

[en] Dictionary: 10000 types
[en] ./data/split/train.en: 389051 sents, 7972723 tokens, 4.82% replaced (by <unk>)
[en] Dictionary: 10000 types
[en] ./data/split/valid.en: 3929 sents, 80476 tokens, 5.01% replaced (by <unk>)
[en] Dictionary: 10000 types
[en] ./data/split/test.en: 1000 sents, 20242 tokens, 4.54% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/train.zh: 389051 sents, 2011056 tokens, 39.4% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/valid.zh: 3929 sents, 20342 tokens, 39.5% replaced (by <unk>)
[zh] Dictionary: 10000 types
[zh] ./data/split/test.zh: 1000 sents, 5903 tokens, 38.4% replaced (by <unk>)

很遗憾, 情况没有好转, 试试其他算法 nltk, space 也都没好转. 当使用 'tokenizer': 'sentencepiece' 时, 告诉我只支持 moses, nltk, space 三种分词工具. 汉语的话, 还是老老实实用 sentencepiece.

3. 二进制文件的读取过程

首先看 .idx 文件的读取:

class Index:
	def __init__(self, path):
		with open(path, "rb") as stream:
			magic_test = stream.read(9)
			...
			version = struct.unpack("<Q", stream.read(8))
			...
			(dtype_code,) = struct.unpack("<B", stream.read(1))
			...
			self._len = struct.unpack("<Q", stream.read(8))[0]
			offset = stream.tell()
		...
		self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
		self._bin_buffer = memoryview(self._bin_buffer_mmap)
		self._sizes = np.frombuffer(
			self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
		)
		self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes)

	@lru_cache(maxsize=8)
	def __getitem__(self, i):
		return self._pointers[i], self._sizes[i]

首先以 'rb' 模式打开了 '.idx' 文件, 依次读取了 9, 8, 1, 8 个字节的内容, 分别是 magic_test 用于检查存储方式是否符合本类的读取方式; version 记录存储文件版本, dtype_code 是文件中数据的类型; self._len 是文件的数据长度. offset = stream.tell() 是数据开始的位置.

好! 在读取了文件的元信息之后, 后面的四行代码我们求助通义千问:

这段代码展示了如何使用 NumPymemmap 功能结合 memoryviewNumPy 的数组操作, 高效地从磁盘上的二进制文件中读取和处理数据. 下面是对每一步的详细解释:

=== 创建内存映射 ===

self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")

创建了一个内存映射文件 self._bin_buffer_mmap, 将磁盘上的二进制文件 path 映射到内存中, 允许像访问内存中的数组那样访问文件内容. mode="r" 表示只读模式, order="C" 指定数据的布局为 C 风格的行优先顺序.

=== 创建 memoryview ===

self._bin_buffer = memoryview(self._bin_buffer_mmap)

这一行创建了一个 memoryview 对象 self._bin_buffer, 它提供了对 self._bin_buffer_mmap低级内存视图, 允许对内存映射文件进行高效、低开销的切片操作, 且不会复制数据, 这对于处理大型数据集特别有用.

=== buffer 中读取数据 ===
接下来的两行代码使用 np.frombuffer 方法self._bin_buffer 中读取数据:

self._sizes = np.frombuffer(
	self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
)
self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len, offset=offset + self._sizes.nbytes)

self._sizes 读取了数据集中每个样本的大小信息, 存储为 np.int32 类型的数组. count=self._len 表示读取 self._len 个这样的元素, offset=offset 指定从 self._bin_buffer 的哪个位置开始读取.

self._pointers 读取了指向每个样本起始位置的指针, 存储为 np.int64 类型的数组. 读取的起点是从 self._sizes 数组的末尾开始, 即 offset + self._sizes.nbytes.

根据上面通义千问的解释, 我们可以认为, self._pointers, self._sizes 分别是存储了数据位置和大小的两个 numpy 数组, 只不过它们是由内存映射实现的. 然后就可以 return self._pointers[i], self._sizes[i] 作为 __getitem__ 的结果了.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1857529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虚拟机拖拽文档造成缓存过大

查看文件夹大小&#xff1a;du -h --max-depth1 缓存位置&#xff1a;~/.cache/vmware/drag_and_drop 删除&#xff1a;rm -fr ~/.cache/vmware/drag_and_drop 释放了3GB

Javascript中的this关键字指向

this关键字介绍 不同情况下的this 1.对象调用方法中的this 2.在全局使用this(单独使用) 3.函数中的this 4.函数严格模式下 5.事件中的this 6.构造函数中的this 7.箭头函数没有this call()、apply()、bind() 的用法 this关键字介绍 面向对象语言中 this 表示当前对象…

JavaScript中的内置对象

// 用法2&#xff1a;参数&#xff1a;指定的时间的字符串 创建一个指定的时间 // new Date(‘2018-12-12 12:00:00’) var date1 new Date(‘2018-12-12 12:00:00’); console.log(date1); // 用法3&#xff1a; 参数可以是年月日时分秒 月份从0开始 // var date2 new …

基于Java实验室课程管理系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f31f;文末获取源码数据库&#x1f31f; 感兴趣的可以先收藏起来&#xff0c;…

Linux 阻塞和非阻塞 IO 实验学习

Linux 阻塞和非阻塞 IO 实验学习 IO 指的是 Input/Output&#xff0c;也就是输入/输出&#xff0c;是应用程序对驱动设备的输入/输出操作。当应用程序对设备驱动进行操作的时候&#xff0c;如果不能获取到设备资源&#xff0c;那么阻塞式 IO 就会将应用程序对应的线程挂起&…

【初阶数据结构】二叉树(附题)

目录 1.树概念及结构 1.1树的概念 1.2 树的相关概念&#xff08;树结构的相关概念命名参考自然树和人的血缘关系&#xff09; 1.3 树的表示 1.4 树在实际中的运用&#xff08;表示文件系统的目录树结构&#xff0c;初次之外网盘中使用到&#xff09; 2.二叉树概念及结构 …

Google浏览器快捷方式固定到任务栏启动被其他网页劫持

场景复现 1、Google浏览器设置启动时继续浏览上次打开的网页 2、先浏览CSDN网站&#xff0c;然后关闭Google浏览器 3、再次打开Google浏览器时&#xff0c;除了显示我们上次浏览的CSDN网页外&#xff0c;还默认打开了百度网页 解决办法 1、在Google浏览器中新建标签页&am…

Redis缓存雪崩(主从复制、哨兵模式(脑裂)、分片集群)

缓存雪崩&#xff1a; 在同一时段大量的缓存key同时失效或者Redis服务宕机&#xff0c;导致大量请求到达数据库&#xff0c;带来巨大压力。 方法一&#xff1a; 给不同key的TTL添加随机值&#xff0c;以此避免同一时间大量key失效。&#xff08;用于解决同一时间大量key过期&…

linux笔记10--编辑器之神VIM

文章目录 1. 简单介绍① 为什么叫vim② linux常见的编辑器③ 注意事项④ 其它 2. 操作模式的划分① 两种 -- 国际上普通模式(命令操作模式)插入模式 ② 三种 -- 国内普通模式如何进入与退出界面 插入模式如何进入与退出界面 命令模式如何进入与退出界面常见的命令模式 ③ 区别④…

RFID技术在人工晶体清洗台上的应用案例分析

RFID技术在人工晶体清洗台上的应用案例分析 应用背景 在医疗领域中人工晶体清洗台发挥着极为重要的作用&#xff0c;随着市场需求的持续增长、技术的不断创新、定制化趋势的加强以及环保要求的提高&#xff0c;人工晶体清洗台不免暴露出一下应用痛点需要解决。 痛点&#xff…

SAP ABAP 常用的便利小手段:大写+自动对齐

目录 一&#xff0c;字体变大写 二&#xff0c;自动对齐行 一&#xff0c;字体变大写 找到上面的ユーティリティ⇒設定、 在【ユーザ固有の設定】里选择&#xff0c;【打文字】&#xff0c;同时勾除【名称を変更しない】 二&#xff0c;自动对齐行 在页面右下角找到黄色的【…

Java如何快速实现发送模版消息?

Java如何快速实现发送模版消息&#xff1f; 这次分析模版消息&#xff1a; 公众号&#xff08;小程序同理&#xff09;登录微信公众平台&#xff0c;创建模版&#xff0c;拿到模版id, 拿到appid,appSecret&#xff0c;根据开发文档找到对应功能的api进行开发即可&#xff0c;记…

EcmaScript6全新语法特性-----EcmaScript6(1)

age : 20,language : "Eng"}// 对象也可以用结构表达式来获取对应的值const { name,age,language} person;// 这样可以将我们获取的值name变成abc这个变量// const { name:abc,age,language} person;// 字符串拓展let str "Hello,vue";// 判断是否以xxx…

ES6 逐点突破系列 -- 函数的扩展

} f() // 1 var x 1; function foo(x, y function() { x 2; }) { var x 3; y(); console.log(x); } foo() // 3 x // 1 上面代码中&#xff0c;函数foo的参数形成一个单独作用域。这个作用域里面&#xff0c;首先声明了变量x&#xff0c;然后声明了变量y&#xf…

Axios发送ajax请求

}, // 请求体参数 data: { username: ‘admin’, password: ‘admin’ } }).then(response>{ // 响应状态码 console.log(response.status); // 响应状态字符串 console.log(response.statusText); // 响应头信息 console.log(response.headers); // 响应体 c…

docker 部署的 wordpress 接入阿里云短信服务 详细实操介绍

一、阿里云短信服务配置&#xff1a; 1、登录 阿里云短信服务 完成指引短信相关配置 2、创建RAM用户 并完成授权 出于安全及规范考虑 需通过RAM 用户来完成OponApl 接口调用&#xff0c;创建成功需完成短信接口&#xff08;AliyunDysmsFullAccess、AliyunDysmsReadOnlyAccess…

量检具管理有一套

量检具是用于测量和检验产品尺寸、形状和质量的工具。有一位年轻的工程师小张&#xff0c;他负责管理工厂的量检具&#xff0c;确保它们能够准确地测量产品尺寸和质量。有一天&#xff0c;小张发现量检具出现了一些问题。他注意到一些量具的读数不准确&#xff0c;导致生产出来…

加载资源文件失败

背景 自己以前装了一个海康的深度学习算法平台&#xff0c;试用期是一个月&#xff0c;过了一个月之后&#xff0c;因为没有有效注册码或者加密狗的支持了导致无法使用&#xff0c;于是打算卸载掉&#xff0c;在卸载一个软件的时候&#xff0c;无论是使用控制面板还是软件自带的…

SpringIOC核心源码

一、Spring IOC容器源码解析 1、Spring IOC容器的核心类 &#xff08;1&#xff09;BeanFactory与ApplicationContext &#xff08;2&#xff09;默认容器DefaultListableBeanFactory a. DefaultListableBeanFactory实现的接口 b.DefaultListableBeanFactory继承的类&#…

TCP/IP 在 Linux 内核中的实现

之前出了一个python的socket编程的文章&#xff0c;里面讲的是怎么进行socket编程。最近想到TCP/IP协议的原理&#xff0c;然后查阅资料后说是在操作系统级别实现的&#xff0c;python的socket模块只是一个接口。 本文就来谈一下Linux源码里实现TCP/IP协议簇的源代码在哪里&am…