BERT模型核心组件详解及其实现

news2024/11/16 18:09:39
摘要

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer架构的预训练模型,在自然语言处理领域取得了显著的成果。本文详细介绍了BERT模型中的几个关键组件及其实现,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。

1. 引言

BERT模型由Google研究人员于2018年提出,通过大规模的无监督预训练和任务特定的微调,显著提升了多个自然语言处理任务的性能。本文将重点介绍BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等,并提供相应的代码实现。

2. 激活函数
2.1 Gaussian Error Linear Unit (GELU)

GELU是一种平滑的ReLU变体,其数学表达式如下:

GELU(x)=x⋅Φ(x)GELU(x)=x⋅Φ(x)

其中,Φ(x)Φ(x)是标准正态分布的累积分布函数(CDF)。在TensorFlow中,GELU可以通过以下方式实现:

def gelu(x):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415
  Args:
    x: float Tensor to perform activation.

  Returns:
    `x` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.tanh(
      (np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  return x * cdf
2.2 激活函数映射

为了方便使用不同的激活函数,我们定义了一个映射函数get_activation,该函数根据传入的字符串返回相应的激活函数:

def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)
3. 变量初始化

在深度学习中,合理的变量初始化对于模型的收敛速度和最终性能至关重要。BERT模型中使用了截断正态分布初始化器(truncated_normal_initializer),其标准差为0.02:

def create_initializer(initializer_range=0.02):
  """Creates a `truncated_normal_initializer` with the given range."""
  return tf.truncated_normal_initializer(stddev=initializer_range)
4. 嵌入查找

嵌入查找是将输入的token id转换为向量表示的过程。BERT模型中使用了两种方法:一种是使用tf.gather(),另一种是使用one-hot编码:

def embedding_lookup(input_ids,
                     vocab_size,
                     embedding_size=128,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
  """Looks up words embeddings for id tensor.

  Args:
    input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
      ids.
    vocab_size: int. Size of the embedding vocabulary.
    embedding_size: int. Width of the word embeddings.
    initializer_range: float. Embedding initialization range.
    word_embedding_name: string. Name of the embedding table.
    use_one_hot_embeddings: bool. If True, use one-hot method for word
      embeddings. If False, use `tf.gather()`.

  Returns:
    float Tensor of shape [batch_size, seq_length, embedding_size].
  """
  if input_ids.shape.ndims == 2:
    input_ids = tf.expand_dims(input_ids, axis=[-1])

  embedding_table = tf.get_variable(
      name=word_embedding_name,
      shape=[vocab_size, embedding_size],
      initializer=create_initializer(initializer_range))

  flat_input_ids = tf.reshape(input_ids, [-1])
  if use_one_hot_embeddings:
    one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
    output = tf.matmul(one_hot_input_ids, embedding_table)
  else:
    output = tf.gather(embedding_table, flat_input_ids)

  input_shape = get_shape_list(input_ids)
  output = tf.reshape(output,
                      input_shape[0:-1] + [input_shape[-1] * embedding_size])
  return (output, embedding_table)
5. 层归一化

层归一化(Layer Normalization)是一种常用的归一化技术,用于加速训练过程并提高模型的泛化能力。BERT模型中使用了tf.contrib.layers.layer_norm来进行层归一化:

def layer_norm(input_tensor, name=None):
  """Run layer normalization on the last dimension of the tensor."""
  return tf.contrib.layers.layer_norm(
      inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)

为了方便使用,我们还定义了一个组合函数layer_norm_and_dropout,该函数先进行层归一化,再进行dropout操作:

def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
  """Runs layer normalization followed by dropout."""
  output_tensor = layer_norm(input_tensor, name)
  output_tensor = dropout(output_tensor, dropout_prob)
  return output_tensor
6. Dropout

Dropout是一种常用的正则化技术,用于防止模型过拟合。在BERT模型中,dropout的概率可以通过配置参数进行设置:

def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
  """
  if dropout_prob is None or dropout_prob == 0.0:
    return input_tensor

  output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
  return output
7. 从检查点加载变量

在微调过程中,通常需要从预训练的模型检查点中加载变量。get_assignment_map_from_checkpoint函数用于计算当前变量与检查点变量的映射关系:

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)
8. 结论

本文详细介绍了BERT模型中的几个核心组件,包括激活函数、变量初始化、嵌入查找、层归一化等。通过深入理解这些组件,读者可以更好地掌握BERT模型的工作原理,并在实际应用中进行优化和调整。希望本文能为读者在自然语言处理领域的研究和开发提供有益的参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2241660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jmeter常用配置元件介绍总结之逻辑控制器

系列文章目录 安装jmeter jmeter常用配置元件介绍总结之逻辑控制器 逻辑控制器1.IF控制器2.事务控制器3.循环控制器4.While控制器5.ForEach控制器6.Include控制器7.Runtime控制器8.临界部分控制器9.交替控制器10.仅一次控制器11.简单控制器12.随机控制器13.随机顺序控制器14.吞…

探索 HTML 和 CSS 实现的蜡烛火焰

效果演示 这段代码是一个模拟蜡烛火焰的HTML和CSS代码。它创建了一个具有动态效果的蜡烛火焰动画&#xff0c;包括火焰的摆动、伸缩和光晕的闪烁。 HTML <div class"holder"><div class"candle"><div class"blinking-glow"&g…

机器学习 - 为 Jupyter Notebook 安装新的 Kernel

https://ipython.readthedocs.io/en/latest/install/kernel_install.html 当使用jupyter-notebook --no-browser 启动一个 notebook 时&#xff0c;默认使用了该 jupyter module 所在的 Python 环境作为 kernel&#xff0c;比如 C:\devel\Python\Python311。 如果&#xff0c…

SwiftUI-基础入门

开发OS X 图形应用界面时有三种实现方式&#xff1a;XIB、Storyboard、SwiftUI。Storyboard基于XIB做了优化&#xff0c;但XIB基本被放弃了&#xff0c;而SwiftUI是苹果公司后来开发的一套编程语言&#xff0c;用来平替Objective-C。虽然现在Swift 6 还是有些不完善的地方&…

androidstudio入门到放弃配置

b站视频讲解传送门 android_studio安装包&#xff1a;https://developer.android.google.cn/studio?hlzh-cn 下载安装 开始创建hello-world 1.删除缓存 文件 下载gradle文件压缩&#xff1a;gradle-8.9用自己创建项目时自动生成的版本即可&#xff0c;不用和我一样 https://…

如何在pycharm中 判断是否成功安装pytorch环境

1、在电脑开始端&#xff0c;找到 2、打开后 在base环境下 输入conda env list 目前我的环境中没有pytorch 学习视频&#xff1a;【Anaconda、Pytorch、Pycharm到底是什么关系?什么是环境?什么是包?】https://www.bilibili.com/video/BV1CN411s7Ue?vd_sourcefad0750b8c6…

昆明华厦眼科医院举办中外专家眼科技术研讨会

9月13日&#xff0c;“睿智迭代&#xff0c;增效赋能”Menicon Z Night中外专家研讨会在昆明华厦眼科医院成功举办。此次会议由目立康公司与昆明华厦眼科医院携手共筑&#xff0c;标志着双方合作迈向新的高度。 昆明华厦眼科医院总经理王若镜首先发表了热情洋溢的致辞&#xff…

HarmonyOS ArkUI(基于ArkTS) 开发布局 (上)

一 ArkUI(基于ArkTS)概述 基于ArkTS的声明式开发范式的方舟开发框架是一套开发极简、高性能、支持跨设备的UI开发框架&#xff0c;提供了构建应用UI所必需的能力 点击详情 特点 开发效率高&#xff0c;开发体验好 代码简洁&#xff1a;通过接近自然语义的方式描述UI&#x…

【STM32】项目实战——OV7725/OV2604摄像头颜色识别检测(开源)

本篇文章分享关于如何使用STM32单片机对彩色摄像头&#xff08;OV7725/OV2604&#xff09;采集的图像数据进行分析处理&#xff0c;最后实现颜色的识别和检测。 目录 一、什么是颜色识别 1、图像采集识别的一些基本概念 1. 像素&#xff08;Pixel&#xff09; 2. 分辨率&am…

SpringCloud-使用FFmpeg对视频压缩处理

在现代的视频处理系统中&#xff0c;压缩视频以减小存储空间、加快传输速度是一项非常重要的任务。FFmpeg作为一个强大的开源工具&#xff0c;广泛应用于音视频的处理&#xff0c;包括视频的压缩和格式转换等。本文将通过Java代码示例&#xff0c;向您展示如何使用FFmpeg进行视…

大数据学习14之Scala面向对象--至简原则

1.类和对象 1.1基本概念 面向对象&#xff08;Object Oriented&#xff09;是一种编程思想&#xff0c;面向对象主要是把事物给对象化&#xff0c;包括其属性和行为。面向对象编程更贴近实际生活的思想&#xff0c;总体来说面向对象的底层还是面向过程&#xff0c;面向过程抽象…

pipx安装提示找不到包

执行&#xff1a; pipx install --include-deps --force "ansible6.*"WARNING: Retrying (Retry(total4, connectNone, readNone, redirectNone, statusNone)) after connection broken by NewConnectionError(<pip._vendor.urllib3.connection.HTTPSConnection …

‘conda‘ 不是内部或外部命令,也不是可运行的程序或批处理文件,Miniconda

下载了conda&#xff0c;但是在cmd里执行conda --version会显示’conda’ 不是内部或外部命令&#xff0c;也不是可运行的程序或批处理文件。 原因是环境变量里没有添加conda&#xff0c;无法识别路径。 需要在系统环境变量里添加如下路径&#xff1a; 保存之后重新打开cmd&am…

【Qt】使用QString的toLocal8Bit()导致的问题

问题 使用Qt发送一个Http post请求的时候&#xff0c;服务一直返回错误和失败信息。同样的url以及post参数&#xff0c;复制黏贴到postman里就可以发送成功。就感觉很神奇。 原因 最后排查出原因是因为参数中含有汉字而导致的编码问题。 在拼接post参数时&#xff0c;使用了…

设计一致性的关键:掌握 Axure 母版使用技巧

设计一致性的关键&#xff1a;掌握 Axure 母版使用技巧 前言 在快节奏的产品开发周期中&#xff0c;设计师们一直在寻找能够提升工作效率和保持设计一致性的方法。 Axure RP&#xff0c;作为一款强大的原型设计工具&#xff0c;其母版功能为设计师们提供了一个强大的解决方案…

鸿蒙next ui安全区域适配(刘海屏、摄像头挖空等)

目录 相关api 团结引擎对于鸿蒙的适配已经做了安全区域的适配&#xff0c;也考虑到了刘海屏和摄像机挖孔的情况&#xff0c;在团结引擎内可以直接使用Screen.safeArea 相关api 团结引擎对于鸿蒙的适配已经做了安全区域的适配&#xff0c;也考虑到了刘海屏和摄像机挖孔的情况&am…

Android OpenGL ES详解——实例化

目录 一、实例化 1、背景 2、概念 实例化、实例数量 gl_InstanceID 应用举例 二、实例化数组 1、概念 2、应用举例 三、应用举例——小行星带 1、不使用实例化 2、使用实例化 四、总结 一、实例化 1、背景 假如你有一个有许多模型的场景&#xff0c;而这些模型的…

前端传数组 数据库存Json : [1,2,3]格式

一、前端正常传数组&#xff0c;但是value.toString() 即可 const empIds ref([1,2,3]) empIds.value empIds.value.toString() await updateApiRules(empIds.value) // 接口传参 二、后端用String类型接收后转换 String[] empIds updateDO.getEmpId().split("&#x…

《Java核心技术 卷I》用户图形界面鼠标事件

鼠标事件 如果只希望用户能够点击按钮或菜单&#xff0c;那么就不需要显式地处理鼠标事件&#xff0c;鼠标操作将由用户界面中的各种组件内部处理&#xff0c;不过&#xff0c;如果希望用户能使用鼠标画图&#xff0c;就需要捕获鼠标移动&#xff0c;点击和拖动事件。 本节&am…

贪心算法入门(三)

相关文章 贪心算法入门&#xff08;一&#xff09;-CSDN博客 贪心算法入门&#xff08;二&#xff09;-CSDN博客 1.什么是贪心算法&#xff1f; 贪心算法是一种解决问题的策略&#xff0c;它将复杂的问题分解为若干个步骤&#xff0c;并在每一步都选择当前最优的解决方案&am…