模型优化系列1:分类器CenterLoss使用Pytorch实现MNIST、CIFAR10、CIFAR100分类图示

news2024/11/24 20:28:14

CentLoss实现

前言

参考文章:史上最全MNIST系列(三)——Centerloss在MNIST上的Pytorch实现(可视化)
源码:Gitee或Github都有上传,保留了最优版,对最优版调整了一些参数看效果
Gitee传送门:点击这里跳转Gitee源码库
Github传送门:点击这里跳转Github源码库
--------------------无---------------情---------------分---------------割---------------线--------------------
贴心省流:最优效果见文末的v5版

v1

一开始实现的效果还行,但!是!
我忘记参数了(0.0!!)
于是,开始了漫长的尝试。。。

效果

image.png

v2

网络结构

self.hidden_layer = nn.Sequential(
    ConvLayer(1, 32, 3, 1, 1),
    ConvLayer(32, 32, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(32, 64, 3, 1, 1),
    ConvLayer(64, 64, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(64, 128, 3, 1, 1),
    ConvLayer(128, 128, 3, 1, 1),
    nn.MaxPool2d(2)
)

self.fc = nn.Sequential(
    nn.Linear(128 * 3 * 3, 12)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.9)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.9)

效果

image.png

结论

分类效果不明显

v2_2

v2基础上仅修改了网络结构

网络结构

self.hidden_layer = nn.Sequential(
    ConvLayer(1, 32, 3, 1, 1),
    ConvLayer(32, 32, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(32, 64, 3, 1, 1),
    ConvLayer(64, 64, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(64, 128, 3, 1, 1),
    ConvLayer(128, 128, 3, 1, 1),
    nn.MaxPool2d(2)
)

self.fc = nn.Sequential(
    nn.Linear(128 * 3 * 3, 2)
)

self.output_layer = nn.Sequential(
    nn.Linear(2, 10)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.9)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.9)

效果

image.png

结论

输出12–>2+10,分类效果明确,但未完全分开

v2_3

v2_2基础上仅修改了学习率

网络结构

self.hidden_layer = nn.Sequential(
    ConvLayer(1, 32, 3, 1, 1),
    ConvLayer(32, 32, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(32, 64, 3, 1, 1),
    ConvLayer(64, 64, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(64, 128, 3, 1, 1),
    ConvLayer(128, 128, 3, 1, 1),
    nn.MaxPool2d(2)
)

self.fc = nn.Sequential(
    nn.Linear(128 * 3 * 3, 2)
)

self.output_layer = nn.Sequential(
    nn.Linear(2, 10)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.8)

效果

image.png

结论

动量0.9–>0.8,分类效果不明显

v2_4

v2_2基础上仅修改了学习率

网络输出:2+10

self.hidden_layer = nn.Sequential(
    ConvLayer(1, 32, 3, 1, 1),
    ConvLayer(32, 32, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(32, 64, 3, 1, 1),
    ConvLayer(64, 64, 3, 1, 1),
    nn.MaxPool2d(2),
    ConvLayer(64, 128, 3, 1, 1),
    ConvLayer(128, 128, 3, 1, 1),
    nn.MaxPool2d(2)
)

self.fc = nn.Sequential(
    nn.Linear(128 * 3 * 3, 2)
)

self.output_layer = nn.Sequential(
    nn.Linear(2, 10)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.9)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.5, momentum=0.9)

效果

image.png

结论

center_loss学习率0.01–>0.5,分类效果不明显

v3

网络结构

self.hidden_layer = nn.Sequential(
    ConvLayer(1, 32, 5, 1, 1),
    ConvLayer(32, 32, 5, 1, 2),
    nn.MaxPool2d(2, 2),
    ConvLayer(32, 64, 5, 1, 1),
    ConvLayer(64, 64, 5, 1, 2),
    nn.MaxPool2d(2, 2),
    ConvLayer(64, 128, 5, 1, 1),
    ConvLayer(128, 128, 5, 1, 2),
    nn.MaxPool2d(2, 2)
)

self.fc = nn.Sequential(
    nn.Linear(128, 12)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.002, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 20, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.7)

效果

image.png

v3_2

v3基础上修改学习率

网络结构

class MainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_layer = nn.Sequential(
            ConvLayer(1, 32, 5, 1, 1),
            ConvLayer(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(32, 64, 5, 1, 1),
            ConvLayer(64, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(64, 128, 5, 1, 1),
            ConvLayer(128, 128, 5, 1, 2),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 12)
        )

    # self.output_layer = nn.Sequential(
    #     nn.Linear(2, 10)
    # )

    def forward(self, _x):
        h_out = self.hidden_layer(_x)
        h_out = h_out.reshape(-1, 128)
        # feature = self.fc(h_out)
        # outs = self.output_layer(feature)
        outs = self.fc(h_out)
        return outs

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.8)

效果

image.png

v4

网络结构

class MainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_layer = nn.Sequential(
            ConvLayer(1, 32, 5, 1, 1),
            ConvLayer(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(32, 64, 5, 1, 1),
            ConvLayer(64, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(64, 128, 5, 1, 1),
            ConvLayer(128, 128, 5, 1, 2),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 2)
        )

        self.output_layer = nn.Sequential(
            nn.Linear(2, 10)
        )

    def forward(self, _x):
        h_out = self.hidden_layer(_x)
        h_out = h_out.reshape(-1, 128)
        feature = self.fc(h_out)
        outs = self.output_layer(feature)
        return torch.cat((feature, outs), dim=1)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.1, momentum=0.8)

效果

image.png

v4_2

v4基础上修改了学习率

网络结构

class MainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_layer = nn.Sequential(
            ConvLayer(1, 32, 5, 1, 1),
            ConvLayer(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(32, 64, 5, 1, 1),
            ConvLayer(64, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(64, 128, 5, 1, 1),
            ConvLayer(128, 128, 5, 1, 2),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 2)
        )

        self.output_layer = nn.Sequential(
            nn.Linear(2, 10)
        )

    def forward(self, _x):
        h_out = self.hidden_layer(_x)
        h_out = h_out.reshape(-1, 128)
        feature = self.fc(h_out)
        outs = self.output_layer(feature)
        return torch.cat((feature, outs), dim=1)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.8)

效果

image.png

结论

centerloss优化器学习率0.1–>0.01,分类效果明显

更换批次

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=256)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.01, momentum=0.8)

效果

image.png

结论

批次1024–>256,计算速度加快,分类速度减慢

v4_3

v4基础上修改了学习率

网络结构

class MainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_layer = nn.Sequential(
            ConvLayer(1, 32, 5, 1, 1),
            ConvLayer(32, 32, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(32, 64, 5, 1, 1),
            ConvLayer(64, 64, 5, 1, 2),
            nn.MaxPool2d(2, 2),
            ConvLayer(64, 128, 5, 1, 1),
            ConvLayer(128, 128, 5, 1, 2),
            nn.MaxPool2d(2, 2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 2)
        )

        self.output_layer = nn.Sequential(
            nn.Linear(2, 10)
        )

    def forward(self, _x):
        h_out = self.hidden_layer(_x)
        h_out = h_out.reshape(-1, 128)
        feature = self.fc(h_out)
        outs = self.output_layer(feature)
        return torch.cat((feature, outs), dim=1)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=1024)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.8)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 30, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.5, momentum=0.8)

