【机器学习】机器学习的基本分类-自监督学习-对比学习(Contrastive Learning)

news2025/1/4 9:22:32

对比学习是一种自监督学习方法,其目标是学习数据的表征(representation),使得在表征空间中,相似的样本距离更近,不相似的样本距离更远。通过设计对比损失函数(Contrastive Loss),模型能够有效捕捉数据的语义结构。


核心思想

对比学习的关键在于:

  1. 正样本(Positive Pair):具有相似语义或来源的样本对,例如同一图像的不同增强版本。
  2. 负样本(Negative Pair):语义不同或来源不同的样本对,例如不同图像。

通过对比正负样本对,模型能够学习区分不同数据点的特征。


方法流程

  1. 数据增强:对一个样本 x 应用两种不同的增强方法,生成 x_1, x_2​,作为正样本对。
  2. 特征提取:通过编码器(如卷积神经网络)将数据映射到潜在特征空间,得到表征 z_1, z_2
  3. 对比损失:设计损失函数,使正样本对的表征距离最小化,负样本对的表征距离最大化。

对比学习的损失函数

1. 对比损失(Contrastive Loss)

对比损失鼓励正样本对的距离更小,负样本对的距离更大。

L = \frac{1}{N} \sum_{i=1}^N \left[ y_i \cdot d(z_i, z_j)^2 + (1 - y_i) \cdot \max(0, m - d(z_i, z_j))^2 \right]

  • y_i:样本对是否为正样本(1 表示正样本,0 表示负样本)。
  • d(z_i, z_j):样本对在表征空间中的距离(通常使用欧氏距离)。
  • m:负样本对的最小距离(margin)。
2. InfoNCE 损失

用于最大化正样本对的相似性,同时将负样本对的相似性最小化。

L = - \log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{N} \exp(\text{sim}(z_i, z_k) / \tau)}

  • \text{sim}(z_i, z_j) = \frac{z_i \cdot z_j}{\|z_i\| \|z_j\|}:余弦相似度。
  • \tau:温度参数,用于控制分布的平滑程度。
  • N:批量中样本数量。

典型方法

1. SimCLR

SimCLR 是对比学习的经典方法之一:

  • 核心思想:通过数据增强生成正样本对,并利用 InfoNCE 损失函数进行优化。
  • 数据增强:随机裁剪、颜色抖动、模糊等。
2. MoCo(Momentum Contrast)

通过维护一个动态更新的“字典”,解决负样本数量不足的问题。

  • 核心思想:使用动量编码器(momentum encoder)生成更多的负样本。
3. BYOL(Bootstrap Your Own Latent)

无需显式的负样本,通过自回归(self-prediction)学习特征表征。

  • 核心思想:一个在线网络(Online Network)和一个目标网络(Target Network)协同训练。
4. SWAV(Swapping Assignments Between Views)

结合聚类和对比学习,利用图像的多视图表征。

  • 核心思想:通过在线分配伪标签,避免显式使用负样本。

示例代码:SimCLR

以下是一个实现 SimCLR 的示例代码:

import tensorflow as tf
from tensorflow.keras import layers, models


# 图像增强函数
def augment_image(image):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_crop(image, size=(32, 32, 3))
    image = tf.image.random_brightness(image, max_delta=0.5)
    return image


# 定义编码器
def create_encoder():
    base_model = tf.keras.applications.ResNet50(include_top=False, pooling='avg', input_shape=(32, 32, 3))
    return models.Model(inputs=base_model.input, outputs=base_model.output)


# SimCLR 模型
class SimCLRModel(tf.keras.Model):
    def __init__(self, encoder, projection_dim):
        super(SimCLRModel, self).__init__()
        self.encoder = encoder
        self.projection_head = tf.keras.Sequential([
            layers.Dense(256, activation='relu'),
            layers.Dense(projection_dim)
        ])

    def call(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return tf.math.l2_normalize(projections, axis=1)


# 构建模型
encoder = create_encoder()
simclr_model = SimCLRModel(encoder, projection_dim=128)


# InfoNCE 损失
def info_nce_loss(features, temperature=0.5):
    batch_size = tf.shape(features)[0]
    labels = tf.range(batch_size)
    similarity_matrix = tf.matmul(features, features, transpose_b=True)
    logits = similarity_matrix / temperature
    return tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True))


# 训练
(X_train, _), _ = tf.keras.datasets.cifar10.load_data()
X_train = tf.image.resize(X_train, (32, 32)) / 255.0


def preprocess_data(image):
    return augment_image(image), augment_image(image)


