深度学习之---迁移学习

news2024/11/30 10:35:41

目录

一、什么是迁移学习

二、为什么需要迁移学习?

1. 大数据与少标注的矛盾:

2. 大数据与弱计算的矛盾:

3. 普适化模型与个性化需求的矛盾:

4. 特定应用(如冷启动)的需求。

三、迁移学习的基本问题有哪些?

四、 迁移学习有哪些常用概念?

​编辑 五、迁移学习与传统机器学习有什么区别?

六、迁移学习的核心及度量准则? 


一、什么是迁移学习

        迁移学习(Transfer Learning)是一种机器学习方法,就是把为任务 A 开发 的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过 从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算 法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学 习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现 学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

        找到目标问题的相似性,迁移学习任务就是从相似性出发,将旧领域 (domain)学习过的模型应用在新领域上

二、为什么需要迁移学习?

1. 大数据与少标注的矛盾:

        虽然有大量的数据,但往往都是没有标注的, 无法训练机器学习模型。人工进行数据标定太耗时。

2. 大数据与弱计算的矛盾:

        普通人无法拥有庞大的数据量与计算资源。因 此需要借助于模型的迁移。

3. 普适化模型与个性化需求的矛盾:

        即使是在同一个任务上,一个模型也 往往难以满足每个人的个性化需求,比如特定的隐私设置。这就需要在 不同人之间做模型的适配。

4. 特定应用(如冷启动)的需求。

三、迁移学习的基本问题有哪些?

基本问题主要有3个:

  • How to transfer: 如何进行迁移学习?(设计迁移方法)
  • What to transfer: 给定一个目标领域,如何找到相对应的源领域, 然后进行迁移?(源领域选择)
  • When to transfer: 什么时候可以进行迁移,什么时候不可以?(避 免负迁移)

四、 迁移学习有哪些常用概念?

基本定义

域(Domain):数据特征和特征分布组成,是学习的主体

        源域 (Source domain):已有知识的域

        目标域 (Target domain):要进行学习的域

任务 (Task):由目标函数和学习结果组成,是学习的结果

按特征空间分类

 按迁移情景分类

        归纳式迁移学习(Inductive TL):源域和目标域的学习任务 不同

        直推式迁移学习(Transductive TL):源域和目标域不同,学 习任务相同

        无监督迁移学习(Unsupervised TL):源域和目标域均没有 标签 按迁移方法分类

        基于样本的迁移 (Instance based TL):通过权重重用源域和 目标域的样例进行迁移

        基于样本的迁移学习方法 (Instance based Transfer Learning) 根据一定的权重生成规则,对数据样本进行重用, 来进行迁移学习。下图形象地表示了基于样本迁移方法的思想 源域中存在不同种类的动物,如狗、鸟、猫等,目标域只有狗 这一种类别。在迁移时,为了最大限度地和目标域相似,我们 可以人为地提高源域中属于狗这个类别的样本权重。

        基于特征的迁移 (Feature based TL):将源域和目标域的特 征变换到相同空间  

        基于特征的迁移方法 (Feature based Transfer Learning) 是 指将通过特征变换的方式互相迁移,来减少源域和目标域之间的 差距;或者将源域和目标域的数据特征变换到统一特征空间中, 然后利用传统的机器学习方法进行分类识别。根据特征的同构 和异构性,又可以分为同构和异构迁移学习。下图很形象地表示 了两种基于特 征的迁移学习方法。

        基于模型的迁移 (Parameter based TL):利用源域和目标域的参数共享 模型

        基于模型的迁移方法 (Parameter/Model based Transfer Learning) 是指 从源域和目标域中找到他们之间共享的参数信息,以实现迁移的方法。这种迁移 方式要求的假设条件是: 源域中的数据与目标域中的数据可以共享一些模型的 参数。下图形象地表示了基于模型的迁移学习方法的基本思想。  

        基于关系的迁移 (Relation based TL):利用源域中的逻辑网络关系进行迁移

        基于关系的迁移学习方法 (Relation Based Transfer Learning) 与上述三种 方法具有截然不同的思路。这种方法比较关注源域和目标域的样本之间的关 系。下图形象地表示了不 同领域之间相似的关系。  

 五、迁移学习与传统机器学习有什么区别?

