Fine-Tuning a Chinese NER Model with BERT-CRF


Table of Contents

  • Dataset
  • Model Definition
  • Dataset Preprocessing
    • BIO Tag Conversion
    • Custom Dataset
    • Train/Test Split
  • Training
  • Validation & Testing
  • Metric Computation
  • Inference
  • Miscellaneous
    • Hyperparameters and Setup
    • CRF Module

Dataset

  • CLUE-NER dataset: https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/README.md

Model Definition

import torch
import torch.nn as nn
from pytorch_crf import CRF
from transformers import BertPreTrainedModel, BertModel

class BertCrfForNer(BertPreTrainedModel):
    def __init__(self, config):
        super(BertCrfForNer, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.num_labels = config.num_labels
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, input_lens=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # per-token emission scores
        outputs = (logits,)
        if labels is not None:
            # CRF.forward returns the log likelihood, so negate it to obtain a loss to minimize
            loss = self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (-1 * loss,) + outputs
        return outputs  # (loss), logits

The CRF module in pytorch_crf.py is listed in the "CRF Module" section at the end of this article.
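
A minimal smoke test of the model definition above (not from the original post): it assumes the local checkpoint path used later and the 11-label scheme (O plus B-/I- tags for the five allowed entity types) defined in the preprocessing section. When labels are passed, forward returns (negative CRF log-likelihood, logits); otherwise it returns (logits,) only.

model = BertCrfForNer.from_pretrained('models/bert-base-chinese', num_labels=11)

ids = torch.randint(100, 1000, (2, 16))   # dummy token ids: batch of 2, sequence length 16
mask = torch.ones_like(ids)               # the CRF requires the first timestep to be unmasked
tags = torch.zeros_like(ids)              # dummy all-"O" labels
loss, logits = model(input_ids=ids, attention_mask=mask, labels=tags)
print(loss.item(), logits.shape)          # logits shape: (2, 16, 11)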

Dataset Preprocessing

BIO Tag Conversion

ALLOW_LABEL = ["name", "organization", "address","company","government"]

def generate_bio_tags(tokenizer, text_json, allowed_type=ALLOW_LABEL):
    def tokenize_with_location(tokenizer, input_data):
        # tokenize and keep each token's (start, end) character offsets in the original text
        encoded_input = tokenizer.encode_plus(input_data, return_offsets_mapping=True)
        return list(zip([tokenizer.decode(i) for i in encoded_input.input_ids],
                        encoded_input.offset_mapping))

    def get_bio_tag(labels, token_start, token_end):
        if token_start >= token_end:
            return "O"
        for entity_type, entities in labels.items():
            if entity_type in allowed_type:
                for entity_name, positions in entities.items():
                    for position in positions:
                        start, end = position
                        if token_start >= start and token_end <= end + 1:  # CLUE-NER label spans use an inclusive end index
                            if token_start == start:
                                return f"B-{entity_type}"
                            else:
                                return f"I-{entity_type}"
        return "O"
                    
    text = text_json["text"]
    labels = text_json["label"]

    # tokenize with the BERT tokenizer; the offsets let us align BIO tags to characters
    tokenized_text = tokenize_with_location(tokenizer, text)
    tokens, bio_tags = [], []
    for token, loc in tokenized_text:
        loc_s, loc_e = loc
        bio_tag = get_bio_tag(labels, loc_s, loc_e)
        bio_tags.append(bio_tag)
        tokens.append(token)
    return tokens, bio_tags

# Input JSON sample: "game" is not in ALLOW_LABEL, so every token below ends up tagged "O"
input_json = {"text": "你们是最棒的!#英雄联盟d学sanchez创作的原声王", "label": {"game": {"英雄联盟": [[8, 11]]}}}
generate_bio_tags(tokenizer, input_json)
"""
(['[CLS]',
  '你',
  '们',
  '是',
  '最',
  '棒',
  '的',
  '!',
  '#',
  '英',
  '雄',
  '联',
  '盟',
  'd',
  '学',
  'san',
  '##che',
  '##z',
  '创',
  '作',
  '的',
  '原',
  '声',
  '王',
  '[SEP]'],
 ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O'])"""

Custom Dataset

from tqdm.notebook import tqdm
import json
import pickle
import os

cached_dataset = 'train.dataset.pkl'
train_file = 'train.json'
if not os.path.exists(cached_dataset):
    
    dataset = []
    with open(train_file, 'r', encoding='utf-8') as file:
        for line in tqdm(file.readlines()):
            data = json.loads(line.strip())
            tokens, bio_tags = generate_bio_tags(tokenizer, data)
            # keep only samples that contain at least one non-"O" tag
            if len(set(bio_tags)) > 1:
                dataset.append({"text": data["text"], "tokens": tokens, "tags": bio_tags})
    with open(cached_dataset, 'wb') as f:
        pickle.dump(dataset, f)
        
else:
    with open(cached_dataset, 'rb') as f:
        dataset = pickle.load(f)

First convert each raw sample {"text": ..., "label": ...} into {"text": ..., "tokens": ..., "tags": ...}, caching the result in a pickle file.

from itertools import product
from torch.utils.data import Dataset, DataLoader

labels = ["O"] + [f"{i}-{j}" for i,j in product(['B','I'], ALLOW_LABEL)]
label2id = {k: v for v, k in enumerate(labels)}
id2label = {v: k for v, k in enumerate(labels)}

class BertDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_len):
        self.len = len(dataset)
        self.data = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __getitem__(self, index):
        # step 1: fetch the pre-tokenized sample (tokens + BIO tags produced by generate_bio_tags)
        item = self.data[index]

        # step 2: special tokens ([CLS]/[SEP]) and their "O" labels are already included
        tokenized_sentence = item["tokens"]
        labels = item["tags"]


        # step 3: truncating/padding
        maxlen = self.max_len

        if (len(tokenized_sentence) > maxlen):
            # truncate
            tokenized_sentence = tokenized_sentence[:maxlen]
            labels = labels[:maxlen]
        else:
            # pad
            tokenized_sentence = tokenized_sentence + ['[PAD]' for _ in range(maxlen - len(tokenized_sentence))]
            labels = labels + ["O" for _ in range(maxlen - len(labels))]

        # step 4: obtain the attention mask
        attn_mask = [1 if tok != '[PAD]' else 0 for tok in tokenized_sentence]
        
        # step 5: convert tokens to input ids
        ids = self.tokenizer.convert_tokens_to_ids(tokenized_sentence)

        label_ids = [label2id[label] for label in labels]
        # the following line is deprecated
        #label_ids = [label if label != 0 else -100 for label in label_ids]
        
        return {
              'ids': torch.tensor(ids, dtype=torch.long),
              'mask': torch.tensor(attn_mask, dtype=torch.long),
              #'token_type_ids': torch.tensor(token_ids, dtype=torch.long),
              'targets': torch.tensor(label_ids, dtype=torch.long)
        } 
    
    def __len__(self):
        return self.len
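
For reference (not in the original post), the label space defined above works out to 11 tags, which is what num_labels will be set to when the model is loaded:

print(len(labels))  # 11 = "O" plus B-/I- for the five allowed entity types
print(labels)
# ['O', 'B-name', 'B-organization', 'B-address', 'B-company', 'B-government',
#  'I-name', 'I-organization', 'I-address', 'I-company', 'I-government']
print(label2id['B-name'], id2label[0])  # 1 'O'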

Train/Test Split

import numpy as np
import random
def split_train_test_valid(dataset, train_size=0.9, test_size=0.1):
    dataset = np.array(dataset)
    total_size = len(dataset)
    
    # define the ratios
    train_len = int(total_size * train_size)
    test_len = int(total_size * test_size)

    # split according to the shuffled indices
    idx = list(range(total_size))
    random.shuffle(idx)  # shuffle the index list
    data_train = dataset[idx[:train_len]]
    data_test = dataset[idx[train_len:train_len + test_len]]
    data_valid = dataset[idx[train_len + test_len:]]  # whatever remains is the validation split (effectively empty with the default 0.9/0.1 ratios)
 
    return data_train, data_test, data_valid


data_train, data_test, data_valid = split_train_test_valid(dataset)
print("FULL Dataset: {}".format(len(dataset)))
print("TRAIN Dataset: {}".format(data_train.shape))
print("TEST Dataset: {}".format(data_test.shape))

training_set = BertDataset(data_train, tokenizer, MAX_LEN)
testing_set = BertDataset(data_test, tokenizer, MAX_LEN)
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': True,
                'num_workers': 0
                }
training_loader = DataLoader(training_set, **train_params)
testing_loader = DataLoader(testing_set, **test_params)
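
A quick sanity check (not in the original post) that one batch comes out with the expected shapes, given MAX_LEN and TRAIN_BATCH_SIZE from the "Hyperparameters and Setup" section:

batch = next(iter(training_loader))
print(batch['ids'].shape, batch['mask'].shape, batch['targets'].shape)
# torch.Size([16, 128]) torch.Size([16, 128]) torch.Size([16, 128])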

Training

model = BertCrfForNer.from_pretrained('models/bert-base-chinese',
# model = AutoModelForTokenClassification.from_pretrained('save_model',
                                      num_labels=len(id2label),
                                      id2label=id2label,
                                      label2id=label2id)
if MULTI_GPU:
    model = torch.nn.DataParallel(model, )
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')

def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    # put model in training mode
    model.train()
    
    for idx, batch in enumerate(training_loader):
        
        ids = batch['ids'].to(device, dtype = torch.long)
        mask = batch['mask'].to(device, dtype = torch.long)
        targets = batch['targets'].to(device, dtype = torch.long)

        outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
#         loss, tr_logits = outputs.loss, outputs.logits
        loss, tr_logits = outputs[0], outputs[1]
        if MULTI_GPU:
            loss = loss.mean()
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)
        
        if idx % 100==0:
            loss_step = tr_loss/nb_tr_steps
            print(f"Training loss per 100 training steps: {loss_step}")
           
        # compute training accuracy
        flattened_targets = targets.view(-1) # shape (batch_size * seq_len,)
        num_labels = model.module.num_labels if MULTI_GPU else model.num_labels
        active_logits = tr_logits.view(-1, num_labels) # shape (batch_size * seq_len, num_labels)
        flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * seq_len,)
        # now, use mask to determine where we should compare predictions with targets (includes [CLS] and [SEP] token predictions)
        active_accuracy = mask.view(-1) == 1 # active accuracy is also of shape (batch_size * seq_len,)
        targets = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)
        
        tr_preds.extend(predictions)
        tr_labels.extend(targets)
        
        tmp_tr_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
        tr_accuracy += tmp_tr_accuracy
    
        # backward pass; clip gradients after backward() so clipping acts on this step's gradients
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=MAX_GRAD_NORM
        )
        optimizer.step()

    epoch_loss = tr_loss / nb_tr_steps
    tr_accuracy = tr_accuracy / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
    print(f"Training accuracy epoch: {tr_accuracy}")

for epoch in range(EPOCHS):
    print(f"Training epoch: {epoch + 1}")
    train(epoch)
"""
Training epoch: 1
Training loss per 100 training steps: 76.82186126708984
Training loss per 100 training steps: 26.512494955912675
Training loss per 100 training steps: 18.23713019356799
Training loss per 100 training steps: 14.71561597431221
Training loss per 100 training steps: 12.793566083075698
Training loss epoch: 12.138352865534845
Training accuracy epoch: 0.9093487211512798
"""

Validation & Testing

def valid(model, testing_loader):
    # put model in evaluation mode
    model.eval()
    
    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels = [], []
    
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):
            
            ids = batch['ids'].to(device, dtype = torch.long)
            mask = batch['mask'].to(device, dtype = torch.long)
            targets = batch['targets'].to(device, dtype = torch.long)
            
            outputs = model(input_ids=ids, attention_mask=mask, labels=targets)
            loss, eval_logits = outputs[0], outputs[1]
            if MULTI_GPU:
                loss = loss.mean()
            
            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += targets.size(0)
        
            if idx % 100==0:
                loss_step = eval_loss/nb_eval_steps
                print(f"Validation loss per 100 evaluation steps: {loss_step}")
              
            # compute evaluation accuracy
            flattened_targets = targets.view(-1) # shape (batch_size * seq_len,)
            num_labels = model.module.num_labels if MULTI_GPU else model.num_labels
            active_logits = eval_logits.view(-1, num_labels) # shape (batch_size * seq_len, num_labels)
            flattened_predictions = torch.argmax(active_logits, axis=1) # shape (batch_size * seq_len,)
            # now, use mask to determine where we should compare predictions with targets (includes [CLS] and [SEP] token predictions)
            active_accuracy = mask.view(-1) == 1 # active accuracy is also of shape (batch_size * seq_len,)
            targets = torch.masked_select(flattened_targets, active_accuracy)
            predictions = torch.masked_select(flattened_predictions, active_accuracy)
            
            eval_labels.extend(targets)
            eval_preds.extend(predictions)
            
            tmp_eval_accuracy = accuracy_score(targets.cpu().numpy(), predictions.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy
    
    #print(eval_labels)
    #print(eval_preds)

    labels = [id2label[id.item()] for id in eval_labels]
    predictions = [id2label[id.item()] for id in eval_preds]

    #print(labels)
    #print(predictions)
    
    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_steps
    print(f"Validation Loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")

    return labels, predictions

labels, predictions = valid(model, testing_loader)
"""
Validation loss per 100 evaluation steps: 5.371463775634766
Validation Loss: 5.623965330123902
Validation Accuracy: 0.9547014622783095
"""

Metric Computation

from seqeval.metrics import classification_report

print(classification_report([labels], [predictions]))
"""
              precision    recall  f1-score   support

     address       0.50      0.62      0.55       316
     company       0.65      0.77      0.70       270
  government       0.69      0.85      0.76       208
        name       0.87      0.87      0.87       374
organization       0.76      0.82      0.79       343

   micro avg       0.69      0.79      0.74      1511
   macro avg       0.69      0.79      0.73      1511
weighted avg       0.70      0.79      0.74      1511
"""

Inference

from transformers import pipeline

model_to_test = (
    model.module if hasattr(model, "module") else model
)
pipe = pipeline(task="token-classification", model=model_to_test.to("cpu"), tokenizer=tokenizer, aggregation_strategy="simple")

pipe("我的名字是michal johnson,我的手机号是13425456344,我家住在东北松花江上8幢7单元6楼5号房")
"""
[{'entity_group': 'name',
  'score': 0.83746755,
  'word': 'michal johnson',
  'start': 5,
  'end': 19},
 {'entity_group': 'address',
  'score': 0.924768,
  'word': '东 北 松 花 江 上 8 幢 7 单 元 6 楼 5 号 房',
  'start': 42,
  'end': 58}]
"""

Miscellaneous

Hyperparameters and Setup

import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1,3'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LEN = 128
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 32
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10
MULTI_GPU = False
ALLOW_LABEL = ["name", "organization", "address","company","government"]
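
The tokenizer used throughout is never shown in the post. Because generate_bio_tags relies on return_offsets_mapping, it has to be a fast tokenizer; a plausible setup, assuming the same local path as the model checkpoint, is:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('models/bert-base-chinese')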

CRF Module

Reference: https://github.com/CLUEbenchmark/CLUENER2020/blob/master/pytorch_version/models/crf.py

import torch
import torch.nn as nn
from typing import List, Optional

class CRF(nn.Module):
    """Conditional random field.
    This module implements a conditional random field [LMP01]_. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has `~CRF.decode` method which finds
    the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.
    Attributes:
        start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~torch.nn.Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.
    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution
        between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self, emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = 'mean') -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            tags (`~torch.LongTensor`): Sequence of tags tensor of size
                ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            reduction: Specifies  the reduction to apply to the output:
                ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
                ``sum``: the output will be summed over batches. ``mean``: the output will be
                averaged over batches. ``token_mean``: the output will be averaged over tokens.
        Returns:
            `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
            reduction is ``none``, ``()`` otherwise.
        """
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, tags=tags, mask=mask)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None,
               nbest: Optional[int] = None,
               pad_tag: Optional[int] = None) -> List[List[List[int]]]:
        """Find the most likely tag sequence using Viterbi algorithm.
        Args:
            emissions (`~torch.Tensor`): Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
            nbest (`int`): Number of most probable paths for each sequence
            pad_tag (`int`): Tag at padded positions. Often input varies in length and
                the length will be padded to the maximum length in the batch. Tags at
                the padded positions will be assigned with a padding tag, i.e. `pad_tag`
        Returns:
            A PyTorch tensor of the best tag sequence for each batch of shape
            (nbest, batch_size, seq_length)
        """
        if nbest is None:
            nbest = 1
        if mask is None:
            mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,
                              device=emissions.device)
        if mask.dtype != torch.uint8:
            mask = mask.byte()
        self._validate(emissions, mask=mask)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        if nbest == 1:
            return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
        return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)

    def _validate(self, emissions: torch.Tensor,
                  tags: Optional[torch.LongTensor] = None,
                  mask: Optional[torch.ByteTensor] = None) -> None:
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self, emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        seq_length, batch_size = tags.shape
        mask = mask.float()

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor,
                        pad_tag: Optional[int] = None) -> List[List[int]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size), pad_tag,
                             dtype=torch.long, device=device)

        # - score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # - history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(-1), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score
        # shape: (batch_size, num_tags)
        end_score = score + self.end_transitions
        _, end_tag = end_score.max(dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
                             end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size),
                                    dtype=torch.long, device=device)
        best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx], 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size)

        return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)

    def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
                              mask: torch.ByteTensor,
                              nbest: int,
                              pad_tag: Optional[int] = None) -> List[List[List[int]]]:
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # return: (nbest, batch_size, seq_length)
        if pad_tag is None:
            pad_tag = 0

        device = emissions.device
        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
                                  dtype=torch.long, device=device)
        oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
                              dtype=torch.long, device=device)
        oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
                             dtype=torch.long, device=device)

        # + score is a tensor of size (batch_size, num_tags) where for every batch,
        #   value at column j stores the score of the best tag sequence so far that ends
        #   with tag j
        # + history_idx saves where the best tags candidate transitioned from; this is used
        #   when we trace back the best tag sequence
        # - oor_idx saves the best tags candidate transitioned from at the positions
        #   where mask is 0, i.e. out of range (oor)

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            if i == 1:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1)
                # shape: (batch_size, num_tags, num_tags)
                next_score = broadcast_score + self.transitions + broadcast_emission
            else:
                broadcast_score = score.unsqueeze(-1)
                broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
                # shape: (batch_size, num_tags, nbest, num_tags)
                next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission

            # Find the top `nbest` maximum score over all possible current tag
            # shape: (batch_size, nbest, num_tags)
            next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)

            if i == 1:
                score = score.unsqueeze(-1).expand(-1, -1, nbest)
                indices = indices * nbest

            # convert to shape: (batch_size, num_tags, nbest)
            next_score = next_score.transpose(2, 1)
            indices = indices.transpose(2, 1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags, nbest)
            score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)
            indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx)
            history_idx[i - 1] = indices

        # End transition score shape: (batch_size, num_tags, nbest)
        end_score = score + self.end_transitions.unsqueeze(-1)
        _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1

        # insert the best tag at each sequence end (last position with mask == 1)
        history_idx = history_idx.transpose(1, 0).contiguous()
        history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
                             end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
        history_idx = history_idx.transpose(1, 0).contiguous()

        # The most probable path for each sequence
        best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
                                    dtype=torch.long, device=device)
        best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
                         .view(1, -1).expand(batch_size, -1)
        for idx in range(seq_length - 1, -1, -1):
            best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)
            best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest

        return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)
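
A quick standalone check of the CRF API above (illustrative values only, not from the original post):

num_tags, batch_size, seq_len = 5, 2, 3
crf = CRF(num_tags=num_tags, batch_first=True)
emissions = torch.randn(batch_size, seq_len, num_tags)
tags = torch.randint(0, num_tags, (batch_size, seq_len))
mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8)  # second sequence has one padded position

llh = crf(emissions, tags, mask=mask)     # mean log likelihood over the batch (a scalar, typically <= 0)
paths = crf.decode(emissions, mask=mask)  # tensor of shape (1, batch_size, seq_len)
print(llh.item(), paths)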
