【多任务学习】Multi-task Learning 手把手编码带数据集, 一文吃透多任务学习

news2025/1/16 5:17:44

文章目录

  • 前言
  • 1.多任务学习
    • 1.1 定义
    • 1.2 原理
  • 2. 多任务学习code
    • 2.1 数据集初探
    • 2.2 预处理
    • 2.3 网络结构
    • 2.4 训练
  • 3. 总结


前言

我们之前讲过的模型通常聚焦单个任务,比如预测图片的类别等,在训练的时候,我们会关注某一个特定指标的优化.
但是有时候,我们需要知道一个图片,从它身上知道新闻的类型(政治/体育/娱乐)和是男性的新闻还是女性的.
我们关注某一个特定指标的优化,可能忽略了对有关注的指标的有用信息.具体来说就是训练相关任务所带来的额外信息,通过在多个相关任务中共享表示,我们可以使得模型在我们原本任务上获得更好的泛化能力.这种方法就叫做多任务学习.


1.多任务学习

1.1 定义

同时完成多个预测,共享表示,共享特征提取.使得模型关注到一些特有的特征.其实一套提取特征的网络,配合多个损失函数,就是多任务损失.
图像定位是单任务,若还需要知道类别,就变成了多任务学习.
在这里插入图片描述

1.2 原理

多任务学习的模型通常通过所有任务重共用隐藏层(特征提取层),而针对不同任务使用多个输出层来实现.自动学习到的任务越多,模型就能获得捕捉所有任务的表示,而原本任务上过拟合的风险更小.
多任务学习中,针对一个任务的特征提取,由于其它任务也能对提取的特征做出筛选,所以可以帮助模型将注意力集中到那些真正起作用的特征上.
模型会学习那些尽量表达多个任务的特征,而这些特征泛化能力会很好.

2. 多任务学习code

同时预测一个物品的颜色和类别.

2.1 数据集初探

一个分支用于分类给定输入图像的服装种类(比如衬衫、裙子、牛仔裤、鞋子等);
另一个分支负责分类该服装的颜色(黑色、红色、蓝色等)。
总体而言,我们的数据集由 2525 张图像构成,分为 7 种「颜色+类别」组合,包括:

黑色牛仔裤(344 张图像)
黑色鞋子(358 张图像)
蓝色裙子(386 张图像)
蓝色牛仔裤(356 张图像)
蓝色衬衫(369 张图像)
红色裙子(380 张图像)
红色衬衫(332 张图像)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc

2.2 预处理

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils import data
from PIL import Image

img_paths = glob.glob(r"F:\multi-output-classification\dataset\*\*.jpg")
img_paths[:5]

在这里插入图片描述
路径文件夹就表示了标签,所以要获取其标签:

label_names = [img_path.split("\\")[-2] for img_path in img_paths]
label_names[:5]

在这里插入图片描述

label_array = np.array([la.split("_") for la in label_names])
label_array

在这里插入图片描述

label_color = label_array[:,0]
label_color

在这里插入图片描述

label_item = label_array[:,1]
label_item


吧他们转成index,因为torch中只认数字

unique_color = np.unique(label_color)
unique_color
unique_item = np.unique(label_item)
unique_item
item_to_idx = dict((v,k) for k, v in enumerate(unique_item))
item_to_idx
color_to_idx = dict((v,k) for k, v in enumerate(unique_color))
color_to_idx
label_item = [item_to_idx.get(k) for k in label_item]
label_color = [color_to_idx.get(k) for k in label_color ]
transform = transforms.Compose([
    transforms.Resize((96,96)),
    transforms.ToTensor(),
])

自定义数据集

class Multi_dataset(data.Dataset):
    def __init__(self,imgs_path, label_color, label_item) -> None:
        super().__init__()
        self.imgs_path = imgs_path
        self.label_color = label_color
        self.label_item = label_item
    
    def __getitem__(self, index):
        img_path = self.imgs_path[index]
        pil_img = Image.open(img_path)
        # 防止有图片有黑白图
        pil_img = pil_img.convert('RGB')
        pil_img = transform(pil_img)
        label_c = self.label_color[index]
        label_i = self.label_item[index]
        return pil_img, (label_c,label_i)
    def __len__(self):
        return len(self.imgs_path)