六、迁移学习的核心及度量准则? 

        迁移学习的总体思路可以概括为:开发算法来最大限度地利用有标注的领 域的知识,来辅助目标领域的知识获取和学习。

        迁移学习的核心是:找到源领域和目标领域之间的相似性,并加以合理利 用。这种相似性非常普遍。比如,不同人的身体构造是相似的;自行车和摩托 车的骑行方式是相似的;国际象棋和中国象棋是相似的;羽毛球和网球的打球 方式是相似的。这种相似性也可以理解为不变量。以不变应万变,才能立于不 败之地。

        有了这种相似性后,下一步工作就是, 如何度量和利用这种相似性。度量 工作的目标有两点:一是很好地度量两个领域的相似性,不仅定性地告诉我们 它们是否相似,更定量地给出相似程度。二是以度量为准则,通过我们所要采 用的学习手段,增大两个领域之间的相似性,从而完成迁移学习。

        一句话总结: 相似性是核心,度量准则是重要手段。

 七、迁移学习三步走

        1加载预训练模型(inceptionnet-v3)(主干网络,backbone),提取所 有图片数据集的特征(特征向量2048维度)。(调用别人训练好的模型,因为 他们的模型泛化能力强,不用自己创建训练模型)

        2用特征向量训练自己的后端网络模型,(后端用自己创建dense后端模 型,保存dense后端6个模型)

        3调用最后一个模型来显示测试集16张图片预测结果

 第一步

import os.path
import numpy as np
# # import tensorflow.compat.v1 as tf
# import tensorflow._api.v2.compat.v1 as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.platform import gfile
MODEL_FILE = 'model/tensorflow_inception_graph.pb'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
INPUT_IMAGE = 'data/agriculture'
OUTPUT_VEC = 'data/bottleneck'
def load_google_model(path):
   with gfile.FastGFile(path, "rb") as f:
       graph_def = tf.GraphDef()
       graph_def.ParseFromString(f.read())
       jpeg_data_tensor, bottleneck_tensor = \
           tf.import_graph_def(graph_def, return_elements=
[JPEG_DATA_TENSOR_NAME, BOTTLENECK_TENSOR_NAME])
   return jpeg_data_tensor, bottleneck_tensor
def get_random_cached_bottlenecks(sess, path, 
jpeg_data_tensor, bottleneck_tensor):
   for _, class_name in enumerate(os.listdir(path)):
       sub_path = os.path.join(path, class_name)
       for img in os.listdir(sub_path):
           img_path=os.path.join(sub_path,img)
           image_data = gfile.FastGFile(img_path, 
'rb').read()
           bottleneck_values = sess.run(bottleneck_tensor, 
feed_dict={jpeg_data_tensor: image_data})第二步骤:
           bottleneck_values = np.squeeze(bottleneck_values)
           sub_dir_path = os.path.join(OUTPUT_VEC, 
class_name)
           if not os.path.exists(sub_dir_path):
               os.makedirs(sub_dir_path)
           new_image_path=os.path.join(sub_dir_path, 
img)+".txt"
           if not os.path.exists(new_image_path):
               bottleneck_string = ','.join(str(x) for x in 
bottleneck_values)
               with open(new_image_path, 'w') as 
bottleneck_file:
                   bottleneck_file.write(bottleneck_string)
           else:
               break
if __name__ == '__main__':
   jpeg_data_tensor, bottleneck_tensor = 
load_google_model(MODEL_FILE)
   with tf.Session() as sess:
       tf.global_variables_initializer().run()
       get_random_cached_bottlenecks(sess, INPUT_IMAGE, 
jpeg_data_tensor, bottleneck_tensor)

第二步 

import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
# import tensorflow._api.v2.compat.v1 as tf
from sklearn.model_selection import train_test_split
IN_DIR = 'data/bottleneck'
OUT_DIR = 'runs'
checkpoint_every = 100 #every 每,
def get_data(path):   x_vecs=[]
   y_labels=[]
   for i, j in enumerate(os.listdir(path)): #enumerate代表枚
举,把元素一个个列举出来。
       sub_path = os.path.join(path, j)
       for vec in os.listdir(sub_path):
           vec_path = os.path.join(sub_path, vec)
           with open(vec_path, 'r') as f:
               vec_str = f.read()
           vec_values = [float(x) for x in 
vec_str.split(',')]
           x_vecs.append(vec_values)
           y_labels.append(np.eye(5)[i])
   return np.array(x_vecs), np.array(y_labels)
