人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测

news2024/9/28 7:20:01

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型14-pytorch搭建Siamese Network模型(孪生网络),实现模型的训练与预测。孪生网络是一种用于度量学习(Metric Learning)和比较学习(Comparison Learning)的深度神经网络模型。它主要用于学习将两个输入样本映射到一个共享的嵌入空间,并衡量它们之间的相似性。
孪生网络通常由两个相同的子网络组成,这两个子网络共享参数和权重。每个子网络将输入样本分别映射到嵌入空间中的特征向量。这些特征向量可以被用来度量两个输入样本之间的相似性或距离。

文章目录:

  1. 引言
  2. Siamese Network模型原理
  3. 使用PyTorch搭建Siamese Network模型
    3.1 数据预处理
    3.2 模型架构设计
    3.3 损失函数选择
    3.4 模型训练与评估
  4. 实现代码
  5. 数据样例
  6. 结果与分析
  7. 总结

1. 引言

在计算机视觉领域,Siamese Network(孪生网络)被广泛应用于人脸识别、图像检索和目标跟踪等任务。Siamese Network模型通过将两个相似或不相似的输入序列映射到同一个特征空间中,并计算它们的相似度来实现任务目标。本文将介绍如何使用PyTorch搭建Siamese Network模型,并提供完整的代码示例。

2. Siamese Network模型原理

Siamese Network模型是一种基于孪生网络结构设计的深度学习模型。该模型的核心思想是通过共享相同的权重参数来处理两个输入序列,使得同类样本的特征表示更加接近,异类样本的特征表示更加远离。

模型的基本原理如下:

  1. 输入层:接受输入的两个序列数据(如图像、文本等)。
  2. 共享层:采用相同的权重参数处理两个输入序列数据,将它们映射到同一个特征空间中。
  3. 相似度计算层:计算两个输入序列在特征空间中的相似度得分。
  4. 损失函数:根据相似度得分和真实标签之间的差异,计算模型的损失值。
  5. 反向传播与优化:利用梯度下降算法,通过反向传播方法来优化模型的权重参数。

Siamese Network模型的数学原理可以通过以下方式表示:

假设我们有两个输入样本 x 1 x_1 x1 x 2 x_2 x2,它们分别通过共享的子网络 θ \theta θ映射到嵌入空间中的特征向量 h 1 h_1 h1 h 2 h_2 h2,即:

h 1 = θ ( x 1 ) , h 2 = θ ( x 2 ) h_1 = \theta(x_1),h_2 = \theta(x_2) h1=θ(x1),h2=θ(x2)

接下来,我们可以使用一种相似度度量函数 d ( h 1 , h 2 ) d(h_1, h_2) d(h1,h2)来计算 h 1 h_1 h1 h 2 h_2 h2之间的相似度或距离。常见的相似度度量函数包括欧氏距离、余弦相似度等。

在训练过程中,我们希望正样本对 ( x 1 , x 2 + ) (x_1, x_2^+) (x1,x2+)的特征向量在嵌入空间中更加接近,而负样本对 ( x 1 , x 2 − ) (x_1, x_2^-) (x1,x2)的特征向量在嵌入空间中更加远离。因此,我们可以定义一个对比损失函数 L \mathcal{L} L来衡量样本对的相似度或差异度,例如:

L ( x 1 , x 2 + , x 2 − ) = [ d ( h 1 , h 2 + ) − d ( h 1 , h 2 − ) + m ] + \mathcal{L}(x_1, x_2^+, x_2^-) = [d(h_1, h_2^+) - d(h_1, h_2^-) + m]_+ L(x1,x2+,x2)=[d(h1,h2+)d(h1,h2)+m]+

其中, [ ⋅ ] + [\cdot]_+ []+表示取正值操作, m m m是一个预先定义的边界值,用于控制正样本对和负样本对之间的距离间隔。

通过最小化损失函数 L \mathcal{L} L来更新网络的参数,我们可以使得正样本对在嵌入空间中更加接近,负样本对在嵌入空间中更加远离。

整个Siamese Network模型的训练过程可以使用梯度下降等优化算法进行。在前向传播过程中,输入样本经过子网络映射得到特征向量。然后计算损失函数并进行反向传播,根据梯度更新网络参数,以逐渐优化特征表示和相似性度量。

