【机器学习】--- 自监督学习

news2024/11/15 12:09:21

在这里插入图片描述

1. 引言

机器学习近年来的发展迅猛,许多领域都在不断产生新的突破。在监督学习和无监督学习之外,自监督学习(Self-Supervised Learning, SSL)作为一种新兴的学习范式,逐渐成为机器学习研究的热门话题之一。自监督学习通过从数据中自动生成标签,避免了手工标注的代价高昂,进而使得模型能够更好地学习到有用的表示。

自监督学习的应用领域广泛,涵盖了图像处理、自然语言处理、音频分析等多个方向。本篇博客将详细介绍自监督学习的核心思想、常见的自监督学习方法及其在实际任务中的应用。我们还将通过具体的代码示例来加深对自监督学习的理解。

2. 自监督学习的核心思想

自监督学习的基本理念是让模型通过从数据本身生成监督信号进行训练,而无需人工标注。常见的方法包括生成对比任务、预测数据中的某些属性或部分等。自监督学习的关键在于设计出有效的预训练任务,使模型在完成这些任务的过程中能够学习到数据的有效表示。

2.1 自监督学习与监督学习的区别

在监督学习中,模型的训练需要依赖大量的人工标注数据,而无监督学习则没有明确的标签。自监督学习介于两者之间,它通过从未标注的数据中创建监督信号,完成预训练任务。通常,自监督学习的流程可以分为两步:

  1. 预训练:利用自监督任务对模型进行预训练,使模型学习到数据的有效表示。
  2. 微调:将预训练的模型应用到具体任务中,通常需要进行一些监督学习的微调。
2.2 常见的自监督学习任务

常见的自监督任务包括:

  • 对比学习(Contrastive Learning):从数据中生成正样本和负样本对,模型需要学会区分正负样本。
  • 预文本任务(Pretext Tasks):如图像块预测、顺序预测、旋转预测等任务。
2.3 自监督学习的优点

自监督学习具备以下优势:

  • 减少对人工标注的依赖:通过生成任务标签,大大降低了数据标注的成本。
  • 更强的泛化能力:在大量未标注的数据上进行预训练,使模型能够学习到通用的数据表示,提升模型在多个任务上的泛化能力。

3. 自监督学习的常见方法

在自监督学习中,研究者设计了多种预训练任务来提升模型的学习效果。以下是几种常见的自监督学习方法。

3.1 对比学习(Contrastive Learning)

对比学习是目前自监督学习中最受关注的一个方向。其基本思想是通过构造正样本对(相似样本)和负样本对(不同样本),让模型学习区分样本之间的相似性。典型的方法包括SimCLR、MoCo等。

SimCLR 的实现
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
import numpy as np

# SimCLR数据增强
class SimCLRTransform:
    def __init__(self, size):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=(3, 3)),
            transforms.ToTensor()
        ])

    def __call__(self, x):
        return self.transform(x), self.transform(x)

# 定义对比损失
class NTXentLoss(nn.Module):
    def __init__(self, temperature):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        batch_size = z_i.size(0)
        z = torch.cat([z_i, z_j], dim=0)
        sim_matrix = torch.mm(z, z.t()) / self.temperature
        mask = torch.eye(2 * batch_size, dtype=torch.bool).to(sim_matrix.device)
        sim_matrix.masked_fill_(mask, -float('inf'))
        
        positives = torch.cat([torch.diag(sim_matrix, batch_size), torch.diag(sim_matrix, -batch_size)], dim=0)
        negatives = sim_matrix[~mask].view(2 * batch_size, -1)
        
        logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
        labels = torch.zeros(2 * batch_size).long().to(logits.device)
        
        loss = nn.CrossEntropyLoss()(logits, labels)
        return loss

# 定义模型架构
class SimCLR(nn.Module):
    def __init__(self, base_model, projection_dim=128):
        super(SimCLR, self).__init__()
        self.backbone = base_model
        self.projector = nn.Sequential(
            nn.Linear(self.backbone.fc.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim)
        )

    def forward(self, x):
        h = self.backbone(x)
        z = self.projector(h)
        return z

# 模型训练
def train_simclr(model, train_loader, epochs=100, lr=1e-3, temperature=0.5):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = NTXentLoss(temperature)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x_i, x_j in train_loader:
            optimizer.zero_grad()
            z_i = model(x_i)
            z_j = model(x_j)
            loss = criterion(z_i, z_j)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(train_loader)}')

# 示例:在CIFAR-10上进行SimCLR训练
from torchvision.datasets import CIFAR10