image_data,labels=get_data(IN_DIR)
train_data,test_data,train_label,test_label=train_test_split(
image_data,labels,train_size=0.8,shuffle=True)
test_data,val_data,test_label,val_label=train_test_split(test
_data,test_label,train_size=0.5)
if __name__ == '__main__':#入口
   X = tf.placeholder(tf.float32, [None, 2048])
   Y = tf.placeholder(tf.float32, [None, 5])
   with tf.name_scope('final_training_ops'):
       logits = tf.layers.dense(X, 5)
   with tf.name_scope('loss'):
       cross_entropy_mean = 
tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits
=logits, labels=Y))
   with tf.name_scope('Optimizer'):
       train_step = 
tf.train.GradientDescentOptimizer(0.001).minimize(cross_entro
py_mean)
   with tf.name_scope('evaluation'):       correct_prediction = tf.equal(tf.argmax(logits, 1), 
tf.argmax(Y, 1))
       evaluation_step = 
tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       # 保存检查点
       checkpoint_dir = 
os.path.abspath(os.path.join(OUT_DIR, 'checkpoints'))
       checkpoint_prefix = os.path.join(checkpoint_dir, 
'model')
       if not os.path.exists(checkpoint_dir):
           os.makedirs(checkpoint_dir)
       saver = tf.train.Saver(tf.global_variables(), 
max_to_keep=6)
       for epoch in range(1001):
           batch_size = 64
           start = 0
           num_step = len(train_data) // batch_size
           for i in range(num_step):
               xb = train_data[start : start + batch_size]
               yb = train_label[start : start + batch_size]
               start = start + batch_size
               _ = sess.run([train_step], feed_dict={X: xb, 
Y: yb})
           if epoch % 100 == 0:
               validation_accuracy = 
