深度学习模型训练的全流程

news2024/11/23 22:30:09

目标是使用Pytorch来完成CNN的训练和验证过程,CNN网络结构。需要完成的逻辑结构如下:

构造训练集和验证集;

每轮进行训练和验证,并根据最优验证集精度保存模型。

# 将自定义的Dataset封装成一个Batch Size大小的Tensor,用于后面的训练。
# 训练集封装 批量处理数据
train_loader = torch.utils.data.DataLoader(
train_dataset,    # 数据加载
batch_size=10,     # 批处理大小设置
shuffle=True,     # 是否进项洗牌操作
num_workers=10,   # 是否进行多进程加载数据设置
)
# 验证集封装
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=10,
shuffle=False,
num_workers=10,
)
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False) # 计算交叉熵(交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters())  # Adam优化算法
best_loss = 1000.0
for epoch in range(20):
	train(train_loader, model, criterion, optimizer, epoch)
	val_loss = validate(val_loader, model, criterion)
	# 保存验证集精度
	if val_loss < best_loss:
		best_loss = val_loss
	torch.save(model.state_dict(), './model.pt')  
	print('Epoch: ', epoch)

train()、validate()与predict()代码如下:

# 训练函数
def train(train_loader, model, criterion, optimizer, epoch):
    # 切换模型为训练模式
    model.train()
    for i, (input, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                criterion(c2, data[1][:, 2]) + \
                criterion(c3, data[1][:, 3]) + \
                criterion(c4, data[1][:, 4]) + \
                criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# 验证函数
def validate(val_loader, model, criterion):
    # 切换模型为预测模型
    model.eval()
    val_loss = []
    # 不记录模型梯度信息
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            c0, c1, c2, c3, c4, c5 = model(data[0])
            loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                    criterion(c2, data[1][:, 2]) + \
                    criterion(c3, data[1][:, 3]) + \
                    criterion(c4, data[1][:, 4]) + \
                    criterion(c5, data[1][:, 5])
            loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)
# 预测函数
def predict(test_loader, model, tta=10):
    model.eval()
    test_pred_tta = None

    # TTA 次数
    for _ in range(tta):
        test_pred = []

        with torch.no_grad():
            for i, (input, target) in enumerate(test_loader):
                # if use_cuda:
                # input = input.cuda()

                c0, c1, c2, c3, c4, c5 = model(input)
                output = np.concatenate([
                    c0.data.numpy(),
                    c1.data.numpy(),
                    c2.data.numpy(),
                    c3.data.numpy(),
                    c4.data.numpy()], axis=1)
                test_pred.append(output)

        test_pred = np.vstack(test_pred)
        if test_pred_tta is None:
            test_pred_tta = test_pred
        else:
            test_pred_tta += test_pred
    return test_pred_tta

模型保存与加载:

# 保存模型为文件model.pt
torch.save(model_object.state_dict(), 'model.pt')

# 读取文件model.pt载入模型
model.load_state_dict(torch.load(' model.pt'))

模型调参

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/703309.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

threejs后期处理

个人博客地址: https://cxx001.gitee.io 1. 如何使用Threejs的后期处理 后期处理就是在场景渲染完后&#xff0c;最后对场景显示效果调整的手段。 使用后期处理步骤&#xff1a; &#xff08;1&#xff09;创建THREE.EffectComposer对象。(效果组合器) &#xff08;2&#x…

指定某个时间,计算和当前时间间隔几天几时几分