train_dataset = CIFAR10(root='./data', train=True, transform=SimCLRTransform(32), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

resnet_model = models.resnet18(pretrained=False)
simclr_model = SimCLR(base_model=resnet_model)

train_simclr(simclr_model, train_loader)

以上代码展示了如何实现SimCLR对比学习模型。通过数据增强生成正样本对,使用NT-Xent损失函数来区分正负样本对,进而让模型学习到有效的数据表示。

3.2 预文本任务(Pretext Tasks)

除了对比学习,预文本任务也是自监督学习中的一种重要方法。常见的预文本任务包括图像块预测、旋转预测、Jigsaw拼图任务等。我们以Jigsaw拼图任务为例,展示如何通过打乱图像块顺序,让模型进行重新排序来学习图像表示。

Jigsaw任务的实现
import random

# 定义Jigsaw数据预处理
class JigsawTransform:
    def __init__(self, size, grid_size=3):
        self.size = size
        self.grid_size = grid_size
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor()
        ])

    def __call__(self, x):
        x = self.transform(x)
        blocks = self.split_into_blocks(x)
        random.shuffle(blocks)
        return torch.cat(blocks, dim=1), torch.tensor([i for i in range(self.grid_size ** 2)])

    def split_into_blocks(self, img):
        c, h, w = img.size()
        block_h, block_w = h // self.grid_size, w // self.grid_size
        blocks = []
        for i in range(self.grid_size):
            for j in range(self.grid_size):
                block = img[:, i*block_h:(i+1)*block_h, j*block_w:(j+1)*block_w]
                blocks.append(block.unsqueeze(0))
        return blocks

# 定义Jigsaw任务模型
class JigsawModel(nn.Module):
    def __init__(self, base_model):
        super(JigsawModel, self).__init__()
        self.backbone = base_model
        self.classifier = nn.Linear(base_model.fc.in_features, 9)

    def forward(self, x):
        features = self.backbone(x)
        out = self.classifier(features)
        return out

# 示例:在CIFAR-10上进行Jigsaw任务训练
train_dataset = CIFAR10(root='./data', train=True, transform=JigsawTransform(32), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

jigsaw_model = JigsawModel(base_model=resnet_model)

# 训练过程同样可以采用类似SimCLR的方式进行

Jigsaw任务通过打乱图像块并要求模型恢复原始顺序来学习图像的表示,训练方式与

普通的监督学习任务相似,核心是构建预训练任务并生成标签。

4. 自监督学习的应用场景

自监督学习目前在多个领域得到了成功的应用,包括但不限于:

  • 图像处理:通过预训练任务学习到丰富的图像表示,进而提升在图像分类、目标检测等任务上的表现。
  • 自然语言处理:BERT等模型的成功应用展示了自监督学习在文本任务中的巨大潜力。
  • 时序数据分析:例如在视频处理、音频分析等领域,自监督学习也展示出了强大的能力。

5. 结论

自监督学习作为机器学习中的一个新兴热点,极大地推动了无标注数据的利用效率。通过设计合理的预训练任务,模型能够学习到更加通用的数据表示,进而提升下游任务的性能。在未来,自监督学习有望在更多实际应用中发挥重要作用,帮助解决数据标注昂贵、难以获取的难题。

在这篇文章中,我们不仅阐述了自监督学习的基本原理,还通过代码示例展示了如何实现对比学习和Jigsaw任务等具体方法。通过深入理解这些技术,读者可以尝试将其应用到实际任务中,从而提高模型的表现。

参考文献

  1. Chen, Ting, et al. “A simple framework for contrastive learning of visual representations.” International conference on machine learning. PMLR, 2020.
  2. Gidaris, Spyros, and Nikos Komodakis. “Unsupervised representation learning by predicting image rotations.” International Conference on Learning Representations. 2018.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2141884.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++题解】1996. 每个小组的最大年龄

欢迎关注本专栏《C从零基础到信奥赛入门级(CSP-J)》 问题:1996. 每个小组的最大年龄 类型:二维数组 题目描述: 同学们在操场上排成了一个 n 行 m 列的队形,每行的同学属于一个小组,请问每个小…

PCIe进阶之TL:Completion Rules TLP Prefix Rules

1 Completion Rules & TLP Prefix Rules 1.1 Completion Rules 所有的 Read、Non-Posted Write 和 AtomicOp Request 都需要返回一个 Completion。Completion 有两种类型:一种带数据负载的,一种不带数据负载的。以下各节定义了 Completion header 中每个字段的规则。 C…

【磨皮美白】基于Matlab的人像磨皮美白处理算法,Matlab处理

博主简介:matlab图像代码项目合作(扣扣:3249726188) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于Matlab的图像磨皮美白处理,用matlab实现。 一、案例背景和算法介绍 …

【图像匹配】基于SURF算法的图像匹配,matlab实现

博主简介:matlab图像代码项目合作(扣扣:3249726188) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于基于SURF算法的图像匹配,用matlab实现。 一、案例背景和算法介绍 前…

7天速成前端 ------学习日志 (继苍穹外卖之后)

前端速成计划总结: 全26h课程,包含html,css,js,vue3,预计7天内学完。 起始日期:9.16 预计截止:9.22 每日更新,学完为止。 学前计划 课…

文字loading加载

效果 1. 导入库 import sys from PyQt5.QtCore import QTimer, Qt, QThread, pyqtSignal from PyQt5.QtGui import QPainter, QFont, QColor, QBrush from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QProgressBar, QLabel 代码首先导入了P…

编辑器拓展(入门与实践)

学习目标:入门编辑器并实现几个简单的工具 菜单编辑器 MenuItem [MenuItem("编辑器拓展/MenuItem")]static void MenuItem(){Debug.Log("这是编辑器拓展");} } 案例 1:在场景中的 GameObject 设置 1. 设置面板2. 直接创建 GameObject 结构…