效果

image.png

结论

centerloss优化器学习率0.1–>0.5,分类效果明显

v5

加深网络,修改centerloss计算(取消平方根),修改学习率

centerloss计算

# return lamda / 2 * torch.mean(torch.div(torch.sqrt(torch.sum(torch.pow(_x - center_exp, 2), dim=1)), count_exp))
return lamda / 2 * torch.mean(torch.div(torch.sum(torch.pow((_x - center_exp), 2), dim=1), count_exp))

网络结构

self.hidden_layer = nn.Sequential(
    # ConvLayer(1, 32, 5, 1, 2),
    ConvLayer(3, 32, 5, 1, 2),
    ConvLayer(32, 64, 5, 1, 2),
    nn.MaxPool2d(2, 2),
    ConvLayer(64, 128, 5, 1, 2),
    ConvLayer(128, 256, 5, 1, 2),
    nn.MaxPool2d(2, 2),
    ConvLayer(256, 512, 5, 1, 2),
    ConvLayer(512, 512, 5, 1, 2),
    nn.MaxPool2d(2, 2),
    ConvLayer(512, 256, 5, 1, 2),
    ConvLayer(256, 128, 5, 1, 2),
    ConvLayer(128, 64, 5, 1, 2),
    nn.MaxPool2d(2, 2)
)

