孪生神经网络MATLAB实战[含源码]

news2024/7/11 9:13:57

​一、算法原理

      孪生神经网络( Siamese neural network)是一种深度学习网络,它使用两个或多个具有相同架构、共享相同参数和权重的相同子网。孪生网络通常用于寻找两个可比较事物之间的关系的任务。孪生网络的一些常见应用包括面部识别、签名验证或释义识别。孪生神经网络是基于两个人工神经网络建立的耦合架构,孪生神经网络以两个样本为输入,输出其嵌入高维空间的表征,以比较两个样本的相似程度,狭义的孪生神经网络由两个结构相同,且权重共享的神经网络拼接而成,网络框架如下图所示。

        广义的孪生神经网络(又称pseudo-siamese network,伪孪生神经网络),可由两个任意的神经网络拼接而成,可由卷积神经网络、循环神经网络等组成,网络框架如下图所示。

    简单来说,孪生神经网络就是衡量两个输入的相似程度。孪生神经网络有两个输入,将两个输入输入到两个神经网络,这两个神经网络分别将输入映射到新的空间,形成输入在新的空间中的表示。通过Loss的计算,评价两个输入的相似度。孪生神经网络在这些任务中表现良好,因为它们的共享权重意味着在训练过程中需要学习的参数更少,并且它们可以用相对较少的训练数据产生良好的结果。

二、代码实战

%% matlab学习之家
%% 孪生神经网络
clc
clear
%% 读取训练集
dataFolderTrain = "D:\S\孪生神经网络\images_background"; %% 更换路径
imdsTrain = imageDatastore(dataFolderTrain, ...
    IncludeSubfolders=true, ...
    LabelSource="none");
​
files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"-");
imdsTrain.Labels = categorical(labels);
​
%% 显示图片
idx = randperm(numel(imdsTrain.Files),8);
​
for i = 1:numel(idx)
    subplot(4,2,i)
    imshow(readimage(imdsTrain,idx(i)))
    title(imdsTrain.Labels(idx(i)),Interpreter="none");
end
​
batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getTwinBatch(imdsTrain,batchSize);
​
for i = 1:batchSize
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    subplot(2,5,i)
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    title(s)
end
%% 定义孪生神经网络架构
layers = [
    imageInputLayer([105 105 1],Normalization="none")
    convolution2dLayer(10,64,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(7,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(4,128,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(5,256,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")
    reluLayer
    fullyConnectedLayer(4096,WeightsInitializer="narrow-normal",BiasInitializer="narrow-normal")];
​
net = dlnetwork(layers);
​
fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));
​
fcParams = struct(...
    "FcWeights",fcWeights,...
    "FcBias",fcBias);
%% 定义损失函数
numIterations = 10000;
miniBatchSize = 180;
learningRate = 6e-5;
gradDecay = 0.9;
gradDecaySq = 0.99;
executionEnvironment = "auto";
​
if canUseGPU
    gpu = gpuDevice;
    disp(gpu.Name + " GPU detected and available for training.")
end
​
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];
​
monitor = trainingProgressMonitor(Metrics="Loss",XLabel="Iteration",Info="ExecutionEnvironment");
if canUseGPU
    updateInfo(monitor,ExecutionEnvironment=gpu.Name + " GPU")
else
    updateInfo(monitor,ExecutionEnvironment="CPU")
end
​
start = tic;
iteration = 0;
%% 使用自定义训练循环训练模型。循环遍历训练数据并在每次迭代时更新网络参数。
while iteration < numIterations && ~monitor.Stop
​
    iteration = iteration + 1;
​
    
    [X1,X2,pairLabels] = getTwinBatch(imdsTrain,miniBatchSize);
​
    
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");
​
  
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end
​
    % 使用dlfeval和modelLoss评估模型损失和梯度
    [loss,gradientsSubnet,gradientsParams] = dlfeval(@modelLoss,net,fcParams,X1,X2,pairLabels);
​
    % 更新孪生神经网络参数
    [net,trailingAvgSubnet,trailingAvgSqSubnet] = adamupdate(net,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
​
    % 更新全连接层参数.
    [fcParams,trailingAvgParams,trailingAvgSqParams] = adamupdate(fcParams,gradientsParams, ...
        trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);
​
    recordMetrics(monitor,iteration,Loss=loss);
    monitor.Progress = 100 * iteration/numIterations;
​
end
%% 读取测试集
dataFolderTest = "D:\S\孪生神经网络\images_evaluation" %% 更换路径
imdsTest = imageDatastore(dataFolderTest, ...
    IncludeSubfolders=true, ...
    LabelSource="none");
​
files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),"_");
imdsTest.Labels = categorical(labels);
​
numClasses = numel(unique(imdsTest.Labels));
accuracy = zeros(1,5);
accuracyBatchSize = 150;
​
for i = 1:5
    
    [X1,X2,pairLabelsAcc] = getTwinBatch(imdsTest,accuracyBatchSize);
