PyTorch中的学习率预热(warmup)

news2025/1/20 1:10:18

      PyTorch提供了学习率调度器(learning rate schedulers),用于在训练过程中实现各种调整学习率的方法。实现在torch.optim.lr_scheduler.py中,根据epoch数调整学习率。大多数学习率调度器可以称为背对背(back-to-back),也称为链式调度器,结果是每个调度器都一个接一个地应用于前一个调度器获得的学习率。学习率调度器应在优化器更新(optimizer.step())后应用

      warmup是 ResNet 论文中提出的方法:We further explore n = 18 that leads to a 110-layer ResNet. In this case, we find that the initial learning rate of 0.1 is slightly too large to start converging. so we use 0.01 to warm up the training until the training error is below 80% (about 400 iterations), and then go back to 0.1 and continue training. The rest of the learning schedule is as done previously.

      warmup是一种学习率优化方法。使用warmup可以在训练初期使用较小的学习率进行稳定的模型训练,然后逐渐增加学习率以提高收敛速度和模型性能。有助于减缓模型在初始阶段对mini-batch的提前过拟合现象,保持分布的平稳。有助于保持模型深层的稳定性。

      由于刚开始训练时,模型的权重是随机初始化的,loss比较大,此时若选择一个较大的学习率,可能带来模型的不稳定(振荡),选择warmup的方式,可以使得开始训练的几个epoch或者一些step内学习率较小,在预热的小学习率下,模型可以慢慢趋于稳定,等模型相对稳定后再选择预先设置的学习率进行训练,有助于保持模型深层的稳定性,使得模型收敛速度变得更快,模型效果更佳。最终稳定阶段降低学习率更容易找到局部最优,可以增加batch size,这样更稳定。

      如果收敛太快,很快就在训练集上过拟合了,可以降低学习率,如果训练过慢或不收敛,则可以增加学习率。

      一般学习率设置:上升----平稳----下降。

      :以上内容来自于网络整理。

      PyTorch中的学习率预热方法:假设optimizer中设置的学习率为lr

      (1).ConstantLR(optimizer, factor, total_iters):前total_iters次,学习率为lr*factor,以后学习率变为lr。

      (2).LinearLR(optimizer, start_factor, end_factor, total_iters):前total_iters次,学习率从lr*start_factor逐次增加,以后学习率变为lr*end_factor。

      (3).LambdaLR(optimizer, lr_lambda):lr_lambda为lambda函数,如为以下:则学习率为lr*0.95的epoch次方。

lr_lambda = lambda epoch: 0.95 ** epoch

      (4).ExponentialLR(optimizer, gamma):学习率为lr*gamma的epoch次方。

      (5).StepLR(optimizer, step_size, gamma):学习率为lr*gamma的(当前opoch/step_size)次方。

      (6).MultiStepLR(optimizer, milestones, gamma):milestones为列表,如为[5,10,50,200],则epoch<5时,学习率为lr;epoch在[5,10)之间时为lr*gamma的1次方;epoch在[10,50)之间时为lr*gamma的2次方,依次类推。与StepLR相比,它允许学习率在不同的时间点以不同的步长衰减。

      (7).CosineAnnealingLR(optimizer, T_max, eta_min):学习率按照余弦函数进行周期性调整,每个周期结束时重置为初始学习率。T_max为周期内的最大迭代次数,eta_min为最小学习率。

      (8).ReduceLROnPlateau(optimizer, mode, factor, patience, threshold, threshold_mode, cooldown, min_lr, eps):当验证集上的loss停止改进时,自动降低学习率。这种方法不需要预先定义学习率衰减的时间表,而是根据模型的表现动态调整。

      (9).CyclicLR(optimizer, base_lr, max_lr, ...):学习率已恒定频率在给定的两个边界之间循环。

      (10).PolynomialLR(optimizer, total_iters, power):使用多项式函数衰减学习率。当epoch大于total_iters时,后面的学习率都为0。

      测试代码如下所示:

import colorama
import argparse
import time
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.models as models

def parse_args():
	parser = argparse.ArgumentParser(description="learning rate warm up")
	parser.add_argument("--epochs", required=True, type=int, help="number of training")
	parser.add_argument("--dataset_path", required=True, type=str, help="source dataset path")
	parser.add_argument("--model_name", required=True, type=str, help="the model generated during training or the model loaded during prediction")
	parser.add_argument("--pretrained_model", type=str, default="", help="pretrained model loaded during training")
	parser.add_argument("--batch_size", type=int, default=2, help="specify the batch size")

	args = parser.parse_args()
	return args

