kaggle竞赛 | 计算机视觉 | Doodle Recognition Challenge

news2024/10/7 6:39:57

目录

  • 赛题链接
  • 赛题背景
  • 数据集探索
    • 合并多个类别CSV数据集
  • 数据建模 (pytorch)

赛题链接

https://www.kaggle.com/competitions/quickdraw-doodle-recognition/overview/evaluation
数据集从上述链接中找

赛题背景

'Quick,Draw!'作为实验性游戏发布,以有趣的方式向公众宣传 AI 的工作原理。游戏提示用户绘制描绘特定类别的图像,例如“香蕉”、“桌子”等。游戏生成了超过 1B 幅图画,其中的一个子集被公开发布,作为本次比赛训练集的基础。该子集包含 5000 万张图纸,涵盖 340 个标签类别。

听起来很有趣,对吧?挑战在于:由于训练数据来自游戏本身,绘图可能不完整或可能与标签不匹配。您需要构建一个识别器,它可以有效地从这些嘈杂的数据中学习,并在来自不同分布的手动标记的测试集上表现良好。

您的任务是为现有的 Quick, Draw! 构建一个更好的分类器。数据集。通过在此数据集上推进模型,Kagglers 可以更广泛地改进模式识别解决方案。这将对手写识别及其在 OCR(光学字符识别)、ASR(自动语音识别)和 NLP(自然语言处理)等领域的稳健应用产生直接影响。
在这里插入图片描述
属于多分类问题

数据集探索

字段解释

KeyTypeDescription
key_id64 位无符号整数所有图纸的唯一标识符。
wordstring玩家绘制的类别。
recognizedboolean该词是否被游戏识别。
timestampdatetime创建绘图时间
countrycodestring玩家所在位置的两个字母国家代码
drawingstring表示矢量绘图的 JSON 数组

example:
在这里插入图片描述
根据矢量绘图的JSON数组画图

