预训练中文GPT2(包括重新训练tokenizer)

news2025/1/10 12:15:58

训练数据

1. 训练文件以 .json 为后缀

2. 数据为 JSON Lines 格式:每行一条 JSON 记录

3. json结构如下

{
  
  "content": "①北京和上海户籍的游客可获得韩国多次签证;②“整容客”可以不经由韩国使领馆、直接在网上申请签证;③中泰免签的实施日期尚未敲定;④越南已向中国持通行证旅游的公民全面开放。"
}

tokenizer训练(BPE)

from datasets import load_dataset
from transformers import AutoTokenizer

# JSON-lines corpus of Chinese text; the text of each record is read from its "content" field.
path = r'/tmp/pycharm_project_806/LCSTS_new/train.json'
raw_data = load_dataset("json", data_files=path, split='train')


def _content_batches(dataset, batch_size=1000):
    """Yield the "content" column of `dataset` in consecutive slices of `batch_size` rows."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start : start + batch_size]["content"]


# Lazy iterator, so the whole corpus is never materialised as one Python list.
training_corpus = _content_batches(raw_data)

# Train a new tokenizer from the existing GPT-2 tokenizer, with a vocabulary size of 52000.
old_tokenizer = AutoTokenizer.from_pretrained("/home/chenjq/model/gpt2")
tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 52000)

# Tokenize the same Chinese sentence with both tokenizers to compare the results.
example = '就是去美国大使馆的官方网站,它有中文版,去把每一条仔细研究透了,把每一个表格和材料都准备好了'
for label, tok in (('old_tokens:', old_tokenizer), ('new_tokens', tokenizer)):
    print(label, tok.tokenize(example))

tokenizer.save_pretrained("./my-tok")

tokenizer训练(SentencePiece 风格的 Unigram 模型)

from datasets import load_dataset
from tokenizers import (
    Regex,
    Tokenizer,
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
)
from transformers import PreTrainedTokenizerFast

path = r'all_train.json'  # a chinese text dataset
# path = r'/tmp/pycharm_project_806/cluener.json'  # a chinese text dataset
raw_data = load_dataset("json", data_files=path, split='train')

# Lazily yield the "content" column in slices of 1000 rows for the trainer.
training_corpus = (
    raw_data[i : i + 1000]["content"]
    for i in range(0, len(raw_data), 1000)
)


tokenizer = Tokenizer(models.Unigram())

# A text-generation (NLG) tokenizer should not include normalizers.Lowercase():
# once the input is lower-cased, upper-case text can no longer be produced at decode time.
# For NLU models such as BERT, normalizers.Lowercase() is acceptable, because those models
# are used for understanding tasks (text classification, entity extraction) rather than
# generation, where letter case does not really matter.
tokenizer.normalizer = normalizers.Sequence(
    [
        normalizers.Replace("``", '"'),
        normalizers.Replace("''", '"'),
        normalizers.NFKD(),
        normalizers.StripAccents(),
        normalizers.Replace(Regex(" {2,}"), " "),
    ]
)

tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

# Show how the pre-tokenizer splits a sample sentence before any training happens.
sample_text = "北京是中国的首都,今天天气真好。"
print(tokenizer.pre_tokenizer.pre_tokenize_str(sample_text))

unk_token = "<unk>"
special_tokens = ["<|endoftext|>"]
trainer = trainers.UnigramTrainer(
    vocab_size=52000,
    special_tokens=special_tokens,
    unk_token=unk_token,
    max_piece_length=4,
)
tokenizer.train_from_iterator(training_corpus, trainer=trainer)

tokenizer.decoder = decoders.Metaspace()

# Wrap the raw tokenizer so it can be loaded through transformers. The unknown token given to
# the trainer is registered here as well; otherwise the saved tokenizer has no `unk_token`.
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    bos_token="<|endoftext|>",
    eos_token="<|endoftext|>",
    unk_token=unk_token,
)
wrapped_tokenizer.save_pretrained('./sp-tok')

print(wrapped_tokenizer.tokenize(sample_text))

模型训练

#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
import warnings
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import evaluate
import torch
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.37.0.dev0")

# Same kind of guard for the `datasets` package; the second argument is the hint shown on failure.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Keys of the causal-LM model mapping and their `model_type` strings.
# MODEL_TYPES is listed in the help text of the `model_type` argument.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    `__post_init__` rejects `config_overrides` when it is combined with `config_name` or
    `model_name_or_path`.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Annotated Optional because the default is None (no explicit token supplied).
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # NOTE(review): annotated `bool` although the default is None. Left unchanged on purpose,
    # since the argument parser may derive the flag's command-line behaviour from the annotation.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # The trailing space after "This option" was missing, which rendered as "optionshould".
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        """Validate the combination of arguments.

        Raises:
            ValueError: if `config_overrides` is given together with `config_name` or
                `model_name_or_path`.
        """
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Options describing the data fed to the model for training and evaluation.

    After initialisation the instance checks that at least one data source was given
    (a dataset name, a training file or a validation file) and that any file supplied
    ends in `csv`, `json` or `txt`.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(
        default=False,
        metadata={"help": "Enable streaming mode"},
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True,
        metadata={"help": "Whether to keep line breaks when using TXT files or not."},
    )

    def __post_init__(self):
        """Check the streaming requirement, the presence of a data source and the file extensions.

        Raises:
            ValueError: if no dataset name, training file or validation file is given.
            AssertionError: if a supplied file does not end in csv, json or txt.
        """
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        sources = (self.dataset_name, self.train_file, self.validation_file)
        if all(source is None for source in sources):
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Each optional file is paired with the message reported when its extension is rejected.
        allowed_extensions = ["csv", "json", "txt"]
        file_checks = (
            (self.train_file, "`train_file` should be a csv, a json or a txt file."),
            (self.validation_file, "`validation_file` should be a csv, a json or a txt file."),
        )
        for file_name, message in file_checks:
            if file_name is None:
                continue
            assert file_name.split(".")[-1] in allowed_extensions, message


def main():
    """Run causal-language-model training and/or evaluation from command-line arguments.

    The function, in order:

    1. parses `ModelArguments`, `DataTrainingArguments` and `TrainingArguments` (from the command
       line, or from a single `.json` file passed as the only argument);
    2. configures logging and looks for a checkpoint to resume from in the output directory;
    3. loads the raw datasets (hub dataset or local csv/json/txt files), carving a validation
       split out of the training data when none is provided;
    4. loads or creates the config, tokenizer and model;
    5. tokenizes the text column and groups the token ids into blocks of `block_size`;
    6. builds a `Trainer`, then trains and/or evaluates depending on `--do_train` / `--do_eval`;
    7. pushes to the hub or writes a model card.

    Raises:
        ValueError: on conflicting `token` / `use_auth_token`, a non-empty output directory without
            a checkpoint, a missing tokenizer source, or a missing train/validation split.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Map the deprecated `use_auth_token` onto `token`, refusing the ambiguous case where both are set.
    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'content', else the column called 'text',
    # else the first column. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
        )
        if "validation" not in raw_datasets.keys():
            # No validation split: use the first N% of "train" for validation and the rest for training.
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # The loader name is taken from the file extension (training file first, else validation file).
        extension = (
            data_args.train_file.split(".")[-1]
            if data_args.train_file is not None
            else data_args.validation_file.split(".")[-1]
        )
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        # Training from scratch: build a default config for the requested model type.
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        # "auto" and None are passed through unchanged; any other name is resolved to a torch dtype.
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
        # Keying by data_ptr counts parameters that share storage only once.
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    # Prefer the "content" column (used by the JSON-lines corpus this script was adapted for), then the
    # conventional "text" column, then the first column, so other datasets do not fail with a KeyError.
    if "content" in column_names:
        text_column_name = "content"
    elif "text" in column_names:
        text_column_name = "text"
    else:
        text_column_name = column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize one batch of the text column, capturing the tokenizer's length warning."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )
    if hasattr(config, "max_position_embeddings"):
        max_pos_embeddings = config.max_position_embeddings
    else:
        # Define a default value if the attribute is missing in the config.
        max_pos_embeddings = 1024

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        """Concatenate a batch of tokenized examples and cut it into chunks of `block_size`.

        The remainder shorter than `block_size` is dropped; `labels` is a copy of `input_ids`.
        """
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, and if the total_length < block_size  we exclude this batch and return an empty dict.
        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/process#map

    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {block_size}",
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            """Reduce the logits to predicted token ids (argmax over the last dimension)."""
            if isinstance(logits, tuple):
                # Depending on the model and config, logits may contain extra tensors,
                # like past_key_values, but logits always come first
                logits = logits[0]
            return logits.argmax(dim=-1)

        metric = evaluate.load("accuracy")

        def compute_metrics(eval_preds):
            """Compute accuracy between the predictions and the labels shifted by one position."""
            preds, labels = eval_preds
            # preds have the same shape as the labels, after the argmax(-1) has been calculated
            # by preprocess_logits_for_metrics but we need to shift the labels
            labels = labels[:, 1:].reshape(-1)
            preds = preds[:, :-1].reshape(-1)
            return metric.compute(predictions=preds, references=labels)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_tpu_available()
        else None,
    )

    # Training
    if training_args.do_train:
        # An explicit --resume_from_checkpoint wins over a checkpoint detected in the output directory.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        # Perplexity is exp(eval_loss); report infinity when the exponential overflows.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Metadata passed to the hub push or to the generated model card.
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    """Entry point for xla_spawn (TPUs): runs `main()`; `index` is accepted but not used."""
    main()