​
   
    X1 = dlarray(X1,"SSCB");
    X2 = dlarray(X2,"SSCB");
​
   
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end
​
  
    Y = predictTwin(net,fcParams,X1,X2);
​
    Y = gather(extractdata(Y));
    Y = round(Y);
​
​
    accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end
​
averageAccuracy = mean(accuracy)*100;
​
testBatchSize = 10;
​
[XTest1,XTest2,pairLabelsTest] = getTwinBatch(imdsTest,testBatchSize);
​
XTest1 = dlarray(XTest1,"SSCB");
XTest2 = dlarray(XTest2,"SSCB");
​
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    XTest1 = gpuArray(XTest1);
    XTest2 = gpuArray(XTest2);
end
YScore = predictTwin(net,fcParams,XTest1,XTest2);
YScore = gather(extractdata(YScore));
YPred = round(YScore);
XTest1 = extractdata(XTest1);
XTest2 = extractdata(XTest2);
f = figure;
tiledlayout(2,5);
f.Position(3) = 2*f.Position(3);
​
predLabels = categorical(YPred,[0 1],["dissimilar" "similar"]);
targetLabels = categorical(pairLabelsTest,[0 1],["dissimilar","similar"]);
%% 用预测的标签和预测的分数绘制图像
for i = 1:numel(pairLabelsTest)
    nexttile
    imshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);
​
    title( ...
        "Target: " + string(targetLabels(i)) + newline + ...
        "Predicted: " + string(predLabels(i)) + newline + ...
        "Score: " + YScore(i))
end

仿真结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1366540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3.9 EXERCISES

矩阵加法需要两个输入矩阵A和B&#xff0c;并产生一个输出矩阵C。输出矩阵C的每个元素都是输入矩阵A和B的相应元素的总和&#xff0c;即C[i][j] A[i][j] B[i][j]。为了简单起见&#xff0c;我们将只处理元素为单精度浮点数的平方矩阵。编写一个矩阵加法内核和主机stub函数&am…

用红葡萄酿造的白葡萄酒是怎样的?

“由黑变白”这是“黑葡萄”的直译&#xff0c;代表一种由深蓝到黑葡萄制成的白葡萄酒&#xff0c;这种酿酒方式起源于法国&#xff0c;黑皮诺和莫尼尔的红葡萄一直被加工成白葡萄酒&#xff0c;作为香槟的基础。这是可能的&#xff0c;因为红色浆果通常果肉较轻。红色素&#…

TypeScript 从入门到进阶之基础篇(八)函数篇

系列文章目录 TypeScript 从入门到进阶系列 TypeScript 从入门到进阶之基础篇(一) ts基础类型篇TypeScript 从入门到进阶之基础篇(二) ts进阶类型篇TypeScript 从入门到进阶之基础篇(三) 元组类型篇TypeScript 从入门到进阶之基础篇(四) symbol类型篇TypeScript 从入门到进阶…

Java课程设计团队博客 —— 基于网页的时间管理系统

博客目录 1.项目简介2.项目采用的技术3.功能需求分析4.项目亮点5.主要功能截图6.Git地址7.总结 Java团队博客分工 姓名职务负责模块孙岚组长 资源文件路径和tomcat服务器的相关配置。 前端的页面设计与逻辑实现的代码编写。 Servlet前后端数据交互的编写。 用户登录和断开连接…

独立式键盘控制的4级变速流水灯

#include<reg51.h> // 包含51单片机寄存器定义的头文件 unsigned char speed; //储存流水灯的流动速度 sbit S1P1^4; //位定义S1为P1.4 sbit S2P1^5; //位定义S2为P1.5 sbit S3P1^6; //位定义S3为P1.6 sbit S4P1^7; //位…

熟悉HBase常用操作

1. 用Hadoop提供的HBase Shell命令完成以下任务 (1)列出HBase所有表的相关信息,如表名、创建时间等。 启动HBase: cd /usr/local/hbase bin/start-hbase.sh bin/hbase shell列出HBase所有表的信息: hbase(main):001:0> list(2)在终端输出指定表的所有记录数据。 …

鸿蒙原生应用/元服务开发-消息通知整体说明

应用/元服务可以通过通知接口发送通知消息&#xff0c;终端用户可以通过通知栏查看通知内容&#xff0c;也可以点击通知来打开应用。 通知常见的使用场景&#xff1a;显示接收到的短消息、即时消息等。显示应用的推送消息&#xff0c;如广告、版本更新等。显示当前正在进行的事…

CCF模拟题 202309-1 坐标变换(其一)

问题描述 试题编号&#xff1a; 202309-1 试题名称&#xff1a; 坐标变换&#xff08;其一&#xff09; 时间限制&#xff1a; 1.0s 内存限制&#xff1a; 512.0MB 问题描述&#xff1a; 对于平面直角坐标系上的坐标&#xff08;x,y&#xff09;&#xff0c;小P定义了一个包含…

代码随想Day60 | 84.柱状图中最大的矩形