2-96 基于matlab的SMOTE数据扩充算法

基于matlab的SMOTE数据扩充算法,主动设置数据扩充百分比,并考虑最近邻居数进行扩充,计算样本到他所在类样本集中所有样本距离,从样本的K近邻中随机选择若干样本添加到扩充样本集。程序已调通,可直接运行。 下载源程序…

c++中引用是通过指针的方式实现

其实在汇编层面上&#xff0c;引用的代码和指针的代码是一致的。 先看指针情况下的代码分析&#xff0c;如下所示&#xff1a; #include <iostream>using namespace std;void fuzhi(int *x)//引用传参 {*x 10; }int main(int argc, char** argv) {int a 0;int b;a …

LeetCode[简单] 283.移动零

给定一个数组 nums&#xff0c;编写一个函数将所有 0 移动到数组的末尾&#xff0c;同时保持非零元素的相对顺序。 请注意 &#xff0c;必须在不复制数组的情况下原地对数组进行操作。 思路&#xff1a;利用快慢指针&#xff0c;快指针遍历数组&#xff0c;慢指针是非零元素索…

【D3.js in Action 3 精译_023】3.3 使用 D3 将数据绑定到 DOM 元素

当前内容所在位置&#xff1a; 第一部分 D3.js 基础知识 第一章 D3.js 简介&#xff08;已完结&#xff09; 1.1 何为 D3.js&#xff1f;1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践&#xff08;上&#xff09;1.3 数据可视化最佳实践&#xff08;下&#xff09;1.4 本…

销管系统 —— P14 菜单项悬停高亮显示遇到的问题

悬停在子菜单背景颜色并没有显示&#xff0c;为什么&#xff1a; 什么是后代选择器 —— 选中父元素 后代中 满足条件的元素&#xff1b;这个子菜单menu—item它既满足上面的也满足下面的&#xff0c;按这个顺序的话&#xff0c;下面的就被覆盖了&#xff08;CSS优先级规则&…

Nginx实用篇:实现负载均衡、限流与动静分离

Nginx实用篇&#xff1a;实现负载均衡、限流与动静分离 | 原创作者/编辑&#xff1a;凯哥Java | 分类&#xff1a;Nginx学习系列教程 Nginx 作为一款高性能的 HTTP 服务器及反向代理解决方案&#xff0c;在互联网架构中扮演着至关重要的角色。它…

可视化深度网络的强大工具:Grad-CAM介绍与使用步骤

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发…

第一次安装Pytorch

1、新版本的Anaconda内置的python版本是3.12&#xff0c; 目前 Windows 上的 PyTorch 仅支持 Python 3.8-3.11;不支持 Python 2.x。 1、创建运行环境 在不创建虚拟环境的情况下&#xff0c;不建议使用最新的Python和Anaconda。 在几次失败后&#xff0c;我使用的是Anaconda3-2…

单相可控整流电路(单相半波整流电路、单相桥式全控整流电路)

目录 1. 单相半波整流电路 2. 单相桥式全控整流电路 单相可控整流电路是利用可控硅&#xff08;晶闸管&#xff09;将交流电转换为直流电的电路&#xff0c;主要有两种常见类型&#xff1a;单相半波整流电路和单相桥式全控整流电路。 1. 单相半波整流电路 单相半波整流电路是…

python实现多个pdf文件合并

打印发票时&#xff0c;需要将pdf合并成一个&#xff0c;单页两张打印。网上一些pdf合并逐渐收费&#xff0c;这玩意儿都能收费&#xff1f;自己写一个脚本使用。 实现代码&#xff1a; 输入pdf文件夹路径data_dir&#xff0c;统计目录下的“合并后的PDF”文件夹下&#xff0c;…

十六,Spring Boot 整合 Druid 以及使用 Druid 监控功能

十六&#xff0c;Spring Boot 整合 Druid 以及使用 Druid 监控功能 文章目录 十六&#xff0c;Spring Boot 整合 Druid 以及使用 Druid 监控功能1. Druid 的基本介绍2. 准备工作&#xff1a;3. Druid 监控功能3.1 Druid 监控功能 —— Web 关联监控3.2 Druid 监控功能 —— SQL…

数组学习内容

动态初始化 只给长度&#xff0c;数据类型【】 数组名new 数据类型【数组长度】 内存图

MySQL篇(数值函数/)(持续更新迭代)

目录 常见函数一&#xff1a;数值函数 一、常见数值函数 1. 基本函数 2. 角度与弧度互换函数 3. 三角函数 4. 指数与对数 5. 进制间的转换 常见函数二&#xff1a;日期函数 一、常见日期函数 二、SQL演示 1. curdate&#xff1a;当前日期 2. curtime&#xff1a;当前…