使用Tensorboard多超参数随机搜索训练

news2024/11/16 8:25:40

文章目录

  • 1超参数训练代码
  • 2远端电脑启动tensorboard

完整代码位置https://gitee.com/chuge325/base_machinelearning.git
这里还参考了tensorflow的官方文档

但是由于是pytorch训练的差别还是比较大的,经过多次尝试完成了训练
硬件是两张v100

1超参数训练代码

这个代码可以查看每次训练的loss曲线和超参数的对比信息
在这里插入图片描述
在这里插入图片描述

import pandas as pd
import torch
from utils.DataLoader import MNIST_data
from torchvision import transforms
from utils.DataLoader import RandomRotation,RandomShift
from model.AlexNet import Net
# from model.Net import Net
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from tensorboard.plugins.hparams import api as hp

from sklearn.model_selection import ParameterSampler
train_df = pd.read_csv('../datasets/digit-recognizer/train.csv')
n_train = len(train_df)
n_pixels = len(train_df.columns) - 1
n_class = len(set(train_df['label']))

# 定义超参数搜索空间
HP_LEARNING_RATE = hp.HParam('learning_rate', hp.RealInterval(1e-4, 1e-2))
HP_BATCH_SIZE = hp.HParam('batch_size', hp.Discrete([64, 128]))
HP_EPOCH = hp.HParam('epoch', hp.Discrete([50,100]))
METRIC_ACCURACY = hp.Metric('accuracy')
# 定义超参数搜索空间
param_dist = {
    'lr': [0.001, 0.003,0.01 ,0.1],
    'batch_size': [64,128],
    'num_epochs': [50,100]
}


def train(epoch):
    model.train()
    

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target)
        
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        
        loss.backward()
        optimizer.step()
        
        if (batch_idx + 1)% 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), 
                loss.item()))
    exp_lr_scheduler.step()
    writer.add_scalar('Loss/train', loss/len(train_loader), epoch)
    writer.flush()