def load_dataset(dataset_path, batch_size):
	mean = (0.53087615, 0.23997033, 0.45703197)
	std = (0.29807151489753686, 0.3128615049442739, 0.15151863355831655)

	transform = transforms.Compose([
		transforms.CenterCrop(224),
		transforms.ToTensor(),
		transforms.Normalize(mean=mean, std=std), # RGB
	])

	train_dataset = ImageFolder(root=dataset_path+"/train", transform=transform)
	train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

	val_dataset = ImageFolder(root=dataset_path+"/val", transform=transform)
	val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
	assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories int the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"

	return len(train_dataset.class_to_idx), len(train_dataset), len(val_dataset), train_loader, val_loader

def train(model, train_loader, device, optimizer, criterion, train_loss, train_acc):
	model.train() # set to training mode
	for _, (inputs, labels) in enumerate(train_loader):
		inputs = inputs.to(device)
		labels = labels.to(device)

		optimizer.zero_grad() # clean existing gradients
		outputs = model(inputs) # forward pass
		loss = criterion(outputs, labels) # compute loss
		loss.backward() # backpropagate the gradients
		optimizer.step() # update the parameters

		train_loss += loss.item() * inputs.size(0) # compute the total loss
		_, predictions = torch.max(outputs.data, 1) # compute the accuracy
		correct_counts = predictions.eq(labels.data.view_as(predictions))
		acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
		train_acc += acc.item() * inputs.size(0) # compute the total accuracy

	return train_loss, train_acc

def validate(model, val_loader, device, criterion, val_loss, val_acc):
	model.eval() # set to evaluation mode
	with torch.no_grad():
		for _, (inputs, labels) in enumerate(val_loader):
			inputs = inputs.to(device)
			labels = labels.to(device)

			outputs = model(inputs) # forward pass
			loss = criterion(outputs, labels) # compute loss
			val_loss += loss.item() * inputs.size(0) # compute the total loss
			_, predictions = torch.max(outputs.data, 1) # compute validation accuracy
			correct_counts = predictions.eq(labels.data.view_as(predictions))
			acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
			val_acc += acc.item() * inputs.size(0) # compute the total accuracy

	return val_loss, val_acc

def training(epochs, dataset_path, model_name, pretrained_model, batch_size):
	classes_num, train_dataset_num, val_dataset_num, train_loader, val_loader = load_dataset(dataset_path, batch_size)
	model = models.ResNet(block=models.resnet.BasicBlock, layers=[2,2,2,2], num_classes=classes_num) # ResNet18

	if pretrained_model != "":
		model.load_state_dict(torch.load(pretrained_model))

	optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99), eps=1e-7) # set the optimizer
	scheduler = optim.lr_scheduler.ConstantLR(optimizer, factor=0.2, total_iters=10)
	# scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.2, end_factor=0.8, total_iters=5)
	# assert len(optimizer.param_groups) == 1, f"optimizer.param_groups's length must be equal to 1: {len(optimizer.param_groups)}"
	# lr_lambda = lambda epoch: 0.95 ** epoch
	# scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
	# scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
	# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.2)
	# scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10,15], gamma=0.2)
	# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=0.05)
	# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
	# scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.05)
	# scheduler = optim.lr_scheduler.PolynomialLR(optimizer, total_iters=5, power=1.)
	print(f"epoch: 0/{epochs}: learning rate: {scheduler.get_last_lr()}")

	criterion = nn.CrossEntropyLoss() # set the loss

	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	model.to(device)

	highest_accuracy = 0.
	minimum_loss = 100.

	for epoch in range(epochs):
		epoch_start = time.time()
		train_loss = 0.0 # loss
		train_acc = 0.0 # accuracy
		val_loss = 0.0
		val_acc = 0.0

		train_loss, train_acc = train(model, train_loader, device, optimizer, criterion, train_loss, train_acc)
		val_loss, val_acc = validate(model, val_loader, device, criterion, val_loss, val_acc)
		# scheduler.step(val_loss) # update lr, ReduceLROnPlateau
		scheduler.step() # update lr

		avg_train_loss = train_loss / train_dataset_num # average training loss
		avg_train_acc = train_acc / train_dataset_num # average training accuracy
		avg_val_loss = val_loss / val_dataset_num # average validation loss
		avg_val_acc = val_acc / val_dataset_num # average validation accuracy
		epoch_end = time.time()
		print(f"epoch:{epoch+1}/{epochs}; learning rate: {scheduler.get_last_lr()}, train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")

		if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
			torch.save(model.state_dict(), model_name)
			highest_accuracy = avg_val_acc
			minimum_loss = avg_val_loss

		if avg_val_loss < 0.00001 and avg_val_acc > 0.9999:
			print(colorama.Fore.YELLOW + "stop training early")
			torch.save(model.state_dict(), model_name)
			break