84.柱状图中最大的矩形 这道题和接雨水遥相呼应&#xff0c;接雨水是求外部凹槽&#xff0c;这道题是求内部面积&#xff0c;这道题的整体思路是某一个元素&#xff0c;找到其左边的第一个小于该数的位置&#xff0c;右边的第一个小于该数的位置&#xff0c;然后以当前索引的高…

CodeWave智能开发平台--03--目标:应用创建--07供应商数据表格02

摘要 本文是网易数帆CodeWave智能开发平台系列的第10篇&#xff0c;主要介绍了基于CodeWave平台文档的新手入门进行学习&#xff0c;实现一个完整的应用&#xff0c;本文主要完成07供应商数据表格下 CodeWave智能开发平台的10次接触 CodeWave参考资源 网易数帆CodeWave开发…

202312 青少年软件编程等级考试Scratch一级真题(电子学会)

2023年12月 青少年软件编程等级考试Scratch一级真题&#xff08;电子学会&#xff09; 试卷总分数&#xff1a;100分 试卷及格分&#xff1a;60 分 考试时长&#xff1a;60 分钟 第 1 题 单选题 观察下列每个圆形中的四个数&#xff0c;找出规律&#xff0c;在括…

鸿蒙原生应用/元服务开发-短时任务

概述 应用退至后台一小段时间后&#xff0c;应用进程会被挂起&#xff0c;无法执行对应的任务。如果应用在后台仍需要执行耗时不长的任务&#xff0c;如状态保存等&#xff0c;可以通过本文申请短时任务&#xff0c;扩展应用在后台的运行时间。 约束与限制 申请时机&#xf…

使用Vite创建vue3工程

介绍 使用Vite构建工具&#xff0c;创建Vue3工程 示例 第一步&#xff1a;执行创建项目的命令&#xff0c;study-front-vue3是项目名称 npm init vite-app study-front-vue3第二步&#xff1a;进入项目文件夹&#xff0c;执行命令&#xff0c;安装模块 cd study-front-vue…

如何将ElementUI组件库中的时间控件迁移到帆软报表中

需求:需要将ElementUI组件库中的时间控件迁移到帆软报表中,具体为普通报表的参数面板中,填报报表的组件中,决策报表的组件与参数面板中。 这三个场景中分别需要用到帆软报表二开平台的ParameterWidgetOptionProvider,FormWidgetOptionProvider,CellWidgetOptionProvider开…

文献阅读:Sparse Low-rank Adaptation of Pre-trained Language Models

文献阅读&#xff1a;Sparse Low-rank Adaptation of Pre-trained Language Models 1. 文章简介2. 具体方法介绍 1. SoRA具体结构2. 阈值选取考察 3. 实验 & 结论 1. 基础实验 1. 实验设置2. 结果分析 2. 细节讨论 1. 稀疏度分析2. rank分析3. 参数位置分析4. 效率考察 4.…

什么是检索增强生成 (RAG)

什么是 RAG RAG&#xff0c;即检索增强生成&#xff0c;是一种将预训练的大型语言模型的功能与外部数据源相结合的技术。这种方法将 GPT-3 或 GPT-4 等 LLM 的生成能力与专用数据搜索机制的精确性相结合&#xff0c;从而形成一个可以提供细微响应的系统。 本文更详细地探讨了…

JavaWeb——Spring事务管理

六、Spring事务管理 1. 注解 注解&#xff1a;Transactional 位置&#xff1a;业务&#xff08;service&#xff09;层的方法上、类上、接口上——一般在执行多条增删改方法上加 作用&#xff1a;将当前方法交给spring进行事务管理&#xff0c;方法执行前&#xff0c;开启事…

编程语言的语法糖,你了解多少?

什么是语法糖 语法糖是一种编程语言的特性&#xff0c;通常是一些简单的语法结构或函数调用&#xff0c;它可以通过隐藏底层的复杂性&#xff0c;并提供更高级别的抽象&#xff0c;从而使代码更加简洁、易读和易于理解&#xff0c;但它并不会改变代码的执行方式。 为什么需要语…

(aiohttp-asyncio-FFmpeg-Docker-SRS)实现异步摄像头转码服务器

1. 背景介绍 在先前的博客文章中&#xff0c;我们已经搭建了一个基于SRS的流媒体服务器。现在&#xff0c;我们希望通过Web接口来控制这个服务器的行为&#xff0c;特别是对于正在进行的 RTSP 转码任务的管理。这将使我们能够在不停止整个服务器的情况下&#xff0c;动态地启动…

OPPO Find X7 Ultra 发布,搭载双潜望四主摄摄影技术

2024年1月8日&#xff0c;深圳——OPPO发布旗舰Find X7 Ultra&#xff0c;定义移动影像的终极形态。Find X7 Ultra 首创的双潜望四主摄构成哈苏大师镜头群&#xff0c;以六个光学品质焦段提供目前手机最强大、品质最高的多摄变焦能力。首次搭载专为超光影图像引擎定制的一英寸传…