划分训练集

count = len(multi_dataset)
count
# 划分训练集 测试集
train_count = int(count*0.8)
test_count =  count - train_count
train_ds, test_ds = data.random_split(multi_dataset,[train_count, test_count])
len(train_ds),len(test_ds)
BATCHSIZE = 32
train_dl = data.DataLoader(train_ds,batch_size=BATCHSIZE,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=BATCHSIZE)

在这里插入图片描述
在这里插入图片描述

2.3 网络结构

2.4 训练

3. 总结

未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/501358.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows远程执行进程工具psexec和wmiexec介绍

在自动化测试或者自动化工具开发中,通常需要向其它电脑或者服务器发送指令,比如Windows发送命令到Linux服务器开启某个服务进程,或者读取状态信息,我们可以使用ssh协议实现。 如果Windows主机需要发送命令到局域网内的其它Window…

第四十五章 Unity 滚动视图 (Scroll View) UI

我们介绍一下滚动条 (Scrollbar),它允许用户滚动由于太大而无法完全看到的图像或其他视图。这种效果在我们网页中经常看到,尤其是网页内容太长的时候,就会在垂直方向出现滚动条。当然,有时候也会在水平方向出现滚动条。我们拖动滚…

WebRTC引用计数和线程

1. 什么是引用计数: 引用计数是计算机编程语言中的一种内存管理技术,是指将资源(可以是对象、内存、或磁盘空间等等)的被引用次数保存起来,当被引用次数变为零时就将其释放的过程。 使用引用计数技术可以实现自动资源…

纯HTML+CSS网页设计期末作业(海贼王主题网站——图片可换,主题可换)

纯HTMLCSS网页设计期末作业(海贼王主题网站——图片可换,主题可换) 博主:命运之光 专栏:web网页制作大作业 网页最底下有视频可以观看网页效果,喜欢的可以在博主主页资源里免费下载(●’◡’●)让我们一起加…

[数据库系统] 一、添加常用约束(educoder)

1.任务:给表添加常用的约束。 2.相关知识 在数据库的使用过程中我们经常要限制属性的取值,比如有些属性不能为空,就需要添加非空约束,本关要求完成常用约束的添加和定义。 目录 (1)唯一约束 (2)添加非空约束 (3)使用默认约束…

大数据|Hive和数据仓库

