基于pytorch的神经网络与对比学习CL的训练示例实战和代码解析

news2025/1/12 20:38:04

目录

    • 对比学习原理解析
    • 构建一个对比学习模型(代码详解)
      • 导入库
      • 构建简单的神经网络
      • 构建对比学习的损失函数
      • 开始训练
    • 完整代码

对比学习原理解析

对比学习(Contrastive Learning)是一种无监督学习方法,用于从未标记的数据中学习表示。它的目标是通过将相似样本靠近并将不相似样本分开来学习有意义的表示。

对比学习的核心思想是通过最大化相似样本之间的相似性,并最小化不相似样本之间的相似性来训练模型。具体而言,对于每个样本,对比学习会构建一个正样本对和若干个负样本对。正样本对由同一类别或相似的样本组成,而负样本对则由不同类别或不相似的样本组成。然后,模型会被训练以使得正样本对的表示在嵌入空间中更接近,而负样本对的表示则更远离彼此。

对比学习的一个常见应用是自监督学习,其中模型通过利用输入数据的某种变换作为自动生成的标签进行训练。例如,在图像领域,可以通过对图像执行随机裁剪、旋转、遮挡等操作来生成正样本对和负样本对。模型被训练以使得经过变换的两个图像的表示更接近,而不同图像的表示更远离彼此。这样,模型可以学习到具有良好判别性能的图像表示,从而在后续的任务(如分类、检测等)中表现更好。

对比学习的优点之一是它不需要标记数据,因此可以广泛应用于大规模未标记数据集。它还具有较强的泛化能力和可解释性,使得它在许多领域如计算机视觉、自然语言处理和推荐系统等方面受到关注和应用。

图片识别一种对比学习:
在这里插入图片描述

加噪声自监督对比学习:
在这里插入图片描述

详情请见对比学习解释:https://blog.csdn.net/Exploer_TRY/article/details/116502372

构建一个对比学习模型(代码详解)

导入库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

构建简单的神经网络

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 5 * 5, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128)
        )
        
    def forward(self, x1, x2):
        out1 = self.conv(x1)
        out1 = out1.view(out1.size(0), -1)
        out1 = self.fc(out1)
        
        out2 = self.conv(x2)
        out2 = out2.view(out2.size(0), -1)
        out2 = self.fc(out2)
        
        return out1, out2

.view() 的作用是将张量转换为指定的形状。这在深度学习中很常见,因为神经网络的输入和输出通常需要特定的形状。通过使用 torch.view(),可以方便地调整张量的维度,以匹配网络的要求。

构建对比学习的损失函数

对比学习的损失函数需要自己构建,基本思想事利用欧式距离等计算两个编码之间的相似度。

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

解释:
1、margin 表示一个边界或阈值
用于指定相似性的门槛。它是 Contrastive Loss(对比损失)函数的一个超参数。

Contrastive Loss 旨在鼓励相似样本距离更近,不相似样本距离更远。通过引入 margin 值,我们可以控制样本之间被认为是相似的程度。

具体来说,在计算对比损失时,我们会考虑两个样本输出向量的欧氏距离(euclidean_distance)。如果 label 表示这两个样本是相同类别的,我们希望 euclidean_distance 较小,因为相似的样本应该更接近。反之,如果 label 表示这两个样本是不同类别的,我们希望 euclidean_distance 较大,因为不相似的样本应该相隔较远。

在计算损失时,我们使用以下公式:

loss_contrastive = (1 - label) * euclidean_distance^2 + label * max(margin - euclidean_distance, 0)^2

其中 (1 - label) * euclidean_distance^2 部分表示相似样本的损失,label * max(margin - euclidean_distance, 0)^2 部分表示不相似样本的损失。当 euclidean_distance 小于 margin 时,不相似样本的损失为 0。

因此,margin 控制了相似性的门槛,较小的 margin 值会鼓励更严格的相似性定义,而较大的 margin 值会放宽相似性的限制。通过调整 margin 的取值,可以根据任务需求和数据特点来优化对比学习模型的性能。

2、label 是一个表示样本对是否属于同一类别的标签
它用于计算对比损失(contrastive loss)。

在对比学习中,通常使用成对的样本来构建训练数据集。对每个样本对,label 可以取以下两个值之一:

如果两个样本属于相同的类别或具有相似性,则 label 为 0。
如果两个样本属于不同的类别或具有差异性,则 label 为 1。