# Run the training/evaluation entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()

"""
python run_clm.py \
    --train_file /tmp/pycharm_project_806/LCSTS_new/train.json \
    --tokenizer_name /home/chenjq/pythonWork/nlp/train_new_gpt2/my-tok \
    --model_type gpt2 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --do_train \
    --output_dir ./tmp/test-clm
    
/tmp/pycharm_project_806/LCSTS_new/train.json
/tmp/pycharm_project_806/cluener.json
    --gradient_accumulation_steps 8 \

--max_train_samples 1000
"""

训练代码参考:

https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/README.md

 结果对比

推理代码

from transformers import GPT2Tokenizer,GPT2LMHeadModel, set_seed
# Fix the random seed so the sampling examples further down are repeatable.
set_seed(42)

# Directory of the model to load. The commented-out path looks like a
# checkpoint from the training run above (output_dir ./tmp/test-clm); swap it
# in to try that model instead of the original GPT-2.
# model_path = '/tmp/pycharm_project_806/tmp/test-clm/checkpoint-5500'
model_path = "/home/chenjq/model/gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(model_path)

# Reuse the EOS token id as the padding id so that generate() does not warn
# about a missing pad token.
eos_id = tokenizer.eos_token_id
model = GPT2LMHeadModel.from_pretrained(model_path, pad_token_id=eos_id)