train_data = tf.data.Dataset.from_tensor_slices(X_train)
train_data = train_data.map(preprocess_data).batch(32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

for epoch in range(10):
    for x1, x2 in train_data:
        with tf.GradientTape() as tape:
            z1 = simclr_model(x1)
            z2 = simclr_model(x2)
            loss = info_nce_loss(tf.concat([z1, z2], axis=0))
        gradients = tape.gradient(loss, simclr_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, simclr_model.trainable_variables))
    print(f"Epoch {epoch + 1}, Loss: {loss.numpy()}")

输出结果

Epoch 1, Loss: 3.465735912322998
Epoch 2, Loss: 3.465735912322998
Epoch 3, Loss: 3.465735912322998
Epoch 4, Loss: 3.465735912322998
Epoch 5, Loss: 3.465735912322998

对比学习的优势与挑战

优势
  1. 无需标签数据:适用于大规模无标签数据集。
  2. 高质量特征:学习的表征具有很强的迁移能力。
  3. 通用性强:适用于图像、文本、语音等多种模态。
挑战
  1. 负样本选择:负样本数量和质量对性能影响大。
  2. 计算成本:对比学习需要大量计算资源,尤其是在大规模数据上训练。
  3. 超参数调整:温度参数等对模型表现至关重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2269120.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

xterm + vue3 + websocket 终端界面

xterm.js 下载插件 // xterm npm install --save xterm// xterm-addon-fit 使终端适应包含元素 npm install --save xterm-addon-fit// xterm-addon-attach 通过websocket附加到运行中的服务器进程 npm install --save xterm-addon-attach <template><div :…

记一次护网通过外网弱口令一路到内网

视频教程在我主页简介或专栏里 目录&#xff1a; 资产收集 前期打点 突破 完结 又是年底护网季&#xff0c;地市护网有玄机&#xff0c;一路磕磕又绊绊&#xff0c;终是不负领导盼。 扯远了-_-!!&#xff0c;年底来了一个地市级护网&#xff0c;开头挺顺利的&#xff0c…

XIAO ESP32 S3网络摄像头——2视频获取

本文主要是使用XIAO Esp32 S3制作网络摄像头的第2步,获取摄像头图像。 1、效果如下: 2、所需硬件 3、代码实现 3.1硬件代码: #include "WiFi.h" #include "WiFiClient.h" #include "esp_camera.h" #include "camera_pins.h"// 设…

uniapp:微信小程序文本长按无法出现复制菜单

一、问题描述 在集成腾讯TUI后&#xff0c;为了能让聊天文本可以复制&#xff0c;对消息组件的样式进行修改&#xff0c;主要是移除下面的user-select属性限制&#xff1a; user-select: none;-webkit-user-select: none;-khtml-user-select: none;-moz-user-select: none;-ms…

2025:OpenAI的“七十二变”?

朋友们&#xff0c;准备好迎接AI的狂欢了吗&#xff1f;&#x1f680; 是不是跟我一样&#xff0c;每天醒来的第一件事就是看看AI领域又有什么新动向&#xff1f; 尤其是那个名字如雷贯耳的 OpenAI&#xff0c;简直就是AI界的弄潮儿&#xff0c;一举一动都牵动着我们这些“AI发…

无人机频射信号检测数据集,平均正确识别率在94.3%,支持yolo,coco json,pasical voc xml格式的标注,364张原始图片

无人机频射信号检测数据集&#xff0c;平均正确识别率在94.3&#xff05;&#xff0c;支持yolo&#xff0c;coco json&#xff0c;pasical voc xml格式的标注&#xff0c;364张原始图片 可识别下面的信号&#xff1a; 图像传输信号LFST &#xff08;Image_Transmission_sign…

柱状图中最大的矩形 - 困难

************* c topic: 84. 柱状图中最大的矩形 - 力扣&#xff08;LeetCode&#xff09; ************* chenck the topic first: Think about the topics I have done before. the rains project comes:盛最多水的容器 - 中等难度-CSDN博客https://blog.csdn.net/ElseWhe…

第17篇 使用数码管实现计数器___ARM汇编语言程序<四>

Q&#xff1a;如何使用定时器实现数码管循环计数器&#xff1f; A&#xff1a;DE1-SoC_Computer系统有许多硬件定时器&#xff0c;本次实验使用A9 Private Timer定时器实现延时&#xff1a;定时器首先向Load寄存器写入计数值&#xff0c;然后向Control寄存器中的使能位E写1来启…

SSM 进销存系统

&#x1f942;(❁◡❁)您的点赞&#x1f44d;➕评论&#x1f4dd;➕收藏⭐是作者创作的最大动力&#x1f91e; &#x1f496;&#x1f4d5;&#x1f389;&#x1f525; 支持我&#xff1a;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;欢迎留言讨论 &#x1f525;&#x1f525;&…

通过Cephadm工具搭建Ceph分布式存储以及通过文件系统形式进行挂载的步骤

