深度学习——优化器Optimizer

news2025/1/18 9:06:28

代码以及详细注释:

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible
"""
    超参数
"""
# 学习率
LR = 0.01
# 批大小
BATCH_SIZE = 32
# 轮次
EPOCH = 12

# 造数据
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))

# # plot dataset
# plt.scatter(x.numpy(), y.numpy())
# plt.show()


# put dateset into torch dataset
torch_dataset = Data.TensorDataset(x, y)
# 数据加载器
loader = Data.DataLoader(dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)


# default network
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(1, 20)   # hidden layer
        self.predict = torch.nn.Linear(20, 1)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        return x

if __name__ == '__main__':
    # 相同的网络结构
    net_SGD         = Net()
    net_Momentum    = Net()
    net_RMSprop     = Net()
    net_Adam        = Net()
    # 将上面的网络集成到这里
    nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

    # 不同的优化器
    opt_SGD         = torch.optim.SGD(net_SGD.parameters(), lr=LR)
    opt_Momentum    = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.8)
    opt_RMSprop     = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
    opt_Adam        = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
    # 将上面的优化器集成到这里
    optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]

    # 损失函数
    loss_func = torch.nn.MSELoss()
    losses_his = [[], [], [], []]   # record loss

    # 训练轮次
    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        # 分批训练
        for step, (b_x, b_y) in enumerate(loader):          # for each training step
            for net, opt, l_his in zip(nets, optimizers, losses_his):
                output = net(b_x)              # get output for every net
                loss = loss_func(output, b_y)  # compute loss for every net
                opt.zero_grad()                # clear gradients for next train
                loss.backward()                # backpropagation, compute gradients
                opt.step()                     # apply gradients
                l_his.append(loss.data)        # loss recoder

    # 绘图
    labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
    for i, l_his in enumerate(losses_his):
        plt.plot(l_his, label=labels[i])
    plt.legend(loc='best')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.ylim((0, 0.2))
    plt.show()

运行结果:

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/738774.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

API测试之Postman使用完全指南

前言 Postman是一个可扩展的API开发和测试协同平台工具,可以快速集成到CI/CD管道中。旨在简化测试和开发中的API工作流。 Postman 工具有 Chrome 扩展和独立客户端,推荐安装独立客户端。 Postman 有个 workspace 的概念,workspace 分 pers…

16、Python读取气象数据的正确姿势

文章目录 一、气象数据格式(常用)二、单个文件读取1. 常规格式2. CSV格式3. NetCDF格式4. GRIB格式 一、气象数据格式(常用) 常规格式(Plain Text):气象数据可以使用纯文本格式进行存储&#xf…

漏洞复现 || 某友文件上传

免责声明 技术文章仅供参考,任何个人和组织使用网络应当遵守宪法法律,遵守公共秩序,尊重社会公德,不得利用网络从事危害国家安全、荣誉和利益,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此…

HarmonyOS学习路之开发篇—流转(跨端迁移 一)

跨端迁移开发 场景介绍 开发者在应用FA中通过调用流转任务管理服务、分布式任务调度的接口,实现跨端迁移。 1. 设备A上的应用FA向流转任务管理服务注册一个流转回调: Alt1-系统推荐流转:系统感知周边有可用设备后,主动为用户提…

网络版本的计算器

文章目录 1. TCP协议通讯流程2. 应用层2.1 再谈 "协议" 3. 网络版计算器3.1 服务器提供服务3.1.1 提取有效载荷3.1.2 服务器的反序列化3.1.3 计算服务3.1.4 服务器的序列化3.1.5 添加序列化后的长度 3.2 客户端发送请求3.2.1 填充客户端请求3.2.2 客户端进行序列化3.…

为什么我挖不倒sql注入啊!

为什么我挖不倒sql注入啊! 背景一句话讲原理小白速挖注入 背景 不知道是不是初学安全的小伙伴都和我一样,刚开始学的时候,诶挺简单啊!我咋这么聪明一学就会,靶场轻轻松松过关,到了实战根本挖不出来&#x…

【C++】float / double 与 0 值比较

【C】float / double 与 0 值比较 文章目录 【C】float / double 与 0 值比较1. 概述不同1.1 - float 与 double 实际存储1.2 - C 语言与 C 中不同 2. 比较方法2.1 - C 风格比较2.2 - 使用 limits 函数 3. 参考链接 References 1. 概述不同 当然使用普通的比较没有问题&#xf…

项目管理中,WBS与项目计划有什么区别?

