Extractive QA参考:
https://juejin.cn/post/7180925054559977533
https://huggingface.co/learn/nlp-course/en/chapter7/7
目录
滑动窗口例子(提取开始结束点任务)
滑动窗口代码实现
tokenize() 默认添加问题
每个滑窗添加标题和摘要
训练label跨滑窗情况token匹配
Token匹配忽略固定标题摘要
输出结果拼接
根据样本论文标号拼接预测结果
滑窗重复token位置处理
posing questions about a document and identifying the answers as spans of text in the document itself.
常见范式:做句子的二分类任务(该句是否属于摘要),将预测为“属于”的句子拼起来,组成摘要。
像BERT这样的纯编码器模型往往擅长提取诸如“谁发明了Transformer架构?”之类的事实问题的答案,但是当给出诸如“为什么天空是蓝色的?”之类的开放式问题时,则表现不佳。在这些更具挑战性的情况下,通常使用 T5 和 BART 等编码器-解码器模型以类似文本摘要的方式综合信息.
Bert在SQuAD 数据集训练的模型例子:
https://huggingface.co/huggingface-course/bert-finetuned-squad
squad 评估的输出格式包含两个字段, 每个都是列表:
Answer: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]} |
评估时, 某些问题有几个可能的答案,并且此脚本会将预测答案与所有可接受的答案进行比较,并取得最高分数。
使用深度学习方法做抽取式摘要的经典论文:
Friendly Topic Assistant for Transformer Based Abstractive Summarization
Friendly Topic Assistant for Transformer Based Abstractive Summarization - ACL Anthology
SummaRuNNer: A Recurrent Neural Network Based Sequence Model for Extractive Summarization of Documents
Extractive Summarization using Deep Learning
Neural Extractive Summarization with Side Information
Ranking Sentences for Extractive Summarization with Reinforcement Learning
Fine-tune BERT for Extractive Summarization
Extractive Summarization of Long Documents by Combining Global and Local Context
Extractive Summarization as Text Matching
重要模型: Fine-tune BERT for Extractive Summarization(BertSum)
https://arxiv.org/abs/1903.10318
GitHub - nlpyang/BertSum: Code for paper Fine-tune BERT for Extractive Summarization
滑动窗口例子(提取开始结束点任务)
模型内部的注意力滑窗
全局注意力可以通过参数设置
用户可以通过设置张量来定义哪些代币“本地”参与,哪些代币“全局”参与 global_attention_mask 适当地在运行时。所有 Longformer 模型都采用以下逻辑 global_attention_mask:
0: the token attends “locally”,
1: the token attends “globally”.
Longformer 自注意力结合了局部(滑动窗口)和全局注意力
文档没有提到更多滑窗相关内容, 可能这种模型的创新就在于attention window, 而不是整个段落滑窗. 因此还是需要另外编写.
输入时候的滑动窗口例子
(参考上文提到的huggingface文档)
从数据集的一个样本创建多个训练特征来处理长上下文,并在它们之间有一个滑动窗口。
为了使用当前示例了解其工作原理,我们可以将长度限制为 100 个,并使用 50 个标记的滑动窗口。提醒一下,我们使用:
max_length 设置最大长度(此处为 100)
truncation="only_second" 当问题及其上下文过长时,截断上下文(位于第二个位置)
stride 设置两个连续块之间的重叠令牌数(此处为 50)
return_overflowing_tokens=True 为了让tokenizer知道我们想要溢出的token
inputs = tokenizer( question, context, max_length=100, truncation="only_second", stride=50, return_overflowing_tokens=True, ) for ids in inputs["input_ids"]: print(tokenizer.decode(ids)) |
'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]' '[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]' '[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 [SEP]' '[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]' |
针对只有一半答案的情况, 数据集为我们提供了上下文中答案的开始字符,通过添加答案的长度,我们可以在上下文中找到结束字符。为了将它们映射到令牌索引,我们需要使用我们研究的偏移映射 第 6 章.我们可以让我们的标记器通过传递来返回这些 return_offsets_mapping=True:
inputs = tokenizer( question, context, max_length=100, truncation="only_second", stride=50, return_overflowing_tokens=True, return_offsets_mapping=True, ) inputs.keys() |
这样inputs里面就会包含偏移值
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping']) |
得到列表inputs["overflow_to_sample_mapping"]
它代表滑窗来自第几个样本
如[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3]代表样本0~2每个样本分别有4个滑窗,样本3有7个滑窗
answers = raw_datasets["train"][2:6]["answers"] start_positions = [] end_positions = [] for i, offset in enumerate(inputs["offset_mapping"]): sample_idx = inputs["overflow_to_sample_mapping"][i] answer = answers[sample_idx] start_char = answer["answer_start"][0] end_char = answer["answer_start"][0] + len(answer["text"][0]) sequence_ids = inputs.sequence_ids(i) # Find the start and end of the context idx = 0 while sequence_ids[idx] != 1: idx += 1 context_start = idx while sequence_ids[idx] == 1: idx += 1 context_end = idx - 1 # If the answer is not fully inside the context, label is (0, 0) if offset[context_start][0] > start_char or offset[context_end][1] < end_char: start_positions.append(0) end_positions.append(0) else: # Otherwise it's the start and end token positions idx = context_start while idx <= context_end and offset[idx][0] <= start_char: idx += 1 start_positions.append(idx - 1) idx = context_end while idx >= context_start and offset[idx][1] >= end_char: idx -= 1 end_positions.append(idx + 1) start_positions, end_positions |
([83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0], [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0]) |
第一个特征标签 (83, 85)-> 从 83 到 85(含)的解码token跨度(decoded span of tokens)
idx = 0 sample_idx = inputs["overflow_to_sample_mapping"][idx] answer = answers[sample_idx]["text"][0] start = start_positions[idx] end = end_positions[idx] labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1]) print(f"Theoretical answer: {answer}, labels give: {labeled_answer}") |
'Theoretical answer: the Main Building, labels give: the Main Building' |
其中的(0, 0)意味着答案不在该特征的上下文块中, 其中特征是其中一个input_ids组
滑动窗口代码实现
tokenize_and_align_labels中处理sentences时由于longformer模型及其tokenizer的最大输入长度为4096, tokenizer目前的设置将会对句子进行截断.
修改代码, 在使用tokenizer时, 用滑动窗口功能处理sentences, 及其对应的labels.
使用参数 max_length(默认值为 4096)和 stride(默认值为 256)的滑动窗口方法。这样,该函数就可以处理长度超过最大长度的输入,方法是将它们分成具有重叠的块。
训练循环:主训练循环保持不变,但它现在处理滑动窗口方法生成的标记化输入和标签。
在获得模型的输出后对输出进行拼接
注意不能在输入模型之前进行不同滑窗的拼接, 因为longformer模型的最大输入长度为4096.
修改前token_labels和inputid一样的形状 shape: torch.Size([1, 4096])
检查模型输入-还原成token
tokenizer.convert_ids_to_tokens(input_ids[0]) |
得到训练token全部是开头这一段
所以之前token划分就已经切割, 因此没有在输入模型时候超出限制
接着检查使用dataloader之前的input长度
inputs, token_labels = tokenize_and_align_labels(sentences, labels_texts, tokenizer) |
inputs 中的id shape也是: torch.Size([4096])
最后对比tokenize_and_align_labels中第一次从sentences转成token的维度
print("sentences lenth:",len(sentences[0])) print("tokenized_inputs shape:",tokenized_inputs["input_ids"][0].shape) |
sentences lenth: 42723 tokenized_inputs shape: torch.Size([4096]) |
说明需要从此处开始加滑窗
修改tokenizer增加滑窗
sentence_tokens = tokenizer( title, sentence, max_length=max_length, # truncation=True, truncation="only_second", stride=stride, return_overflowing_tokens=True, return_offsets_mapping=True, padding='max_length' ) |
NotImplementedError: return_offset_mapping is not available when using Python tokenizers.To use this feature, change your tokenizer to one deriving from transformers.PreTrainedTokenizerFast.More information on available tokenizers at https://github.com/huggingface/transformers/pull/2674 |
得到两个特殊的键, 之后可以看看怎么使用
sentence_tokens['offset_mapping'][0] |
[(0, 0), (0, 3), (3, 4), (4, 9), (9, 10), (11, 13), (14, 24), (25, 28), (28, 30), (30, 32), (33, 36), (37, 46), (47, 52), (53, 62), (62, 67), (68, 76), (77, 86), (87, 92), (93, 107), (108, 109), (109, 110), (110, 112), (112, 113), (114, 116), (117, 118), (119, 122), (122, 128), (129, 131), (132, 138), (138, 143), (144, 151), (152, 157), (158, 168), (169, 175), (176, 185), (186, 196), (197, 200), (201, 210), (210, 211), (212, 216), (217, 220), (221, 232), (233, 237), (238, 242), (243, 247), (248, 254), (255, 262), (263, 265), (266, 269), (270, 281), (282, 288), (288, 289), (290, 293), (294, 303), (304, 311), (312, 315), (316, 321), (322, 326), (327, 335), ...] |
sentence_tokens['overflow_to_sample_mapping'] |
[0, 0, 0] |
sentence lenth: 42723
sentence_tokens len: 4096
第一段还是和原来一样的开头, 但是有三段了
print("sentence_tokens len:",len(sentence_tokens["input_ids"])) |
sentence_tokens len: 3 |
可以看到第二段不再是开头, 而是接下去的内容了
sentence_tokens: ['<s>', 'Ġ(', '64', ',', '219', ')', 'ĠAverage', 'Ġreturn', 'Ġ(', 'by', 'Ġassets', ',', 'Ġwhole', 'Ġperiod', ')', 'Ġ37', '.', '16', '%', 'ĠAverage', 'Ġreturn', 'Ġ(', 'by', 'Ġcustomers', ',', 'Ġwhole', 'Ġperiod', ')', …', 'Ġwith', 'Ġover', 'Ġ5', ',', '000', 'Ġinteractions', 'Ġhappening', 'Ġev', '-', 'Ġ', 'ery', 'Ġmonth', '.', 'ĠThe', 'Ġprevious', 'Ġperiod', ',', 'Ġbetween', 'ĠJanuary', 'Ġ2018', 'Ġand', 'ĠDecember', 'Ġ2019', 'Ġonly', 'Ġreceives', 'Ġbetwe |
现在输入token维度是没有问题了,
每个元素append, 再整个torch.stack的方法得到的是列表中新的元素, 不是拼接成同一个字符串, 因此不会超过长度限制
现在改后的预测处predicted_token_class_ids更加复杂, 需要再处理一层
Traceback (most recent call last): File "d:/Projects/longformer/tests/tkn_clsfy_slide.py", line 154, in <module> prediction_strings = [get_prediction_string(prediction, predicted_inputs_id) for prediction, predicted_inputs_id in zip(predicted_token_class_ids, inputs["input_ids"])] TypeError: zip argument #1 must support iteration |
取出来了, 虽然最后一段预测的token包含padding.
tokenize() 默认添加<pad>问题
检查label, 发现tokenize的时候被加了padding.
一般来说tokenize()是不添加<pad>的. 可能是tokenizer默认参数设置的问题. 但是这个函数又没有padding参数可以传入, 直接改tokenizer又可能对后续产生影响, 目前 删除这些padding
tokenized_label = [token for token in tokenized_label if token != '<pad>']#删除pad |
现在可以正确匹配了
matched tokenized_label: ['ĠAs', 'Ġfuture', 'Ġwork', ',', 'Ġwe', 'Ġshall', 'Ġexplore', 'Ġthe', 'Ġuse', 'Ġof', 'Ġthis', 'Ġdataset', 'Ġfor', 'Ġmultiple', 'Ġtasks', ',', 'Ġnot', 'Ġonly', 'Ġincluding', 'Ġfinancial', 'Ġasset', 'Ġrecommendation', ',', 'Ġbut', 'Ġalso', 'Ġportfolio', 'Ġconstruction', 'Ġand', 'Ġoptim', 'isation', 'Ġor', 'Ġinvestor', 'Ġand', 'Ġasset', 'Ġmodelling', '.'] |
例子中每个输入都包含问题和上下文的某些部分,
可以改成每个输入都包含标题和摘要, 但这需要数据集的预处理
预测匹配了目标句子, 但是输出的预测依然有许多padding
打印出来可以看到超出的pad部分label是0, 所以label是没有问题的
matched clsfy_label: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, …, 0, 0, 0] |
查看评估中的tokenize, 去掉所有tokenize处的pad
#评估时 ref_tokens = [token for token in ref_tokens if token != '<pad>']#删除pad pred_tokens = [token for token in pred_tokens if token != '<pad>']#删除pad |
#预测token拼接时 for token, pred_class in zip(tokenized_sub_sentence, predicted_tokens_classes): if token.startswith("Ġ"): if (len(current_word) != 0) & current_word_pred: dataset_description_words.append(current_word) current_word = token[1:] current_word_pred = (pred_class == 'Dataset description') elif token != '<pad>':#删除pad current_word += token current_word_pred = current_word_pred or (pred_class == 'Dataset description') if (len(current_word) != 0)& (current_word != '<pad>') & current_word_pred: dataset_description_words.append(current_word) |
到这里所有的<pad>都处理完了
每个滑窗添加标题和摘要
例子
将问题和上下文一起传递给我们的标记器,它将正确插入特殊标记以形成一个句子
[CLS] question [SEP] context [SEP] |
context = raw_datasets["train"][0]["context"] question = raw_datasets["train"][0]["question"] inputs = tokenizer(question, context) tokenizer.decode(inputs["input_ids"]) |
'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, ' 'the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin ' … 'Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues ' 'and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]' |
得到[CLS] question [SEP] context [SEP]形式
在longformer中支持的分隔形式是</s >
例子2添加更多参数
inputs = tokenizer( question, context, max_length=100, truncation="only_second", stride=50, return_overflowing_tokens=True, ) for ids in inputs["input_ids"]: print(tokenizer.decode(ids)) |
'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]' … '[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]' |
在输入参数中加入title, 并且截断参数设置truncation="only_second"只针对第二个滑窗
for i, sentence in enumerate(sentences): labels_text = labels_texts[i] title = titles[i] sentence_tokens = tokenizer( title, sentence, max_length=max_length, # truncation=True, truncation="only_second", stride=stride, return_overflowing_tokens=True, return_offsets_mapping=True, padding='max_length' ) print("sentence_tokens keys:",sentence_tokens.keys()) input_ids = sentence_tokens['input_ids'] attention_mask = sentence_tokens['attention_mask'] offset_mapping = sentence_tokens['offset_mapping'] print("input_token:",tokenizer.convert_ids_to_tokens(input_ids[0])) |
sentence_tokens keys: dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping']) input_token: ['<s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', …'ĠWe', 'Ġalso', 'Ġprovide', 'Ġa', 'Ġbench', '-', 'mark', 'ing', 'Ġcomparison', 'Ġbetween', 'Ġeleven', 'ĠFAR', 'Ġalgorithms', 'Ġover', 'Ġthe', 'Ġdata', 'Ġfor', 'Ġuse', 'Ġas', 'Ġfuture', 'Ġbas', 'elines', '.', 'ĠThe', 'Ġdataset', 'Ġcan', 'Ġbe', 'Ġdownloaded', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', '</s>', '</s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', 'ĠJavier', 'ĠSan', 'z', '-', 'Cruz', 'ado', '1', 'Ġ,', 'ĠNikola', 'os', 'ĠD', 'rou', 'kas', '2', 'Ġ,', 'ĠRichard', 'ĠMcC', 'read', 'ie', '1', 'Ġ1', 'University', 'Ġof', 'ĠGlasgow', 'Ġ2', 'National', 'ĠBank', 'Ġof', 'ĠGreece', 'Ġj', 'avier', '.', 'san', 'z', '-', 'cru', 'z', 'ad', 'op', 'u', 'ig', '@', 'glas', 'gow', '.', 'ac', '.', 'uk', ',', 'Ġd', 'rou', 'kas', '.', 'nik', 'ola', 'os', '@', 'n', 'bg', '.', 'gr', … |
得到结果:
Epoch 1/3, Loss: 0.35347285121679306 Epoch 2/3, Loss: 0.04643934965133667 Epoch 3/3, Loss: 0.026251467876136303 Model and tokenizer saved to 'trained_model' Prediction Strings: ['problem by', '', 'In'] d:\Users\laugo\anaconda3\envs\longformer\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior. _warn_prf(average, modifier, msg_start, len(result)) Precision: 0.0000, Recall: 0.0000, F1 Score: 0.0000 |
训练label跨滑窗情况token匹配
目前打token label时候是完整匹配
for k in range(len(tokenized_sentence) - label_length + 1): |
tokenized_sentence[k:k + label_length] == tokenized_label |
从遍历开始就保证了label预留长度
可以考虑改成按照tokenized_sentence的长度将tokenized_label截断后比较
分别对后半部分缺少和前半部分缺少的情况进行处理
# 处理label跨滑窗 for k in range(len(tokenized_sentence)-1): # print("tokenized_sentence:",tokenized_sentence[k:k + label_length]) end_position=min(len(tokenized_sentence)-1, k + label_length) if tokenized_sentence[k:end_position] == tokenized_label[0:end_position]:#后半部分没有的情况 print("matched tokenized_label:",tokenized_label,"\n",tokenized_sentence[k:end_position]) # end_position=min(len(tokenized_sentence)-1, k + label_length) # end_position=k + label_length token_label[k:end_position] = [1] * (end_position-k) # print("matched clsfy_label:",token_label[k:len(tokenized_sentence)-1]) |
for start in range(label_length-1):#前半部分没有的情况 # print("tokenized_sentence:",tokenized_sentence[k:k + label_length]) if tokenized_sentence[0:label_length-start] == tokenized_label[start:]:#后半部分没有的情况 print("matched tokenized_label:",tokenized_sentence[0:label_length-start],"\n",tokenized_label[start:]) # end_position=min(len(tokenized_sentence)-1, k + label_length) # end_position=k + label_length token_label[0:label_length-start] = [1] * (label_length-start) |
得到输出
Prediction Strings: ['FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution. aim solve problem by introducing FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.', 'introducing FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.', 'FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.'] Precision: 1.0000, Recall: 0.9091, F1 Score: 0.9524 |
Token匹配忽略固定标题摘要
加入滑窗后, 每一个滑窗的第一段Q部分是题目摘要, 第一个滑窗的第二段A部分也是题目摘要, 出现提取出token的重复, 因此token匹配的起点改为第二段A开头.
(提取式QA的Q在每一个滑窗中, 但是答案一般不会出现在Q中, 因此不造成影响.
但是题目摘要中如果有描述, 将每次被Token匹配.
目前Token匹配的时候默认是全部匹配, 改成所有匹配从第二个</s>段开始.)
滑窗后的输入样本token(tokenized_sentence)的格式如下, 其中用special tokens 即<s>标记区分不同文本, 第一对<s></s>标记包裹的是标题摘要, 第二对</s></s>标记包裹的是全部文本, 注意第二个的开头是'</s>'
A Longformer sequence has the following format: single sequence: <s> X </s> pair of sequences: <s> A </s></s> B </s> |
例子
['<s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', …'ĠWe', 'Ġalso', 'Ġprovide', 'Ġa', 'Ġbench', '-', 'mark', 'ing', 'Ġcomparison', 'Ġbetween', 'Ġeleven', 'ĠFAR', 'Ġalgorithms', 'Ġover', 'Ġthe', 'Ġdata', 'Ġfor', 'Ġuse', 'Ġas', 'Ġfuture', 'Ġbas', 'elines', '.', 'ĠThe', 'Ġdataset', 'Ġcan', 'Ġbe', 'Ġdownloaded', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', '</s>', '</s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', 'ĠJavier', 'ĠSan', 'z', '-', 'Cruz', 'ado', '1', 'Ġ,', 'ĠNikola', 'os', 'ĠD', 'rou', 'kas', '2', 'Ġ,', 'ĠRichard', 'ĠMcC', 'read', 'ie', '1', 'Ġ1', 'University', 'Ġof', 'ĠGlasgow', 'Ġ2', 'National', 'ĠBank', 'Ġof', 'ĠGreece', 'Ġj', 'avier', '.', 'san', 'z', '-', 'cru', 'z', 'ad', 'op', 'u', 'ig', '@', 'glas', 'gow', '.', 'ac', '.', 'uk', ',', 'Ġd', 'rou', 'kas', '.', 'nik', 'ola', 'os', '@', 'n', 'bg', '.', 'gr', …] |
修改对全文的token与标记文本的token进行匹配生成二分类 label的代码, 使匹配从第二个'<s>'开始:
#在每个滑窗中寻找正文开始处token位置 def find_main_body(tokenized_sentence): start_index = 0 s_token_count = 0 for idx, token in enumerate(tokenized_sentence): if token == '</s>': s_token_count += 1 # print("</s>:",s_token_count) if s_token_count == 2: start_index = idx + 1 return start_index |
for k in range(main_body_start,len(tokenized_sentence)-1):#从正文部分开始匹配 |
if tokenized_sentence[main_body_start:label_length-label_start] == tokenized_label[label_start:]:#后半部分没有的情况 token_label[main_body_start:label_length-label_start] = [1] * (label_length-label_start) |
现在匹配的内容与label一一对应, 没有重复标题中的内容
['ĠIn', 'Ġthis', 'Ġpaper', ',', 'Ġwe', 'Ġaim', 'Ġto', 'Ġsolve', 'Ġthis', 'Ġproblem', 'Ġby', 'Ġintroducing', 'ĠFAR', '-', 'Trans', ',', 'Ġthe', 'Ġfirst', 'Ġpublic', 'Ġdataset', 'Ġfor', 'ĠFAR', ',', 'Ġcontaining', 'Ġpricing', 'Ġinformation', 'Ġand', 'Ġretail', 'Ġinvestor', 'Ġtransactions', 'Ġacquired', 'Ġfrom', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', '.'] ['ĠThis', 'Ġwork', 'Ġaims', 'Ġto', 'Ġsolve', 'Ġthis', 'Ġlimitation', 'Ġby', 'Ġproposing', 'Ġa', 'Ġnovel', 'Ġdataset', 'Ġfor', 'Ġthe', 'ĠFAR', 'Ġtask', ',', 'Ġprovided', 'Ġby', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', '.', 'ĠAs', 'Ġfar', 'Ġas', 'Ġwe', 'Ġare', 'Ġaware', ',', 'Ġthis', 'Ġdataset', 'Ġrepresents', 'Ġthe', 'Ġfirst', 'Ġdataset', 'Ġin', 'Ġthis', 'Ġdomain', 'Ġcontaining', 'Ġpricing', 'Ġtime', 'Ġseries', 'Ġfor', 'Ġmultiple', 'Ġasset', 'Ġtypes', 'Ġ(', 'stocks', ',', 'Ġbonds', 'Ġand', 'Ġmutual', 'Ġfunds', '),', 'Ġasset', 'Ġdescriptions', ',', 'Ġas', 'Ġwell', 'Ġas', 'Ġmost', 'Ġimportantly', 'Ġ(', 'an', 'onym', 'ised', ')', 'Ġcustomer', 'Ġinformation', 'Ġand', 'Ġinvestment', 'Ġtransactions', '.'] ['ĠWe', 'Ġintroduce', 'Ġin', 'Ġthis', 'Ġwork', 'Ġa', 'Ġnovel', 'Ġdataset', 'Ġfor', 'Ġfinancial', 'Ġrecommendation', ',', 'Ġwhich', 'Ġwe', 'Ġshall', 'Ġname', 'ĠFAR', '-', 'Trans', '.', 'ĠAs', 'Ġfar', 'Ġas', 'Ġwe', 'Ġare', 'Ġaware', 'Ġthis', 'Ġdataset', 'Ġrepresents', 'Ġthe', 'Ġfirst', 'Ġpublic', 'Ġdataset', 'Ġcontaining', 'Ġboth', 'Ġasset', 'Ġpricing', 'Ġinformation', 'Ġand', 'Ġinvestment', 'Ġtransactions', 'Ġfor', 'ĠFAR', '.', 'ĠThe', 'Ġdata', 'Ġhas', 'Ġbeen', 'Ġprovided', 'Ġby', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', ',', 'Ġrepresenting', 'Ġa', 'Ġsnapshot', 'Ġof', 'Ġthe', 'Ġmarket', 'Ġavailable', 'Ġto', 'ĠGreek', 'Ġinvestors', 'Ġbetween', 'ĠJanuary', 'Ġ2018', 'Ġand', 'ĠNovember', 'Ġ2022', '.', 'ĠFAR', '-', 'Trans', 'Ġcovers', 'Ġpricing', 'Ġdata', 'Ġfor', 'Ġstocks', ',', 'Ġbonds', 'Ġand', 'Ġmutual', 'Ġfunds', ',', 'Ġas', 'Ġwell', 'Ġas', 'Ġinvestment', 'Ġtransaction', 'Ġlogs', 'Ġ(', 'ass', 'et', 'Ġbuy', 'Ġand', 'Ġsell', 'Ġactions', ')', 'Ġhandled', 'Ġby', 'Ġthe', 'Ġinstitution', ',', 'Ġcustomer', ',', 'Ġmarket', 'Ġand', 'Ġasset', 'Ġinformation', '.', 'ĠThis', 'Ġsection', 'Ġprovides', 'Ġa', 'Ġdescription', 'Ġof', 'Ġthe', 'Ġdataset', 'Ġand', 'Ġthe', 'Ġacquisition', 'Ġand', 'Ġcleaning', 'Ġmethodology', '.', 'ĠThe', 'Ġdataset', 'Ġis', 'Ġavailable', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', 'ĠTable', 'Ġ1', 'Ġsummarizes', 'Ġits', 'Ġglobal', 'Ġproperties', '.'] ['ĠFinally', ',', 'Ġwe', 'Ġsolve', 'Ġinconsistencies', 'Ġon', 'Ġasset', 'Ġprices', 'Ġby', 'Ġproviding', 'Ġan', 'Ġestimate', 'Ġof', 'Ġthe', 'Ġtotal', 'Ġvalue', 'Ġof', 'Ġthe', 'Ġtransaction', '.', 'ĠWe', 'Ġestimate', 'Ġthe', 'Ġvalue', 'Ġby', 'Ġmultiplying', 'Ġthe', 'Ġnumber', 'Ġof', 'Ġshares', 'Ġby', 'Ġthe', 'Ġclosing', 'Ġprice', 'Ġof', 'Ġthe', 'Ġasset', 'Ġon', 'Ġthe', 'Ġdate', 'Ġof', 'Ġthe', 'Ġtransaction', '.', 'ĠIn', 'Ġthe', 'Ġend', ',', 'Ġwe', 'Ġhave', 'Ġ388', ',', '0', '49', 'Ġtransactions', 'Ġin', 'Ġour', 'Ġdataset', ',', 'Ġcorresponding', 'Ġto', 'Ġ29', ',', '090', 'Ġcustomers', '.'] |
输出结果拼接
参考QA任务文档
对于同一篇文章被滑窗分成的不同段落, 输出时如何判断他们是连在一起的呢. 目前是直接输出了, 因为每个滑窗的label是对应的, 可能不影响训练, 但是应用的时候如何处理呢?
在输入参数return_overflowing_tokens=True, return_offsets_mapping=True,的情况下, sentence_tokens keys: dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])
其中有两个特殊的键
offset_mapping代表偏移, 每次切出来的单词对应着第几个字母
text: "80 % of Americans … offset_mapping: (0,2) (2,3) (4,6) (7,16)… |
overflow_to_sample_mapping代表这个窗口来自第几个样本
sentence_tokens['overflow_to_sample_mapping'] |
[0, 0, 0] |
如果要使其标识不同样本, 需要所有样本一起放入, 参考文档例子:
inputs = tokenizer( raw_datasets["train"][2:6]["question"], raw_datasets["train"][2:6]["context"], max_length=100, truncation="only_second", stride=50, return_overflowing_tokens=True, return_offsets_mapping=True, ) print(f"The 4 examples gave {len(inputs['input_ids'])} features.") print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.") |
'The 4 examples gave 19 features.' 'Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].' |
这个标识可以用来拼接同一个样本的输出
但是还有一个问题, 如果确定了几个滑窗来自同一个样本, 重叠部分的输出该如何处理呢?
首先要知道输出的这个token位置, 针对滑窗重复部分256个token中取出的, 则去除.
offset mappings的清理:
它们将包含问题和上下文的偏移量,但是一旦进入后处理阶段,我们将无法知道输入 ID 的哪一部分对应于上下文以及哪一部分是问题(sequence_ids()方法仅适用于分词器的输出)。因此,我们将与问题相对应的偏移量设置为 None |
根据样本论文标号拼接预测结果
尝试在使用tokenizer划分滑窗的时候就加上标号, 否则后面很难区分
tokenized_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"paper_idx":[]} |
tokenized_inputs["paper_idx"].append(torch.tensor(i)) |
tokenized_inputs["paper_idx"] = torch.stack(tokenized_inputs["paper_idx"]) |
得到tokenized_inputs["paper_idx"]
tensor([0, 0, 0]) |
取出时考虑paper_idx
内部token拼接时考虑的是当前与下一个样例的id, 外部段落拼接考虑的是当前与上一个段落idx比较
或者都使用当前与上一个段落idx比较.
外部段落拼接如果当前与上一个段落idx相同, 则把当前段落接在同一个字符串, 否则作为新元素.
token拼接时如果相同则前256个重复token不考虑. 此时可以使用刚才判断的bool值.
第一个滑窗比较特殊, 不判断不删除.
def get_extracted_description(predicted_token_class_id, predicted_inputs_id, predicted_paper_idxs): dataset_descriptions = [] for i, predicted_inputs_id in enumerate(predicted_inputs_id): prediction_class = predicted_token_class_id[i] #判断是否同一篇文章 is_same_paper=False if(i>0): if(predicted_paper_idxs[i]==predicted_paper_idxs[i-1]): is_same_paper=True dataset_description=get_prediction_string(prediction_class, predicted_inputs_id, is_same_paper) #拼接 if(is_same_paper): dataset_descriptions[-1].join(dataset_description) else: dataset_descriptions.append(dataset_description) return dataset_descriptions dataset_descriptions = get_extracted_description(predicted_token_class_ids, inputs["input_ids"], inputs["paper_idx"]) |
拼接后针对同一篇文章的输出在同一个字符串中了
dataset_descriptions: [''] |
滑窗重复token位置处理
加入每一个滑窗的最后重复位置不考虑的话, 最后一个滑窗会比较特殊, 可以判断如果下一个滑窗和当前的样本标号相同, 则忽略滑窗重复的位置.
token拼接时如果当前与上一个段落idx相同, 则前256个重复token不考虑. 此时可以使用刚才判断的bool值.
#同一篇文章则删除滑窗重复部分 if(is_same_paper): tokenized_sub_sentence=tokenized_sub_sentence[stride:] predicted_tokens_classes=predicted_tokens_classes[stride:] |
评估label
label不是窗口的label, 而是整段论文的所有描述拼接, 之前已经完成. ref列表中的每个元素是同一篇论文中的各段描述, join后为完整描述的字符串
" ".join(ref) |
因此输出的文本也全部拼接完成了