# Prompt that every generation example below is conditioned on.
prompt = '美国'
input_ids = tokenizer.encode(prompt, return_tensors='pt')

# Greedy decoding: max_length=50 caps the total output length, prompt tokens
# included.
greedy_output = model.generate(input_ids, max_length=50)

print("Output:\n" + "-" * 100)
greedy_text = tokenizer.decode(greedy_output[0], skip_special_tokens=True)
print(greedy_text)


# Beam search with 5 beams and early stopping enabled. The arguments are kept
# in one dict because the next example reuses them unchanged.
beam_search_args = {"max_length": 50, "num_beams": 5, "early_stopping": True}
beam_output = model.generate(input_ids, **beam_search_args)

print("Output:\n" + "-" * 100)
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))



# Same beam search, but with no_repeat_ngram_size=2 so that no 2-gram is
# generated twice.
beam_output = model.generate(
    input_ids,
    no_repeat_ngram_size=2,
    **beam_search_args,
)

print("Output:\n" + "-" * 100)
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))



# Ask beam search for several results instead of only the top one by setting
# num_return_sequences > 1 (it must not exceed num_beams).
beam_outputs = model.generate(
    input_ids,
    max_length=50,
    num_beams=5,
    no_repeat_ngram_size=2,
    num_return_sequences=5,
    early_stopping=True,
)

# beam_outputs now holds 5 sequences, matching num_return_sequences above.
print("Output:\n" + 100 * '-')
for rank, sequence in enumerate(beam_outputs):
    text = tokenizer.decode(sequence, skip_special_tokens=True)
    print(f"{rank}: {text}")




# Plain sampling: do_sample=True turns sampling on and top_k=0 switches the
# top-k filter off.
sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=0)

print("Output:\n" + "-" * 100)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))




# Same sampling setup with temperature=0.7, which makes low-probability
# candidates less likely to be picked.
sample_output = model.generate(
    input_ids, do_sample=True, max_length=50, top_k=0, temperature=0.7
)

print("Output:\n" + "-" * 100)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))





# Top-k sampling: restrict each step to the 50 most likely next tokens.
sample_output = model.generate(input_ids, do_sample=True, max_length=50, top_k=50)