为了成功完成项目并控制成本,我们有必要采取科学的项目管理方法。实现这一目标的工具是项目计划和工作分解结构(WBS)。 WBS 与项目计划是项目管理中必不可少的工具,但两者有不同的用途。WBS精确描述了项目工作和可交付成果&#…

前端vue入门(纯代码)26_多级路由

如果耐不住寂寞,你就看不到繁华。 【24.Vue Router--多级路由】 [可以去官网看看Vue Router文档](嵌套路由 | Vue Router (vuejs.org)) 在实际开发中,我们不单单会使用到一层路由,有时候会涉及到两层或两层以上的路由,多级路由…

带清除按钮的输入框

// index.html <!DOCTYPE html> <html> <head><meta charset"utf-8"><meta name"viewport" content"widthdevice-width, initial-scale1, maximum-scale1"><title>测试 - layui</title><link rel&…

Gof23设计模式之桥接模式

1.概述 桥接模式&#xff08;Bridge Pattern&#xff09;是一种结构型设计模式&#xff0c;它将抽象部分与实现部分分离&#xff0c;使它们可以独立地变化。它的核心思想就是将一个大类或一系列紧密关联的类拆分成两个独立的抽象和实现部分&#xff0c;以便能够更加灵活地扩展…

html相关面试题

html相关面试题 1.html和css中的图片加载与渲染规则是什么样的&#xff1f;2.title与h1的区别、b与strong的区别、i与em的区别&#xff1f;title 和 h1 的区别b 和 strong 的区别i 和 em 的区别最后 3.script 标签为什么建议放在 body 标签的底部&#xff08;defer、async&…

Duplicate keys detected: ‘0‘. This may cause an update error.

问题 vue报错 Duplicate keys detected: ‘0‘. This may cause an update error. 原因 <div v-for“(item,id) in items” :key"id”>{{item.name}} </div><div v-for“(item,id) in items” :key"id”>{{item.address}} </div>:key重…

G1垃圾收集器

一、内存结构 G1将堆内存划分成2048个相同大小的内存Region&#xff0c;一般Region大小等于堆内存大小除以2048&#xff0c;比如堆内存有4个G&#xff0c;每个Region大小为2M&#xff08;-XX:G1HeapRegionSize参数可以设置Region大小&#xff0c;一般不推荐修改&#xff09; G…

C. Strong Password

Problem - C - Codeforces 思路&#xff1a;根据题意我们能够知道就是对于每一位都要再区间范围内&#xff0c;并且不是s的子序列&#xff0c;我们先看第一位&#xff0c;第一位有l[1]-r[1]这几种选择&#xff0c;假如说某一种选择在s中没有那么我们就选择以这个开头的作为答案…

一文讲透进销存管理,和4款值得推荐的进销存管理软件!

进销存管理已经成为当下很多企业和商户必须面对的问题&#xff0c;想要在激烈的市场竞争中取胜&#xff0c;告别混乱管理&#xff0c;必须要有完善合理的进销存管理方法。 那么&#xff0c;进销存管理具体指的是什么&#xff0c;如何做好进销存管理&#xff0c;以及市面上有哪些…

L1-033 出生年(c语言)

作者 陈越 单位 浙江大学 以上是新浪微博中一奇葩贴&#xff1a;“我出生于1988年&#xff0c;直到25岁才遇到4个数字都不相同的年份。”也就是说&#xff0c;直到2013年才达到“4个数字都不相同”的要求。本题请你根据要求&#xff0c;自动填充“我出生于y年&#xff0c;直到…

【风险管理】认知风险管理

NLP技术的商业应用 介绍 机器学习 (ML) 应用程序已经无处不在。每天都有关于自动驾驶汽车人工智能、在线客户支持、虚拟个人助理等的新闻。然而&#xff0c;如何将现有的商业实践与所有这些惊人的创新联系起来可能并不明显。一个经常被忽视的领域是应用自然语言处理 (NLP) 和深…

极智AI | cv::cuda::GpuMat数据排布的误区

欢迎关注我的公众号 [极智视界]&#xff0c;获取我的更多经验分享 大家好&#xff0c;我是极智视界&#xff0c;本文来谈谈 cv::cuda::GpuMat 数据排布的误区。 邀您加入我的知识星球「极智视界」&#xff0c;星球内有超多好玩的项目实战源码下载&#xff0c;链接&#xff1a;…

Tomcat NIO 实现

1. tomcat网络整体架构 来自 https://www.cnblogs.com/cuzzz/p/17499364.html 上图是tomcat整个网络请求模型 Acceptor线程作为监听线程,会通过通过 accept 方法 获取连接&#xff0c;该线程没有使用selector进行多路复用&#xff0c;使用了阻塞式的accept有请求连接后&#x…