前文回顾:HBase基本工作原理 目录 📚数据仓库和OLAP 🐇数据仓库 🥕面向主题 🥕集成的 🥕时变的 🥕非易失的 🐇OLTP(联机事务处理)vs OLAP(…

织梦城市分站怎么安装织梦二级域名站群织梦制作企业分站教程

1、安装说明 一、下载织梦多城市二级域名源码; 二、上传源码到服务器;(必须支持泛解析) 三、在浏览器输入http://你的域名/install进入安装页面; 可以参考http://www.hlzcb.com/zhimengxueyuan/zhimenganzhuangshiyong/25830.html 四、输入数据库用户名密码和数据库…

公司股权转让,变更股东要了解哪些?

什么是公司股权? 股权即股票持有者所具有的与其拥有的股票比例相应的权益及承担一定责任的权力。股权转让是一种物权变动行为,股权转让后,股东基于股东地位而对公司所发生的权利义务关系全部同时移转于受让人,受让人因此成为公司…

SpringBoot的yml多环境配置3种方法

目录 方式一:多个yml文件 步骤一、创建多个配置文件 步骤二、applicaiton.yml中指定配置 方式二: 单个yml文件 方式三:在pom.xml中指定环境配置 步骤一、创建多个配置文件 步骤二、在application.yml中添加多环境配置属性 步骤三、在po…

创客匠人直播怎么样?

快速发展的互联网时代,直播成为各个行业的引流利器。如今,众多教培机构转型线上,直播无疑是最好的线上教学方式,不仅可以让老师和学员异地教学还能最好的还原线下教学场景,利用科技提高学习质量。 做好一场直播不仅仅…

JColorChooser和JFileChooser

Swing提供了JColorChooser和JFileChooser这两种对话框,可以很方便的完成演示的选择和本地文件的选择。 1.JColorChooser JColorChooser用于创建颜色选择器对话框,该类的用法非常简单,只需要调用它的静态方法就可以快速生成一个颜色选择对话框…

蓝牙耳机什么牌子好?500内好用的蓝牙耳机推荐

随着蓝牙耳机的受欢迎程度越来越高,近几年来,无蓝牙耳机市场呈爆发式增长,蓝牙耳机品牌也越来越多。那么蓝牙耳机什么牌子好?接下来,我来给大家推荐几款500内好用的蓝牙耳机,一起来看看吧。 一、南卡小音舱…

GDD471A001 PLC / DCS维护日志

​ GDD471A001 PLC / DCS维护日志 PLC维护日志 PLC/DCS 维护日志将帮助您跟踪过去的故障、解决方案、零件更换。如果以后再次出现同样的问题,跟踪日志将帮助您立即解决。 您的控制系统的可靠性可以通过参考维护日志来确定。 使用 PLC/DCS 维护日志可以识别频繁出…

React的生命周期及Redux状态管理器等

生命周期 一个应用或页面从创建到消亡过程中某一时刻自动调用的回调函数称为生命周期钩子函数 挂载 constructor :来初始化函数内部 state,为 事件处理函数 绑定实例render:渲染 DOMcomponentDidMount:组件挂载、DOM 渲染完后&…

V8 JavaScript引擎

简介 V8 (v8.dev)是 Google 的开源高性能 JavaScript 和 WebAssembly 引擎,用 C 编写。它用于 Chrome 和 Node.js 等。它实现了 ECMAScript 和 WebAssembly,并运行在 Windows 7 或更高版本、macOS 10.12 以及使用 x64、IA-32、ARM 或 MIPS 处理器的 Lin…

FFmpeg HEVC 解码 YUV

1. 概要与流程图 1.1 FFmpeg 支持 h264,hevc 等解码,由于分离视频文件为 hevc 格式,为了方便起见,当前解码的格式为 hevc,代码支持各种视频格式解码,需要修改参数和适配 1.2 HEVC 解码 YUV 流程图如下: 2. 封装读写文件操作 2.1 读写头文件,FileTool.h #import <Fou…

MAC常用操作

1. 添加环境变量 vi ~/.bash_profile export PATHselfdefine_path:$PATH source ~/.bash_profile适用于安装Application之后将该Application的Contents/bin下的可执行程序添加到环境变量&#xff0c;使得在终端能够启用。 例如使用cmake-3.25.0-macos-universal.dmg安装好cmak…

基于微信小程序网上书城系统

开发工具&#xff1a;IDEA、微信小程序 服务器&#xff1a;Tomcat8.0&#xff0c; jdk1.8 项目构建&#xff1a;maven 数据库&#xff1a;mysql5.7 前端技术&#xff1a;vue、uniapp 服务端技术&#xff1a;springspringmvcmybatis(ssm框架) 本系统分微信小程序和管理后…

双向链表(数据结构)(C语言)

目录 概念 带头双向循环链表的实现 前情提示 双向链表的结构体定义 双向链表的初始化 关于无头单向非循环链表无需初始化函数&#xff0c;顺序表、带头双向循环链表需要的思考 双向链表在pos位置之前插入x 双向链表的打印 双链表删除pos位置的结点 双向链表的尾插 关…

Windows命令提示行使用指南一

命令提示行使用指南 前言一、起源和发展二、和DOS的关系三、常用命令 前言 cmd 是 Windows 操作系统中的命令行界面&#xff08;CLI&#xff09;&#xff0c;也称为命令提示符&#xff08;CMD&#xff09;或批处理文件。它是 Windows 命令行界面的主要组成部分&#xff0c;用于…