def show_imale(n,owls,drawing):
    fig,axs = plt.subplots(nrows=n,ncols=n,sharex=True,sharey=True,figsize = (16.10))
    for i , drawing in enumerate(owls,drawing):
        ax = axs[i//n,i%n]
        for x,y in drawing:
            ax.plot(x,-np.array(y),lw=3)
    fig.savefig('owls.png',dpi=200)
    plt.show();

在这里插入图片描述

赛题建模思路

  1. 读取数据并转化为图像
  2. 构建分类模型
  3. 确定训练细节和数据扩增方法;
  4. 对测试集完成预测并完成模型集成

数据集的文件结构:
在这里插入图片描述
每一种类型的数据图片,都放在一个单独的csv中,下面要对整个数据集进行处理。

合并多个类别CSV数据集

import os, sys, codecs, glob
import numpy as np
import pandas as pd
import cv2
import timm

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# 读取单个csv文件
def read_df(path, nrows):
    print('Reading...', path)
    if nrows.isdigit():
        return pd.read_csv(path, nrows=int(nrows), parse_dates=['timestamp'])
    else:
        return pd.read_csv(path, parse_dates=['timestamp'])

# 读取多个csv文件
def contcat_df(paths, nrows):
    dfs = []
    for path in paths:
        dfs.append(read_df(path, nrows))
    return pd.concat(dfs, axis=0, ignore_index=True)

def main():
    if not os.path.exists('./data'):
        os.mkdir('./data')
    
    CLASSES_CSV = glob.glob('../input/train_simplified/*.csv')
    CLASSES = [x.split('/')[-1][:-4] for x in CLASSES_CSV]

    print('Reading data...')
    # 读取指定行数的csv文本,并进行拼接
    df = contcat_df(CLASSES_CSV, number)
    
    # 数据打乱
    df = df.reindex(np.random.permutation(df.index))
    
    lbl = LabelEncoder().fit(df['word'])
    df['word'] = lbl.transform(df['word'])
    
    if df.shape[0] * 0.05 < 120000:
        df_train, df_val = train_test_split(df, test_size=0.05)
    else:
        df_train, df_val = df.iloc[:-500000], df.iloc[-500000:]
    
    print('Train:', df_train.shape[0], 'Val', df_val.shape[0])
    print('Save data...')
    df_train.to_pickle(os.path.join('./data', 'train_' + str(number) + '.pkl'))
    df_val.to_pickle(os.path.join('./data', 'val_' + str(number) + '.pkl'))

# python 1_save2df.py 50000
# python 1_save2df.py all
if __name__ == "__main__":
    number = str(sys.argv[1])
    main()

其中glob的作用如下注释所示

import glob

#获取指定目录下的所有图片
print (glob.glob(r"/home/qiaoyunhao/*/*.png"),"\n")#加上r让字符串不转义

#获取上级目录的所有.py文件
print (glob.glob(r'../*.py')) #相对路径

得到的结果如下所示:
32300个训练集,1700个测试集
这里我们是先采用少量数据集训练,试一下数据是否拟合,若拟合
在这里插入图片描述

数据建模 (pytorch)

导入所需库

import os, sys, codecs, glob
from PIL import Image, ImageDraw

import numpy as np
import pandas as pd
import cv2

import torch
torch.backends.cudnn.benchmark = False
import timm

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

import logging
logging.basicConfig(level=logging.DEBUG, filename='example.log',
                    format='%(asctime)s - %(filename)s[line:%(lineno)d]: %(message)s')  

将绘图的轨迹转变为图片
这里用的是opencv,cv的处理速度大于pillow

def draw_cv2(raw_strokes, size=256, lw=6, time_color=True):
    BASE_SIZE = 299
    img = np.zeros((BASE_SIZE, BASE_SIZE), np.uint8)
    for t, stroke in enumerate(eval(raw_strokes)):
        str_len = len(stroke[0])
        for i in range(len(stroke[0]) - 1):
            
            # 数据集随机丢弃一些像素,属于数据集的drop out,防止过拟合
            if np.random.uniform() > 0.95:
                continue
            
            color = 255 - min(t, 10) * 13 if time_color else 255
            _ = cv2.line(img, (stroke[0][i] + 22, stroke[1][i]  + 22),
                         (stroke[0][i + 1] + 22, stroke[1][i + 1] + 22), color, lw)
    
    if size != BASE_SIZE:
        return cv2.resize(img, (size, size))
    else:
        return img

计算topk准确率

def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # print(correct.shape)
        res = []
        for k in topk:
            # print(correct[:k].shape)
            correct_k = correct[:k].float().sum()
            res.append(correct_k.mul_(100.0 / batch_size))
            
        # print(res)
        return res

数据扩展

class QRDataset(Dataset):
    def __init__(self, img_drawing, img_label, img_size, transform=None):
        self.img_drawing = img_drawing
        self.img_label = img_label
        self.img_size = img_size
        self.transform = transform

    def __getitem__(self, index):
        img = np.zeros((self.img_size, self.img_size, 3))
        img[:, :, 0] = draw_cv2(self.img_drawing[index], self.img_size)
        img[:, :, 1] = img[:, :, 0]
        img[:, :, 2] = img[:, :, 0]
        img = Image.fromarray(np.uint8(img))
        
        if self.transform is not None:
            img = self.transform(img)
        
        label = torch.from_numpy(np.array([self.img_label[index]]))
        return img, label

    def __len__(self):
        return len(self.img_drawing)

载入模型

def get_resnet18():
    model = models.resnet18(True)
    model.avgpool = nn.AdaptiveAvgPool2d(1) # 匹配不固定的输入尺寸
    model.fc = nn.Linear(512, 340)
    return model

def get_resnet34():
    model = models.resnet34(True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(512, 340)
    return model

def get_resnet50():
    model = models.resnet50(True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(2048, 340)
    return model

def get_resnet101():
    model = models.resnet101(True)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(2048, 340)

图片mixup操作

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    # x 是一个batch 一批的输入
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

主函数

  1. 数据扩展
def main():
	df_train = pd.read_pickle(os.path.join('./data', 'train_' + dataset + '.pkl'))
	    # df_train = df_train.reindex(np.random.permutation(df_train.index))
	    df_val = pd.read_pickle(os.path.join('./data', 'val_' + dataset + '.pkl'))
	    
	    train_loader = torch.utils.data.DataLoader(
	        QRDataset(df_train['drawing'].values, df_train['word'].values, imgsize,
	                         transforms.Compose([
	                            transforms.RandomHorizontalFlip(),
	                            transforms.RandomVerticalFlip(),
	                            # transforms.RandomAffine(5, scale=[0.95, 1.05]),
	                            transforms.ToTensor(),
	                            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
	            ])
	        ),
	        batch_size=200, shuffle=True, num_workers=5,
	    )
	
	    val_loader = torch.utils.data.DataLoader(
	        QRDataset(df_val['drawing'].values, df_val['word'].values, imgsize,
	                         transforms.Compose([
	                            transforms.RandomHorizontalFlip(),
	                            transforms.RandomVerticalFlip(),
	                            transforms.ToTensor(),
	                            # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
	            ])
	        ),
	        batch_size=200, shuffle=False, num_workers=5,
	    )

载入模型

if modelname == 'resnet18':
        model = get_resnet18()
    elif modelname == 'resnet34':
        model = get_resnet34()
    elif modelname == 'resnet50':
        model = get_resnet50()
    elif modelname == 'resnet101':
        model = get_resnet101()
    else:
        model = timm.create_model(modelname, num_classes=340, pretrained=True, in_chans=3)

设置优化器等损失函数

# model = nn.DataParallel(model).cuda()
    # nvismodel.load_state_dict(torch.load('./resnet50_64_7_0.pt'))
    # model.load_state_dict(torch.load('./data/resnet18_64_16_110.pt'))
    
    model = model.cuda()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2, 3, 5, 7, 8], gamma=0.1)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=len(train_loader) / 10, gamma=0.95)
    
    print('Train:', df_train.shape[0], 'Val', df_val.shape[0])
    print('Epoch/Batch\t\tTrain: loss/Top1/Top3\t\tTest: loss/Top1/Top3')

