2023.05.14-微调ResNet参加kaggle上猫狗大战比赛打到99%的分类准确率_convert

news2024/11/23 20:53:05

文章目录

  • 1. 前言
  • 2. 下载数据集
  • 3. 比赛成绩排名
  • 4. baseline
  • 5. 尝试
    • 5.1. 数据归一化(98.994%)
    • 5.2. 使用AdamW优化器(98.63%)
    • 5.3. 使用AdamW优化器+SegNet模块(95.05%)
  • 6. 结语
  • 7. 感慨
  • 8. 代码
      • 8.1. ResNet+Normalize+AdamW完整代码
      • 8.1.1. 仓库

1. 前言

  • 一直想玩一下这个猫狗大战,但是总是没有下功夫调参。周末有时间,又租借了一个云服务器,万事俱备,只欠东风,开始搞起。

2. 下载数据集

  • 想要参加kaggle官网上面的这个猫狗大战比赛,首先需要注册一个kaggle账号用来下载对应的数据集。

打开下面的网站进行下载即可

  • Dogs vs. Cats | Kaggle

3. 比赛成绩排名

  • www.kaggle.com/competitions/dogs vs cats/leaderboard
  • 第一名的分数是0.98914

4. baseline

  • 自己最开始的时候使用的是ResNet 18的代码作为baseline,分类准确度可以轻轻松松达到98%

5. 尝试

  • 自己搜索了网上对于猫狗大战中可以涨点的策略,自己主要做了以下尝试

5.1. 数据归一化(98.994%)

添加这个归一化代码

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

完整代码

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

效果

  • 不得不说,对数据进行归一化之后,可以极大的提高这个网络收敛的速度。第一个周期的验证准确率就可以达到98.39%
  • 100个周期跑完,最好可以达到98.994%的效果

5.2. 使用AdamW优化器(98.63%)

AdamW是带有权重衰减(而不是L2正则化)的Adam,它在错误实现、训练时间都胜过Adam。
对应的数据