sess.run(evaluation_step, feed_dict={X: val_data, Y: 
val_label})
               print("[epoch {}]验证集准确率
{:.3f}%".format(epoch, validation_accuracy * 100))
               path = saver.save(sess, checkpoint_prefix, 
global_step=epoch)
               print('Saved model checkpoint to 
{}\n'.format(path))
       test_accuracy = sess.run(evaluation_step, feed_dict=
{X: test_data, Y: test_label})第三步骤:
       print("测试集准确率{:.3f}%".format(test_accuracy * 
100))

 第三步

import numpy as np
import cv2
import os
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.python.platform import gfile
import matplotlib.pyplot as plt
MODEL_FILE = 'model/tensorflow_inception_graph.pb'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
CHECKPOINT_DIR = 'runs/checkpoints'
test_dir = 'data/test/agriculture'
def load_google_model(path):
   with gfile.FastGFile(path, "rb") as f:
       graph_def = tf.GraphDef()
       graph_def.ParseFromString(f.read())
       jpeg_data_tensor, bottleneck_tensor = \
           tf.import_graph_def(graph_def, return_elements=
[JPEG_DATA_TENSOR_NAME, BOTTLENECK_TENSOR_NAME])
   return jpeg_data_tensor, bottleneck_tensor
def create_test_featrue(sess, test_dir, jpeg_data_tensor, 
bottleneck_tensor):
   test_data, test_feature, test_labels = [], [], []
   for i in os.listdir(test_dir):
       img = cv2.imread(os.path.join(test_dir, i))
       img = cv2.resize(img, (256, 256))
       img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
       test_data.append(img)
       img_data = gfile.FastGFile(os.path.join(test_dir, i), 
"rb").read()       feature = sess.run(bottleneck_tensor, feed_dict=
{jpeg_data_tensor: img_data})
       test_feature.append(feature)
       test_labels.append(i.split("_")[0])
   return test_data, np.reshape(test_feature, (-1, 2048)), 
np.array(test_labels)
def show_img(test_data, pre_labels, test_labels):
   _, axs = plt.subplots(4, 4)
   for i, axi in enumerate(axs.flat):
       axi.imshow(test_data[i])
       print(pre_labels[i], test_labels[i])
       axi.set_xlabel(xlabel=pre_labels[i], color="black" if 
pre_labels[i] == test_labels[i] else "red")
       axi.set(xticks=[], yticks=[])
   plt.savefig(os.path.join("data/test/", 'agriculture' + 
".jpg"))
   plt.show()
if __name__ == '__main__':
   jpeg_data_tensor, bottleneck_tensor = 
load_google_model(MODEL_FILE)
   class_names = os.listdir("data/agriculture")
   num_class= len(class_names)
   x_transfer = tf.placeholder(tf.float32, [None, 2048])
   y_transfer = tf.placeholder(tf.int64, [None, num_class]) 
# [None,5]
   logits = tf.layers.dense(x_transfer, num_class)
   saver = tf.train.Saver()
   with tf.Session() as sess:
       sess.run(tf.global_variables_initializer())
       print(CHECKPOINT_DIR)
       last_point = 
tf.train.latest_checkpoint(CHECKPOINT_DIR)
       print(last_point)
       saver.restore(sess, last_point)三个步骤代码组合起来实现迁移学习:
       test_data, test_feature, test_labels = \
           create_test_featrue(sess, test_dir, 
jpeg_data_tensor, bottleneck_tensor)
       pred = sess.run(tf.argmax(logits, 1), {x_transfer: 
test_feature})
       show_img(test_data, [class_names[i] for i in pred], 
test_labels)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1824657.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

utm投影

一 概述 UTM (Universal Transverse Mercator)坐标系是由美国军方在1947提出的。虽然我们仍然将其看作与“高斯-克吕格”相似的坐标系统,但实际上UTM采用了网格的分带(或分块)。除在美国本土采用Clarke 1866椭球体以外&#xff0c…

聚观早报 | 深蓝G318价格发布;比亚迪方程豹豹3官图发布

聚观早报每日整理最值得关注的行业重点事件,帮助大家及时了解最新行业动态,每日读报,就读聚观365资讯简报。 整理丨Cutie 6月15日消息 深蓝G318价格发布 比亚迪方程豹豹3官图发布 夸克App升级高考AI搜索 iOS 18卫星通信实测 Redmi K70…

AI模型部署:Triton+Marker部署PDF转markdown服务

前言 在知识库场景下往往需要对PDF文档进行解析,从而能够通过RAG完成知识检索,本文介绍开源的PDF转Markdown工具marker,并借助Triton Inference Server将其服务化。 内容摘要 知识库场景下pdf解析简述Marker简介和安装Marker快速开始使用Tr…

Rust 实战丨绘制曼德博集

曼德博集 曼德博集其实是一个“没什么用”的发现。 曼德博集(Mandelbrot Set)是一种在复平面上形成独特且复杂图案的点的集合。这个集合是以数学家本华曼德博(Benoit Mandelbrot)的名字命名的,他在研究复杂结构和混沌…

LED显示屏色差处理方法

LED显示屏以其高亮度、低功耗和长寿命等优点,在广告、信息发布和舞台背景等领域得到广泛应用。然而,由于生产批次的不同,LED显示屏在亮度和色度上可能存在差异,影响显示效果。本文将探讨如何通过逐点校正技术来解决这一问题。 逐点…

【智源大会2024】(一)智源技术专题

智源的全家桶: 微调数据相关: 1.千万级数据集: BAAI创建了首个千万级别的高质量开源指令微调数据集。 2.模型性能与数据质量: 强调了模型性能与数据质量之间的高度相关性。 3.技术亮点: 使用了高质量的指令数据筛选与合成技术。这些技术显著提升了模型…

【ARM Cache 及 MMU 系列文章 6.5 -- 如何进行 Cache miss 统计?】

请阅读【ARM Cache 及 MMU/MPU 系列文章专栏导读】 及【嵌入式开发学习必备专栏】 文章目录 ARM Cache Miss 统计Cache 多层架构简介Cache 未命中的类型Cache 未命中统计Cache miss 统计代码实现Cache Miss 统计意义ARM Cache Miss 统计 在ARMv8/v9架构中,缓存未命中(Cache …

IIC通信总线

文章目录 1. IIC总线协议1. IIC简介2. IIC时序1. 数据有效性2. 起始信号和终止信号3. 数据格式4. 应答和非应答信号5. 时钟同步6. 写数据和读数据 2. AT24C023. AT24C02读写时序4. AT24C02配置步骤5. 代码部分1. IIC基本信号2. AT24C02驱动代码3. 实验结果分析 1. IIC总线协议 …

MAC系统下安装VUE

下载node.js 点击链接 选择图片中的稳定版本 安装node.js 打开终端,输入 node -v 和 npm -v 显示如上信息表示安装成功 安装vue脚手架🔧 sudo npm install -g vue/cli查看vue版本 vue -V6. 启动项目 1 采用 图形页面方式 控制台输入&#xff…

2024最新D卷 华为OD统一考试题库清单(按算法分类),如果你时间紧迫,就按这个刷

目录 专栏导读华为OD机试算法题太多了,知识点繁杂,如何刷题更有效率呢? 一、逻辑分析二、数据结构1、线性表① 数组② 双指针 2、map与list3、队列4、链表5、栈6、滑动窗口7、二叉树8、并查集9、矩阵 三、算法1、基础算法① 贪心思维② 二分查…

【c++进阶(三)】STL之vector的介绍和使用

💓博主CSDN主页::Am心若依旧💓 ⏩专栏分类c从入门到精通⏪ 🚚代码仓库:青酒余成🚚 🌹关注我🫵带你学习更多c   🔝🔝 vector的介绍 1.vector表示的是可变序列大小的容器 2、vector…

MySQL 日志(一)

本篇主要介绍MySQL日志的相关内容。 目录 一、日志简介 常用日志 一般查询日志和慢查询日志的输出形式 日志表 二、一般查询日志 三、慢查询日志 四、错误日志 一、日志简介 常用日志 在MySQL中常用的日志主要有如下几种: 这些日志通常情况下都是关闭的&a…

一文读懂Java线程池之线程复用原理

什么是线程复用 在Java中,我们正常创建线程执行任务,一般都是一条线程绑定一个Runnable执行任务。而Runnable实际只是一个普通接口,真正要执行,则还是利用了Thread类的run方法。这个rurn方法由native本地方法start0进行调用。我们看Thread类的run方法实现 /* What will be…

Mysql8.0.31开启mysqlbinlog

1、查看mysqlbinlog是否已经开启 show variables like %log_bin%; log_bin: ON是OFF否已经开启binlog log_bin_basename: binlog所在路径的文件开头前缀名 lob_bin_index: binlog文件的索引文件所在路径 2、若log_binOFF,则开启log_bin -- 退出mysql client ex…

open-amv开发环境搭建

open-amv是基于rv1103主控芯片的视觉开发板子 1.板子使用 板子使用type c作为调试口,同时供电,请在电脑上下载adb,当板子通过tpye c与电脑连接后,执行命令adb shell就会进入到板子的linux系统命令行。 2.编译环境 2.1 搭建doc…

【网络编程】优雅断开套接字连接

Linux的close函数和Windows的closesocket函数意味着完全断开连接。完全断开不仅指无法传输数据,而且也不能接收数据。 2台主机正在进行双向通信,主机A发送完最后的数据后,调用close函数断开了连接,之后主机A无法再接收主机B传输的…

超全Midjourney自学教程,怒码1万3千字!这是我见过最良心的教程啦!

前段时间,后台有网友私信我,说想跟我一起学AI~当时一边开心一边惶恐,满足于被人看到自己的努力、又担心自己是不是教不好别人,毕竟我自己也是业余时间边学边发的那种~ 不过,我还是会继续搬运或整理一些我认为值得记录…

C++100行超简单系统

非常好用&#xff0c;小白也可以自己修改 先来看图片&#xff1a; 用法附在代码里了&#xff01; #include <bits/stdc.h> #include <windows.h>using namespace std;struct users {string name;string num; bool f; } u[10000];int now_users 0; /*当前用户数*…

MyBatis使用 PageHelper 分页查询插件的详细配置

1. MyBatis使用 PageHelper 分页查询插件的详细配置 文章目录 1. MyBatis使用 PageHelper 分页查询插件的详细配置2. 准备工作3. 使用传统的 limit 关键字进行分页4. PageHelper 插件&#xff08;配置步骤&#xff09;4.1 第一步&#xff1a;引入依赖4.2 第二步&#xff1a;在m…

LDR6020显示器应用:革新连接体验,引领未来显示技术

一、引言 随着科技的飞速发展&#xff0c;显示器作为信息展示的重要载体&#xff0c;其性能和应用场景不断得到拓展。特别是在办公、娱乐以及物联网等领域&#xff0c;用户对显示器的需求越来越多样化。在这一背景下&#xff0c;LDR6020显示器的出现&#xff0c;以其卓越的性能…