【域适应】基于域分离网络的MNIST数据10分类典型方法实现

news2024/10/7 14:25:17

关于

大规模数据收集和注释的成本通常使得将机器学习算法应用于新任务或数据集变得异常昂贵。规避这一成本的一种方法是在合成数据上训练模型,其中自动提供注释。尽管它们很有吸引力,但此类模型通常无法从合成图像推广到真实图像,因此需要域适应算法来操纵这些模型,然后才能成功应用。现有的方法要么侧重于将表示从一个域映射到另一个域,要么侧重于学习提取对于提取它们的域而言不变的特征。然而,通过只关注在两个域之间创建映射或共享表示,他们忽略了每个域的单独特征。域分离网络可以实现对每个域的独特之处进行特征建模,,同时进行模型域不变特征的提取

参考文章: https://arxiv.org/abs/1608.06019

工具

 

方法实现

数据集定义
import torch.utils.data as data
from PIL import Image
import os


class GetLoader(data.Dataset):
    def __init__(self, data_root, data_list, transform=None):
        self.root = data_root
        self.transform = transform

        f = open(data_list, 'r')
        data_list = f.readlines()
        f.close()

        self.n_data = len(data_list)

        self.img_paths = []
        self.img_labels = []

        for data in data_list:
            self.img_paths.append(data[:-3])
            self.img_labels.append(data[-2])

    def __getitem__(self, item):
        img_paths, labels = self.img_paths[item], self.img_labels[item]
        imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB')

        if self.transform is not None:
            imgs = self.transform(imgs)
            labels = int(labels)

        return imgs, labels

    def __len__(self):
        return self.n_data
模型搭建
import torch.nn as nn
from functions import ReverseLayerF