dateDiff(startTime,endTime) {let t1 new Date(startTime).getTime()*1000; //开始时间 2023-06-29 10:00:00let t2 new Date(endTime).getTime()*1000; //结束时间 1688090400000000 2023-06-30 10:00:00 1688092230000000 2023-06-30 10:30:30let dateTime 1000 *…

小程序反编译

第一步&#xff1a;下载软件 根据把博客下载好三个软件 夜神模拟器 RE文件管理器 Node.js 第二步&#xff1a;打开模拟器中的 “微信” 第三步&#xff1a;点击要下载的小程序 并 记录当时的时间 方便一会查找pkg文件 第四步&#xff1a;打开文件资源管理器 第五步&#xff1a…

PyTorch的ONNX结合MNIST手写数字数据集的应用(.pth和.onnx的转换与onnx运行时)

在PyTorch以前的模型都是.pth格式&#xff0c;后面Meta跟微软一起做了一个.onnx的通用格式。这里对这两种格式文件&#xff0c;分别做一个介绍&#xff0c;依然使用MNIST数据集来做示例 1、CUDA下的pth文件 那pth文件里面是什么结构呢&#xff1f;其实在以前的文章就有介绍过…

0基础学习VR全景平台篇 第50篇:高级功能-自定义右键

本期为大家带来蛙色VR平台&#xff0c;高级功能—自定义右键功能操作。 功能位置示意 一、本功能将用在哪里&#xff1f; 自定义右键功能&#xff0c;观看者可通过电脑端右键和手机端长按屏幕&#xff0c;出现作者配置的自定义内容&#xff0c;使VR全景玩法变得多样化。 二、…

欧科云链2023年报:毛利达1.55亿港元,数字资产业务成最大增长点

据香港商报报道&#xff0c;2023年6月28日&#xff0c;欧科云链控股有限公司&#xff08;以下简称“欧科云链”&#xff09;及其附属公司&#xff08;股份代号&#xff1a;1499.HK&#xff0c;以下简称“集团”&#xff09;发布了截至2023年3月31日的年度报告。报告期内&#x…

工业读码器的选择和使用注意事项有哪些?

工业读码器是一种能够读取条形码、二维码等信息的设备&#xff0c;广泛应用于物流、生产制造、零售等行业。如何选择和使用工业读码器呢?下面是一些注意事项。 选择工业读码器 要根据应用场景选择合适的读码器类型&#xff0c;如手持式、固定式、手动旋转式等。 要考虑读取码的…

【C++】详解多态

目录 一、多态的概念二、多态的定义及实现1、多态的构成条件2、虚函数3、虚函数的重写1、虚函数重写的两个例外 4、C11 override 和 final5、重载、覆盖(重写)、隐藏(重定义)的对比 三、抽象类1、概念2、接口继承和实现继承 四、多态的原理1、虚函数表2、多态的原理3、动态绑定…

Mysql架构篇--Mysql(M-S) 主从同步

文章目录 前言&#xff1a;一、主从同步是什么&#xff1f;二、主从同步实现&#xff1a;1.准备工作&#xff1a;2.开启主从复制&#xff1a;2.1 mysql 服务端配置文件修改&#xff1a;2.2 mysql master 节点用户创建&#xff1a;2.3 mysql slave 节点开启数据复制&#xff1a;…

突破传统设计灵感,虚拟展厅设计方案

导语&#xff1a; 随着科技的不断发展&#xff0c;虚拟展厅设计方案正成为现代设计行业的新宠。这种创新的设计形式不仅突破了传统设计的局限&#xff0c;还为传统设计公司带来了诸多优势和特点&#xff0c;从而提高了设计产量和创意灵感。 在这篇软文中&#xff0c;我们将深入…

雅迪、爱玛谁是“新宠”?

电动两轮车下半场&#xff0c;谁是“新王”&#xff1f; 6月15日&#xff0c;爱玛科技有限公司&#xff08;下称“爱玛”&#xff0c;603529.SH)迎来了上市两周年。 作为电动两轮车的头部玩家&#xff0c;雅迪控股有限公司&#xff08;下称“雅迪”&#xff0c;01585.HK&…

HJ101 输入整型数组和排序标识,对其元素

描述 输入整型数组和排序标识&#xff0c;对其元素按照升序或降序进行排序 数据范围&#xff1a; 1≤n≤1000 1≤n≤1000 &#xff0c;元素大小满足 0≤val≤100000 0≤val≤100000 输入描述&#xff1a; 第一行输入数组元素个数 第二行输入待排序的数组&#xff0c;每个…

python实现九九乘法表

九九乘法表 i 1 while i < 9:j 1while j < i:print(f{j}*{i}{i * j}, end\t)j 1print()i 1结果&#xff1a;

window10 查看本机TCP协议进程

1. netstat 是一个常见的网络工具&#xff0c;用于显示网络连接状态、路由表、接口统计信息等网络相关的信息&#xff0c;可以帮助诊断和解决网络问题。 其中&#xff0c;各参数的含义为&#xff1a; -a&#xff1a;显示所有的网络连接和监听端口。 -e&#xff1a;显示以太网…

CDH yarn Fair 队列最大资源使用限制,任务无法提交

一、问题背景描述 1.任务提交异常日志 2023-06-29 15:48:20,877 INFO org.apache.flink.yarn.YarnClusterDescriptor [] - Deployment took more than 60 seconds. Please check if the requested resources are available in the YARN cluster 2023-06-29 15:48:21,129 IN…

1-什么是NumPy?【视频版】

目录 问题解答观看视频 问题 解答 NumPy&#xff0c;全称Numerical Python&#xff0c;是一个开源的Python科学计算库。它为Python提供了大量的数学库&#xff0c;包括&#xff1a; 强大的N维数组对象成熟的广播功能集成C/C和Fortran代码的工具有用的线性代数、傅里叶变换和随…

第一个spring程序

我们今天写第一个spring程序 我们采用maven形式创建工程。 我们首先在pom.xml中加入引用。 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSch…

(6)蜂鸣器(又称音调报警)

文章目录 6.1 使用有源蜂鸣器而不是无源蜂鸣器 6.2 安装蜂鸣器 6.3 使蜂鸣器安静 蜂鸣器&#xff08;或音调报警器&#xff09;可用于以声音指示飞行器的状态变化。根据电路板的能力&#xff0c;它可以是一个有源设备&#xff08;只需要施加电压来产生一个单一频率的音调&am…

给定一组数据样本,计算:【样本的平均值】, 【样本的标准差】, 【样本的变异系数】,【样本的标准误差】

一、指标含义 样本的平均值&#xff1a;指样本中所有数据的总和除以样本大小&#xff0c;是样本的中心趋势的度量。平均值常用于描述数据的集中程度&#xff0c;具有良好的代表性和易于计算的优点。 样本的标准差&#xff1a;指样本中每个数据与平均值的偏差的平方和的平均值的…

openssl版本升级与降级

openssl版本升级与降级 flyfish 环境 Ubuntu 22.04 1.1.1升级3.1.1 查看openssl版本 openssl versionOpenSSL 1.1.1t 7 Feb 2023https://www.openssl.org/source/ 编译和安装 ./config --prefix/usr/local/openssl311 make -j8 make install进入/usr/local/openssl311/l…