print("top_k Output:\n" + "-" * 100)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))





# Top-p (nucleus) sampling: top-k is switched off (top_k=0) and each step
# samples only from the most likely tokens that together cover 92% of the
# probability mass.
sample_output = model.generate(
    input_ids, do_sample=True, max_length=50, top_p=0.92, top_k=0
)

print("top_p Output:\n" + "-" * 100)
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

原始GPT2

自己训练的GPT2 (BPE tokenizer)

自己训练的GPT2  (sentencePiece tokenizer)

结论

1.训练数据采用了LCSTS数据集,LCSTS_new是中文短摘要最常用的LCSTS短摘要数据集的升级版本,在数据量、质量方面均有显著提升,在信息摘要与提炼的过程中,与原文的事实一致性需要得到重点关注。

{
  "id": 6,
  "summary": "中国游客大增多国放宽签证",
  "content": "①北京和上海户籍的游客可获得韩国多次签证;②“整容客”可以不经由韩国使领馆、直接在网上申请签证;③中泰免签的实施日期尚未敲定;④越南已向中国持通行证旅游的公民全面开放。"
}

2.从生成结果上看,自己训练的比原始的更好。

3.训练数据大约500M,都是短文本,新闻数据,缺乏多样性。可以尝试增加数据多样性,增加文本长度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1377867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

屏幕截图编辑工具Snagit中文

Snagit是一款优秀的屏幕、文本和视频捕获与转换程序。它能够捕获屏幕、窗口、客户区窗口、最后一个激活的窗口或用鼠标定义的区域,并支持BMP、PCX、TIF、GIF或JPEG格式的保存。Snagit还具有自动缩放、颜色减少、单色转换、抖动等功能,并能将捕获的图像转…

【Linux】:Linux中的Git分支管理

本章开始介绍 Git 的杀手级功能之一(注意是之一,也就是后面还有之二,之三……):分支。分支就是科幻电影里面的平行宇宙,当你正在电脑前努力学习 C 的时候,另一个你正在另一个平行宇宙里努力学习…

x3daudio1_7.dll如何恢复,这6个方法都能修复x3daudio1_7.dll丢失问题

“x3daudio1_7.dll文件缺失”。那么,什么是x3daudio1_7.dll文件?它的作用和影响又是什么呢?本文将详细介绍x3daudio1_7.dll文件的定义、作用和影响,并提供6个修复方法来解决这个问题。 一、x3daudio1_7.dll是什么? x3dau…

高级分布式系统-第6讲 分布式系统的容错性--可靠的组通信

可靠的组通信 组内通信最好是每个进程之间都建立点到点的通信&#xff0c; 但实际中这样的组织结构不是有效的&#xff0c; 因为会浪费很大的通信带宽。 在平等组中&#xff0c; 多播是主要的组织结构。 但多播是具有同步性质的容错结构&#xff0c; 并不适用拜占庭模型。 多…

OpenGl L6坐标系统

一.标准化设备坐标 我们在L5谈到了对顶点着色器中的点进行变换,而变换的范围必须在 -1.0到1.0 之间,否则将不可见。只有将所有的点转换为标准化设备坐标后,才能全部传入光栅器,再转换为屏幕上的像素。 将坐标变换为标准化设备坐标…

【MySQL】C语言连接MySQL

文章目录 一、引入库下载库文件验证是否引入成功 二、MySQL C API相关接口三、总结 一、引入库 mysql的基础&#xff0c;我们之前已经学过&#xff0c;后面我们只关心使用要 使用C语言连接mysql &#xff0c;需要使用mysql官网提供的库&#xff0c;大家可以去MySQL官网下载。 …

随机漫步【scatter的使用】

去掉scatter的坐标轴&#xff08;未成功版&#xff09; import matplotlib.pyplot as plt from random import choice class RandomWalk():def __init__(self,num_points 5000):self.num_points num_pointsself.x_values [0]self.y_values [0]def fill_walk(self):while l…

父组件中 arr.push改变数组,但是子组件监听不到 arr 的变化

目录 一、问题 二、解决方法 三、总结 tips:如嫌繁琐,直接移步总结即可! 一、问题 1.真是奇怪呀,一般来说通过 push 方法改变数组,是一定会有响应式的,那就可以监听到变化。但是我今天却遇到了一件奇怪的事情。在…

多模态推荐系统综述:四、模型优化

