使用 JAX 进行 LLM 分布式监督微调

news2025/1/15 23:30:48

LLM distributed supervised fine-tuning with JAX — ROCm Blogs (amd.com)

24年1月25日,Douglas Jia 发布在AMD ROCm 博客上的文章。

在这篇文章中,我们回顾了使用 JAX 对基于双向编码器表示(BERT)的大型语言模型(LLM)进行文本分类任务微调的过程。我们探讨了在多个 AMD GPU 上并行化这一微调过程的技术,然后评估模型在测试数据集上的性能。为此,我们使用了一个基于 BERT的 cased transformer 模型和 General Language Understanding Evaluation(GLUE)基准数据集在多个 AMD GPU 上进行实验。

我们重点关注 JAX 中两个单程序多数据(SPMD)并行化方法。这两个方法是:
- 使用 pmap 函数在单个领先轴上进行简单的数据分发。
- 使用 jit、`Mesh` 和 mesh_utils 函数在设备之间分片数据,提供更大的并行化控制。

我们主要强调第一个方法,并在文章的最后部分提供了第二个方法的详细说明。
在撰写本文时,我们参考了这个教程,我们强烈推荐阅读。

什么是监督微调?

在人工智能(AI)时代,基于Transformer架构的模型(如 BERT、GPT-3 及其后续版本)为实现各种自然语言处理(NLP)任务(如文本分类、文本生成和情感分析)的尖端性能提供了坚实的基础。然而,当这些大型预训练模型单独应用于这些特定任务时,常常表现出一定的局限性。监督微调(SFT)为解决这些局限性提供了方案。

与在大规模、多样化数据集上进行广泛无监督训练的预训练模型不同,SFT采用了一种专注且资源高效的方法。通常,这需要一个相对紧凑、高质量的数据集,该数据集精确地针对特定任务量身定制。SFT可以在不需要长时间训练的情况下,将模型性能提升到最先进的水平,因为它能够利用预训练模型所获得的广泛知识。

SFT过程包括微调模型的现有权重或添加额外参数,以确保与指定任务的复杂性保持一致。通常,这种适应会结合任务特定的层,例如为分类添加一个 softmax 层,从而增强模型解决监督任务的能力。

什么是 JAX?

JAX 是一个高性能的 Python 数值计算库。与传统的机器学习框架(如 TensorFlow 和 PyTorch)相比,JAX 的速度和效率都非常出色。JAX 利用即时编译(JIT),无缝的自动微分,以及高效向量化和并行化代码的能力,使其能简单地适配 AI 加速器(如 GPU 和 TPU)。

为什么使用 AMD GPU?

AMD GPU 因其强大的开源支持而脱颖而出,工具如 ROCm 和 HIP 使其易于适配 AI 工作流程。AMD 具有竞争力的性价比,非常适合寻求成本效益的 AI 和深度学习任务解决方案的用户。随着 AMD 在市场上的影响力不断增长,越来越多的机器学习库和框架正在添加对 AMD GPU 的支持。

硬件要求和运行环境

为了利用完成此任务所需的计算能力,我们使用AMD加速器云平台 (AAC)。AAC 是一个按需提供云计算资源和API的付费平台。具体来说,我们使用一个JAX Docker容器,其在AAC上拥有8个GPU,以充分利用先进的GPU并行计算能力。

本文是硬件无关的,这意味着要成功运行提供的代码示例,不需要访问AAC。只要您有加速器设备(如GPU或TPU),您应该能够以最小的代码修改来运行这些代码示例。如果您使用的是AMD GPU,请确保正确安装了ROCm及其兼容版本的JAX和Jaxlib。参考以下教程进行安装:

  • ROCm 安装

  • JAX and Jaxlib 安装: 您也可以直接通过链接拉取一个JAX Docker镜像。

代码示例:对Transformer模型进行SFT

为了演示,我们使用一个通用语言理解评估(GLUE)基准数据集Quora Question Pairs(QQP)微调一个基于transformer的LLM(如:bert-base-cased)。该数据集包含超过40万对问题,每对问题都有一个二进制注释,指示这两个问题是否是相互的复述。输入变量是两个问题的句子,而输出变量是一个二进制指标,表示这两个问题是否具有相同的含义。

安装

首先,安装所需的软件包 (%%capture 是一个 _cell magic_,它将抑制单元格的输出)。

%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git
!pip install evaluate
!pip install ipywidgets
!pip install black isort # 单元格中的格式化器;可选项

导入剩余的软件包和功能。