self.fc = nn.Sequential(
    # nn.Linear(64, 2)
    nn.Linear(64 * 2 * 2, 2)
)

self.output_layer = nn.Sequential(
    nn.Linear(2, 10)
)

批次和学习率

data_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=256)
# ...
net_opt = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(net_opt, 20, gamma=0.8)
c_l_opt = torch.optim.SGD(center_loss_fn.parameters(), lr=0.5)

效果

MNIST
image.png
output.gif
image.png
CIFAR10
image.png
output_cifar10.gif
image.png
CIFAR100
image.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1561787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MySQL】DCL-数据控制语言-【管理用户&权限控制】 (语法语句&案例演示&可cv案例代码)

前言 大家好吖,欢迎来到 YY 滴MySQL系列 ,热烈欢迎! 本章主要内容面向接触过C Linux的老铁 主要内容含: 欢迎订阅 YY滴C专栏!更多干货持续更新!以下是传送门! YY的《C》专栏YY的《C11》专栏YY的…

13.5k star, 免费开源 Markdown 编辑器

13.5k star, 免费开源 Markdown 编辑器 分类 开源分享 项目名: Editor.md -- Markdown 编辑器 Github 开源地址: https://github.com/pandao/editor.md 在线测试地址: Editor.md - 开源在线 Markdown 编辑器 完整实例: HTML Preview(mark…

网络安全 | 什么是威胁情报?

关注WX:CodingTechWork 威胁情报 威胁情报-介绍 威胁情报也称为“网络威胁情报”(CTI),是详细描述针对组织的网络安全威胁的数据。威胁情报可帮助安全团队更加积极主动地采取由数据驱动的有效措施,在网络攻击发生之前就将其消弭于无形。威…

Java基础入门--面向对象课后题(2)

文章目录 1 Employee2 SalariedEmployee3 HourlyEmployee4 SalesEmployee5 BasePlusSalesEmployee6 测试类 Example177 完整代码 某公司的雇员分为5类,每类员工都有相应的封装类,这5个类的信息如下所示。 (1) Employee:这是所有员工总的父类。…

用Typora+picgo+cloudflare+Telegraph-image的免费,无需服务器,无限空间的图床搭建(避坑指南)

用TyporapicgocloudflareTelegraph-image的免费,无需服务器,无限空间的图床搭建(避坑指南) 前提:有github何cloudflare (没有的话注册也很快) 首先,是一个别人写的详细的配置流程,傻瓜式教程&am…

华为配置防止ARP中间人攻击实验

配置防止ARP中间人攻击实验 组网图形 图1 配置防止ARP中间人攻击组网图 动态ARP检测简介配置注意事项组网需求配置思路操作步骤配置文件 动态ARP检测简介 ARP(Address Resolution Protocol)安全是针对ARP攻击的一种安全特性,它通过一系列…

网络原理 - HTTP / HTTPS(1)——http请求

目录 一、认识HTTP协议 理解 应用层协议 二、fiddler的安装以及介绍 1、fiddler的安装 2、fiddler的介绍 三、HTTP 报文格式 1、http的请求 2、http的响应 五、认识URL 六、关于URL encode 一、认识HTTP协议 HTTP 全称为:“超文本传输协议”,是…

如何将平板或手机作为电脑的外接显示器?

先上官网链接:ExtensoDesk 家里有一台华为平板,自从买回来以后除了看视频外,基本没什么作用,于是想着将其作为我电脑的第二个屏幕,提高我学习办公的效率,废物再次利用。最近了解到华为和小米生态有多屏协同…