class DSN(nn.Module):
    def __init__(self, code_size=100, n_class=10):
        super(DSN, self).__init__()
        self.code_size = code_size

        ##########################################
        # private source encoder
        ##########################################

        self.source_encoder_conv = nn.Sequential()
        self.source_encoder_conv.add_module('conv_pse1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                padding=2))
        self.source_encoder_conv.add_module('ac_pse1', nn.ReLU(True))
        self.source_encoder_conv.add_module('pool_pse1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.source_encoder_conv.add_module('conv_pse2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                                                                padding=2))
        self.source_encoder_conv.add_module('ac_pse2', nn.ReLU(True))
        self.source_encoder_conv.add_module('pool_pse2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.source_encoder_fc = nn.Sequential()
        self.source_encoder_fc.add_module('fc_pse3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))
        self.source_encoder_fc.add_module('ac_pse3', nn.ReLU(True))

        #########################################
        # private target encoder
        #########################################

        self.target_encoder_conv = nn.Sequential()
        self.target_encoder_conv.add_module('conv_pte1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                padding=2))
        self.target_encoder_conv.add_module('ac_pte1', nn.ReLU(True))
        self.target_encoder_conv.add_module('pool_pte1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.target_encoder_conv.add_module('conv_pte2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                                                                padding=2))
        self.target_encoder_conv.add_module('ac_pte2', nn.ReLU(True))
        self.target_encoder_conv.add_module('pool_pte2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.target_encoder_fc = nn.Sequential()
        self.target_encoder_fc.add_module('fc_pte3', nn.Linear(in_features=7 * 7 * 64, out_features=code_size))
        self.target_encoder_fc.add_module('ac_pte3', nn.ReLU(True))

        ################################
        # shared encoder (dann_mnist)
        ################################

        self.shared_encoder_conv = nn.Sequential()
        self.shared_encoder_conv.add_module('conv_se1', nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5,
                                                                  padding=2))
        self.shared_encoder_conv.add_module('ac_se1', nn.ReLU(True))
        self.shared_encoder_conv.add_module('pool_se1', nn.MaxPool2d(kernel_size=2, stride=2))

        self.shared_encoder_conv.add_module('conv_se2', nn.Conv2d(in_channels=32, out_channels=48, kernel_size=5,
                                                                  padding=2))
        self.shared_encoder_conv.add_module('ac_se2', nn.ReLU(True))
        self.shared_encoder_conv.add_module('pool_se2', nn.MaxPool2d(kernel_size=2, stride=2))

        self.shared_encoder_fc = nn.Sequential()
        self.shared_encoder_fc.add_module('fc_se3', nn.Linear(in_features=7 * 7 * 48, out_features=code_size))
        self.shared_encoder_fc.add_module('ac_se3', nn.ReLU(True))

        # classify 10 numbers
        self.shared_encoder_pred_class = nn.Sequential()
        self.shared_encoder_pred_class.add_module('fc_se4', nn.Linear(in_features=code_size, out_features=100))
        self.shared_encoder_pred_class.add_module('relu_se4', nn.ReLU(True))
        self.shared_encoder_pred_class.add_module('fc_se5', nn.Linear(in_features=100, out_features=n_class))

        self.shared_encoder_pred_domain = nn.Sequential()
        self.shared_encoder_pred_domain.add_module('fc_se6', nn.Linear(in_features=100, out_features=100))
        self.shared_encoder_pred_domain.add_module('relu_se6', nn.ReLU(True))

        # classify two domain
        self.shared_encoder_pred_domain.add_module('fc_se7', nn.Linear(in_features=100, out_features=2))

        ######################################
        # shared decoder (small decoder)
        ######################################

        self.shared_decoder_fc = nn.Sequential()
        self.shared_decoder_fc.add_module('fc_sd1', nn.Linear(in_features=code_size, out_features=588))
        self.shared_decoder_fc.add_module('relu_sd1', nn.ReLU(True))

        self.shared_decoder_conv = nn.Sequential()
        self.shared_decoder_conv.add_module('conv_sd2', nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5,
                                                                  padding=2))
        self.shared_decoder_conv.add_module('relu_sd2', nn.ReLU())

        self.shared_decoder_conv.add_module('conv_sd3', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=5,
                                                                  padding=2))
        self.shared_decoder_conv.add_module('relu_sd3', nn.ReLU())

        self.shared_decoder_conv.add_module('us_sd4', nn.Upsample(scale_factor=2))

        self.shared_decoder_conv.add_module('conv_sd5', nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3,
                                                                  padding=1))
        self.shared_decoder_conv.add_module('relu_sd5', nn.ReLU(True))

        self.shared_decoder_conv.add_module('conv_sd6', nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3,
                                                                  padding=1))

    def forward(self, input_data, mode, rec_scheme, p=0.0):

        result = []

        if mode == 'source':

            # source private encoder
            private_feat = self.source_encoder_conv(input_data)
            private_feat = private_feat.view(-1, 64 * 7 * 7)
            private_code = self.source_encoder_fc(private_feat)

        elif mode == 'target':

            # target private encoder
            private_feat = self.target_encoder_conv(input_data)
            private_feat = private_feat.view(-1, 64 * 7 * 7)
            private_code = self.target_encoder_fc(private_feat)

        result.append(private_code)

        # shared encoder
        shared_feat = self.shared_encoder_conv(input_data)
        shared_feat = shared_feat.view(-1, 48 * 7 * 7)
        shared_code = self.shared_encoder_fc(shared_feat)
        result.append(shared_code)

        reversed_shared_code = ReverseLayerF.apply(shared_code, p)
        domain_label = self.shared_encoder_pred_domain(reversed_shared_code)
        result.append(domain_label)

        if mode == 'source':
            class_label = self.shared_encoder_pred_class(shared_code)
            result.append(class_label)

        # shared decoder

        if rec_scheme == 'share':
            union_code = shared_code
        elif rec_scheme == 'all':
            union_code = private_code + shared_code
        elif rec_scheme == 'private':
            union_code = private_code

        rec_vec = self.shared_decoder_fc(union_code)
        rec_vec = rec_vec.view(-1, 3, 14, 14)

        rec_code = self.shared_decoder_conv(rec_vec)
        result.append(rec_code)

        return result




 模型训练
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import numpy as np
from torch.autograd import Variable
from torchvision import datasets
from torchvision import transforms
from model_compat import DSN
from data_loader import GetLoader
from functions import SIMSE, DiffLoss, MSE
from test import test

######################
# params             #
######################

source_image_root = os.path.join('.', 'dataset', 'mnist')
target_image_root = os.path.join('.', 'dataset', 'mnist_m')
model_root = 'model'
cuda = True
cudnn.benchmark = True
lr = 1e-2
batch_size = 32
image_size = 28
n_epoch = 100
step_decay_weight = 0.95
lr_decay_step = 20000
active_domain_loss_step = 10000
weight_decay = 1e-6
alpha_weight = 0.01
beta_weight = 0.075
gamma_weight = 0.25
momentum = 0.9

manual_seed = random.randint(1, 10000)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

#######################
# load data           #
#######################

img_transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset_source = datasets.MNIST(
    root=source_image_root,
    train=True,
    transform=img_transform
)