epoch	train loss	train acc	val loss	val acc
0	43.95111	97.75%	2.93358	98.51%
1	430.50297	64.70%	36.61037	77.67%
2	137.0172	91.71%	5.94341	96.82%
3	40.69821	97.84%	3.16171	98.71%
4	28.72242	98.44%	5.38266	97.71%
5	21.23378	98.85%	5.59306	97.02%
6	18.11441	99.04%	3.98322	98.03%
7	19.32834	99.00%	5.01681	98.07%
8	11.94442	99.44%	4.81179	97.91%
9	11.1338	99.45%	4.59616	97.83%
10	14.35451	99.27%	8.86029	95.98%
11	9.79262	99.46%	9.53059	97.43%
12	11.3338	99.40%	7.66958	97.43%
13	8.59158	99.63%	5.31387	98.59%
14	12.89642	99.31%	3.93019	98.19%
15	6.99155	99.71%	5.23799	98.47%
16	8.25213	99.57%	4.20161	98.03%
17	6.52411	99.68%	8.51102	97.63%
18	10.21184	99.52%	4.32666	98.51%
19	7.15083	99.69%	6.45723	98.19%
20	6.47147	99.68%	5.964	98.15%
21	6.40303	99.72%	8.30525	97.51%
22	4.46209	99.82%	8.23106	98.11%
23	7.30719	99.64%	4.91704	98.63%
24	7.41548	99.66%	4.51357	98.35%
25	4.41403	99.78%	7.23314	98.39%
26	8.96065	99.64%	5.85345	98.07%
27	5.97362	99.73%	4.949	98.39%
28	8.65173	99.58%	4.26699	98.43%
29	1.94975	99.92%	4.99152	98.55%
30	5.14563	99.74%	3.90554	98.63%
31	1.1131	99.96%	7.56679	98.35%
32	10.75336	99.48%	5.23759	97.87%
33	0.86672	99.97%	9.2502	98.31%
34	7.93448	99.64%	4.37685	98.03%
35	2.44822	99.87%	7.21055	97.87%
36	6.85281	99.75%	5.51565	97.91%
37	3.2463	99.85%	9.12831	97.79%
38	6.26243	99.69%	5.899	97.75%
39	3.29857	99.90%	7.2071	97.87%
40	0.5045	99.99%	7.05801	98.51%
41	0.0135	100.00%	7.54731	98.43%
42	0.0027	100.00%	8.59324	98.47%
43	0.00083	100.00%	8.99156	98.43%
44	0.00045	100.00%	9.55036	98.43%
45	0.00027	100.00%	10.0697	98.43%
46	0.00017	100.00%	10.39488	98.43%
47	0.0001	100.00%	10.98709	98.43%
48	0.00008	100.00%	11.46222	98.43%
49	0.00005	100.00%	11.51941	98.35%
50	0.00004	100.00%	11.73555	98.39%
51	0.00002	100.00%	12.03522	98.35%
52	0.00002	100.00%	12.54926	98.35%
53	0.00001	100.00%	12.42227	98.35%
54	0.00001	100.00%	13.2006	98.31%
55	0.00001	100.00%	13.64486	98.31%
56	0	100.00%	12.90368	98.35%
57	0	100.00%	13.13818	98.35%
58	0	100.00%	13.7345	98.31%
59	0	100.00%	13.65401	98.27%
60	0	100.00%	13.74176	98.31%
61	0	100.00%	13.78569	98.31%
62	0	100.00%	14.64054	98.27%
63	0	100.00%	14.17896	98.27%
64	0	100.00%	13.99432	98.31%
65	0	100.00%	14.73406	98.31%
66	0	100.00%	14.69667	98.31%
67	0	100.00%	14.58825	98.27%
68	0	100.00%	14.88915	98.31%
69	0	100.00%	14.95989	98.27%
70	0	100.00%	15.37874	98.27%
71	0	100.00%	15.86721	98.27%
72	0	100.00%	16.20822	98.23%
73	0	100.00%	16.20378	98.31%
74	0	100.00%	17.1774	98.31%
75	25.10347	98.93%	5.52769	97.91%
76	9.66224	99.53%	4.98326	98.11%
77	2.80008	99.88%	6.26822	98.43%
78	5.21812	99.79%	4.73304	98.31%
79	3.3407	99.85%	8.41819	98.11%
80	0.46344	99.98%	7.39496	98.47%
81	0.01035	100.00%	7.52614	98.51%
82	0.00332	100.00%	8.00924	98.51%
83	0.00135	100.00%	8.59734	98.47%
84	0.00056	100.00%	9.3975	98.55%
85	0.00024	100.00%	9.93917	98.43%
86	0.00008	100.00%	11.35343	98.43%
87	0.00003	100.00%	11.89728	98.43%
88	0.00002	100.00%	12.30812	98.43%
89	0.00001	100.00%	12.8423	98.47%
90	0.00001	100.00%	13.57241	98.35%
91	0	100.00%	13.41991	98.51%
92	0	100.00%	13.87756	98.43%
93	0	100.00%	14.49194	98.31%
94	0	100.00%	14.60349	98.47%
95	0	100.00%	15.24883	98.39%
96	0	100.00%	15.04266	98.43%
97	0	100.00%	16.21219	98.39%
98	0	100.00%	15.58381	98.51%
99	0	100.00%	16.35482	98.35%

效果

最高可以达到98.63%

98.51%

5.3. 使用AdamW优化器+SegNet模块(95.05%)