1、什么是Ceph Ceph是一种开源、分布式存储系统&#xff0c;旨在提供卓越的性能、可靠性和可伸缩性。它是为了解决大规模数据存储问题而设计的&#xff0c;使得用户可以在无需特定硬件支持的前提下&#xff0c;通过普通的硬件设备来部署和管理存储解决方案。Ceph的灵活性和设计…

【Rust自学】8.4. String类型 Pt.2:字节、标量值、字形簇以及字符串的各类操作

8.4.0. 本章内容 第八章主要讲的是Rust中常见的集合。Rust中提供了很多集合类型的数据结构&#xff0c;这些集合可以包含很多值。但是第八章所讲的集合与数组和元组有所不同。 第八章中的集合是存储在堆内存上而非栈内存上的&#xff0c;这也意味着这些集合的数据大小无需在编…

svn分支相关操作(小乌龟操作版)

在开发工作中进行分支开发&#xff0c;涉及新建分支&#xff0c;分支切换&#xff0c;合并分支等 新建远程分支 右键选择branch/tagert按钮 命名分支的路径名称 点击确定后远程分支就会生成一个当时命名的文件夹&#xff08;开发分支&#xff09; 分支切换 一般在开发阶段&a…

24年收尾之作------动态规划<六> 子序列问题(含对应LeetcodeOJ题)

目录 引例 经典LeetCode OJ题 1.第一题 2.第二题 3.第三题 4.第四题 5.第五题 6.第六题 7.第七题 引例 OJ传送门 LeetCode<300>最长递增子序列 画图分析: 使用动态规划解决 1.状态表示 dp[i]表示以i位置元素为结尾的子序列中&#xff0c;最长递增子序列的长度 2.…

蓝牙|软件 Qualcomm S7 Sound Platform开发系列之初级入门指南

本文适用范围 ADK24.2~ 问题/功能描述 S7开发环境搭建与编译介绍 实现方案 本文介绍适用于windows平台Application部分,audio ss的说明会在下一篇文章在做说明,Linux平台如果不进行AI算法的开发,个人认知是没有必要配置,若是做服务器倒是不错的选择.因为编译完成后烧录调试还…

Redis - 4 ( 9000 字 Redis 入门级教程 )

一&#xff1a; Zset 有序集合 1.1 常用命令 有序集合在 Redis 数据结构中相较于字符串、列表、哈希和集合稍显陌生。它继承了集合中元素不允许重复的特点&#xff0c;但与集合不同的是&#xff0c;有序集合的每个元素都关联一个唯一的浮点分数&#xff08;score&#xff09;…

ubuntu 使用samba与windows共享文件[注意权限配置]

在Ubuntu上使用Samba服务与Windows系统共享文件&#xff0c;需要正确配置Samba服务以及相应的权限。以下是详细的步骤&#xff1a; 安装Samba 首先&#xff0c;确保你的Ubuntu系统上安装了Samba服务。 sudo apt update sudo apt install samba配置Samba 安装完成后&#xff0c…

打印进度条

文章目录 1.Python语言实现(1)黑白色(2)彩色&#xff1a;蓝色 2.C语言实现(1)黑白颜色(2)彩色版&#xff1a;红绿色 1.Python语言实现 (1)黑白色 import sys import timedef progress_bar(percentage, width50):"""打印进度条:param percentage: 当前进度百分比…

深度解析 LDA 与聚类结合的文本主题分析实战

🌟作者简介:热爱数据分析,学习Python、Stata、SPSS等统计语言的小高同学~🍊个人主页:小高要坚强的博客🍓当前专栏:《Python之文本分析》🍎本文内容:深度解析 LDA 与聚类结合的文本主题分析实战🌸作者“三要”格言:要坚强、要努力、要学习 目录 引言 技术框架…

点跟踪基准最早的论文学习解读:TAP-Vid: A Benchmark for Tracking Any Point in a Video—前置基础

TAP-Vid: A Benchmark for Tracking Any Point in a Video— TAP-Vid&#xff1a;跟踪视频中任意点的基准、 学习这一篇文章的本来的目的是为了学习一下TAP-NET便于理解后面用到的TAPIR方法的使用。 文章目录 TAP-Vid: A Benchmark for Tracking Any Point in a Video— TAP-V…

C进阶-字符串与内存函数介绍(另加2道典型面试题)

满意的话&#xff0c;记得一键三连哦&#xff01; 我们先看2道面试题 第一道&#xff1a; 我们画图理解&#xff1a; pa&#xff0c;先使用再&#xff0c;pa开始指向a【0】&#xff0c;之后pa向下移动一位&#xff0c;再解引用&#xff0c;指向a【1】&#xff0c;a【1】又指向…