def evaluate(data_loader):
    model.eval()
    loss = 0
    correct = 0
    
    for data, target in data_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()
        
        output = model(data)
        
        loss += F.cross_entropy(output, target, size_average=False).item()

        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        
    loss /= len(data_loader.dataset)
    writer.add_scalar('accuracy/train', correct / len(data_loader.dataset), epoch)
    writer.flush()    
    print('\nAverage loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
    return correct / len(data_loader.dataset)

hparams_dir='logs/hparam_tuning'
hparams_writer = SummaryWriter(hparams_dir)
# 进行随机超参数搜索
param_list = list(ParameterSampler(param_dist, n_iter=20))
model = Net()
criterion = nn.CrossEntropyLoss()

for params in param_list:
    hparams_dict ={
        HP_LEARNING_RATE.name :params["lr"],
        HP_BATCH_SIZE.name :params['batch_size'],
        HP_EPOCH.name :params['num_epochs']

    }
    
    batch_size = params['batch_size']
    train_dataset = MNIST_data('../datasets/digit-recognizer/train.csv', n_pixels =n_pixels,transform= transforms.Compose(
                                [transforms.ToPILImage(), RandomRotation(degrees=20), RandomShift(3),
                                transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                            batch_size=batch_size, shuffle=True)

    logdir = "logs/" + 'epoch{0}_lr{1}_batch_size{2}.pth'.format(params['lr'],params['num_epochs'],params['batch_size'])
    writer = SummaryWriter(logdir)
    optimizer = optim.Adam(model.parameters(), lr=params['lr'])
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()

    n_epochs = params['num_epochs']

    for epoch in range(n_epochs):
        train(epoch)
        accuracy=evaluate(train_loader)
        if epoch==n_epochs-1:
            hparams_writer.add_hparams(hparams_dict,{'Accuracy':accuracy})
    torch.save(model.state_dict(), 'epoch{0}_lr{1}_batch_size{2}.pth'.format(params['num_epochs'],params['lr'],params['batch_size']))
    writer.close()
hparams_writer.close()

2远端电脑启动tensorboard

tensorboard --logdir logs

如果您的 TensorBoard 日志存储在远程服务器上,但您无法通过本地计算机上的浏览器访问它,可能是由于防火墙或网络设置的限制导致的。以下是一些可能的解决方案:

使用 SSH 隧道:如果您无法直接访问远程服务器上的 TensorBoard 日志,请考虑使用 SSH 隧道来建立本地和远程服务器之间的安全连接。在终端中使用以下命令建立 SSH 隧道:

   ssh -L 6006:localhost:6006 username@remote_server_ip

其中 username 是您在远程服务器上的用户名,remote_server_ip 是远程服务器的 IP 地址。然后,在本地计算机上打开浏览器并访问 http://localhost:6006 即可访问 TensorBoard 日志。

如果您使用的是 Windows 操作系统,可以使用 PuTTY 或其他 SSH 客户端来建立 SSH 隧道。
实例

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/431367.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Studio升级Gradle Plugin升级导致项目运行失败问题

背景&错误 升级Android Studio 旧项目无法运行,奇奇怪怪什么错误都有 例如: java.lang.IllegalAccessError: class org.gradle.api.internal.tasks.compile.processing.AggregatingProcessingStrategy (in unnamed module 0x390ea9fb) cannot acce…

传智健康-day2

一.需求分析(预约管理功能开发) 预约管理功能,包括检查项管理、检查组管理、体检套餐管理、预约设置等、预约管理属于系统的基础功能,主要就是管理一些体检的基础数据。 检查组是检查项的集合 二.基础环境搭建 1导入预约管理模块数据表 需要用到的…

Ubuntu安装MySQL及常用操作

一、安装MySQL 使用以下命令即可进行mysql安装,注意安装前先更新一下软件源以获得最新版本: sudo apt-get update #更新软件源 sudo apt-get install mysql-server #安装mysql 上述命令会安装以下包: apparmor mysql-client-5.7 mysql-c…

不定期更新:我对 ChatGPT 进行多方位了解后的报告,超级全面,建议想了解的朋友看看

优质介绍视频: GPT4前端【AI编程新纪元】 【渐构】万字科普GPT4为何会颠覆现有工作流;为何你要关注微软Copilot、文心一言等大模型 此文章不定期更新(一周应该会更新一次) 最近一次更新:2023.4.16 12:00 ChatGPT 是什…

零基础搭建私人影音媒体平台【远程访问Jellyfin播放器】

文章目录1. 前言2. Jellyfin服务网站搭建2.1. Jellyfin下载和安装2.2. Jellyfin网页测试3.本地网页发布3.1 cpolar的安装和注册3.2 Cpolar云端设置3.3 Cpolar本地设置4.公网访问测试5. 结语1. 前言 随着移动智能设备的普及,各种各样的使用需求也被开发出来&#xf…

关于加强供水企业营销管理的几点思考

供水营销部门是供水企业最重要的职能部门之一,其工作职能直接与供水企业的经济利益和社会效益息息相关,具体来说,主要涉及到五个方面的指标内容:水费回收率、 水量漏损率(产销差率)、水表完好率、水价调整及…

《年会抽奖》:无人获奖的概率

目录 一、题目 二、思路 1、错排问题 2、n 的阶乘 3、输出格式要求 三、代码 一、题目 题目:年会抽奖 题目链接:年会抽奖 今年公司年会的奖品特别给力,但获奖的规矩却很奇葩: 1. 首先,所有人员都将…

SpringBoot起步依赖和自动配置

文章目录 1、起步依赖2、自动配置 1、起步依赖 概念 起步依赖本质上是一个Maven项目对象模型(Project Object Model,POM),定义了对其他库的传递依赖,这些东西加在一起支持某一功能。 简单的说,起步依赖就…

这才是后端API该有的样子

一般系统大致架构如下: 有些小伙伴会说,这个架构太简单太low了吧,什么网关、缓存、消息中间件都没有。 需要说明的是,因为我们主题是API接口(tbAPI,pinduoduo API接口调用)所以聚焦这一点上就行…

Java FileChannel文件的读写实例

一、概述: 文件通道FileChannel是用于读取,写入,文件的通道。FileChannel只能被InputStream、OutputStream、RandomAccessFile创建。使用fileChannel.transferTo()可以极大的提高文件的复制效率,他们读和写直接建立了通道&#x…

【Leetcode刷题】链表的中间结点和合并两个有序链表

生命如同寓言,其价值不在与长短,而在与内容。 ——塞涅卡 目录 一.链表的中间结点 1.快慢指针 二.合并两个有序链表 1.尾插法 一.链表的中间结点 给你单链表的头结点 head ,请你找出并返回链表的中间结…

Java——对象克隆(复制)

假如想复制一个简单变量。很简单: int apples 5; int pears apples; 不仅int类型,其它七种原始数据类型(boolean,char,byte,short,float,double.long)同样适用于该类情况。 但是如果你复制的是一个对象,情况就复杂了。 假设说我是一个b…

windows安装scoop

scoop介绍 Scoop是一款适用于Windows平台的命令行软件(包)管理工具。简单来说,就是可以通过命令行工具(PowerShell、CMD等)实现软件(包)的安装管理等需求,通过简单的一行代码实现软…

博客上几种新职业的工作指南

© 2019 Conmajia 我不是在嘲讽谁,真的😅 看了不少博客,发现了一些共同点。我觉得可以把这些博主分类一下,形成几种新的职业。 1. 超文本抄书匠Hypertext Copier Job description 拥有悠久历史的手打大师,大段抄录…

2023-04-16 算法面试中常见的栈和队列问题

栈和队列 1 栈的基础应用:20.括号匹配 class Solution {public boolean isValid(String s) {Stack<Character> stack new Stack<>();for (int i 0; i < s.length(); i) {char c s.charAt(i);if (c ( || c [ || c {) {stack.push(c);} else {// 还有字符…

【Linux】进程间通信 -- 命名管道

前言 在管道的通信中&#xff0c;除了匿名管道&#xff0c;还有一个命名管道。 匿名管道只支持具有“亲戚关系”的进程间通信&#xff0c;而命名管道就可以支持不同的&#xff0c;任意的进程通信。 那就下来就开始我们今天的学习。 文章目录 前言一. 命名管道二. 命名管道的应用…

快速了解数据仓库建模

快速了解数据仓库建模 1、什么是OLTP和OLAP&#xff1f;2、为什么不在业务系统做数据分析呢&#xff1f;3、什么是数据库建模&#xff1f;4、关系建模。5、维度建模。5.1 事实表5.2 维度表 6、 数据仓库分层。6.1、 数仓分层结构6.2、 为什么需要对数据仓库分层&#xff1f; 1、…

Linux网络服务之DHCP篇

目录一、了解DHCP服务1.1DHCP定义1.2DHCP好处1.3DHCP的分配方式二、DHCP工作过程三. 使用DHCP动态配置主机地址一、了解DHCP服务 1.1DHCP定义 DHCP&#xff08;动态主机配置协议&#xff09;是一个局域网的网络协议。指的是由服务器控制一段IP地址范围&#xff0c;客户机登录…

基于SDM450 兼容st7701s不同id屏幕

authordaisy.skye的博客_CSDN博客-嵌入式,Qt,Linux领域博主 sdm450 P326 在高通的 SDM450 中&#xff0c;有两种屏幕初始化代码&#xff0c;分别称为 "lk" 和 "kernel" 代码&#xff0c; "lk" 代码是用于在内核中初始化屏幕的代码。它通常在内核…

【Linux】进程间通信 -- 匿名管道的应用

前言 上篇博客初步学习了匿名管道的周边知识和使用&#xff0c;本篇文章将基于这些知识&#xff0c;实现一下进程间通信 话不多说&#xff0c;马上开始今天的内容 文章目录 前言一. 大体框架二. 分配任务三. 创建控制模块四. 开始通信五. 关闭程序六. 完整代码结束语 一. 大体框…