训练50次

for epoch in range(50):
        train_losss, train_acc1s, train_acc5s = [], [], []
        for i, data in enumerate(train_loader):
            scheduler.step()
            model = model.train()
            train_img, train_label = data
            optimizer.zero_grad()
            
            # TODO: data paraell
            # train_img = Variable(train_img).cuda(async=True)
            # train_label = Variable(train_label.view(-1)).cuda()
            
            train_img = Variable(train_img).cuda()
            train_label = Variable(train_label.view(-1)).cuda()
            
            # 加入mixup
            if np.random.randint(1, 10) >= 5:
                mixed_x, y_a, y_b, lam = mixup_data(train_img, train_label)
                output = model(mixed_x)
                train_loss = mixup_criterion(loss_fn, output, y_a, y_b, lam)
            else:
                output = model(train_img)
                train_loss = loss_fn(output, train_label)
            
            # output = model(train_img)
            # train_loss = loss_fn(output, train_label)
            
            train_loss.backward()
            optimizer.step()
            
            train_losss.append(train_loss.item())
            if i % 5 == 0:
                logging.info('{0}/{1}:\t{2}\t{3}.'.format(epoch, i, optimizer.param_groups[0]['lr'], train_losss[-1]))
            
            if i % int(10) == 0:
                val_losss, val_acc1s, val_acc5s = [], [], []
                
                with torch.no_grad():
                    train_acc1, train_acc3 = accuracy(output, train_label, topk=(1, 3))
                    train_acc1s.append(train_acc1.data.item())
                    train_acc5s.append(train_acc3.item())
                
                    for data in val_loader:
                        val_images, val_labels = data
                        
                        # val_images = Variable(val_images).cuda(async=True)
                        # val_labels = Variable(val_labels.view(-1)).cuda()

                        val_images = Variable(val_images).cuda()
                        val_labels = Variable(val_labels.view(-1)).cuda() 
                       
                        output = model(val_images)
                        val_loss = loss_fn(output, val_labels)
                        val_acc1, val_acc3 = accuracy(output, val_labels, topk=(1, 3))
                        
                        val_losss.append(val_loss.item())
                        val_acc1s.append(val_acc1.item())
                        val_acc5s.append(val_acc3.item())
                        
                
                logstr = '{0:2s}/{1:6s}\t\t{2:.4f}/{3:.4f}/{4:.4f}\t\t{5:.4f}/{6:.4f}/{7:.4f}'.format(
                    str(epoch), str(i),
                    np.mean(train_losss, 0), np.mean(train_acc1s, 0), np.mean(train_acc5s, 0),
                    np.mean(val_losss, 0), np.mean(val_acc1s, 0), np.mean(val_acc5s, 0),
                )
                torch.save(model.state_dict(), './data/{0}_{1}_{2}_{3}.pt'.format(modelname, imgsize, epoch, i))
                print(logstr)