import os
from itertools import chain
from typing import Callable

import evaluate
import flax
import jax
import jax.numpy as jnp
import optax
import pandas as pd
from datasets import load_dataset
from flax import traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from ipywidgets import IntProgress as IProgress
from tqdm.notebook import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
)

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

JAX 预先分配75%的GPU内存以减少首次运行JAX操作时的开销和碎片,但可能会触发内存不足(OOM)错误。为了避免OOM问题,可通过将 XLA_PYTHON_CLIENT_PREALLOCATE 标志设置为 false 来抑制默认行为。

检查是否可以通过JAX检测到GPU设备。如果不能,可能需要重新安装ROCm、JAX和Jaxlib。如果JAX安装正确,你可以看到所有请求的GPU设备,在我们的例子中是8个GPU。

jax.local_devices()
[gpu(id=0),
 gpu(id=1),
 gpu(id=2),
 gpu(id=3),
 gpu(id=4),
 gpu(id=5),
 gpu(id=6),
 gpu(id=7)]

获取微调数据集和预训练模型检查点

指定你的微调过程的设置:数据集、预训练模型以及每个设备每批次要处理的数据量。

task = "qqp"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 64

加载数据集和评估指标模块。

raw_dataset = load_dataset("glue", task)
metric = evaluate.load("glue", task)

接下来的几段代码展示了如何使用模型特定的分词器对文本数据进行分词,并加载分词后的训练和验证数据。使用与预训练模型相同的分词器确保在微调过程中相同的词会被转换为相同的嵌入向量。

重要的是,我们在原始训练数据中对训练和评估数据集进行了10%的抽样。尽管如此,QQP数据集仍然提供了足够的数据来实现令人满意的性能,并且可以在每个epoch后观察到指标的改进。这种抽样方法还加快了我们的训练过程,便于说明。

使用数据预处理函数和map包装器的批处理和并行处理功能处理训练和评估数据集。你可以在以下输出中查看分词后的数据集。

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
def preprocess_function(examples):
    texts = (examples["question1"], examples["question2"])
    processed = tokenizer(*texts, padding="max_length", max_length=128, truncation=True)
    processed["labels"] = examples["label"]
    return processed
# 关于如何处理和操作 huggingface 数据集的详细信息:
# https://huggingface.co/docs/datasets/process
data = raw_dataset["train"].shuffle(seed=0)
train_data = data.select(list(range(int(data.shape[0] * 0.1))))
eval_data = data.select(list(range(int(data.shape[0] * 0.1), int(data.shape[0] * 0.2))))
print(f"原始训练数据集的形状为: {data.shape}")
print(f"当前训练数据集的形状为: {train_data.shape}")
print(f"当前验证数据集的形状为: {eval_data.shape}")
原始训练数据集的形状为: (363846, 4)
当前训练数据集的形状为: (36384, 4)
当前验证数据集的形状为: (36385, 4)
train_dataset = train_data.map(
    preprocess_function, batched=True, remove_columns=train_data.column_names
)
eval_dataset = eval_data.map(
    preprocess_function, batched=True, remove_columns=eval_data.column_names
)
# 你可以在以下单元格的输出中查看已分词的数据集
pd.DataFrame(train_dataset[:3])

从Hugging Face下载预训练模型配置和检查点。注意,你会看到一个警告信息,指出某些模型权重未使用。这是预期的,因为BERT模型检查点是一个PreTraining模型类,而你正在初始化一个
SequenceClassification模型。警告信息指出:你可能需要在下游任务上训练该模型,以便能够将其用于预测和推理。 这就是我们接下来要关注的内容。

num_labels = 2
seed = 0
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, config=config, seed=seed
)
某些在bert-base-cased模型检查点中的权重在初始化FlaxBertForSequenceClassification时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 如果您正在从另一个任务或架构的模型检查点初始化FlaxBertForSequenceClassification(例如,从BertForPreTraining模型初始化BertForSequenceClassification模型),这是预期的。
- 如果您正在从您期望完全相同的模型检查点初始化FlaxBertForSequenceClassification(从BertForSequenceClassification模型初始化BertForSequenceClassification模型),这不是预期的。
某些在bert-base-cased模型检查点中的权重未被初始化到FlaxBertForSequenceClassification并被重新初始化: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
您可能需要在下游任务中训练此模型,以便能够使用它进行预测和推理。

定义微调模型的状态

以下代码块展示了如何设置训练参数,比如训练周期数和初始学习率。学习率调度是为了使学习率在训练过程中线性衰减,以确保学习的效率和稳定性。