我是想在之前的基础上添加一个注意力机制模块,但是不知道为什么训练级的准确率很高,但是验证集上的效果却要差很多,可能是因为自己添加的这个注意力机制模块使得网络的泛化性变差了吧
对应的数据

	00	841.16123	66.698%	80.34021	74.849%	
	01	782.98219	70.309%	66.07160	80.322%	
	02	593.89293	80.817%	55.39222	83.702%	
	03	485.13791	84.672%	49.78145	86.398%	
	04	386.34337	88.332%	34.40874	90.744%	
	05	324.79488	90.300%	37.25761	89.537%	
	06	273.36514	92.112%	41.78502	88.531%	
	07	245.33996	92.756%	30.65071	91.549%	
	08	209.99650	93.893%	24.99330	93.280%	
	09	174.44573	94.946%	40.70310	90.865%	
	10	152.54020	95.590%	24.66959	93.642%	
	11	126.36934	96.429%	26.63028	92.958%	
	12	107.61617	96.962%	24.49496	93.843%	
	13	94.44031	97.433%	27.07281	93.320%	
	14	77.85434	97.926%	33.65216	92.998%	
	15	71.30835	98.055%	27.37954	94.044%	
	16	56.10977	98.534%	37.30386	93.119%	
	17	51.94865	98.583%	45.16884	92.596%	
	18	45.82673	98.863%	33.09134	93.682%	
	19	46.00949	98.748%	30.61986	93.763%	
	20	39.88356	98.965%	32.49509	94.245%	
	21	35.98075	99.076%	30.70699	94.728%	
	22	36.77068	99.072%	26.50579	94.487%	
	23	29.62899	99.272%	29.40019	94.487%	
	24	30.70629	99.232%	37.46327	93.843%	
	25	38.08304	99.054%	28.52988	94.366%	
	26	25.40524	99.400%	37.30047	94.044%	
	27	33.73834	99.174%	30.09059	94.889%	
	28	24.33486	99.449%	34.55807	94.447%	
	29	29.78610	99.325%	31.62320	94.809%	
	30	23.03223	99.427%	46.01729	94.205%	
	31	26.88877	99.312%	42.09933	94.809%	
	32	25.12524	99.409%	36.05506	94.044%	
	33	22.30487	99.436%	33.46056	94.326%	
	34	23.79032	99.365%	33.57563	94.406%	
	35	18.53882	99.569%	31.54106	95.050%	
	36	20.52793	99.511%	37.89401	94.487%	
	37	21.22465	99.467%	43.78654	93.763%	
	38	19.86762	99.467%	47.26076	94.165%	
	39	17.43618	99.591%	52.05411	93.078%	
	40	19.54660	99.498%	32.24883	94.567%	
	41	15.23968	99.645%	42.51051	94.205%	
	42	20.26523	99.529%	37.01770	94.366%	
	43	13.82244	99.614%	39.53712	94.648%	
	44	18.52900	99.507%	36.48620	94.728%	
	45	13.13430	99.671%	46.33306	94.527%	
	46	20.10074	99.525%	38.52874	95.493%	
	47	17.74225	99.574%	30.75011	94.648%	
	48	11.84078	99.698%	41.63479	94.567%	
	49	18.99130	99.520%	35.11506	94.245%	
	50	13.96501	99.654%	36.95696	94.326%	
	51	10.47367	99.747%	42.35815	94.567%	
	52	17.46265	99.614%	49.29176	94.245%	
	53	13.03071	99.658%	45.44298	94.849%	
	54	12.27281	99.658%	45.32041	95.010%	
	55	15.32756	99.685%	40.12351	94.447%	
	56	14.36285	99.671%	39.26911	94.809%	
	57	10.85270	99.729%	41.98047	94.366%	
	58	13.66196	99.667%	45.47937	94.648%	
	59	13.33846	99.689%	44.08331	93.964%	
	60	12.87245	99.680%	43.91811	94.286%	
	61	11.93796	99.738%	36.15065	93.239%	
	62	12.06105	99.760%	33.67126	94.085%	
	63	13.68432	99.725%	45.61084	94.406%	
	64	14.13714	99.694%	36.90194	94.648%	
	65	8.25917	99.800%	49.74482	94.406%	
	66	12.15086	99.707%	42.50143	94.930%	
	67	10.02019	99.751%	40.13083	94.567%	
	68	9.81753	99.813%	57.15547	94.648%	
	69	13.14721	99.676%	41.48277	94.608%	
	70	10.72047	99.725%	43.08352	94.849%	
	71	10.62724	99.698%	39.06533	94.406%	
	72	8.58425	99.791%	45.32018	93.763%	