dataloader_source = torch.utils.data.DataLoader(
    dataset=dataset_source,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

train_list = os.path.join(target_image_root, 'mnist_m_train_labels.txt')

dataset_target = GetLoader(
    data_root=os.path.join(target_image_root, 'mnist_m_train'),
    data_list=train_list,
    transform=img_transform
)

dataloader_target = torch.utils.data.DataLoader(
    dataset=dataset_target,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8
)

#####################
#  load model       #
#####################

my_net = DSN()

#####################
# setup optimizer   #
#####################


def exp_lr_scheduler(optimizer, step, init_lr=lr, lr_decay_step=lr_decay_step, step_decay_weight=step_decay_weight):

    # Decay learning rate by a factor of step_decay_weight every lr_decay_step
    current_lr = init_lr * (step_decay_weight ** (step / lr_decay_step))

    if step % lr_decay_step == 0:
        print 'learning rate is set to %f' % current_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    return optimizer


optimizer = optim.SGD(my_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

loss_classification = torch.nn.CrossEntropyLoss()
loss_recon1 = MSE()
loss_recon2 = SIMSE()
loss_diff = DiffLoss()
loss_similarity = torch.nn.CrossEntropyLoss()

if cuda:
    my_net = my_net.cuda()
    loss_classification = loss_classification.cuda()
    loss_recon1 = loss_recon1.cuda()
    loss_recon2 = loss_recon2.cuda()
    loss_diff = loss_diff.cuda()
    loss_similarity = loss_similarity.cuda()

for p in my_net.parameters():
    p.requires_grad = True

#############################
# training network          #
#############################
 MNIST数据重建/共有部分特征/私有数据特征可视化

 MNIST_m数据重建/共有部分特征/私有数据特征可视化

代码获取

相关问题和项目开发,欢迎私信交流和沟通。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1594483.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pandas相比Excel的优势是哪些?

熟悉Pandas的同学会知道,Pandas相当于Python中的Excel,都是基于二维表的进行数据处理分析,不同的是,Pandas基于代码操作数据,Excel是图形化的分析工具。 不少人会问Excel比Pandas更简单,为什么还要学习Pan…

SSL Pinning之双向认证

双向认证处理流程 概述获取证书逆向app 获取证书的KeyStore的 key通过jadx 反编译 app 获取证书:frida hook 证书转换命令行转换portecle 工具使用 charles 配置 p12 格式证书 概述 本篇只介绍怎么解决ssl pinning, 不讲ssl/tls 原理。 为了解决ssl pinn…

【Java】内存可见性问题是什么?

文章目录 内存模型内存可见性解决方案volatile 内存模型 什么是JAVA 内存模型? Java Memory Model (JAVA 内存模型)是描述线程之间如何通过内存(memory)来进行交互。 具体说来, JVM中存在一个主存区(Main Memory或Java Heap Mem…

MySQL优化慢SQL的6种方式

⛰️个人主页: 蒾酒 🔥系列专栏:《mysql经验总结》 🌊山高路远,行路漫漫,终有归途 目录 写在前面 优化思路 优化方法 1.避免查询不必要的列 2.分页优化 3.索引优化 4.JOIN优化 5.排序优化 6.UNION 优化…

基于51单片机多功能洗衣机控制(强洗弱洗漂洗)设计( proteus仿真+程序+设计报告+原理图+讲解视频)

基于51单片机多功能洗衣机控制(强洗弱洗漂洗)设计( proteus仿真程序设计报告原理图讲解视频) 多功能洗衣机控制-强洗弱洗漂洗 1. 主要功能:2. 讲解视频:3. 仿真设计4. 程序代码5. 设计报告6. 原理图7. 设计资料内容清单资料下载链接&#xf…

Element-UI 自定义-下拉框选择年份

1.实现效果 场景表达&#xff1a; 默认展示当年的年份&#xff0c;默认展示前7年的年份 2.实现思路 创建一个新的Vue组件。 使用<select>元素和v-for指令来渲染年份下拉列表。 使用v-model来绑定选中的年份值。 3.实现代码展示 <template><div><el-…

数学:人工智能学习之路上的“拦路虎”及其背后的奥秘

在人工智能的浪潮席卷全球的今天&#xff0c;越来越多的人开始涉足这一领域&#xff0c;以期掌握其核心技术&#xff0c;为未来的科技发展贡献力量。然而&#xff0c;在学习的道路上&#xff0c;许多人却遇到了一个不小的挑战——数学。为何数学会成为学习人工智能的“拦路虎”…

多波长,无限可能:探索波分复用的前沿 ✨

&#x1f331;当我们谈论光纤通信时&#xff0c;波分复用是一个不可忽视的关键词。它不仅改变了通信的速度和容量&#xff0c;更是连接数字世界的一座隐形桥梁。让我们深入了解这项技术的精髓。 &#x1f4da; 目录 ❓什么是WDM波分复用&#xff1f;以及WDM工作原理&#x1f5…

有真的副业推荐吗?

#有真的副业推荐吗# 我做副业项目的时候&#xff0c;认识了一位带娃宝妈&#xff0c;讲一下她空闲时间做副业赚钱的故事吧。在一个温馨的小家庭里&#xff0c;李婷是一位全职宝妈&#xff0c;她的主要任务是照顾和陪伴自己可爱的宝宝。然而&#xff0c;随着宝宝逐渐长大&#x…

Linux应用开发笔记(八)SPI应用层开发及其框架

文章目录 前言一、扩展SPI协议&#xff08;Single/Dual/Quad/Octal SPI&#xff09;二、SPi驱动框架三、SPI应用编程1. SPI相关数据结构与ioctl函数2. 基本函数 前言 与IIC类似&#xff0c;SPI协议也是我们的老朋友了&#xff0c;这里依然不多作赘述&#xff0c;本文将介绍SPI的…

Flutter第八弹 构建拥有不同项的列表

目标&#xff1a;1&#xff09;项目中&#xff0c;数据源可能涉及不同的模版&#xff0c;显示不同类型的子项&#xff0c;类似RecycleView的itemType, 有多种类型&#xff0c;列表怎么显示&#xff1f; 2&#xff09;不同的数据源构建列表 一、创建不同的数据源 采用类似Rec…

韩顺平Java | C25 JDBC和连接池(上)

概述 JDBC概述&#xff1a;JDBC为访问不同数据库提供统一接口&#xff0c;为使用者屏蔽细节问题。Java程序员使用JDBC可以连接任何提供了JDBC驱动程序的数据库系统&#xff0c;从而完成对数据库的各种操作。 // 模拟代码 //JdbcInterface.java --Java规定的JDBC接口(方法) p…

Linux-文件系统理解(磁盘的物理与逻辑结构、什么是inode、OS如何管理磁盘)

一、磁盘 磁盘的物理结构 磁盘的本质是一个机械设备&#xff0c;可以存储大量的二进制信息&#xff0c;是实现数据存储的基础硬件设施&#xff0c;磁盘的盘片类似于光盘&#xff0c;不过盘片的两面都是可读可写可擦除的&#xff0c;每个盘面都有一个磁头&#xff0c;马达可以使…

keil无法查看外设寄存器(生成SFR文件)

1.前言 自从更新了keil&#xff0c;用的是越来越不顺手了&#xff0c;一会是cannot evaluate&#xff0c;一会是与强制与cubemx强制联系了&#xff0c;这次也是的&#xff08;地铁&#xff0c;老人&#xff0c;手机&#xff09;折腾了一下总是搞好了&#xff08;网上的解法只能…

安卓刷机fastboot分段传输

win10 fastboot 无法识别&#xff0c;驱动下载地址GitHub - xushuan/google_latest_usb_driver_windows 把inf文件更新到设备管理器驱动更新即可 问题 archive does not contain super_empty.img Sending vbmeta_a (4 KB) OKAY [ 0.117s] Writing …

Open CASCADE学习|实现Extrude功能

首先定义了一些基本的几何元素&#xff0c;如线、圆和平面&#xff0c;然后使用makeExtrudebydir函数来对一个面进行挤出操作。下面是详细过程&#xff1a; 定义Extrude函数&#xff1a;makeExtrudebydir函数接受一个TopoDS_Shape对象和一个gp_Vec对象作为参数。TopoDS_Shape是…

【vs2019】window10环境变量设置

【vs2019】window10环境变量设置 【先赞后看养成习惯】求关注点赞收藏&#x1f60a; 安装VS2019时建议默认安装地址&#xff0c;最好不要改动&#xff0c;不然容易出问题 以下是安装完VS2019后环境变量的设置情况&#xff0c;C:\Program Files (x86)\Microsoft Visual Studi…

【Unity添加远程桌面】使用Unity账号远程控制N台电脑

设置地址&#xff1a; URDP终极远程桌面&#xff1b;功能强大&#xff0c;足以让开发人员、设计师、建筑师、工程师等等随时随地完成工作或协助别人https://cloud-desktop.u3dcloud.cn/在网站登录自己的Unity 账号上去 下载安装被控端安装 保持登录 3.代码添加当前主机 "…

初探vercel托管项目

文章目录 第一步、注册与登录第二步、本地部署 在个人网站部署的助手vercel&#xff0c;支持 Github部署&#xff0c;只需简单操作&#xff0c;即可发布&#xff0c;方便快捷&#xff01; 第一步、注册与登录 进入vercel【官网】&#xff0c;在右上角 login on&#xff0c;可登…

【小迪安全2023】第23天:WEB攻防-Python考点CTF与CMS-SSTI模版注入PYC反编译

&#x1f36c; 博主介绍&#x1f468;‍&#x1f393; 博主介绍&#xff1a;大家好&#xff0c;我是 hacker-routing &#xff0c;很高兴认识大家~ ✨主攻领域&#xff1a;【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 &#x1f389;点赞➕评论➕收…