LLM Notes 4: Sliding Windows for Long Text


Extractive QA references:

https://juejin.cn/post/7180925054559977533

https://huggingface.co/learn/nlp-course/en/chapter7/7

Contents

Sliding-window example (start/end span extraction task)

Sliding-window code implementation

tokenize() adding <pad> by default

Adding the title and abstract to every window

Token matching when a training label spans windows

Ignoring the fixed title/abstract during token matching

Concatenating the outputs

Concatenating predictions by paper index

Handling tokens repeated across overlapping windows


Extractive QA: posing questions about a document and identifying the answers as spans of text in the document itself.

A common paradigm for extractive summarization: run a binary classification on each sentence (does this sentence belong to the summary?) and concatenate the sentences predicted as "belonging" to form the summary.
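A minimal sketch of that paradigm (classify_sentence below is a hypothetical per-sentence binary classifier, not something defined in this note):

def extract_summary(sentences, classify_sentence):
    # classify_sentence is assumed to return 1 if the sentence belongs to the summary
    selected = [s for s in sentences if classify_sentence(s) == 1]
    return " ".join(selected)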

Encoder-only models such as BERT tend to be good at extracting answers to factual questions like "Who invented the Transformer architecture?", but do poorly on open-ended questions such as "Why is the sky blue?". In those more challenging cases, encoder-decoder models such as T5 and BART are typically used to synthesize the information, in a way that resembles text summarization.

Example of a BERT model fine-tuned on the SQuAD dataset:

https://huggingface.co/huggingface-course/bert-finetuned-squad

The SQuAD evaluation format for an answer contains two fields, each of which is a list:

Answer: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

During evaluation, some questions have several acceptable answers; the evaluation script compares the predicted answer against all acceptable answers and keeps the highest score.
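As a sketch of how that comparison can be run with the evaluate library (the id and texts below are made up for illustration):

import evaluate

squad_metric = evaluate.load("squad")
predictions = [{"id": "0", "prediction_text": "Saint Bernadette Soubirous"}]
references = [{
    "id": "0",
    "answers": {"text": ["Saint Bernadette Soubirous"], "answer_start": [515]},
}]
# The metric compares the prediction against every acceptable answer and keeps the best score.
print(squad_metric.compute(predictions=predictions, references=references))
# {'exact_match': 100.0, 'f1': 100.0}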

Classic papers on extractive summarization with deep learning:

Friendly Topic Assistant for Transformer Based Abstractive Summarization


SummaRuNNer: A Recurrent Neural Network Based Sequence Model for Extractive Summarization of Documents

Extractive Summarization using Deep Learning

Neural Extractive Summarization with Side Information

Ranking Sentences for Extractive Summarization with Reinforcement Learning

Fine-tune BERT for Extractive Summarization

Extractive Summarization of Long Documents by Combining Global and Local Context

Extractive Summarization as Text Matching

Key model: Fine-tune BERT for Extractive Summarization (BertSum)

https://arxiv.org/abs/1903.10318

GitHub - nlpyang/BertSum: Code for paper Fine-tune BERT for Extractive Summarization

Sliding-window example (start/end span extraction task)

Attention-level sliding windows inside the model

Global attention can be configured through a parameter.

Users can define which tokens attend "locally" and which attend "globally" by setting the tensor global_attention_mask appropriately at run time. All Longformer models use the following logic for global_attention_mask:

0: the token attends “locally”,

1: the token attends “globally”.

Longformer self-attention combines local (sliding-window) attention with global attention.
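A minimal sketch of setting it (assuming model and tokenizer are an already-loaded Longformer model and tokenizer, and text is some input string; giving the first token global attention is the usual choice for classification-style tasks):

import torch

encoding = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096)
input_ids = encoding["input_ids"]
attention_mask = encoding["attention_mask"]

# 0 = local (sliding-window) attention, 1 = global attention
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, 0] = 1  # let the <s> token attend globally

outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    global_attention_mask=global_attention_mask,
)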

The documentation says nothing more about sliding windows, so this model's innovation is presumably the attention window itself rather than sliding over whole passages; passage-level sliding still has to be written separately.

Sliding-window example on the input side

(See the Hugging Face course chapter referenced above.)

To handle long contexts, several training features are created from one sample of the dataset, with a sliding window between them.

To see how this works on the current example, we can limit the length to 100 and use a sliding window of 50 tokens. As a reminder, we use:

max_length to set the maximum length (here 100)

truncation="only_second" to truncate the context (which sits in the second position) when the question plus its context is too long

stride to set the number of overlapping tokens between two consecutive chunks (here 50)

return_overflowing_tokens=True to let the tokenizer know we want the overflowing tokens

inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]'

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]'

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 [SEP]'

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

For answers that are only partially inside a window: the dataset gives us the start character of the answer in the context, and adding the length of the answer gives the end character. To map these to token indices we need the offset mappings studied in Chapter 6, which the tokenizer returns when we pass return_offsets_mapping=True:

inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

inputs.keys()

Now inputs contains the offsets:

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

The list inputs["overflow_to_sample_mapping"] records which sample each window comes from.

For example, [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3] means samples 0-2 each produced 4 windows and sample 3 produced 7 windows.
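A quick way to tally that mapping (a small sketch using the inputs built above):

from collections import Counter

# number of windows produced per original sample
print(Counter(inputs["overflow_to_sample_mapping"]))
# e.g. Counter({3: 7, 0: 4, 1: 4, 2: 4})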

answers = raw_datasets["train"][2:6]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs["offset_mapping"]):
    sample_idx = inputs["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

([83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0],
 [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0])

The label of the first feature is (83, 85) -> the decoded span of tokens from 83 to 85 (inclusive):

idx = 0
sample_idx = inputs["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")

'Theoretical answer: the Main Building, labels give: the Main Building'

A label of (0, 0) means the answer is not inside that feature's chunk of context, where a feature is one of the input_ids groups.

Sliding-window code implementation

When tokenize_and_align_labels processes sentences, the Longformer model and its tokenizer have a maximum input length of 4096, so the current tokenizer settings truncate the sentences.

Modify the code so that the tokenizer processes the sentences, and their corresponding labels, with the sliding-window mechanism.

Use a sliding-window approach with the parameters max_length (default 4096) and stride (default 256). The function can then handle inputs longer than the maximum length by splitting them into overlapping chunks.

Training loop: the main training loop stays the same, but it now consumes the tokenized inputs and labels produced by the sliding-window approach.

The outputs are concatenated after they come out of the model.

Note that the windows cannot be concatenated before being fed to the model, because Longformer's maximum input length is 4096.
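A minimal sketch of that plan (assuming model is a LongformerForTokenClassification already on the right device and that each window has already been converted to a tensor): every 4096-token window goes through the model separately, and only the per-window predictions are collected and stitched together afterwards.

import torch

all_window_predictions = []
with torch.no_grad():
    for window_ids, window_mask in zip(inputs["input_ids"], inputs["attention_mask"]):
        logits = model(
            input_ids=window_ids.unsqueeze(0),       # shape (1, 4096)
            attention_mask=window_mask.unsqueeze(0),
        ).logits
        # per-token class ids for this window
        all_window_predictions.append(logits.argmax(dim=-1).squeeze(0))

# concatenation across windows happens only after the model has run
predicted_token_class_ids = all_window_predictions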

Before the change, token_labels has the same shape as input_ids: torch.Size([1, 4096]).

Check the model input by converting it back to tokens:

tokenizer.convert_ids_to_tokens(input_ids[0])

The training tokens are all from this opening passage,

so the earlier tokenization had already truncated the text, which is why the model input never exceeded the limit.

Next, check the input length before it goes into the DataLoader:

inputs, token_labels = tokenize_and_align_labels(sentences, labels_texts, tokenizer)

The id shape inside inputs is also torch.Size([4096]).

Finally, compare the dimensions when the sentences are first converted to tokens inside tokenize_and_align_labels:

    print("sentences lenth:",len(sentences[0]))

    print("tokenized_inputs shape:",tokenized_inputs["input_ids"][0].shape)

sentences lenth: 42723

tokenized_inputs shape: torch.Size([4096])

This is where the sliding window needs to be added.

Modify the tokenizer call to add the sliding window:

        sentence_tokens = tokenizer(
            title,
            sentence,
            max_length=max_length,
            # truncation=True,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length'
        )

NotImplementedError: return_offset_mapping is not available when using Python tokenizers. To use this feature, change your tokenizer to one deriving from transformers.PreTrainedTokenizerFast. More information on available tokenizers at https://github.com/huggingface/transformers/pull/2674
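The offset mapping requires a fast (Rust-backed) tokenizer. One way to load it (the checkpoint name below is an assumption for illustration; use whichever Longformer checkpoint the project actually trains):

from transformers import AutoTokenizer

# Fast tokenizers (PreTrainedTokenizerFast) support return_offsets_mapping
tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096", use_fast=True)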

After switching to a fast tokenizer, the call succeeds and returns two special keys, whose use we can look at later:

sentence_tokens['offset_mapping'][0]

[(0, 0), (0, 3), (3, 4), (4, 9), (9, 10), (11, 13), (14, 24), (25, 28), (28, 30), (30, 32), (33, 36), (37, 46), (47, 52), (53, 62), (62, 67), (68, 76), (77, 86), (87, 92), (93, 107), (108, 109), (109, 110), (110, 112), (112, 113), (114, 116), (117, 118), (119, 122), (122, 128), (129, 131), (132, 138), (138, 143), (144, 151), (152, 157), (158, 168), (169, 175), (176, 185), (186, 196), (197, 200), (201, 210), (210, 211), (212, 216), (217, 220), (221, 232), (233, 237), (238, 242), (243, 247), (248, 254), (255, 262), (263, 265), (266, 269), (270, 281), (282, 288), (288, 289), (290, 293), (294, 303), (304, 311), (312, 315), (316, 321), (322, 326), (327, 335), ...]

sentence_tokens['overflow_to_sample_mapping']

[0, 0, 0]

sentence lenth: 42723

sentence_tokens len: 4096

The first chunk still has the same opening as before, but there are now three chunks:

print("sentence_tokens len:",len(sentence_tokens["input_ids"]))

sentence_tokens len: 3

The second chunk no longer starts at the beginning of the document; it continues from later content:

sentence_tokens: ['<s>', 'Ġ(', '64', ',', '219', ')', 'ĠAverage', 'Ġreturn', 'Ġ(', 'by', 'Ġassets', ',', 'Ġwhole', 'Ġperiod', ')', 'Ġ37', '.', '16', '%', 'ĠAverage', 'Ġreturn', 'Ġ(', 'by', 'Ġcustomers', ',', 'Ġwhole', 'Ġperiod', ')', …', 'Ġwith', 'Ġover', 'Ġ5', ',', '000', 'Ġinteractions', 'Ġhappening', 'Ġev', '-', 'Ġ', 'ery', 'Ġmonth', '.', 'ĠThe', 'Ġprevious', 'Ġperiod', ',', 'Ġbetween', 'ĠJanuary', 'Ġ2018', 'Ġand', 'ĠDecember', 'Ġ2019', 'Ġonly', 'Ġreceives', 'Ġbetwe

Now the input token dimensions are fine.

Appending each window and then applying torch.stack produces new elements in a list rather than concatenating everything into one string, so the length limit is not exceeded.
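A sketch of that append-and-stack step (every window has already been padded to max_length, so stacking works):

import torch

window_input_ids = [torch.tensor(ids) for ids in sentence_tokens["input_ids"]]
window_attention_mask = [torch.tensor(mask) for mask in sentence_tokens["attention_mask"]]

input_ids = torch.stack(window_input_ids)            # shape: (num_windows, 4096)
attention_mask = torch.stack(window_attention_mask)  # one row per window, not one long string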

After this change, predicted_token_class_ids on the prediction side is more deeply nested and needs one extra level of processing:

Traceback (most recent call last):
  File "d:/Projects/longformer/tests/tkn_clsfy_slide.py", line 154, in <module>
    prediction_strings = [get_prediction_string(prediction, predicted_inputs_id) for prediction, predicted_inputs_id in zip(predicted_token_class_ids, inputs["input_ids"])]
TypeError: zip argument #1 must support iteration

The predictions can now be extracted, although the tokens predicted for the last chunk include padding.

tokenize() adding <pad> by default

Checking the labels shows that padding was added during tokenize.

Normally tokenize() does not add <pad>; this is probably caused by the tokenizer's default settings. The function has no padding argument to pass, and changing the tokenizer itself might affect later steps, so for now just delete the padding:

tokenized_label = [token for token in tokenized_label if token != '<pad>']  # drop pad

The labels now match correctly:

matched tokenized_label: ['ĠAs', 'Ġfuture', 'Ġwork', ',', 'Ġwe', 'Ġshall', 'Ġexplore', 'Ġthe', 'Ġuse', 'Ġof', 'Ġthis', 'Ġdataset', 'Ġfor', 'Ġmultiple', 'Ġtasks', ',', 'Ġnot', 'Ġonly', 'Ġincluding', 'Ġfinancial', 'Ġasset', 'Ġrecommendation', ',', 'Ġbut', 'Ġalso', 'Ġportfolio', 'Ġconstruction', 'Ġand', 'Ġoptim', 'isation', 'Ġor', 'Ġinvestor', 'Ġand', 'Ġasset', 'Ġmodelling', '.']

In the QA example, every input contains the question plus some part of the context.

This could be changed so that every input contains the title and abstract, but that requires preprocessing the dataset.

The predictions match the target sentence, but the output predictions still contain a lot of padding.

Printing them shows that the labels in the extra pad region are 0, so the labels themselves are fine:

matched clsfy_label: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, …, 0, 0, 0]

Check the tokenize calls used during evaluation and strip the pads everywhere tokenize is used:

# During evaluation
        ref_tokens = [token for token in ref_tokens if token != '<pad>']    # drop pad
        pred_tokens = [token for token in pred_tokens if token != '<pad>']  # drop pad

# When concatenating predicted tokens
    for token, pred_class in zip(tokenized_sub_sentence, predicted_tokens_classes):
        if token.startswith("Ġ"):
            if (len(current_word) != 0) & current_word_pred:
                dataset_description_words.append(current_word)
            current_word = token[1:]
            current_word_pred = (pred_class == 'Dataset description')
        elif token != '<pad>':  # skip pad
            current_word += token
            current_word_pred = current_word_pred or (pred_class == 'Dataset description')

    if (len(current_word) != 0) & (current_word != '<pad>') & current_word_pred:
        dataset_description_words.append(current_word)

That takes care of all the <pad> tokens.

Adding the title and abstract to every window

Example

Pass the question and the context to the tokenizer together, and it will insert the special tokens correctly to form one sequence:

[CLS] question [SEP] context [SEP]

context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, '

'the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin '

'Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues '

'and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

This gives the [CLS] question [SEP] context [SEP] format.

In Longformer, the separator token is </s> instead.

Example 2, with more parameters:

inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs["input_ids"]:
    print(tokenizer.decode(ids))

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]'

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

Add title to the input arguments, with truncation="only_second" so that only the second sequence (the full text) gets truncated:

    for i, sentence in enumerate(sentences):
        labels_text = labels_texts[i]
        title = titles[i]
        sentence_tokens = tokenizer(
            title,
            sentence,
            max_length=max_length,
            # truncation=True,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length'
        )
        print("sentence_tokens keys:", sentence_tokens.keys())
        input_ids = sentence_tokens['input_ids']
        attention_mask = sentence_tokens['attention_mask']
        offset_mapping = sentence_tokens['offset_mapping']
        print("input_token:", tokenizer.convert_ids_to_tokens(input_ids[0]))

sentence_tokens keys: dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

input_token: ['<s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', …'ĠWe', 'Ġalso', 'Ġprovide', 'Ġa', 'Ġbench', '-', 'mark', 'ing', 'Ġcomparison', 'Ġbetween', 'Ġeleven', 'ĠFAR', 'Ġalgorithms', 'Ġover', 'Ġthe', 'Ġdata', 'Ġfor', 'Ġuse', 'Ġas', 'Ġfuture', 'Ġbas', 'elines', '.', 'ĠThe', 'Ġdataset', 'Ġcan', 'Ġbe', 'Ġdownloaded', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', '</s>', '</s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', 'ĠJavier', 'ĠSan', 'z', '-', 'Cruz', 'ado', '1', 'Ġ,', 'ĠNikola', 'os', 'ĠD', 'rou', 'kas', '2', 'Ġ,', 'ĠRichard', 'ĠMcC', 'read', 'ie', '1', 'Ġ1', 'University', 'Ġof', 'ĠGlasgow', 'Ġ2', 'National', 'ĠBank', 'Ġof', 'ĠGreece', 'Ġj', 'avier', '.', 'san', 'z', '-', 'cru', 'z', 'ad', 'op', 'u', 'ig', '@', 'glas', 'gow', '.', 'ac', '.', 'uk', ',', 'Ġd', 'rou', 'kas', '.', 'nik', 'ola', 'os', '@', 'n', 'bg', '.', 'gr', …

Results:

Epoch 1/3, Loss: 0.35347285121679306

Epoch 2/3, Loss: 0.04643934965133667

Epoch 3/3, Loss: 0.026251467876136303

Model and tokenizer saved to 'trained_model'

Prediction Strings: ['problem by', '', 'In']

d:\Users\laugo\anaconda3\envs\longformer\lib\site-packages\sklearn\metrics\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.

  _warn_prf(average, modifier, msg_start, len(result))

Precision: 0.0000, Recall: 0.0000, F1 Score: 0.0000

Token matching when a training label spans windows

Currently the token labels are assigned by exact full-length matching:

for k in range(len(tokenized_sentence) - label_length + 1):
    if tokenized_sentence[k:k + label_length] == tokenized_label:
        token_label[k:k + label_length] = [1] * label_length

The loop range already reserves the full label length from the start of the iteration.

This can be changed so that tokenized_label is truncated to the length of tokenized_sentence before comparing.

Handle the two cases separately: the tail of the label missing from the window, and the head of the label missing from the window:

                # Handle labels that span sliding windows
                for k in range(len(tokenized_sentence) - 1):
                    end_position = min(len(tokenized_sentence) - 1, k + label_length)
                    if tokenized_sentence[k:end_position] == tokenized_label[0:end_position]:  # tail of the label missing from this window
                        print("matched tokenized_label:", tokenized_label, "\n", tokenized_sentence[k:end_position])
                        token_label[k:end_position] = [1] * (end_position - k)

                for start in range(label_length - 1):  # head of the label missing from this window
                    if tokenized_sentence[0:label_length - start] == tokenized_label[start:]:
                        print("matched tokenized_label:", tokenized_sentence[0:label_length - start], "\n", tokenized_label[start:])
                        token_label[0:label_length - start] = [1] * (label_length - start)

Output:

Prediction Strings: ['FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution. aim solve problem by introducing FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.', 'introducing FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.', 'FAR-Trans, the first public dataset for FAR, containing pricing information and retail investor transactions acquired from a large European financial institution.']

Precision: 1.0000, Recall: 0.9091, F1 Score: 0.9524

Ignoring the fixed title/abstract during token matching

After adding sliding windows, the first segment (the "question" slot) of every window is the title and abstract, and in the first window the start of the second segment (the "context" slot) is the title and abstract again, so extracted tokens get duplicated. The starting point for token matching is therefore moved to the beginning of the second segment.

(In extractive QA the question sits in every window, but the answer normally never appears inside the question, so this causes no problem there. Here, however, any description that appears in the title/abstract would be matched by the token matcher in every window. Token matching currently runs over the whole input; change it so that all matching starts from the segment after the second </s>.)

After windowing, the input token sequence (tokenized_sentence) has the following format: special tokens mark the different texts, the first <s> ... </s> pair wraps the title and abstract, and the second </s> ... </s> pair wraps the full text; note that the second segment starts with '</s>'.

A Longformer sequence has the following format:

single sequence: <s> X </s>

pair of sequences: <s> A </s></s> B </s>

Example:

['<s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', …'ĠWe', 'Ġalso', 'Ġprovide', 'Ġa', 'Ġbench', '-', 'mark', 'ing', 'Ġcomparison', 'Ġbetween', 'Ġeleven', 'ĠFAR', 'Ġalgorithms', 'Ġover', 'Ġthe', 'Ġdata', 'Ġfor', 'Ġuse', 'Ġas', 'Ġfuture', 'Ġbas', 'elines', '.', 'ĠThe', 'Ġdataset', 'Ġcan', 'Ġbe', 'Ġdownloaded', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', '</s>', '</s>', 'ĠFAR', '-', 'Trans', ':', 'ĠAn', 'ĠInvestment', 'ĠDat', 'as', 'et', 'Ġfor', 'ĠFinancial', 'ĠAsset', 'ĠRecommend', 'ation', 'ĠJavier', 'ĠSan', 'z', '-', 'Cruz', 'ado', '1', 'Ġ,', 'ĠNikola', 'os', 'ĠD', 'rou', 'kas', '2', 'Ġ,', 'ĠRichard', 'ĠMcC', 'read', 'ie', '1', 'Ġ1', 'University', 'Ġof', 'ĠGlasgow', 'Ġ2', 'National', 'ĠBank', 'Ġof', 'ĠGreece', 'Ġj', 'avier', '.', 'san', 'z', '-', 'cru', 'z', 'ad', 'op', 'u', 'ig', '@', 'glas', 'gow', '.', 'ac', '.', 'uk', ',', 'Ġd', 'rou', 'kas', '.', 'nik', 'ola', 'os', '@', 'n', 'bg', '.', 'gr', …]

Modify the code that matches the full-text tokens against the labelled-text tokens to produce the binary labels, so that matching starts after the second '</s>':

# Find the token position where the main body starts in each window
def find_main_body(tokenized_sentence):
    start_index = 0
    s_token_count = 0
    for idx, token in enumerate(tokenized_sentence):
        if token == '</s>':
            s_token_count += 1
            if s_token_count == 2:
                start_index = idx + 1
                return start_index

                for k in range(main_body_start, len(tokenized_sentence) - 1):  # match starting from the main body
                    if tokenized_sentence[main_body_start:label_length - label_start] == tokenized_label[label_start:]:  # head of the label missing
                        token_label[main_body_start:label_length - label_start] = [1] * (label_length - label_start)

Now the matched content corresponds one-to-one with the labels, with nothing from the title duplicated:

 ['ĠIn', 'Ġthis', 'Ġpaper', ',', 'Ġwe', 'Ġaim', 'Ġto', 'Ġsolve', 'Ġthis', 'Ġproblem', 'Ġby', 'Ġintroducing', 'ĠFAR', '-', 'Trans', ',', 'Ġthe', 'Ġfirst', 'Ġpublic', 'Ġdataset', 'Ġfor', 'ĠFAR', ',', 'Ġcontaining', 'Ġpricing', 'Ġinformation', 'Ġand', 'Ġretail', 'Ġinvestor', 'Ġtransactions', 'Ġacquired', 'Ġfrom', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', '.']

 ['ĠThis', 'Ġwork', 'Ġaims', 'Ġto', 'Ġsolve', 'Ġthis', 'Ġlimitation', 'Ġby', 'Ġproposing', 'Ġa', 'Ġnovel', 'Ġdataset', 'Ġfor', 'Ġthe', 'ĠFAR', 'Ġtask', ',', 'Ġprovided', 'Ġby', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', '.', 'ĠAs', 'Ġfar', 'Ġas', 'Ġwe', 'Ġare', 'Ġaware', ',', 'Ġthis', 'Ġdataset', 'Ġrepresents', 'Ġthe', 'Ġfirst', 'Ġdataset', 'Ġin', 'Ġthis', 'Ġdomain', 'Ġcontaining', 'Ġpricing', 'Ġtime', 'Ġseries', 'Ġfor', 'Ġmultiple', 'Ġasset', 'Ġtypes', 'Ġ(', 'stocks', ',', 'Ġbonds', 'Ġand', 'Ġmutual', 'Ġfunds', '),', 'Ġasset', 'Ġdescriptions', ',', 'Ġas', 'Ġwell', 'Ġas', 'Ġmost', 'Ġimportantly', 'Ġ(', 'an', 'onym', 'ised', ')', 'Ġcustomer', 'Ġinformation', 'Ġand', 'Ġinvestment', 'Ġtransactions', '.']

 ['ĠWe', 'Ġintroduce', 'Ġin', 'Ġthis', 'Ġwork', 'Ġa', 'Ġnovel', 'Ġdataset', 'Ġfor', 'Ġfinancial', 'Ġrecommendation', ',', 'Ġwhich', 'Ġwe', 'Ġshall', 'Ġname', 'ĠFAR', '-', 'Trans', '.', 'ĠAs', 'Ġfar', 'Ġas', 'Ġwe', 'Ġare', 'Ġaware', 'Ġthis', 'Ġdataset', 'Ġrepresents', 'Ġthe', 'Ġfirst', 'Ġpublic', 'Ġdataset', 'Ġcontaining', 'Ġboth', 'Ġasset', 'Ġpricing', 'Ġinformation', 'Ġand', 'Ġinvestment', 'Ġtransactions', 'Ġfor', 'ĠFAR', '.', 'ĠThe', 'Ġdata', 'Ġhas', 'Ġbeen', 'Ġprovided', 'Ġby', 'Ġa', 'Ġlarge', 'ĠEuropean', 'Ġfinancial', 'Ġinstitution', ',', 'Ġrepresenting', 'Ġa', 'Ġsnapshot', 'Ġof', 'Ġthe', 'Ġmarket', 'Ġavailable', 'Ġto', 'ĠGreek', 'Ġinvestors', 'Ġbetween', 'ĠJanuary', 'Ġ2018', 'Ġand', 'ĠNovember', 'Ġ2022', '.', 'ĠFAR', '-', 'Trans', 'Ġcovers', 'Ġpricing', 'Ġdata', 'Ġfor', 'Ġstocks', ',', 'Ġbonds', 'Ġand', 'Ġmutual', 'Ġfunds', ',', 'Ġas', 'Ġwell', 'Ġas', 'Ġinvestment', 'Ġtransaction', 'Ġlogs', 'Ġ(', 'ass', 'et', 'Ġbuy', 'Ġand', 'Ġsell', 'Ġactions', ')', 'Ġhandled', 'Ġby', 'Ġthe', 'Ġinstitution', ',', 'Ġcustomer', ',', 'Ġmarket', 'Ġand', 'Ġasset', 'Ġinformation', '.', 'ĠThis', 'Ġsection', 'Ġprovides', 'Ġa', 'Ġdescription', 'Ġof', 'Ġthe', 'Ġdataset', 'Ġand', 'Ġthe', 'Ġacquisition', 'Ġand', 'Ġcleaning', 'Ġmethodology', '.', 'ĠThe', 'Ġdataset', 'Ġis', 'Ġavailable', 'Ġfrom', 'Ġhttps', '://', 'doi', '.', 'org', '/', '10', '.', '55', '25', '/', 'gl', 'a', '.', 'research', 'data', '.', '16', '58', '.', 'ĠTable', 'Ġ1', 'Ġsummarizes', 'Ġits', 'Ġglobal', 'Ġproperties', '.']

 ['ĠFinally', ',', 'Ġwe', 'Ġsolve', 'Ġinconsistencies', 'Ġon', 'Ġasset', 'Ġprices', 'Ġby', 'Ġproviding', 'Ġan', 'Ġestimate', 'Ġof', 'Ġthe', 'Ġtotal', 'Ġvalue', 'Ġof', 'Ġthe', 'Ġtransaction', '.', 'ĠWe', 'Ġestimate', 'Ġthe', 'Ġvalue', 'Ġby', 'Ġmultiplying', 'Ġthe', 'Ġnumber', 'Ġof', 'Ġshares', 'Ġby', 'Ġthe', 'Ġclosing', 'Ġprice', 'Ġof', 'Ġthe', 'Ġasset', 'Ġon', 'Ġthe', 'Ġdate', 'Ġof', 'Ġthe', 'Ġtransaction', '.', 'ĠIn', 'Ġthe', 'Ġend', ',', 'Ġwe', 'Ġhave', 'Ġ388', ',', '0', '49', 'Ġtransactions', 'Ġin', 'Ġour', 'Ġdataset', ',', 'Ġcorresponding', 'Ġto', 'Ġ29', ',', '090', 'Ġcustomers', '.']

Concatenating the outputs

See the QA task documentation referenced above.

When a single paper is split into several windows, how do we know at output time that the pieces belong together? At the moment they are simply output as they are; since every window carries its own labels this probably does not affect training, but how should it be handled at inference time?

With return_overflowing_tokens=True and return_offsets_mapping=True, sentence_tokens has the keys dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping']).

Two of these keys are special.

offset_mapping holds the character offsets: which characters of the original text each token was cut from, e.g.:

text:          "80      %      of     Americans …

offset_mapping: (0,2)  (2,3)   (4,6)   (7,16)…

overflow_to_sample_mapping records which sample each window comes from:

sentence_tokens['overflow_to_sample_mapping']

[0, 0, 0]

To make it distinguish different samples, all samples have to be passed in together, as in the documentation example:

inputs = tokenizer(
    raw_datasets["train"][2:6]["question"],
    raw_datasets["train"][2:6]["context"],
    max_length=100,
    truncation="only_second",
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 4 examples gave {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.")

'The 4 examples gave 19 features.'

'Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].'

This mapping can be used to stitch together the outputs of the same sample.
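For instance, a small sketch of grouping window indices by sample using this mapping:

from collections import defaultdict

windows_by_sample = defaultdict(list)
for window_idx, sample_idx in enumerate(inputs["overflow_to_sample_mapping"]):
    windows_by_sample[sample_idx].append(window_idx)
# windows_by_sample[0] -> [0, 1, 2, 3], windows_by_sample[3] -> [12, 13, ..., 18], etc.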

One question remains: once several windows are known to come from the same sample, how should the outputs for the overlapping part be handled?

First we need to know the position of each output token; if it was taken from the 256 tokens repeated between windows, it is removed.

Cleaning the offset mappings:

They contain offsets for both the question and the context, but once we reach post-processing we have no way of knowing which part of the input IDs corresponded to the context and which to the question (the sequence_ids() method is only available on the tokenizer's output). So we set the offsets corresponding to the question to None.
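A sketch of that cleaning step, following the QA-course recipe (sequence_ids(i) is 1 only for context tokens):

for i in range(len(inputs["input_ids"])):
    sequence_ids = inputs.sequence_ids(i)
    offset = inputs["offset_mapping"][i]
    # keep offsets only for context tokens; question/special tokens become None
    inputs["offset_mapping"][i] = [
        o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
    ]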

Concatenating predictions by paper index

Try adding the paper index while the tokenizer splits the windows; otherwise the windows are hard to tell apart later:

tokenized_inputs = {"input_ids": [], "attention_mask": [], "labels": [],"paper_idx":[]}

tokenized_inputs["paper_idx"].append(torch.tensor(i))

tokenized_inputs["paper_idx"] = torch.stack(tokenized_inputs["paper_idx"])

This gives tokenized_inputs["paper_idx"]:

tensor([0, 0, 0])

Take paper_idx into account when extracting the predictions.

Token-level concatenation (inside) compares the current sample id with the next one, while paragraph-level concatenation (outside) compares the current idx with the previous paragraph's.

Alternatively, both can compare the current idx with the previous paragraph's idx.

For the outer paragraph-level concatenation: if the current idx equals the previous paragraph's idx, append the current passage to the same string; otherwise start a new element.

For the token-level concatenation: if they are the same, skip the first 256 repeated tokens; the boolean computed for the check above can be reused.

The first window is a special case: nothing is checked and nothing is removed.

def get_extracted_description(predicted_token_class_id, predicted_inputs_ids, predicted_paper_idxs):
    dataset_descriptions = []
    for i, predicted_inputs_id in enumerate(predicted_inputs_ids):
        prediction_class = predicted_token_class_id[i]
        # Check whether this window comes from the same paper as the previous one
        is_same_paper = False
        if i > 0:
            if predicted_paper_idxs[i] == predicted_paper_idxs[i - 1]:
                is_same_paper = True
        dataset_description = get_prediction_string(prediction_class, predicted_inputs_id, is_same_paper)
        # Concatenate: same paper -> extend the previous string, otherwise start a new element
        if is_same_paper:
            dataset_descriptions[-1] += " " + dataset_description
        else:
            dataset_descriptions.append(dataset_description)
    return dataset_descriptions

dataset_descriptions = get_extracted_description(predicted_token_class_ids, inputs["input_ids"], inputs["paper_idx"])

After concatenation, the outputs for the same paper end up in a single string:

dataset_descriptions: ['']

Handling tokens repeated across overlapping windows

If the repeated positions at the end of each window were the ones dropped, the last window would become a special case; instead we check whether the next window has the same sample index as the current one and ignore the repeated positions accordingly.

When concatenating tokens, if the current window's idx equals the previous window's, the first 256 repeated tokens are skipped; the boolean computed earlier can be reused here:

    # Same paper: drop the part of the window that overlaps the previous one
    if is_same_paper:
        tokenized_sub_sentence = tokenized_sub_sentence[stride:]
        predicted_tokens_classes = predicted_tokens_classes[stride:]

Evaluation labels

The evaluation label is not a per-window label but the concatenation of every description in the whole paper, which was already done earlier. Each element of the ref list is one description passage from the same paper; joining them gives the complete description string:

" ".join(ref)

So the output text, too, is now fully concatenated.
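For reference, one simple way such strings could be scored with scikit-learn (this is only a sketch of word-overlap scoring under my own assumptions; the project's actual metric code is not shown above):

from sklearn.metrics import precision_score, recall_score, f1_score

reference = " ".join(ref)              # full reference description for one paper
prediction = dataset_descriptions[0]   # concatenated prediction for the same paper

vocab = sorted(set(reference.split()) | set(prediction.split()))
y_true = [1 if w in reference.split() else 0 for w in vocab]
y_pred = [1 if w in prediction.split() else 0 for w in vocab]

print(f"Precision: {precision_score(y_true, y_pred, zero_division=0):.4f}, "
      f"Recall: {recall_score(y_true, y_pred, zero_division=0):.4f}, "
      f"F1 Score: {f1_score(y_true, y_pred, zero_division=0):.4f}")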
