迁移学习案例-python代码

news2024/10/1 5:06:20

在这里插入图片描述

大白话

迁移学习就是用不太相同但又有一些联系的A和B数据,训练同一个网络。比如,先用A数据训练一下网络,然后再用B数据训练一下网络,那么就说最后的模型是从A迁移到B的。

迁移学习的具体形式是多种多样的,比如先用A训练好一个网络,然后复制这个网络的某几个层的参数到一个新的网络作为初始化的参数,然后用B数据去训练这个新网络。又或者,面对中文翻译的问题,中文翻译成英文和中文翻译成火星文,前几层在提取特征,可以共享参数层,后面几层由于任务不同就可以各自私有训练。

案例来源:李宏毅课程-机器学习-迁移学习

A数据:是源数据,量大效果好,并且有标签。
在这里插入图片描述
B数据:量少,没标签。
在这里插入图片描述

目的:希望用A数据先训练网络提取到关键特征,然后预测B数据的标签。但是把他们当作两个任务效果不佳,于是以一种迁移的方法解决--域对抗(先用A训练好模型,再直接用B测试,这样效果不佳;而是希望以一种“迁移”的方法,把A数据的知识拿到B上面用)

直接上代码

import matplotlib.pyplot as plt

def no_axis_show(img, title='', cmap=None):
    # imshow, 缩放模式为nearest。
    fig = plt.imshow(img, interpolation='nearest', cmap=cmap)
    # 不要显示axis。
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    plt.title(title)

titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'work/real_or_drawing/train_data/{i}/{500*i}.bmp'), title=titles[i])
plt.figure(figsize=(18, 18))
for i in range(10):
    plt.subplot(1, 10, i+1)
    fig = no_axis_show(plt.imread(f'work/real_or_drawing/test_data/0/' + str(i).rjust(5, '0') + '.bmp'))
import cv2
import matplotlib.pyplot as plt
titles = ['horse', 'bed', 'clock', 'apple', 'cat', 'plane', 'television', 'dog', 'dolphin', 'spider']
plt.figure(figsize=(18, 18))

original_img = plt.imread(f'work/real_or_drawing/train_data/0/0.bmp')
plt.subplot(1, 5, 1)
no_axis_show(original_img, title='original')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

gray_img = cv2.cvtColor(original_img, cv2.COLOR_RGB2GRAY)
plt.subplot(1, 5, 2)
no_axis_show(gray_img, title='gray scale', cmap='gray')

canny_50100 = cv2.Canny(gray_img, 50, 100)
plt.subplot(1, 5, 3)
no_axis_show(canny_50100, title='Canny(50, 100)', cmap='gray')

canny_150200 = cv2.Canny(gray_img, 150, 200)
plt.subplot(1, 5, 4)
no_axis_show(canny_150200, title='Canny(150, 200)', cmap='gray')

canny_250300 = cv2.Canny(gray_img, 250, 300)
plt.subplot(1, 5, 5)
no_axis_show(canny_250300, title='Canny(250, 300)', cmap='gray')
import cv2
import numpy as np
import paddle

import paddle.optimizer as optim
from paddle.io import DataLoader
from paddle.vision.datasets import DatasetFolder
from paddle.nn import Sequential, Conv2D, BatchNorm1D, BatchNorm2D, ReLU, MaxPool2D, Linear
from paddle.vision.transforms import Compose, Grayscale, Transpose, RandomHorizontalFlip, RandomRotation, Resize, ToTensor
class Canny(paddle.vision.transforms.transforms.BaseTransform):
    def __init__(self, low, high, keys=None):
        super(Canny, self).__init__(keys)
        self.low = low
        self.high = high

    def _apply_image(self, img):
        Canny = lambda img: cv2.Canny(np.array(img), self.low, self.high)
        return Canny(img)
source_transform = Compose([
    RandomHorizontalFlip(),
    RandomRotation(15),
    Grayscale(),
    Canny(low=170, high=300),
    # Transpose(),
    ToTensor()
    ])
target_transform = Compose([
    Grayscale(),
    Resize((32, 32)),
    RandomHorizontalFlip(),
    RandomRotation(15, fill=(0,)),
    ToTensor()
    ])

