偏标记学习+图像分类(论文复现)

news2024/10/4 8:41:30

偏标记学习+图像分类(论文复现)

本文所涉及所有资源均在传知代码平台可获取

文章目录

    • 偏标记学习+图像分类(论文复现)
        • 概述
        • 算法原理
        • 核心逻辑
        • 效果演示
        • 使用方式

概述

本文复现论文提出的偏标记学习方法,随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的

在这里插入图片描述

该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关

算法原理

传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下

在这里插入图片描述

在这里插入图片描述

核心逻辑

具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

def CE_loss(probs, targets):
    """交叉熵损失函数"""
    loss = -torch.sum(targets * torch.log(probs), dim = -1)
    loss_avg = torch.sum(loss)/probs.shape[0]
    return loss_avg

class Proden:
    def __init__(self, configs):
        self.configs = configs
    
    def train(self, save = False):
        configs = self.configs
        # 读取数据集
        dataset_path = configs['dataset path']
        if configs['dataset'] == 'CIFAR-10':
            train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 10
        elif configs['dataset'] == 'CIFAR-100':
            train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)
            train_dataset = datasets.Cifar(train_data, train_labels)
            test_dataset = datasets.Cifar(test_data, test_labels)
            output_dimension = 100
        # 生成偏标记
        partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])
        train_dataset.load_partial_labels(partial_labels)
        # 计算数据的均值和方差,用于模型输入的标准化
        mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]
        std = [np.std(train_data[:, i, :, :]) for i in range(3)]
        normalize = transforms.Normalize(mean, std)
        # 设备:GPU或CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # 加载模型
        if configs['model'] == 'ResNet18':
            model = models.ResNet18(output_dimension = output_dimension).to(device)
        elif configs['model'] == 'ConvNet':
            model = models.ConvNet(output_dimension = output_dimension).to(device)
        # 设置学习率等超参数
        lr = configs['learning rate']
        weight_decay = configs['weight decay']
        momentum = configs['momentum']
        optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)
        lr_step = configs['learning rate decay step']
        lr_decay = configs['learning rate decay rate']
        lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)
        for epoch_id in range(configs['epoch count']):
            # 训练模型
            train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)
            model.train()
            for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):
                ids = batch['ids']
                # 标准化输入
                data = normalize(batch['data'].to(device))
                partial_labels = batch['partial_labels'].to(device)
                targets = batch['targets'].to(device)
                optimizer.zero_grad()
                # 计算预测概率
                logits = model(data)
                probs = F.softmax(logits, dim=-1)
                # 更新软标签
                with torch.no_grad():
                    new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)
                    train_dataset.targets[ids] = new_targets.cpu().numpy()
                # 计算交叉熵损失
                loss = CE_loss(probs, targets)
                loss.backward()
                # 更新模型参数
                optimizer.step()
            # 调整学习率
            lr_scheduler.step()
效果演示

我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下

在这里插入图片描述

由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:

在这里插入图片描述

在这里插入图片描述

使用方式

解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置

pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash download.sh

如果希望在本地训练模型,请运行如下命令

python main.py -c [你的配置文件路径] -r [选择下者之一:"train""test""infer"]

如果希望在线部署,请运行如下命令

python main-flask.py

文章代码资源点击附件获取

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

异常场景分析

优质博文:IT-BLOG-CN 为了防止黑客从前台异常信息,对系统进行攻击。同时,为了提高用户体验,我们都会都抛出的异常进行拦截处理。 一、异常处理类 Java把异常当做是破坏正常流程的一个事件,当事件发生后,…

CMU 10423 Generative AI:lec16(Mixture of Experts 混合专家模型)

关于MoE推荐博客: https://huggingface.co/blog/zh/moehttps://www.paddlepaddle.org.cn/documentation/docs/zh/guides/06_distributed_training/moe_cn.html 1 概述 这个文档是关于Mixture of Experts (MoE) 的介绍和实现,主要内容如下:…

virtualbox配置为NAT模式后物理机和虚拟机互通

virtualbox配置为 NAT模式后,虚拟机分配到的 IP地址一般是 10.xx网段的,虚拟机可以通过网络地址转换访问物理机所在的网络,但若不做任何配置,则物理机无法直接访问虚拟机。 virtualbox在提供 NAT配置模式时,也提供了端…

深度学习:CycleGAN图像风格迁移转换

基础概念 CycleGAN是一种GAN的变体,它被设计用来在没有成对训练数据的情况下学习两种不同域之间的图像到图像的转换,不需要同一场景或物体在两个不同域中的对应图像。 CycleGAN由Jun-Yan Zhu等人在2017年提出。 CycleGAN的模型架构主要由两组生成器和…

mac配置python出现DataDirError: Valid PROJ data directory not found错误的解决

最近在利用python下载SWOT数据时出现以下的问题: import xarray as xr import s3fs import cartopy.crs as ccrs from matplotlib import pyplot as plt import earthaccess from earthaccess import Auth, DataCollections, DataGranules, Store import os os.env…

CSS3--美开二度

免责声明:本文仅做分享! 目录 定位 相对定位 绝对定位 定位居中 固定定位 堆叠层级 z-index 定位-小结 CSS 精灵 京东案例 字体图标 下载字体 使用字体 上传矢量图 CSS 修饰属性 垂直对齐方式 vertical-align 过渡 transition 透明度 opa…