运行

# python 2_train.py 模型 数量 图片尺寸
# python 2_train.py resnet18 5000 64
if __name__ == "__main__":
    modelname = str(sys.argv[1]) # 模型名字
    dataset = str(sys.argv[2]) # 数据集规模
    imgsize = int(sys.argv[3]) # 图片的尺寸
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/179441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python元组

python元组 文章目录python元组一、实验目的二、实验原理三、实验环境四、实验内容五、实验步骤1.创建元组2.访问元组3.修改元组4.删除元组5.索引及截取6.元组运算符7.内置函数总结一、实验目的 掌握元组的用法 二、实验原理 Python 的元组与列表类似&#xff0c;不同之处在…

2. MySQL之mysql-connector-python的安装使用

MySQL 是最流行的关系型数据库管理系统&#xff0c;关于数据库以及MySQL相关知识&#xff0c;此处不再赘述。本篇介绍使用 mysql-connector-python 来连接使用 MySQL。 1. 安装mysql-connector-python 执行以下代码&#xff0c;没有报错&#xff0c;证明安装成功。 import my…

旗舰版:Stimulsoft Ultimate 2023.1.5 Crack

Stimulsoft Ultimate 是一套用于创建报告和仪表板的通用工具。该产品包括一整套适用于 WinForms、ASP.NET、.NET Core、JavaScript、WPF、PHP、Java 和其他环境的工具。 无需比较产品功能。Stimulsoft Ultimate 包括一切&#xff01; 报表设计器的一切 我们提供易于使用且功能齐…

Android深入系统完全讲解(41)

我们要学习的是整体逻辑&#xff0c;我们 C 找 Java 的依据是类和对象&#xff0c;参数中 JNIEnv *env, jobject obj 。 env 代表当前环境上下文&#xff0c;这个当我们多个线程调用的时候&#xff0c;需要 AttachCurrentThread 进行设定&#xff0c;让 env 关联到当前线程&…

Linux常见命令 24 - RPM命名管理-包命名与依赖性

目录 1. RPM包命名规则 2. RPM包依赖性 1. RPM包命名规则 如包全名&#xff1a;httpd-2.2.15-15.e16.centos.1.i686.rpm httpd&#xff1a;软件包名2.2.15&#xff1a;软件版本15&#xff1a;软件发布的次数el6.centos&#xff1a;适合的Linux平台&#xff1a;CentOS 6.xi6…

springboot和nacos整合mybatis-plus实现多数据源管理

文章目录1.依赖2.配置文件3.redis测试3.1redis配置文件3.2controller3.3测试4.mysql测试4.1数据库表和结构4.2实体类和枚举4.3DogMapper.xml4.4DogMapper4.5service和serviceImpl4.6controller4.7测试写了一个小demo&#xff0c;通过mybatis-plus实现多数据源管理使用了mysql和…

【笔记】A simple yet effective baseline for 3d human pose estimation

【论文】https://arxiv.org/abs/1705.03098v2 【pytorch】(本文代码参考)weigq/3d_pose_baseline_pytorch: A simple baseline for 3d human pose estimation in PyTorch. (github.com) 【tensorflow】https://github.com/una-dinosauria/3d-pose-baseline 基本上算作是2d人体…

Python压缩JS文件,PythonWeb程序员必看系列,重点是 slimit

Python 压缩文件系列文章&#xff0c;我们已经完成了 2 篇&#xff0c;具体如下&#xff1a; Python Flask 实现 HTML 文件压缩&#xff0c;9 级压缩 Python 压缩 css 文件&#xff0c;第三方模块推荐 压缩JS学习目录&#x1f6a9; jsmin 库&#x1f3a8; 库的安装&#x1f3a8…

HackTheBox Stocker API滥用,CVE-2020-24815获取用户shell,目录遍历提权

靶机地址&#xff1a; https://app.hackthebox.com/machines/Stocker枚举 使用nmap枚举靶机 nmap -sC -sV 10.10.11.196机子开放了22&#xff0c;80端口&#xff0c;我们本地解析一下这个域名 echo "10.10.11.196 stocker.htb" >> /etc/hosts 去浏览器访问…

操作系统真相还原_第5章第4节:特权级