if __name__ == "__main__":
	# python test_learning_rate_warmup.py --epochs 1000 --dataset_path datasets/melon_new_classify --pretrained_model pretrained.pth --model_name best.pth
	colorama.init(autoreset=True)
	args = parse_args()

	training(args.epochs, args.dataset_path, args.model_name, args.pretrained_model, args.batch_size)

	print(colorama.Fore.GREEN + "====== execution completed ======")

      执行结果如下所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2149099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux入门2

文章目录 一、Linux基本命令1.1 文件的创建和查看命令1.2 文件的复制移动删除等命令1.3 查找命令1.4 文件的筛选和管道的使用1.5 echo、tail和重定向符 二、via编辑器三、权限控制3.1 root用户&#xff08;超级管理员&#xff09;3.2 用户和用户组3.3 权限信息3.4 chmod命令 一…

Streamlit:使用 Python 快速开发 Web 应用

一、简单介绍 Streamlit 是一个开源 Python 库&#xff0c;官网地址&#xff1a; https://streamlit.io/http://StreamlitStreamlit 是一个开源的 Python 框架&#xff0c;旨在为数据科学家和 后端工程师们提供只需几行代码即可创建动态数据应用的功能。 让没有任何前端基础…

C#软键盘设计字母数字按键处理相关事件函数

应用场景&#xff1a;便携式设备和检测设备等小型设备经常使用触摸屏来代替键盘鼠标的使用&#xff0c;因此在查询和输入界面的文本或者数字输入控件中使用软件盘来代替真正键盘的输入。 软键盘界面&#xff1a;软键盘界面实质上就是一个普通的窗体上面摆放了很多图片按钮&…

使用SpringCloud构建可伸缩的微服务架构

Spring Cloud是一个用于构建分布式系统的开源框架。它基于Spring Boot构建&#xff0c;并提供了一系列的工具和组件&#xff0c;用于简化开发分布式系统的难度。Spring Cloud可以帮助开发人员快速构建可伸缩的微服务架构。 要使用Spring Cloud构建可伸缩的微服务架构&#xff0…

实时监控局域网计算机桌面怎么设置!五个可实现的方法分享,绝了!

员工在工作时间里究竟在做什么&#xff1f;他们的网络活动是否合规&#xff1f;如何确保敏感信息不被泄露&#xff1f; 在企业管理层面&#xff0c;实时监控局域网计算机桌面已成为提升工作效率、确保数据安全与规范员工行为的重要手段。 技术的不断进步&#xff0c;多种解决…

【C++进阶】map和set的使用

【C进阶】map和set的使用 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;C&#x1f96d; &#x1f33c;文章目录&#x1f33c; 1. 序列式容器和关联式容器 2. set系列的使用 2.1 set 和 multiset 2.2 set 类的介绍 2.3 set 的构造和…

【Linux篇】常用命令及操作技巧(基础篇)

&#x1f30f;个人博客主页&#xff1a;意疏-CSDN博客 希望文章能够给到初学的你一些启发&#xff5e; 如果觉得文章对你有帮助的话&#xff0c;点赞 关注 收藏支持一下笔者吧&#xff5e; 阅读指南&#xff1a; 开篇说明帮助命令常见的七个linux操作终端实用的技巧跟文件目录…

C++11之统一的列表初始化

一.{}初始化 在c98中&#xff0c;标准允许使用{}对数组或结构体元素进行统一的列表初始值设定&#xff1a; struct mess {int _x;string _str; }; int main() {//注意&#xff0c;使用new的一定是指针int* arr new int[4] {1, 2, 3, 4};//数组初始化int arr[] { 1,3,5,6 };…

基于Spring Boot和Vue的私人牙科诊所系统的设计与实现(毕业论文)

目 录 1 前言 1 1.1 研究目的与意义 1 1.2 国内外研究概况 1 1.3 主要研究内容 2 1.4 论文结构 3 2 系统分析 3 2.1 可行性分析 3 2.1.1 技术可行性分析 3 2.1.2 经济可行性分析 3 2.1.3 操作可行性分析 4 2.1.4 法律可行性分析 4 2.2 需求分析 4 2.2.1 管理员需求分析 4 2.2.2…