【西门子V20变频器】 变频器运行时报A922报警

报警说明 原因: 1.变频器未接负载 2.变频器设定的电机参数与实际电机不匹配 3.查看P2179查看 无负载监控 设定的电流极限值,出厂默认为“3.0”

mysql事务 -- 事务的隔离性(测试实验+介绍,脏读,不可重复读,可重复度读,幻读)

目录 事务的隔离性 引入 测试 读未提交 脏读 读提交 不可重复读 属于问题吗? 例子 可重复读 幻读 串行化 原理 总结 事务的隔离性 引入 当我们让两个客户端共同执行begin语句时,就开始了两个事务并发访问 在这个过程中,可能会出现sql交叉的问题 但我们不希望因为…

项目定位与服务器(SERVER)模块划分

目录 定位 HTTP协议以及HTTP服务器 高并发服务器 单Reactor单线程 单Reactor多线程 多Reactor多线程 模块划分 SERVER模块划分 Buffer 模块 Socket模块 Channel 模块 Connection模块 Acceptor模块 TimerQueue模块 Poller模块 EventLoop模块 TcpServer模块 SE…

【ADC】噪声(1)噪声分类

概述 本文学习于TI 高精度实验室课程,总结 ADC 的噪声分类,并简要介绍量化噪声和热噪声。 文章目录 概述一、ADC 中的噪声类型二、量化噪声三、热噪声四、量化噪声与热噪声对比 一、ADC 中的噪声类型 ADC 固有噪声由两部分组成:第一部分是量…

【树莓派系列】树莓派wiringPi库详解,官方外设开发

树莓派wiringPi库详解,官方外设开发 文章目录 树莓派wiringPi库详解,官方外设开发一、安装wiringPi库二、wiringPi库API大全1.硬件初始化函数2.通用GPIO控制函数3.时间控制函数4.串口通信串口API串口通信配置多串口通信配置串口自发自收测试串口间通信测…

Django 后端数据传给前端

Step 1 创建一个数据库 Step 2 在Django中点击数据库连接 Step 3 连接成功 Step 4 settings中找DATABASES Step 5 将数据库挂上面 将数据库引擎和数据库名改成自己的 Step 6 在_init_.py中加上数据库的支持语句 import pymysql pymysql.install_as_MySQLdb() Step7 简单创建两…

以企业的视角进行大学生招聘

课程来源:中国计算机学会---朱颖韶(资深人力资源领域--HR) 一、招聘流程 1.简历->门槛 注重:专业学历、行业经验 2.笔试面试->专业知识与技能 3.简历面试-> 过往的成果 4.面试 沟通能力、学习力-----了解动机、价值观…

Pikachu-Sql Inject-insert/update/delete注入

insert 注入 插入语句 insert into tables values(value1,value2,value3); 如:插入用户表 insert into users (id,name,password) values (id,username,password); 当点击注册 先判断是否有SQL注入漏洞,经过判断之后发现存在SQL漏洞。构造insert的pa…

8644 堆排序

### 思路 堆排序是一种基于堆数据结构的排序算法。堆是一种完全二叉树,分为最大堆和最小堆。堆排序的基本思想是将待排序数组构造成一个最大堆,然后依次将堆顶元素与末尾元素交换,并调整堆结构,直到排序完成。 ### 伪代码 1. 读取…

自闭症干预寄宿学校:专业治疗帮助孩子发展

自闭症干预寄宿学校:星贝育园的专业治疗助力孩子全面发展 在自闭症儿童的教育与康复领域,寄宿学校以其独特的教育模式和全面的关怀体系,为众多家庭提供了重要的选择。广州星贝育园自闭症儿童寄宿制学校,作为这一领域的佼佼者&…

达梦core文件分析(学习笔记)

目录 1、core 文件生成 1.1 前置条件说明 1.2 关于 core 文件生成路径的说明 1.3查看 core 文件的前置条件 2、查看 core 文件堆栈信息 2.1 使用gdb 2.2 使用达梦dmrdc 3、core 分析过程 3.1 服务端主动 core 3.2因未知异常原因导致的 core 4、测试案例 4.1测试环境…

(十八)、登陆 k8s 的 kubernetes-dashboard 更多可视化工具

文章目录 1、回顾 k8s 的安装2、确认 k8s 运行状态3、通过 token 登陆3.1、使用现有的用户登陆3.2、新加用户登陆 4、k8s 可视化工具 1、回顾 k8s 的安装 Mac 安装k8s 2、确认 k8s 运行状态 kubectl proxy kubectl cluster-info kubectl get pods -n kubernetes-dashboard3、…

网页前端开发之Javascript入门篇(4/9):循环控制

Javascript循环控制 什么是循环控制? 答:其概念跟 Python教程 介绍的一样,只是语法上有所变化。 参考流程图如下: 其对应语法: var i 0; // 设置起始值 var minutes 15; // 设置结束值(15分钟…

Stream流的终结方法(一)

1.Stream流的终结方法 2.forEach 对于forEach方法,用来遍历stream流中的所有数据 package com.njau.d10_my_stream;import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.function.Consumer; import java.util…