文章目录特权级TSS简介CPL和DPL入门处理器提供的从低特权级到高特权级的方法门、调用门和RPL序特权级 保护模式下特权级按照权力大小分为0、1、2、3级 0特权级是操作系统内核所在的的特权级 TSS简介 TSS&#xff0c;即Task State Segment&#xff0c;意为任务状态段&#x…

Modbus协议完整版

第一部分&#xff1a;Modbus协议1 引言1.1 范围MODBUS是OSI模型第7层上的应用层报文传输协议&#xff0c;它在连接至不同类型总线或网络的设备之间提供客户机/服务器通信。自从1979年出现工业串行链路的事实标准以来&#xff0c;MODBUS使成千上万的自动化设备能够通信。目前&am…

【图卷积网络】03-空域卷积介绍

注&#xff1a;本文为3.1-3.2 空域卷积视频笔记&#xff0c;仅供个人学习使用 1、谱域图卷积 1.1 回顾 上篇博客【图卷积神经网络】02-谱域图卷积介绍讲到了三个经典的谱域图卷积&#xff1a; SCNN用可学习的对角矩阵来代替谱域的卷积核。 ChebNet采用Chebyshev多项式代替谱…

TIA博途中计算多个数据的算术平均值的具体方法示例

TIA博途中计算多个数据的算术平均值的具体方法示例 我们这里采用官方提供的Floating Average功能块来实现多个数据的算术平均值的计算。 此功能块计算最新输入的100个数值的均值(浮动平均值)。采集的数据队列达到100个之后,队列每入栈一个新数值,将去掉一个队列里最早进来的…

高通平台开发系列讲解(GPS篇)gpsONE 系统架构

文章目录 一、系统架构图二、gpsONE系统组成三、gpsONE交互流程沉淀、分享、成长,让自己和他人都能有所收获!😄 📢高通的定位系统模块,名称叫gpsONE。 一、系统架构图 二、gpsONE系统组成 GPS系统架构可以分为六个部分: APP层Framework Client端(LocationManager API…

网站被挂马植入webshell导致网站瘫痪案例

一、问题现象 下午两点&#xff0c;刚刚睡醒&#xff0c;就接到了客户打来的电话&#xff0c;说他们的网站挂&#xff08;这个用词很不准确&#xff0c;但是感觉到问题的严重性&#xff09;了&#xff0c;询问是怎么发生的&#xff0c;之前做了什么操作&#xff0c;客户的回答…

Bash 脚本实例:获取符号链接的目标位置

我们都熟悉 Linux 中的符号链接&#xff0c;通常称为符号链接或软链接&#xff0c;符号链接是指向任何文件系统中的另一个文件或目录的特定文件。本文将介绍 Linux 中符号链接的基础知识&#xff0c;并创建一个简单的 bash 脚本来获取符号链接的目标位置。符号链接的类型主要有…

【栈和队列】java实现栈和队列以及集合中的栈和队列

前言&#xff1a; 大家好&#xff0c;我是良辰丫&#x1f3cd;&#x1f3cd;&#x1f3cd;&#xff0c;今天我带领大家去学习栈和队列的相关知识&#xff0c;&#x1f49e;&#x1f49e;&#x1f49e;栈和队列在数据结构中是相对简单的&#xff0c;但是应用还是蛮多的&#xff…

分享142个ASP源码,总有一款适合您

ASP源码 分享142个ASP源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c; 142个ASP源码下载链接&#xff1a;https://pan.baidu.com/s/1TxdTrCJpO08rKLCUzIh0hQ?pwdyhka 提取码&#x…

微信小程序+云函数+腾讯云对话机器人API(ChatBot)

文章目录 前言 一、小程序云开发是什么&#xff1f; 二、步骤 1. 在app.js中绑定好云环境id&#xff0c;并且选好当前环境以及选好云文件夹 2. 去到腾讯云API Explorer中选好Region地区和Query这个必填参数&#xff0c;然后进行代码生成 3. 在上面的API Explorer网站点击前往获…

Python局部函数及用法

Python 函数内部可以定义变量&#xff0c;这样就产生了局部变量&#xff0c;有读者可能会问&#xff0c;Python 函数内部能定义函数吗&#xff1f;答案是肯定的。Python 支持在函数内部定义函数&#xff0c;此类函数又称为局部函数。那么&#xff0c;局部函数有哪些特征&#x…