iPhone设备中使用第三方工具查看应用程序崩溃日志的教程

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

【MySQL】数据库函数-案例演示【字符串/数值/日期/流程控制函数】(代码演示&可cv代码)

前言 大家好吖,欢迎来到 YY 滴MySQL系列 ,热烈欢迎! 本章主要内容面向接触过C Linux的老铁 主要内容含: 欢迎订阅 YY滴C专栏!更多干货持续更新!以下是传送门! YY的《C》专栏YY的《C11》专栏YY的…

WebSorcket 集成 Spring Boot

WebSorcket 集成 Spring Boot 配置 Configuration public class WebSocketConfiguraion {Beanpublic ServerEndpointExporter serverEndpointExporter (){ServerEndpointExporter exporter new ServerEndpointExporter();return exporter;} }服务类 import lombok.extern.sl…

Prometheus+grafana环境搭建mysql(docker+二进制两种方式安装)(三)

由于所有组件写一篇幅过长,所以每个组件分一篇方便查看,前两篇 Prometheusgrafana环境搭建方法及流程两种方式(docker和源码包)(一)-CSDN博客 Prometheusgrafana环境搭建rabbitmq(docker二进制两种方式安装)(二)-CSDN博客 1.监控mysql 1.1官方地址:…

iPhone设备中如何导出和分享应用程序崩溃日志的实用方法

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

医院智慧手术麻醉系统管理源码 C# .net有演示

医院智慧手术麻醉系统管理源码 C# .net有演示 手术麻醉管理系统(DORIS)是应用于医院手术室、麻醉科室的计算机软件系统。该系统针对整个围术期,对病人进行全程跟踪与信息管理,自动集成病人HIS、LIS、RIS、PACS信息,采集监护等设备数据&#x…

【前端】FreeMarker学习笔记

文章目录 1. 介绍2.FreeMarker环境搭建(maven版本)3. 语法3.1 freemarker的数据类型3.1.1 布尔类型3.1.2 日期类型 FreeMarker视频教程 1. 介绍 中文官网 英文官网 FreeMarker 是一款 模板引擎: 即一种基于模板和要改变的数据, 并用来生成输出文本(HTML…

第三天开始写了

现在的情况 写俩个接口信息 1. 一个修改 2. 一个 删除 发现了一个问题 只有这些参数无法完成修改的 因为这些关联到一个商品表和一个用户表,我们应该查询他们id信息,修改其中的内容,单独根据字符串查看效果可能不好 这里我们提交应该是用…

深入探究Shrio反序列化漏洞

Shrio反序列化漏洞 什么是shrio反序列化漏洞环境搭建漏洞判断rememberMe解密流程代码分析第一层解密第二层解密2.1层解密2.2层解密 exp 什么是shrio反序列化漏洞 Shiro是Apache的一个强大且易用的Java安全框架,用于执行身份验证、授权、密码和会话管理。使用 Shiro 易于理解的…

neo4j使用详解(七、cypher数学函数语法——最全参考)

Neo4j系列导航: neo4j及简单实践 cypher语法基础 cypher插入语法 cypher插入语法 cypher查询语法 cypher通用语法 cypher函数语法 5.数学函数 5.1.数值函数 数学函数仅对数字表达式进行运算,如果对任何其他值使用,将返回错误 abs()&#xf…

CentOs7.9中修改Mysql8.0.28默认的3306端口防止被端口扫描入侵

若你的服务器被入侵,可以从这些地方找到证据: 若有上述信息,300%是被入侵了,重装服务器系统以后再重装Mysql数据库,除了设置一个复杂的密码以外,还需要修改默认的Mysql访问端口,逃避常规端口扫描…

超图打开不同格式的dem文件

dem,数字高程模型; dem文件的后缀是什么? 有*.dem格式的,也有Raster,ASCII和Tiff类型的。Raster类型的是一个raster文件夹里面有很多不同格式的文件共同组成了DEM文件的内容。ASCII类型的是个txt文件。Tiff类型的也是一个文件夹…