通过这种方式,我们可以指定哪些样本是相似的(label=0),哪些样本是不相似的(label=1)。

在损失函数的前向传播方法中,label 用于加权计算对比损失。具体地,当 label=0 时,我们希望 euclidean_distance 较小,因为相似的样本应该更接近;当 label=1 时,我们希望 euclidean_distance 较大,因为不相似的样本应该相隔较远。

因此,label 在对比损失函数中起到了区分相似性和差异性的作用,并用于根据样本对的真实标签调整损失的计算。

3、计算欧式距离pairwise_distance
F.pairwise_distance 是 PyTorch 中的一个函数,用于计算两个张量之间的欧氏距离。它可以用来衡量两个特征向量或样本之间的相似性或差异性。

开始训练

## 加载数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=64)

## 构建网络
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    for batch_idx, (img1, img2, label) in enumerate(train_dataloader):
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        
        optimizer.zero_grad()
        output1, output2 = model(img1, img2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()
        
        if (batch_idx+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, batch_idx+1, len(train_dataloader), loss.item()))
 
 # 保存模型
torch.save(model.state_dict(), './model.pt')

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 5 * 5, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128)
        )
        
    def forward(self, x1, x2):
        out1 = self.conv(x1)
        out1 = out1.view(out1.size(0), -1)
        out1 = self.fc(out1)
        
        out2 = self.conv(x2)
        out2 = out2.view(out2.size(0), -1)
        out2 = self.fc(out2)
        
        return out1, out2
        
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        
    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

## 加载数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=64)

## 构建网络
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    for batch_idx, (img1, img2, label) in enumerate(train_dataloader):
        img1, img2, label = img1.to(device), img2.to(device), label.to(device)
        
        optimizer.zero_grad()
        output1, output2 = model(img1, img2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()
        
        if (batch_idx+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch+1, num_epochs, batch_idx+1, len(train_dataloader), loss.item()))
 
 # 保存模型