6. 结语

  • 可以说目前这个精度可以达到99%,我觉得应该是比较高的一个精度了,测试集上没有必要达到100%,这是很难的,也是不可能的,毕竟有些猫和狗的图片长得实在是太像了,人眼都很难分出来到底谁是猫谁是狗,所以这个猫狗大战分类的调试尝试到这里应该就差不多了。

7. 感慨

  • 当年猫狗大战的时候,能上到98%都已经算出top 1了。但是现在我们采用预训练模型加微调的方法,可以轻轻搞上99%。不仅感慨现在深度学习越来越卷了,
  • 不过也不得不说,ResNet毕竟是2015年imaginet图像分类比赛中的冠军,效果真的是一级棒。

8. 代码

8.1. ResNet+Normalize+AdamW完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from torchvision import transforms
import torchvision

from torch.utils.data import DataLoader



transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def make_dir(path):
    import os
    dir = os.path.exists(path)
    if not dir:
        os.makedirs(path)
make_dir('models')

batch_size = 8

train_set = torchvision.datasets.ImageFolder(root='data/cat_vs_dog/train', transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=0)  # Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。

val_dataset = torchvision.datasets.ImageFolder(root='data/cat_vs_dog/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                        num_workers=0)  # Batch Size定义:一次训练所选取的样本数。 Batch Size的大小影响模型的优化程度和速度。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = torchvision.models.resnet18(weights=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 2)  # 将输出维度修改为2

criterion = nn.CrossEntropyLoss()
net = net.to(device)
optimizer = torch.optim.AdamW(lr=0.0001, params=net.parameters())
eposhs = 100

for epoch in range(eposhs):
    print(f'--------------------{epoch}--------------------')
    correct_train = 0
    sum_loss_train = 0
    total_correct_train = 0
    for inputs, labels in tqdm(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        output = net(inputs)
        loss = criterion(output, labels)
        sum_loss_train = sum_loss_train + loss.item()
        total_correct_train = total_correct_train + labels.size(0)
        optimizer.zero_grad()
        _, predicted = torch.max(output.data, 1)
        loss.backward()
        optimizer.step()
        correct_train = correct_train + (predicted == labels).sum().item()

    acc_train = correct_train / total_correct_train
    print('训练准确率是{:.3f}%:'.format(acc_train*100) )

    net.eval()
    correct_val = 0
    sum_loss_val = 0
    total_correct_val = 0
    for inputs, labels in tqdm(val_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        output = net(inputs)
        loss = criterion(output, labels)
        sum_loss_val = sum_loss_val + loss.item()

        output = net(inputs)
        total_correct_val = total_correct_val + labels.size(0)
        optimizer.zero_grad()
        _, predicted = torch.max(output.data, 1)
        correct_val = correct_val + (predicted == labels).sum().item()

    acc_val = correct_val / total_correct_val
    print('验证准确率是{:.3f}%:'.format(acc_val*100) )

    torch.save(net,'models/{}-{:.5f}_{:.3f}%_{:.5f}_{:.3f}%.pth'.format(epoch,sum_loss_train,acc_train *100,sum_loss_val,acc_val*100))


8.1.1. 仓库

  • 然后我把所有的代码和权重全部上传到了Huggin Face上面,如果有兴趣的小伙伴可以在我代码的基础上做进一步的尝试
  • NewBreaker/classify-cat_vs_dog · Hugging Face

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/531827.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

贝尔曼福特算法——负权值单源最短路径

title: 贝尔曼福特算法——负权值单源最短路径 date: 2023-05-16 11:42:26 tags: 数据结构与算法 贝尔曼福特算法——负权值单源最短路径 **问题:**具有负权值非环图的单源最短路径算法 git地址:https://github.com/944613709/HIT-Data-Structures-and-A…

阿里云备案服务码怎么申请?

阿里云备案服务码是什么?ICP备案服务码怎么获取?阿里云备案服务码分为免费和付费两种,申请备案服务码是有限制条件的,需要你的阿里云账号下有可用于申请备案服务码的云产品,如云服务器、建站产品、虚拟主机等&#xff…

基于Sentinel自研组件的系统限流、降级、负载保护最佳实践探索 | 京东云技术团队

作者:京东物流 杨建民 一、Sentinel简介 Sentinel 以流量为切入点,从流量控制、熔断降级、系统负载保护等多个维度保护服务的稳定性。 Sentinel 具有以下特征: 丰富的应用场景:秒杀(即突发流量控制在系统容量可以承受的范围&a…

【Android Studio】win10 创建并运行第一个App Hello world 超详细

概述 一个好的文章能够帮助开发者完成更便捷、更快速的开发。书山有路勤为径,学海无涯苦作舟。我是秋知叶i、期望每一个阅读了我的文章的开发者都能够有所成长。 一、开发环境 开发环境:windows10Android Studio 版本 Android Studio Flamingo | 2022…

SwiftUI 布局协议 - Part1

文章目录 简介什么是布局协议视图层次结构的族动态我们的第一个布局实现ProposedViewSizeLayoutSubviewsizeThatFits 方法placeSubviews 方法 容器对齐优先布局LayoutValueKey默认间距布局属性和 Spacer()布局缓存高明的伪装者使用 AnyLayout 切换布局结语 简介 今年 SwiftUI …

「——全部文章专栏汇总——」

欢迎来到我的博客 天喜Studio 在这里&#xff0c;我会分享我在 c语言、操作系统、计算机网络等方面的学习和经验&#xff0c;希望能对读者有所帮助。以下是我写的所有专栏 如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 如有疑问欢迎大家指正讨论…

ML03 网页数据抓取 (note)

很多机器学习数据集是从网页上抓取过来的。 网页数据抓取与网页爬虫的区别&#xff1a;数据抓取&#xff1a;特定的数据&#xff0c; 网页爬虫&#xff1a;将整个网页获取 数据科学家主要进行网页数据抓取&#xff0c;对网页上的特定数据感兴趣。 网页数据获取工具 curl 通常…

【夜莺(Flashcat)V6监控】2.夜莺告警相关:初级使用

介绍 夜莺监控系统不仅提供了强大的数据采集和可视化功能&#xff0c;还提供了灵活的告警配置能力&#xff0c;帮助我们实时掌握系统的运行状况&#xff0c;快速响应和解决潜在问题。 本章主要给大家介绍邮件、微信、钉钉配置并告警&#xff1b;简单几台机器随时看就好了&…

指针穿梭,数据流转:探秘C语言实现单向不带头不循环链表

本篇博客会讲解链表的最简单的一种结构&#xff1a;单向不带头不循环链表&#xff0c;并使用C语言实现。 概述 链表是一种线性的数据结构&#xff0c;而本篇博客讲解的是链表中最简单的一种结构&#xff0c;它的一个结点的声明如下&#xff1a; // 单链表存储的数据类型 typ…

Dcat Admin文件上传漏洞复现

Dcat Admin框架 Dcat Admin是一个基于laravel-admin二次开发而成的后台系统构建工具&#xff0c;只需极少的代码即可快速构建出一个功能完善的高颜值后台系统。支持页面一键生成CURD代码&#xff0c;内置丰富的后台常用组件&#xff0c;开箱即用&#xff0c;让开发者告别冗杂的…

060基于深度学习的建筑物房屋检测

视频演示和demo仓库地址找060期&#xff1a; 银色子弹zg的个人空间-银色子弹zg个人主页-哔哩哔哩视频 效果图如下: 代码所有文件: 运行01create_txt.py会将data文件下的图片路径及标签保存在txt文本内&#xff0c; 运行02train.py会对图片进行读取并训练模型保存在runs文件…

训练自己的ChatGPT(ChatGLM微调 )

目录 准备 操作 上传数据数据 训练进度 推理 验证 异常处理 总结 参考资料 ChatGLM微调 ptuning 准备 接上文https://blog.csdn.net/dingsai88/article/details/130639365 部署好ChatGLM以后&#xff0c;对它进行微调 操作 如果已经使用过 API 或者web模式的应该已经…

Linux安装elasticsearch、ik分词器、kibana

这里写目录标题 前言下载IK分词器下载Elasticsearch下载Kibana下载JDK安装JDK安装Elasticsearch与IK分词器安装Kibana错误调试参考链接扩展部分 前言 一个PHP程序员接入Elasticsearch并不是公司项目的需求&#xff0c;而是自己平时积累了很多项目信息、代码片段、解决问题的网…

设计模式之【模板方法模式】,模板方法和函数式回调,哪个才是趋势?

文章目录 一、什么是模板方法模式1、主要角色2、应用场景3、优缺点4、注意事项及细节 二、实例1、炒菜案例&#xff08;1&#xff09;模板方法模式的钩子方法 2、重构JDBC案例 三、模板方法模式与Callback回调模式1、回调基本原理2、案例一&#xff1a;回调方式重构JDBC3、案例…

Camtasia Studio2023最新版喀秋莎电脑录制屏幕编辑器

不管是在我们平日的工作当中&#xff0c;还是生活当中&#xff0c;camtasia studio可以方便地进行屏幕操作的录制和配音、视频的剪辑和过场动画、添加说明字幕和水印、制作视频封面和菜单、视频压缩和播放。 你都会因为一些事情&#xff0c;从而需要进行录屏的需求。而Camtasi…

超详细,unity如何制作人物行走的遥杆?

介绍 在游戏中&#xff0c;移动遥杆是一种常见的用户界面元素&#xff0c;它允许玩家通过触摸或鼠标输入来控制游戏对象的移动。移动遥杆通常由一个圆形或方形的背景和一个可以拖动的小球&#xff08;称为拇指杆&#xff09;组成。玩家可以通过拖动拇指杆来控制游戏对象的移动…

某IC交易网 js逆向解析学习【2023/05/16】

文章目录 文章目录 文章目录前言网址目标参数确认加密点cookie解密第一步hex1算法解析rind和rnns完结撒花前言 可以关注我哟,一起学习,主页有更多练习例子 如果哪个练习我没有写清楚,可以留言我会补充 如果有加密的网站可以留言发给我,一起学习共享学习路程 如侵权,联系我…

Vue.js表单输入绑定

对于Vue来说&#xff0c;使用v-bind并不能解决表单域对象双向绑定的需求。所谓双向绑定&#xff0c;就是无论是通过input还是通过Vue对象&#xff0c;都能修改绑定的数据对象的值。Vue提供了v-model进行双向绑定。本章将重点讲解表单域对象的双向绑定方法和技巧。 10.1 实现双…

单片机的介绍

目录 一、介绍 1.单片机简介 2.单片机型号 3.体系 二、硬件基础 1.引言 2.电路基础 电的类比 电流 电压 电路 3.电子元器件 电阻 电容 二极管 三极管 4.常见电气接口 传统音频 视频 电源 RJ45网口 DB9串口 5.开发板/最小系统板 三、STM32介绍 1.简介…

JAVA电商 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城

JAVA电商 B2B2C商城系统 多用户商城系统 直播带货 新零售商城 o2o商城 电子商务 拼团商城 分销商城 1. 鸿鹄Cloud架构清单 2. Commonservice&#xff08;通用服务&#xff09; 通用服务&#xff1a;对spring Cloud组件的使用&封装&#xff0c;是一套完整的针对于分布式微…