3.1 数据表的基本查询

我们学习的怎么管理逻辑空间&#xff0c;怎么创建数据表&#xff0c;怎么定义字段&#xff0c;怎么创建索引&#xff0c;这些都是DDL语句。从本次课开始&#xff0c;我们来学习DML语句&#xff0c;也就是该如何增删改查操作数据。我们学习DML语句的前提是数据表要有足够多的数据…

Moving Elevator System Fully functional

这是一个功能齐全的电梯系统,配有电梯箱车、电梯井、电缆和每层的门框 电梯完全被操纵,有动画门、电缆线、滑轮系统。 还有几个C#脚本文件控制电梯、门和灯。 此套餐还包括相关声音,如电梯移动、门打开/关闭、楼层铃叮。 电梯车厢有工作门和按钮,车顶还有一个逃生舱口。 每…

低投入、高效率 基于PHP+MySQL组合开发的求职招聘小程序源码系统 带完整的安装代码包以及搭建部署教程

系统概述 这款求职招聘小程序源码系统是专门为求职招聘领域打造的综合性平台。它利用 PHP 强大的编程语言特性和 MySQL 稳定的数据存储功能&#xff0c;实现了一个功能齐全、性能优越的求职招聘系统。 整个系统架构设计合理&#xff0c;具备良好的扩展性和兼容性。无论是小型…

从《中国数据库前世今生》看中国数据库技术的发展与挑战

从《中国数据库前世今生》看中国数据库技术的发展与挑战 引言 在当今数字化浪潮中&#xff0c;数据库技术已成为支撑全球经济运行的核心基础设施。作为程序员&#xff0c;我一直对数据库技术的发展充满好奇。《中国数据库前世今生》纪录片深入探索了中国数据库技术的演变历程…

【Python报错已解决】libpng warning: iccp: known incorrect sRGB profile

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 专栏介绍 在软件开发和日常使用中&#xff0c;BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

怎么把图片压缩变小?把图片压缩变小的八种压缩方法介绍

怎么把图片压缩变小&#xff1f;在当今这个信息高度共享的时代&#xff0c;图片不仅仅是简单的视觉元素&#xff0c;它们承载着我们的记忆、故事和创意。无论是旅行的风景、家庭的聚会&#xff0c;还是工作中的项目展示&#xff0c;图片都在我们的生活中扮演着不可或缺的角色。…

帕金森患者必看!这5种水果成“抗抖”小能手,吃出健康好生活!

在这个快节奏的时代&#xff0c;健康成为了我们最宝贵的财富之一。而对于帕金森病患者而言&#xff0c;如何在日常生活中通过合理的饮食来缓解症状、提升生活质量&#xff0c;成为了许多家庭关注的焦点。今天&#xff0c;就让我们一起探索那些被誉为“抗抖”小能手的水果&#…

Pandas -----------------------基础知识(一)

目录 Series对象 属性和方法 布尔值列表获取Series对象中部分数据 运算 DateFrame对象 常用属性 常见方法 运算 总结 Series对象 是DataFrame的列对象或者行对象 生成Series对象生成索引使用元组创建Series对象使用字典创建Series对象 通过Pandas创建对象 自定义索引 …

RealityCapture1.4设置成中文

RealityCapture 1.4 设置成中文的教程 RealityCapture 1.4 是一款强大的三维建模软件&#xff0c;它能够从图像或激光扫描中创建实景三维模型和正射影像等。以下是一个详细的教程&#xff0c;指导您如何将 RealityCapture 1.4 的界面设置为中文。 1.找到设置按钮 在WORKFLOW…

【一起学NLP】Chapter1-基本语法与神经网络的推理

备注&#xff1a;本专栏为个人的NLP学习笔记&#xff0c;欢迎大家共同讨论交流学习。代码同步&#xff1a;https://github.com/codesknight/Learning-NLP-Together 参考书籍&#xff1a;《深度学习进阶&#xff1a;自然语言处理》——斋藤康毅 目录 基础知识点复习测试环境使用…

OceanMesh2D | 基于精确距离的沿海海洋/浅水流动模型二维自动网格生成MATLAB工具箱推荐

Precise distance-based two-dimensional automated mesh generation toolbox intended for coastal ocean/shallow water flow models OceanMesh2D | 基于精确距离的沿海海洋/浅水流动模型二维自动网格生成MATLAB工具箱推荐 1. 简介2. 特点3. 代码基本要求:4. 基本流程 1. 简…