torch.save(model.state_dict(), './model.pt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/732772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3 STM32标准库函数 之 窗口看门狗(WWDG)所有函数的介绍及使用

3 STM32标准库函数 之 窗口看门狗(WWDG)所有函数的介绍及使用 1. 图片有格式2 文字无格式三 库函数之窗口看门狗(WWDG)所有函数的介绍及使用前言一、IWDG库函数固件库函数预览1.1 函 数 IWDG_WriteAccessCmd1.1.1 IWDG_WriteAcces…

string模拟实现

文章目录 1.回顾库函数strcpymemcpystrcmpstrstr 2.回顾类和对象哪些函数里会有this指针?this指针调用方法结论:只要是不修改this指针指向的对象内容的成员函数,都可以加上const自己写了构造函数,编译器不会自动生成默认构造2.1构…

代码随想录第21天 | 回溯理论基础 77. 组合

回溯理论基础 回溯法解决的问题都可以抽象为树形结构,是的,我指的是所有回溯法的问题都可以抽象为树形结构! 因为回溯法解决的都是在集合中递归查找子集,集合的大小就构成了树的宽度,递归的深度,都构成的…

MySQL面试题总结(部分)

一.介绍MySQL为什么在面试中会提及 1.为什么要在面试时MySQL会被提及? 在面试中问MySQL问题有几个主要原因: 1. 数据库管理系统的重要性:MySQL作为一种常用的关系型数据库管理系统(RDBMS),在互联网和企业应用中得到广泛使用。对数…

Conda安装及使用方法(常用命令)

系列文章目录 文章目录 系列文章目录前言一、Conda下载安装1.下载2.安装3.配置国内源 二、Conda安装Python环境1.创建虚拟环境2.激活虚拟环境3.虚拟环境安装Python库 三、Conda环境环境执行脚本四、PyCharm配置Conda环境五、Conda迁移环境1.方式一:拷贝环境2.方式二…

Modbus通信从入门到精通_1_Modbus通信基础

关于Modbus通信的相关知识比较零碎,此处对查找到的知识点从理论,通信协议、使用方法方面进行整理。 值得学习的博文:Modbus及调试用软件介绍;Modbus协议和上位机应用开发介绍 文章目录 1. Modbus通信理论1.1 Modbus通信特点1.2 存…

多线程(1): 线程的创建、回收、分离

1. 多线程概述 多线程在项目开发过程中使用频率非常高,因为使用多线程可以提高程序的并发性。提高程序并发性有两种方式:(1)多线程 (2)多进程。但是多线程对系统资源的消耗会更加少一些,并且线程和进程执行效率差不多。 在执行系统应用程序时&#xff…

2023/7/8总结

Tomcat 启动:双击bin目录下的startup.bat文件停止:双击bin目录下的shutdown.bat 文件访问 :http://localhost:8080(默认是8080,可以修改) git的使用 打开git bash git config --global user.name "名…

Vue3---什么是路由缓存问题

使用带有参数的路由时需要注意的是,当用户从 /users/johnny 导航到 /users/jolyne 时,相同的组件实例将被重复使用。因为两个路由都渲染同个组件,比起销毁再创建,复用则显得更加高效。不过,这也意味着组件的生命周期钩…

500万PV的网站需要多少台服务器?

1. 衡量业务量的指标 衡量业务量的指标项有很多,比如,常见Web类应用中的PV、UV、IP。而比较贴近业务的指标项就是大家通常所说的业务用户数。但这个用户数比较笼统,其实和真实访问量有比较大的差距,所以为了更贴近实际业务量及压力…

什么是提示工程?

原文链接:芝士AI吃鱼 理解大规模人工智能模型为何如此行事是一门艺术。即使是最有成就的技术专家也会对大型语言模型 (LLM) 的意想不到的能力感到困惑,大型语言模型是ChatGPT等人工智能聊天机器人的基本构建模块。 因此,提示工程成为生成式 …

特征选择算法 | Matlab 基于最大互信息系数特征选择算法(MIC)的分类数据特征选择

文章目录 效果一览文章概述部分源码参考资料效果一览 文章概述 特征选择算法 | Matlab 基于最大互信息系数特征选择算法(MIC)的分类数据特征选择 部分源码 %--------------------

python 常用数据结构-列表

list 列表 列表定义与使用列表常用方法列表嵌套列表推导式 列表定义 列表是有序的可变元素的集合,使用中括号[]包围,元素之间用逗号分隔 列表是动态的,可以随时扩展和收缩 列表是异构的,可以同时存放不同类型的对象 列表中允…

阶乘后的零(力扣)数学 JAVA

给定一个整数 n ,返回 n! 结果中尾随零的数量。 提示 n! n * (n - 1) * (n - 2) * … * 3 * 2 * 1 示例 1: 输入:n 3 输出:0 解释:3! 6 ,不含尾随 0 示例 2: 输入:n 5 输出&…

WSL2 及 docker开发环境搭建

WSL2 及 docker开发环境搭建 1.使能WSL 控制面板->程序->程序和功能->启动或关闭Windows功能->勾选红框中选项->确认后重启电脑  2.下载Linux Kernel Update安装包 下载地址如下, 附件已将下载的安装包作为附件形式上传,…

ITIL 4服务连续性管理实践

一、目的和描述 关键信息 服务连续性管理实践的目的是确保灾难发生时,服务的可用性和性能能够保持在足够的水平。本实践提供了一个框架机制,利用产生有效响应的能力来构建组织的弹性,以保障关键利益相关者的利益,还有组织的声誉…

element 封装dialog弹窗组件鼠标移动到弹窗出现title

问题&#xff1a; element 封装dialog弹窗组件鼠标移动到弹窗出现title 封装的组件 <template><el-dialog title"111"v-bind"$attrs" v-on"$listeners" :visible.sync"show" ></el-dialog> </template><s…

02-webpack的热更新是如何做的,以及原理

一、是什么 HMR 可以理解为模块热替换&#xff0c;指在应用程序运行过程中&#xff0c;替换、添加、删除模块&#xff0c;而无需重新刷新整个应用. 如&#xff0c;我们在应用运行过程中修改了某个模块&#xff0c;通过自动刷新会导致整个应用的整体刷新&#xff0c;那页面中的…

pygame伪3d 实现地面效果

教程来自What is Mode 7? Let’s code it! 油管镜像 import cv2 import pygame import sys from pygame import gfxdraw import numpy as np(width, height) (800, 600) pygame.init() screen pygame.display.set_mode((width, height)) image pygame.image.load("11…

ElasticSearch基础学习(SpringBoot集成ES)

一、概述 什么是ElasticSearch&#xff1f; ElasticSearch&#xff0c;简称为ES&#xff0c; ES是一个开源的高扩展的分布式全文搜索引擎。 它可以近乎实时的存储、检索数据&#xff1b;本身扩展性很好&#xff0c;可以扩展到上百台服务器&#xff0c;处理PB级别的数据。 ES也…