四、模型优化 由于多模态信息的存在&#xff0c;当多模态编码器和推荐模型一起训练时&#xff0c;模型训练的计算要求大大增加。因此&#xff0c;多模态推荐模型在训练过程中可以分为两类&#xff1a;端到端训练和两步训练。 端到端训练可以利用反向传播获得的每个梯度来更新模…

2024.1.11 关于 Jedis 库操作 Redis 基本演示

目录 引言 通用命令 SET & GET EXISTS & DEL KEYS EXPIRE & TTL TYPE String 类型命令 MGET & MSET GETRANGE & SETRANGE APPEND INCR & DECR List 类型命令 LPUSH & LRANG LPOP & LPOP BLPOP & BRPOP LLEN Set 类型命…

Shutter Encoder多媒体转换v17.8

软件介绍 多媒体包含种类繁多的各种文件格式&#xff0c;每种格式都有其不同的特征和所谓的“怪癖”。 因此&#xff0c;如果使用多种图像、视频或音频格式&#xff0c;找到一个集中的软件来从一个地方处理所有这些格式可能会非常棘手。 这就是 Shutter Encoder 基本上允许做的…

科研绘图(二)气泡图

气泡矩阵图&#xff08;Bubble Matrix Plot&#xff09;&#xff0c;通常用于显示三个变量之间的关系。这种图表类型将数据点表示为气泡的形式&#xff0c;其中气泡的大小通常表示第三个数值变量的大小。图表的X轴和Y轴代表两个分类或定量变量。颜色可能代表另一个分类变量或是…

计算机缺失msvcp120.dll的最新解决方法,实测可以完美修复

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是“msvcp120.dll丢失”。msvcp120.dll是Microsoft Visual C Redistributable Package的一部分&#xff0c;它是运行许多基于Windows操作系统的应用程序所必需的动态链接库文件之一。如果计算机…

矿山无人驾驶方案

矿山无人驾驶运输系统&#xff0c;可实现露天矿采煤装载运输的无人化&#xff0c;满足智能矿山安全、高效、绿色、环保等目标。 无人驾驶应用的总体技术架构包括“车端、场端、云端”三个层面以及相应的安全保障体系&#xff0c;其中车端的智能矿卡具备车辆感知、通信、决策和执…

数字信号处理教程学习笔记1-第2章时域中的离散信号和系统

信号处理的任务示意方框图 模拟信号和数字信号分别是啥样的,有啥区别

【AI视野·今日CV 计算机视觉论文速览 第286期】Tue, 9 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Tue, 9 Jan 2024 Totally 121 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Dr$^2$Net: Dynamic Reversible Dual-Residual Networks for Memory-Efficient Finetuning Authors Chen Zhao, Shuming Li…

富唯智能新研发的复合机器人,轻松破解汽车底盘零配件生产中的难题

随着汽车工业的快速发展&#xff0c;对于底盘零配件的需求也日益增长。为了满足市场需求&#xff0c;智能物流解决方案在汽车底盘零配件生产中扮演着越来越重要的角色。如何实现高效、准确的生产和物流管理&#xff0c;以满足市场快速变化的需求&#xff0c;成为了汽车生产商亟…

日期类的实现|运算符重载的复用

前言 通过前面C入门与类与对象的学习&#xff0c;今天我们将运用所学的知识点完成一个Date类。 本节目标 运用所学知识完成Date类。详细讲解运算符各种重载。理解运算符重载的复用。 一、Date类的六个默认成员函数 六个成员函数&#xff0c;Date类只需要自己实现构造函数即可…

比尔盖茨:如果只能解决一个问题,我的答案总是营养不良

谷禾健康 当地时间12月19日&#xff0c;微软联合创始人、亿万富翁比尔盖茨发布了对来年的年度预测&#xff0c;称 2024 年将是一个“转折点”。 在这封长达 10 页的信中他展示了对人工智能领域的更多创新、婴儿营养不良问题的突破、气候变化谈判的进展等多方面的期待。 人工智能…

vue-virtual-scroll-list(可单选、多选、搜索查询、创建条目)

element-ui-解决下拉框数据量过多问题&#xff08;vue-virtual-scroll-list&#xff09;_element-ui下拉框数据太多如何优化-CSDN博客 的升级版 参考链接&#xff1a;封装el-select&#xff0c;实现虚拟滚动,可单选、多选、搜索查询、创建条目-CSDN博客 1.封装组件 select.v…