num_train_epochs = 6
learning_rate = 2e-5
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)
The overall batch size (both for training and eval) is 512
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs

learning_rate_function = optax.linear_schedule(
    init_value=learning_rate, end_value=0, transition_steps=num_train_steps
)

接下来,需要建立训练状态,包括优化器和损失函数的职责,并监督模型参数在训练过程中的更新。

使用状态对象,初始化和更新模型。当调用模型时,将状态作为输入,模型会返回通过新数据批次更新后的状态,同时保留模型实例。

Flax 提供了一个用户友好的类(`flax.training.train_state.TrainState`),它将模型参数、损失函数和优化器封装在一起。当提供数据时,它可以使用 apply_gradients 函数更新模型参数。

下面的代码块展示了如何定义和建立训练状态、优化器和损失函数。

class TrainState(train_state.TrainState):
    logits_function: Callable = flax.struct.field(pytree_node=False)
    loss_function: Callable = flax.struct.field(pytree_node=False)
# 创建一个 decay_mask_fn 函数,以确保对任何偏置项或 LayerNorm 权重不应用权重衰减,因为这可能不会提高模型性能甚至会有害。

def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    flat_mask = {
        path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
        for path in flat_params
    }
    return traverse_util.unflatten_dict(flat_mask)
# 标准的带权重衰减的 Adam 优化器
def adamw(weight_decay):
    return optax.adamw(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )
def loss_function(logits, labels):
    xentropy = optax.softmax_cross_entropy(
        logits, onehot(labels, num_classes=num_labels)
    )
    return jnp.mean(xentropy)


def eval_function(logits):
    return logits.argmax(-1)
# 实例化 TrainState
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)

定义如何训练、评估模型并启用并行化

train_step 和 eval_step 参数定义了如何训练和评估模型。训练步骤遵循标准的训练过程:

  1. 使用当前的权重计算损失。

  2. 计算损失函数相对于权重的梯度。

  3. 使用梯度和学习率更新权重。

  4. 使用梯度和学习率更新权重。

需要强调的是,`lax.pmean` 函数计算跨所有 8 个 GPU 设备的数据批次梯度的均值。这个关键步骤保证了所有 GPU 设备上的模型参数同步。

def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": learning_rate_function(state.step)},
        axis_name="batch",
    )
    return new_state, metrics, new_dropout_rng
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)

接下来,应用 jax.pmap 函数到定义的 train_step 和 eval_step 函数。将 pmap() 应用于函数时,该函数会使用 XLA 编译(类似于 jit()),然后在 XLA 设备上并行运行,例如多 GPU 设备或多 TPU 核。简单来说,这一步将训练和评估函数发送到所有 GPU 设备。你还需要通过 flax.jax_utils.replicate 将训练状态发送到所有 GPU 设备,这些步骤确保你通过分布式训练在所有 GPU 设备上更新模型状态。

parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")
state = flax.jax_utils.replicate(state)

定义数据加载函数,这些函数返回数据批次生成器。在最终的训练和评估循环中,每一步都会输入一个新的数据批次。

def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # 跳过不完整的批次。
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch
def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch

基于整数种子生成伪随机数生成器(PRNG)密钥,然后将其拆分为 8 个新的密钥,以确保每个 GPU 设备都得到不同的密钥。然后运行训练步骤,以根据预定义的训练参数(如训练轮次和总批次大小)更新 state。在每个轮次结束时,运行评估步骤,以查看评估数据集上的准确率和 F1 指标。由于使用的训练数据集比基准中的原始训练数据集要小,可以看到在前几轮训练中,评估指标(训练损失和评估准确率)稳定提升。

rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
for i, epoch in enumerate(
    tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
    rng, input_rng = jax.random.split(rng)

    # train
    with tqdm(
        total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
    ) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rngs = parallel_train_step(
                state, batch, dropout_rngs
            )
            progress_bar_train.update(1)

    # 评估
    with tqdm(
        total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
    ) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = parallel_eval_step(state, batch)
            metric.add_batch(
                predictions=list(chain(*predictions)), references=list(chain(*labels))
            )
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(flax.jax_utils.unreplicate(train_metrics)["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(
        f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
    )
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.475 | Eval accuracy: 0.799, f1: 0.762
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.369 | Eval accuracy: 0.834, f1: 0.789
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.299 | Eval accuracy: 0.846, f1: 0.797
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.239 | Eval accuracy: 0.846, f1: 0.806
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.252 | Eval accuracy: 0.849, f1: 0.802
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.212 | Eval accuracy: 0.849, f1: 0.805

使用JAX设备网格来实现并行化

from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, config=config, seed=seed
)
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(weight_decay=0.01),
    logits_function=eval_function,
    loss_function=loss_function,
)
一些来自 bert-base-cased 模型检查点的权重在初始化 FlaxBertForSequenceClassification 时未被使用: {('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'dense', 'bias')}
- 当你用模型训练其他任务或用另一种架构初始化 FlaxBertForSequenceClassification 时,这是预期中的情况(例如从 BertForPreTraining 模型初始化 BertForSequenceClassification 模型)。
- 当你期望从与 FlaxBertForSequenceClassification 模型完全相同的检查点初始化时(从 BertForSequenceClassification 模型初始化 BertForSequenceClassification 模型),这不是预期情况。
FlaxBertForSequenceClassification 中一些权重没有从 bert-base-cased 模型检查点初始化,是新初始化的: {('classifier', 'kernel'), ('classifier', 'bias'), ('bert', 'pooler', 'dense', 'kernel'), ('bert', 'pooler', 'dense', 'bias')}
应该将这个模型训练到下游任务上以便用于预测和推断。
@jax.jit
def train_step(state, batch, dropout_rng):
    targets = batch.pop("labels")
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_function(params):
        logits = state.apply_fn(
            **batch, params=params, dropout_rng=dropout_rng, train=True
        )[0]
        loss = state.loss_function(logits, targets)
        return loss

    grad_function = jax.value_and_grad(loss_function)
    loss, grad = grad_function(state.params)
    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": learning_rate_function(state.step)}
    return new_state, metrics, new_dropout_rng
@jax.jit
def eval_step(state, batch):
    logits = state.apply_fn(**batch, params=state.params, train=False)[0]
    return state.logits_function(logits)
num_devices = len(jax.local_devices())
devices = mesh_utils.create_device_mesh((num_devices,))

# 数据将沿批处理轴进行分割
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # 命名网格的轴


def glue_train_data_loader(rng, dataset, batch_size):
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # 略过不完整的批处理。
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {
            k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
        }

        yield batch


def glue_eval_data_loader(dataset, batch_size):
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {
            k: jax.device_put(jnp.array(v), data_sharding) for k, v in batch.items()
        }

        yield batch
# 在所有设备上复制模型和优化器变量
def get_replicated_train_state(devices, state):
    # 所有变量将在所有设备上复制
    var_mesh = Mesh(devices, axis_names=("_"))
    # 在 NamedSharding 中,未提到的轴将被复制(此处为所有轴)
    var_replication = NamedSharding(var_mesh, P())

    # 应用分布设置到模型变量
    state = jax.device_put(state, var_replication)

    return state


state = get_replicated_train_state(devices, state)
rng = jax.random.PRNGKey(seed)
dropout_rng = jax.random.PRNGKey(seed)
for i, epoch in enumerate(
    tqdm(range(1, num_train_epochs + 1), desc=f"Epoch ...", position=0, leave=True)
):
    rng, input_rng = jax.random.split(rng)

    # 训练
    with tqdm(
        total=len(train_dataset) // total_batch_size, desc="Training...", leave=True
    ) as progress_bar_train:
        for batch in glue_train_data_loader(input_rng, train_dataset, total_batch_size):
            state, train_metrics, dropout_rng = train_step(state, batch, dropout_rng)
            progress_bar_train.update(1)

    # 评估
    with tqdm(
        total=len(eval_dataset) // total_batch_size, desc="Evaluating...", leave=False
    ) as progress_bar_eval:
        for batch in glue_eval_data_loader(eval_dataset, total_batch_size):
            labels = batch.pop("labels")
            predictions = eval_step(state, batch)
            metric.add_batch(predictions=list(predictions), references=list(labels))
            progress_bar_eval.update(1)

    eval_metric = metric.compute()

    loss = round(train_metrics["loss"].item(), 3)
    eval_score1 = round(list(eval_metric.values())[0], 3)
    metric_name1 = list(eval_metric.keys())[0]
    eval_score2 = round(list(eval_metric.values())[1], 3)
    metric_name2 = list(eval_metric.keys())[1]
    print(
        f"{i+1}/{num_train_epochs} | Train loss: {loss} | Eval {metric_name1}: {eval_score1}, {metric_name2}: {eval_score2}"
    )
Epoch ...:   0%|          | 0/6 [00:00<?, ?it/s]
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
1/6 | Train loss: 0.469 | Eval accuracy: 0.796, f1: 0.759
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
2/6 | Train loss: 0.376 | Eval accuracy: 0.833, f1: 0.788
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
3/6 | Train loss: 0.296 | Eval accuracy: 0.844, f1: 0.795
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
4/6 | Train loss: 0.267 | Eval accuracy: 0.846, f1: 0.805
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
5/6 | Train loss: 0.263 | Eval accuracy: 0.848, f1: 0.804
Training...:   0%|          | 0/71 [00:00<?, ?it/s]
Evaluating...:   0%|          | 0/71 [00:00<?, ?it/s]
6/6 | Train loss: 0.222 | Eval accuracy: 0.849, f1: 0.805

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2219224.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

bash之基本运算符

一.算术运算符 vim test.sh #!/bin/basha10 b20valexpr $a $b echo "a b : $val"valexpr $a - $b echo "a - b : $val"valexpr $a \* $b echo "a * b : $val"valexpr $b / $a echo "b / a : $val"valexpr $b % $a echo "b % a …

pikachu靶场SSRF-curl测试报告

目录 一、测试环境 1、系统环境 2、使用工具/软件 二、测试目的 三、操作过程 1、实现ssrf攻击 四、源代码分析 五、结论 一、测试环境 1、系统环境 渗透机&#xff1a;本机(127.0.0.1) 靶 机&#xff1a;本机(127.0.0.1) 2、使用工具/软件 测试网址&#xff1a;…

Redis 常用指令详解

Redis是一款开源的、高性能的键值对存储数据库&#xff0c;常用于缓存、会话存储以及其他需要快速访问的数据场景。本文将介绍Redis的一些常用指令&#xff0c;并通过代码示例进行说明。 一、连接操作指令 1. 连接 Redis 服务器 ./redis-cli -h 127.0.0.1 -p 63792. 认证&a…

1.QT概述及C++基础

QT概述及C基础 1.简介2.QT安装3.QT_Creator的基本使用4.C基础 1.简介 概述 Qt 是一个跨平台的应用程序和用户界面框架&#xff0c;用于开发图形用户界面&#xff08;GUI&#xff09;应用程序以及命令行工具。它最初由挪威的 Trolltech &#xff08;奇趣科技&#xff09;公司开发…

MySQL程序介绍<一>

目录 MySQL程序简介 mysqld - MySQL 服务器 ​编辑 mysql - MySQL 命令⾏客⼾端 MySQL程序简介 1.MySQL安装完成通常会包含如下程序&#xff1a; Linux系统程序⼀般在 /usr/bin⽬录下&#xff0c;可以通过命令查看 windows系统⽬录&#xff1a; 你的安装路径\MySQL Server…

Redis JSON介绍和命令大全

Redis JSON介绍和命令大全 Redis JSON先说说JSON是什么再说说JSON Path先推荐两个网站JSONPath JAVA clents Redis JSON 安装内存json命令语法命令url命令解释JSON.ARRAPPENDJSON.ARRINDEXJSON.ARRINSERTJSON.ARRLENJSON.ARRPOPJSON.ARRTRIMJSON.CLEARJSON.DEBUG MEMORYJSON.DE…

Java 入门基础篇15 - java构造方法以及认识新的关键字

一 今日目标 构造方法static关键字代码块math类package关键字import关键字 二 构造方法概述 2.1 构造方法描述 构造方法是一个特殊方法&#xff0c;作用是创建对象&#xff0c;对对象进行初始化。 ​ 如&#xff1a; 对对象中的成员进行初始化值 2.1 构造方法的特征 1、方…

C/C++每日一练:编写一个栈数据结构

通过编写栈&#xff08;Stack&#xff09;数据结构&#xff0c;提升对基本数据结构的理解和运用。这也是掌握更复杂数据结构与算法的基础。栈是计算机科学中的一个重要概念&#xff0c;经常出现在许多算法和应用中。 栈&#xff08;Stack&#xff09; 栈是一种后进先出&#x…

【初阶数据结构】计数排序 :感受非比较排序的魅力

文章目录 前言1. 什么是计数排序&#xff1f;2. 计数排序的算法思路2.1 绝对位置和相对位置2.2 根据计数数组的信息来确认 3. 计数排序的代码4. 算法分析5. 计数排序的优缺点6.计数排序的应用场景 前言 如果大家仔细思考的话&#xff0c;可能会发现这么一个问题。我们学的七大…

【C语言】原码 反码 补码

为什么要有原码 反码 补码的概念&#xff1f; 因为在计算机中最终只能识别机器码&#xff0c;是以 0000 0000 二进制作为表示形式&#xff0c;对于一个数&#xff0c;计算机要使用一定的编码方式进行存储&#xff0c;原码 反码 补码是机器存储一个数值的编码方式&#xff0c;最…

技术分享:A-23OH型树脂在汽车涂装废溶剂回收中的应用

在当今汽车制造业竞争激烈的环境下&#xff0c;提高生产效率、降低成本的同时&#xff0c;满足环保要求已成为各制造商追求的核心目标。水性涂料因其环保、节能等多重优势&#xff0c;在汽车涂装领域的应用日益广泛。然而&#xff0c;随之而来的喷涂废溶剂处理问题也日益凸显。…

2024年软件设计师中级(软考中级)详细笔记【7】面向对象技术(下)23种设计模式(分值10+)

目录 前言阅读前必看 第七章 面向对象技术&#xff08;下&#xff09;7.3 设计模式&#xff08;固定4分&#xff09;7.3.1 设计模式的要素7.3.2 创建型设计模式7.3.2.1 Abstract Factory&#xff08;抽象工厂&#xff09;7.3.2.2 Builder&#xff08;生成器&#xff09;7.3.2.3…

调整奇数偶数的顺序

//调整奇数偶数的顺序 //输入一个整数数组&#xff0c;实现一个函数 //使得数组中所有的奇数位于数组的前半部分&#xff0c;所有的偶数位于数组的后半部分 #include<stdio.h> void tz(int a[],int sz) {int i 0;int j 0;int q 0;int c[100] { 0 };int b[100] { 0 …

Qt第十三天:网络编程:TCP和UDP的使用

我发现了有些人喜欢静静看博客不聊天呐&#xff0c; 但是ta会点赞。 这样的人呢帅气低调有内涵&#xff0c; 美丽大方很优雅。 说的就是你&#xff0c; 不用再怀疑哦 ❤️TCP&#xff1a; 一、创建项目&#xff0c;命名为Server&#xff0c;继承QWidget 二、添加Qt设计师…

Axure重要元件三——中继器添加数据

亲爱的小伙伴&#xff0c;在您浏览之前&#xff0c;烦请关注一下&#xff0c;在此深表感谢&#xff01; 本节课&#xff1a;中继器添加数据 课程内容&#xff1a;添加数据项、自动添加序号、自动添加数据汇总 应用场景&#xff1a;表单数据的添加 案例展示&#xff1a; 步骤…

算法: 模拟题目练习

文章目录 模拟替换所有的问号提莫攻击Z 字形变换外观数列数青蛙 总结 模拟 替换所有的问号 按照题目的要求写代码即可~ public String modifyString(String ss) {int n ss.length();if (n 1) {return "a";}char[] s ss.toCharArray();for (int i 0; i < n; i…

【华为HCIP实战课程十三】OSPF网络中3类LSA及区域间负载均衡,网络工程师

一、ABR SW1查看OSPF ABR为R4而非R3,因为R4连接骨干区域0,R3没有连接到区域0 R6查看OSPF路由: 二、查看3类LSA,由于R6不是ABR因此自身不会产生3类LSA 但是有区域间路由就可以看到3类LSA

分布式介绍

CAP理论 CAP理论是分布式架构中提出来的一种设计思想模型&#xff0c;全称是由Consistency、Availability、Partition Tolerance三个词组成。 C(Consistency&#xff0c;一致性):总能读到最新的写操作的结果A(Availability&#xff0c;可用性):每个请求都要在合理的时间内给出…

Spring Boot知识管理:跨平台集成方案

4系统概要设计 4.1概述 本系统采用B/S结构(Browser/Server,浏览器/服务器结构)和基于Web服务两种模式&#xff0c;是一个适用于Internet环境下的模型结构。只要用户能连上Internet,便可以在任何时间、任何地点使用。系统工作原理图如图4-1所示&#xff1a; 图4-1系统工作原理…

后渗透利用之vcenter

目录 vcenter介绍环境搭建历史漏洞版本信息1、直接访问2、请求接⼝ 打点CVE_2021_21972漏洞描述&#xff1a;POC&#xff1a; 后渗透获取vCenter后台重置密码Cookie登录创建管理员 获取虚拟机Hash分析快照挂载磁盘 获取Esxi 后台获取解密key获取数据库账号密码查询Esxi加密密码…