source_dataset = DatasetFolder('work/real_or_drawing/train_data', transform=source_transform)
target_dataset = DatasetFolder('work/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)
class FeatureExtractor(paddle.nn.Layer):

    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.conv = Sequential(
            Conv2D(1, 64, 3, 1, 1),
            BatchNorm2D(64),
            ReLU(),
            MaxPool2D(2),

            Conv2D(64, 128, 3, 1, 1),
            BatchNorm2D(128),
            ReLU(),
            MaxPool2D(2),

            Conv2D(128, 256, 3, 1, 1),
            BatchNorm2D(256),
            ReLU(),
            MaxPool2D(2),

            Conv2D(256, 256, 3, 1, 1),
            BatchNorm2D(256),
            ReLU(),
            MaxPool2D(2),

            Conv2D(256, 512, 3, 1, 1),
            BatchNorm2D(512),
            ReLU(),
            MaxPool2D(2)
        )
        
    def forward(self, x):
        x = self.conv(x).squeeze()
        return x

class LabelPredictor(paddle.nn.Layer):

    def __init__(self):
        super(LabelPredictor, self).__init__()

        self.layer = Sequential(
            Linear(512, 512),
            ReLU(),

            Linear(512, 512),
            ReLU(),

            Linear(512, 10),
        )

    def forward(self, h):
        c = self.layer(h)
        return c

class DomainClassifier(paddle.nn.Layer):

    def __init__(self):
        super(DomainClassifier, self).__init__()

        self.layer = Sequential(
            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 512),
            BatchNorm1D(512),
            ReLU(),

            Linear(512, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y
feature_extractor = FeatureExtractor()
label_predictor = LabelPredictor()
domain_classifier = DomainClassifier()

class_criterion = paddle.nn.loss.CrossEntropyLoss()
domain_criterion = paddle.nn.BCEWithLogitsLoss()

optimizer_F = optim.Adam(parameters=feature_extractor.parameters())
optimizer_C = optim.Adam(parameters=label_predictor.parameters())
optimizer_D = optim.Adam(parameters=domain_classifier.parameters())
def train_epoch(source_dataloader, target_dataloader, lamb):
    '''
      Args:
        source_dataloader: source data的dataloader
        target_dataloader: target data的dataloader
        lamb: 调控adversarial的loss系数。
    '''

    # D loss: Domain Classifier的loss
    # F loss: Feature Extrator & Label Predictor的loss
    # total_hit: 计算目前对了几笔 total_num: 目前经过了几笔
    running_D_loss, running_F_loss = 0.0, 0.0
    total_hit, total_num = 0.0, 0.0

    for i, ((source_data, source_label), (target_data, _)) in enumerate(zip(source_dataloader, target_dataloader)):

        # source_data = source_data.cuda()
        # source_label = source_label.cuda()
        # target_data = target_data.cuda()
        
        # 我们把source data和target data混在一起,否则batch_norm可能会算错 (两边的data的mean/var不太一样)
        mixed_data = paddle.concat([source_data, target_data], axis=0)
        domain_label = paddle.zeros([source_data.shape[0] + target_data.shape[0], 1])
        # 设定source data的label为1
        domain_label[:source_data.shape[0]] = 1

        # Step 1 : 训练Domain Classifier
        feature = feature_extractor(mixed_data)
        # 因为我们在Step 1不需要训练Feature Extractor,所以把feature detach避免loss backprop上去。
        domain_logits = domain_classifier(feature.detach())
        # print('domain_logits.shape:', domain_logits.shape, 'domain_label.shape:', domain_label.shape)
        loss = domain_criterion(domain_logits, domain_label)
        # running_D_loss+= loss.numpy()[0]
        running_D_loss+= loss.numpy()
        # print('loss:', loss)
        loss.backward()
        optimizer_D.step()

        # Step 2 : 训练Feature Extractor和Domain Classifier
        class_logits = label_predictor(feature[:source_data.shape[0]])
        domain_logits = domain_classifier(feature)
        # loss为原本的class CE - lamb * domain BCE,相减的原因同GAN中的Discriminator中的G loss。
        loss = class_criterion(class_logits, source_label) - lamb * domain_criterion(domain_logits, domain_label)
        # running_F_loss+= loss.numpy()[0]
        running_F_loss+= loss.numpy()
        loss.backward()
        optimizer_F.step()
        optimizer_C.step()

        optimizer_D.clear_grad()
        optimizer_F.clear_grad()
        optimizer_C.clear_grad()
        # print('class_logits.shape:', class_logits.shape, 'source_label.shape:', source_label.shape)
        # print('class_logits[0]:', class_logits[0], 'source_label[0]:', source_label[0])
        total_hit += np.sum((paddle.argmax(class_logits, axis=1) == source_label).numpy())
        total_num += source_data.shape[0]
        print(i, end='\r')

    return running_D_loss / (i+1), running_F_loss / (i+1), total_hit / total_num

# 训练200 epochs
for epoch in range(200):
    train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, lamb=0.1)

    paddle.save(feature_extractor.state_dict(), f'extractor_model.pdparams')
    paddle.save(label_predictor.state_dict(), f'predictor_model.pdparams')

    print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

训练结束,预测一波

result = []
label_predictor.eval()
feature_extractor.eval()
for i, (test_data, _) in enumerate(test_dataloader):
    test_data = test_data

    class_logits = label_predictor(feature_extractor(test_data))

    x = paddle.argmax(class_logits, axis=1).detach().numpy()
    result.append(x)

import pandas as pd
result = np.concatenate(result)

# Generate your submission
df = pd.DataFrame({'id': np.arange(0,len(result)), 'label': result})
df.to_csv('work/DaNN_submission.csv',index=False)

训练比较慢,还得是把代码转到cuda上才行,demo可以把epoch减小一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2181524.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

仿真设计|基于51单片机的智能防火GSM上报仿真

目录 具体实现功能 设计介绍 51单片机简介 资料内容 仿真实现(protues8.7) 程序(Keil5) 全部内容 资料获取 具体实现功能 (1)LCD1602显示实时温度(DS18B20)值和烟雾&#x…

避免学术欺诈!在ChatGPT帮助下实现严格引用并避免抄袭

学境思源,一键生成论文初稿: AcademicIdeas - 学境思源AI论文写作 当今的学术环境中,保持学术诚信至关重要。随着ChatGPT等技术的发展,写作变得更加高效,但也增加了不当使用的风险。严格的引用和避免抄袭不仅是学术道…

一个服务器可以搭建几个网站

一个服务器可以搭建几个网站 该省就得省,一台服务器可以搭建多种不同的网站或应用#服务器#服务器租用 多个站点放在同一个服务器上有什么影响吗? 服务器里面会涉及到就是内存和带宽,如果说你一个服务器只放一个网站肯定更好一点&#xff0…

java将word转pdf

总结 建议使用aspose-words转pdf,poi的容易出问题还丑… poi的(多行的下边框就不对了) aspose-words的(基本和word一样) poi工具转换 <!-- 处理PDF --><dependency><groupId>fr.opensagres.xdocreport</groupId><artifactId>fr.opensagres…

计算机网络实验2——利用Wireshark对上网操作抓包并进行相关协议分析(实验部分)

五、实验过程 1.安装并启动Wireshark。 选择菜单栏上捕获->选项&#xff0c;勾选WLAN网卡。点击Start&#xff0c;进行抓包 Wireshark处于抓包状态中 2.打开浏览器&#xff0c;在地址栏中输入教师指定的web服务器地址。&#xff08;http://202.113.78.39&#xff09; 为了…

基于微信小程序的健康管理系统(源码+定制+文档)

博主介绍&#xff1a; ✌我是阿龙&#xff0c;一名专注于Java技术领域的程序员&#xff0c;全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师&#xff0c;我在计算机毕业设计开发方面积累了丰富的经验。同时&#xff0c;我也是掘金、华为云、阿里云、InfoQ等平台…

【Verilog学习日常】—牛客网刷题—Verilog企业真题—VL69

脉冲同步器&#xff08;快到慢&#xff09; 描述 sig_a 是 clka&#xff08;300M&#xff09;时钟域的一个单时钟脉冲信号&#xff08;高电平持续一个时钟clka周期&#xff09;&#xff0c;请设计脉冲同步电路&#xff0c;将sig_a信号同步到时钟域 clkb&#xff08;100M&…

VMware ESXi 6.7U3u macOS Unlocker 集成驱动版更新 OEM BIOS 2.7 支持 Windows Server 2025

VMware ESXi 6.7U3u macOS Unlocker & OEM BIOS 2.7 集成 Realtek 网卡驱动和 NVMe 驱动 (集成驱动版) UI fix 此版本解决的问题&#xff1a;VMware Host Client 无法将现有虚拟磁盘 (VMDK) 附加到虚拟机 请访问原文链接&#xff1a;https://sysin.org/blog/vmware-esxi-…

基于springboot的数据库原理教学案例案例库管理系统

目录 毕设制作流程功能和技术介绍系统实现截图开发核心技术介绍&#xff1a;使用说明开发步骤编译运行代码执行流程核心代码部分展示可行性分析软件测试详细视频演示源码获取 毕设制作流程 &#xff08;1&#xff09;与指导老师确定系统主要功能&#xff1b; &#xff08;2&am…

tftp传文件被服务器拒绝进入tftp: server error: (768) Access to staonline.pcap denied

环境&#xff1a;测试一个ac下挂ap&#xff0c;ap下的抓包文件传出时&#xff0c;出现问题&#xff1a; ac的wan口ip是192.168.186.167/24&#xff0c;gw是192.168.186.1&#xff0c;下挂ap的ip是192.168.202.199/24&#xff0c;ac上开子接口192.168.202.1/24&#xff0c;ac上开…

【MySQL】数据目录迁移

一、使用场景 使用该方法一般是数据目录所在磁盘不支持扩展&#xff0c;只能通过新加磁盘来扩展数据目录磁盘空间。通常是Windows服务器&#xff0c;或者是Linux服务器的mysql数据目录的磁盘没有使用lvm。 二、准备工作 1. 新磁盘初始化&#xff0c;达到可使用状态 2. 需要自己…

鸿蒙NEXT开发-自定义构建函数(基于最新api12稳定版)

注意&#xff1a;博主有个鸿蒙专栏&#xff0c;里面从上到下有关于鸿蒙next的教学文档&#xff0c;大家感兴趣可以学习下 如果大家觉得博主文章写的好的话&#xff0c;可以点下关注&#xff0c;博主会一直更新鸿蒙next相关知识 专栏地址: https://blog.csdn.net/qq_56760790/…

lambda表达式底层实现

一、lambda 代码 & 反编译 原始Java代码 假设我们有以下简单的Java程序&#xff0c;它使用Lambda表达式来遍历并打印一个字符串列表&#xff1a; import java.util.Arrays; import java.util.List;public class LambdaExample {public static void main(String[] args) {…

解决磁盘负载不均——ElasticSearch 分片分配和路由设置

ES 分片分配&#xff08;Shard Allocation&#xff09;时间点&#xff1a; 初始恢复&#xff08;Initial Recovery&#xff09;副本分配&#xff08;Replica Allocation&#xff09;重平衡&#xff08;Rebalance&#xff09;节点添加或移除 小结&#xff1a; 准备移除节点时&a…

Flask-2

文章目录 请求全局钩子[hook]异常抛出和捕获异常abort 主动抛出HTTP异常errorhandler 捕获错误 context请求上下文(request context)应用上下文(application context)current_appg变量 两者区别&#xff1a; 终端脚本命令flask1.0的终端命令使用自定义终端命令 flask2.0的终端命…

25中国烟草校园招聘面试问题总结 烟草面试全流程及面试攻略

开头附上工作招聘面试必备问题噢~~包括综合面试题、无领导小组面试题资源文件免费&#xff01;全文干货。 工作招聘无领导小组面试全攻略最常见面试题&#xff08;第一部分&#xff09;共有17章可用于国企私企合资企业工作招聘面试面试必备心得面试总结资源-CSDN文库https://d…

.NET CORE程序发布IIS后报错误 500.19

发布IIS后浏览时报错误500.19&#xff0c;同时配置文件web.config的路径中也存在问号“?”。 可能原因&#xff1a;没有安装运行时

树和二叉树知识点大全及相关题目练习【数据结构】

树和二叉树 要注意树和二叉树是两个完全不同的结构、概念&#xff0c;它们之间不存在包含之类的关系 树的定义 树&#xff08;Tree&#xff09;是n&#xff08;n≥0&#xff09;个结点的有限集&#xff0c;它或为空树&#xff08;n 0&#xff09;&#xff1b;或为非空树&a…

WPF 设计属性 设计页面时实时显示 页面涉及集合时不显示处理 设计页面时显示集合样式 显示ItemSource TabControl等集合样式

WPF 设计属性 设计页面时实时显示 页面涉及集合时不显示处理 设计页面时显示集合样式 显示ItemSource TabControl等集合样式 1、设计显示属性 1、设计时显示属性依赖以下属性 xmlns:d"http://schemas.microsoft.com/expression/blend/2008"2、在运行时不显示设计属性…

健康生活,从日常细节开始

健康生活&#xff0c;从日常细节开始 在快节奏的现代生活中&#xff0c;健康养生似乎成了一种奢侈的追求。但殊不知&#xff0c;真正的养生之道&#xff0c;往往就蕴含在我们日常的点点滴滴之中。今天&#xff0c;就让我们一起探讨几个简单却极易被忽视的健康生活小贴士&#…