这就是Siamese Network模型的数学原理,其中通过共享子网络和对比损失函数,可以学习到适应度量学习任务的特征表示,并在嵌入空间中度量样本之间的相似性。
在这里插入图片描述

3. 使用PyTorch搭建Siamese Network模型

3.1 数据预处理

在使用Siamese Network模型前,需要对数据进行预处理,包括数据加载、数据划分和数据增强等操作。以人脸识别为例,可以使用FaceNet数据集,其中包含多个人的人脸图像样本。

3.2 模型架构设计

在PyTorch中搭建Siamese Network模型的关键是定义模型的网络结构。可以使用卷积神经网络(CNN)作为共享层,并添加一些全连接层和激活函数。具体的模型架构可参考以下示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Shared layers (convolutional layers)
        self.conv1 = nn.Conv2d(1, 64, 10)
        self.conv2 = nn.Conv2d(64, 128, 7)
        self.conv3 = nn.Conv2d(128, 128, 4)
        self.conv4 = nn.Conv2d(128, 256, 4)
        # Fully connected layers
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 1024)
        self.fc3 = nn.Linear(1024, 128)

    def forward(self, x1, x2):
        x1 = F.relu(self.conv1(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv2(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv3(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = F.relu(self.conv4(x1))
        x1 = F.max_pool2d(x1, 2)
        x1 = x1.view(x1.size()[0], -1)
        x1 = F.relu(self.fc1(x1))
        x1 = F.relu(self.fc2(x1))
        x1 = self.fc3(x1)

        x2 = F.relu(self.conv1(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv2(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv3(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = F.relu(self.conv4(x2))
        x2 = F.max_pool2d(x2, 2)
        x2 = x2.view(x2.size()[0], -1)
        x2 = F.relu(self.fc1(x2))
        x2 = F.relu(self.fc2(x2))
        x2 = self.fc3(x2)

        return x1, x2

3.3 损失函数选择

在Siamese Network模型中,常用的损失函数是对比损失函数(Contrastive Loss),用于度量两个输入序列之间的相似度。可以通过定义一个自定义的损失函数来实现对比损失函数的计算。

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

3.4 模型训练与评估

在训练Siamese Network模型前,需要加载数据并将其划分为训练集和测试集。然后,使用梯度下降算法来优化参数,并在每个epoch结束时计算模型的损失值和准确率。下面是训练与评估的代码示例:

def train(model, train_loader, optimizer, criterion):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data1, data2, label) in enumerate(train_loader):
        optimizer.zero_grad()
        output1, output2 = model(data1, data2)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(output1.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

    acc = 100 * correct / total
    avg_loss = train_loss / len(train_loader)

    return avg_loss, acc

def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (data1, data2, label) in enumerate(test_loader):
            output1, output2 = model(data1, data2)
            loss = criterion(output1, output2, label)

            test_loss += loss.item()
            _, predicted = torch.max(output1.data, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

    acc = 100 * correct / total
    avg_loss = test_loss / len(test_loader)

    return avg_loss, acc

4. 数据样例

为了方便演示,这里给出几条数据样例,用于训练和测试Siamese Network模型。数据样例应包含两个输入序列(如图像对)以及它们的标签。

# 加载数据集
import torch
from torch.utils.data import Dataset, DataLoader
class SiameseDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        img1 = torch.randn(1, 28, 28)  # 假设图像维度为 3x224x224
        img2 = torch.randn(1, 28, 28)
        label = torch.randint(0, 2, (1,)).item()  # 随机生成标签

        return img1, img2, label

def split_dataset(dataset, train_ratio=0.8):
    train_size = int(train_ratio * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    return train_dataset, test_dataset

# 设置随机种子,以保证可复现性
torch.manual_seed(2023)

# 创建自定义数据集对象
dataset = SiameseDataset(num_samples=1000)

# 划分数据集
train_dataset, test_dataset = split_dataset(dataset, train_ratio=0.8)

# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

6. 训练结果与分析

# 配置模型及优化器
model = SiameseNetwork()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 模型训练与测试
for epoch in range(10):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    test_loss, test_acc = test(model, test_loader, criterion)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Test Loss={test_loss:.4f}")

运行结果:

Epoch 1: Train Loss=139617133.5882, Test Loss=4168544.1429
Epoch 2: Train Loss=18824583.2325, Test Loss=351236.0737
Epoch 3: Train Loss=129070.3893, Test Loss=0.1328
Epoch 4: Train Loss=0.1287, Test Loss=0.1228
Epoch 5: Train Loss=0.1291, Test Loss=0.1306
Epoch 6: Train Loss=0.1219, Test Loss=0.1373
Epoch 7: Train Loss=0.1259, Test Loss=0.1183
Epoch 8: Train Loss=0.1219, Test Loss=0.1127
Epoch 9: Train Loss=0.1278, Test Loss=0.1194
Epoch 10: Train Loss=0.1231, Test Loss=0.1116

7. 总结

本文主要介绍了Siamese Network模型的原理和应用项目,并使用PyTorch实现了该模型。通过搭建Siamese Network模型,可以实现诸如人脸识别、图像检索等任务。最后,通过完整的代码示例和实验结果分析,验证了Siamese Network模型的有效性和可行性。

这篇文章基于PyTorch框架和Siamese Network模型详细介绍了该模型的原理、实现方法以及训练测试流程,提供了完整的代码和数据样例,并进行了实验结果与分析。相信读者可以通过本文了解到Siamese Network模型的基本概念和应用,为进一步研究和实践提供参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/688293.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

教学实训模块升级,助力应用型数据科学人才培养|ModelWhale 版本更新

初夏梅雨季,ModelWhale 迎来新一轮版本更新,多角度优化各领域用户的使用体验。 本次更新中,ModelWhale 主要进行了以下功能迭代: • 优化 课程作业布置(团队版✓ ) • 新增 课程作业关联至课件&#xff…

MySQL的服务层和存储引擎层

1. 服务层(Server Layer): 服务层是MySQL的顶层组件,负责处理客户端与MySQL服务器之间的交互。它提供了一组API和协议,使应用程序能够连接到MySQL服务器,并发送查询、事务管理、用户权限控制等请求。服务层…

6-js基础-1

JavaScript 基础 - 1 了解变量、数据类型、运算符等基础概念,能够实现数据类型的转换,结合四则运算体会如何编程。 JavaScript介绍变量常量数据类型运算符实战案例 重点单词: js介绍 能说出JavaScript 是什么? 怎么写? 能写出JavaScript 输…

【C++详解】——红黑树

目录 红黑树的概念 红黑树的性质 红黑树节点的定义 红黑树的结构 红黑树的插入操作 情况一 情况二 情况三 红黑树的验证 红黑树的查找 红黑树与AVL树的比较 红黑树的概念 红黑树,是一种二叉搜索树,但在每个结点上增加一个存储位表示…

基于SpringBoot的在线拍卖系统【附ppt和万字文档(Lun文)和搭建文档】

主要功能 主要功能 前台登录: ①首页:轮播图、竞拍公告、拍卖商品展示 ②拍卖商品:分类:手机、数码、电器等,可以点击商品竞拍 ③竞拍公告:可以查看竞拍的信息 ④留言反馈:用户可以提交留言 ⑤…

如何办理跨境电商营业执照?加速度jsudo

如今电商行业的发展持续火热,跨境电商亦是如此,随着疫情的好转,各行各业也逐渐好转起来,此时也是一个做跨境电商的好时机,那么做跨境电商的前提需要什么呢?当然是营业执照了,那么如何办理跨境电商营业执照…

Flutter Ping 检查服务器通讯信号强度

Flutter Ping 检查服务器通讯信号强度 前言 对通讯敏感的程序中,我们除了检查当前网络通道外,还要检查与服务器实际的型号强度。 一般我们采用 ping 的方式返回型号的强度和稳定程度。 dart_ping 包 https://pub-web.flutter-io.cn/packages/dart_ping …

【Java】Java 链表类详记

本文仅供学习参考! 相关文章链接: https://www.runoob.com/java/java-linkedlist.html https://www.developer.com/java/java-linkedlist-class/ https://www.w3schools.com/java/java_linkedlist.asp Java 中链表的类型 从最基本的角度来说&#xff0c…

EBO绘制矩形

数据: float vertices[] { 0.5f, 0.5f, 0.0f, // top right 0.5f, -0.5f, 0.0f, // bottom right -0.5f, -0.5f, 0.0f, // bottom left -0.5f, 0.5f, 0.0f // top left }; unsigned int indices[] { // note that we start from 0! 0, 1, 3, // first triangle 1,…

UE4自定义资产类型编辑器实现

在虚幻引擎中,资产是具有持久属性的对象,可以在编辑器中进行操作。 Unreal 附带多种资源类型,从 UStaticMesh 到 UMetasoundSources 等等。 自定义资源类型是实现专门对象的好方法,这些对象需要专门构建的编辑器来进行高效操作。 …

SpringBoot3 快速入门及原理分析

1. 环境要求 环境&工具版本SpringBoot3.0.5IDEA2021.2.1Java17Maven3.5Tomcat10.0 2. SpringBoot是什么 SpringBoot 能帮我们简单、快速地创建一个独立的、生产级别的 Spring 应用(说明:SpringBoot底层是Spring) SpringBoot 应用只需…

CentOS7安装使用Nginx

CentOS7安装使用Nginx CentOS7安装使用Nginx1.安装1.1下载1.2 检验服务器上是否有nginx1.3 解压安装1.4 验证 2.部署2.1基本知识2.1.1常用命令2.1.2配置文件 2.2 配置效果前端后端 CentOS7安装使用Nginx 本文使用的nginx版本为1.22.1 Nginx发布版本分为主线版本和稳定版本&…

如何解决多线程卡死问题?四招教你轻松应对!

多线程大家都用过,可以让一个程序同时执行多个任务,提高效率和性能,一个人干的慢,三个人干。但是,多线程也带来了一些问题和挑战,比如线程同步、线程安全、线程死锁等问题,三个人抢一碗米饭&…

操作系统OS(一)磁盘与文件系统

计算机存储 计算机只能看懂1和0组成的语言,所以计算机存储数据的大小就是存储了多少个1和0. 比特位bit(位) 是计算机世界中最小的存储单位,每个1或者0占据1bit,表示二进制位 字节byte 由8个二进制位构成,1…

OpenGL 几何着色器

1.效果展示 爆破物体。 2.简介 在顶点和片段着色器之间有一个可选的几何着色器,几何着色器的输入是一个图元(如点或三角形)的一组顶点。几何着色器可以在顶点发送到下一着色器阶段之前对它们随意变换。然而,几何着色器最有趣的…

RabbitMQ 2023面试5题(四)

一、RabbitMQ有哪些作用 RabbitMQ是一个消息队列中间件,它的作用是利用高效可靠的消息传递机制进行与平台无关的数据交流,并基于数据通信来进行的分布式系统的集成,主要作用有以下方面: 实现应用程序之间的异步和解耦&#xff1a…

[Africa battleCTF 2023 prequal] CPR部分

非州的比赛,说是总体简单,但也有几个难题0解,估计依然是等不到WP。 这个界面还挺好,除了慢以外没大问题。 Rev SEYI 题目很简单,程序报病毒,win11上的defender关上不容易呀。我的电脑怎么就不能听我的呢…

【Java高级语法】(十八)Optional类:解锁Java的Optional魔法:消灭那些隐匿的空指针,还程序世界一个安稳!~

Java高级语法详解之Optional类 1️⃣ 概念2️⃣ 优势和缺点3️⃣ 使用3.1 常用操作API3.2 案例3.3 使用技巧 4️⃣ 应用场景5️⃣ 实现原理🌾 总结 1️⃣ 概念 Optional类是Java 8引入的新特性,旨在解决空值(null)的处理问题。它…

ProtoBuf介绍与使用

文章目录 1、ProtoBuf概述2、下载和安装3、简单使用 1、ProtoBuf概述 Protobuf(Protocol Buffers)是由Google开发的一种语言无关的数据序列化格式。它旨在将结构化数据(如结构化消息或文档)高效地序列化为紧凑的二进制表示&#…

python GUI工具之PyQt5模块,pyCharm 配置PyQt5可视化窗口

https://doc.qt.io/qt-5/qtwidgets-module.html https://doc.qt.io/qt-5/qt.html#AlignmentFlag-enum 一、简介 PyQt是Qt框架的Python语言实现,由Riverbank Computing开发,是最强大的GUI库之一。PyQt提供了一个设计良好的窗口控件集合,每一…