【人工智能概论】 K折交叉验证

news2025/1/15 21:33:58

【人工智能概论】 K折交叉验证

文章目录

  • 【人工智能概论】 K折交叉验证
  • 一. 简单验证及其缺点
    • 1.1 简单验证简介
    • 1.2 简单验证的缺点
  • 二. K折交叉验证
    • 2.1 K折交叉验证的思路
    • 2.2 小细节
    • 2.3 K折交叉验证的缺点
    • 2.4 K折交叉验证的代码


一. 简单验证及其缺点

1.1 简单验证简介

  • 简单验证: 将原始数据集随机划分成训练集和验证集两部分,例,将数据按照7:3的比例分成两部分,70%的样本用于训练模型;30%的样本用于模型验证,如下图。

1.2 简单验证的缺点

  • 数据都只被用了一次;
  • 验证集上计算出来的评估指标与原始分组有很大关系;
  • 对于时序序列,要保存时序信息,往往不能打乱数据的顺序对数据进行随机截取,这就带来了问题,比如总用春、夏、秋的数据做训练,用冬的数据做测试,这显然是有问题的,是不能容忍的。

二. K折交叉验证

  • 为了解决简单交叉验证的不足,引出K折交叉验证,其既可以解决数据集的数据量不够大的问题,也可以解决参数调优的问题。。

2.1 K折交叉验证的思路

  1. 首先,将全部样本划分成k个大小相等的样本子集;
  2. 依次遍历这k个子集,每次把当前子集作为验证集,其余所有样本作为训练集,进行模型的训练和评估;
  3. 最后把k次评估指标的平均值作为最终的评估指标。在实际实验中,k通常取10,如下图。

在这里插入图片描述

2.2 小细节

  • K折交叉验证中有这样一个细节,下一折的训练不是在上一折的基础上进行的,即每训练新的一折都要重新初始化模型参数。

2.3 K折交叉验证的缺点

  • 因为K折交叉验证执行一次训练的总轮数是每一折的训练轮数(epochs)与总折数(K)的乘积,因此训练的成本会翻倍。

2.4 K折交叉验证的代码

import torch
import random
from torch.utils.data import DataLoader, TensorDataset
from Model.ReconsModel.Recoder import ReconsModel, Loss_function
from Model.ModelConfig import ModelConfig

# 返回第 i+1 折(i取 0 ~ k-1)的训练集(train)与验证集(valid)
def get_Kfold_data(k, i, x):  # k是折数,取第i+1折,x是特征数据
    fold_size = x.size(0) // k  # 计算每一折中的数据数量
    val_start = i * fold_size  # 第 i+1折 数据的测试集初始数据编号
    if i != k - 1:  # 不是最后一折的话,数据的分配策略
        val_end = (i + 1) * fold_size  # 验证集的结束
        valid_data = x[val_start: val_end]
        train_data = torch.cat((x[0: val_start], x[val_end:]), dim=0)
    else:  # 如果是最后一折,数据的分配策略,主要涉及到不能K整除时,多出的数据如何处理
        valid_data = x[val_start:]  # 实际上,多出来的样本,都放在最后一折里了
        train_data = x[0: val_start]

    return train_data, valid_data


# k折交叉验证,某一折的训练
def train(model, train_data, valid_data, batch_size, lr,epochs):
    # 数据准备
    train_loader = DataLoader(TensorDataset(train_data), batch_size, shuffle=True)
    valid_loader = DataLoader(TensorDataset(valid_data), batch_size, shuffle=True)

    # 损失函数,优化函数的准备
    criterion = Loss_function()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

    # 记录每一个epoch的平均损失
    train_loss = []
    valid_loss = []


    for epoch in range(epochs):
        tra_loss = 0
        val_loss = 0
        for i , data in enumerate(train_loader):

            # 假设数据的处理 此时的data是list类型的数据,转化成Tensor,并且把多出来的第0维去掉
            data = torch.stack(data)
            data = data.squeeze(0)


            optimizer.zero_grad()  # 梯度清零
            recon, mu, log_std = model(data, if_train=True)  # if_train不能少

            # 计算损失
            loss = criterion.loss_function(recon, data, mu, log_std)

            # 反向传播
            loss.backward()
            optimizer.step()

            tra_loss = tra_loss + loss.item()
        tra_loss = tra_loss / len(train_data)
        train_loss.append(tra_loss)

        # 计算测试集损失
        with torch.no_grad():
            for i, data in enumerate(valid_loader):

                # 假设数据的处理 此时的data是list类型的数据,转化成Tensor,并且把多出来的第0维去掉
                data = torch.stack(data)
                data = data.squeeze(0)

                optimizer.zero_grad()

                recon, mu, log_std = model(data, if_train=False)

                test_loss = criterion.loss_function(recon, data, mu, log_std).item()

                val_loss = val_loss + test_loss
            val_loss = val_loss / len(valid_data)
            valid_loss.append(val_loss)

        print('第 %d 轮, 训练的平均误差为%.3f, 测试的平均误差为%.3f 。'%(epoch+1, tra_loss, val_loss))
    return train_loss, valid_loss

# k折交叉验证
def k_test(config, datas): # k是总折数,
    valid_loss_sum = 0

    for i in range(config.k):

        model = ReconsModel(config) # 细节,每一折,并不是在上一折训练好的模型基础上继续训练,而是重新训练

        print('-'*25,'第',i+1,'折','-'*25)

        train_data , valid_data = get_Kfold_data(config.k, i, datas) # 获取某一折的训练数据、测试数据

        train_loss, valid_loss = train(model, train_data, valid_data, config.batch_size, config.lr, config.epochs)

        # 求某一折的平均损失
        train_loss_ave = sum(train_loss)/len(train_loss)
        valid_loss_ave = sum(valid_loss)/len(valid_loss)
        print('-*-*-*- 第 %d 折, 平均训练损失%.3f,平均检验损失%.3f -*-*-*-'%(i+1, train_loss_ave,valid_loss_ave))
        valid_loss_sum = valid_loss_sum + valid_loss_ave

    valid_loss_k_ave = valid_loss_sum / config.k  # 基于K折交叉验证的验证损失
    print('*' * 60, )
    print('基于K折交叉验证的验证损失为%.4f'%valid_loss_k_ave)




if __name__ == "__main__":
    # 创建数据集,或者说数据集只要是这样的形式即可
    X = torch.rand(5000, 16, 38)  # 5000条数据,,每条有16个时间步,每步38个特征,时序数据

    # 随机打乱
    index = [i for i in range(len(X))]
    random.shuffle(index)
    X = X[index]  # 要是有标签的话,index要对得上

    config = ModelConfig()
    config.load('./Model/config.json')

    k_test(config, X)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/501449.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

join 语句使用

目录 前言 创建数据 知识点补充 Join算法Index Nested-Loop 小结: Join算法Block Nested-Loop join_buffer放不下驱动表情况 小结: 小表是什么? 总结: 参考内容 前言 在实际开发中,我们一般会有两类问题&a…

腾讯云2核4G服务器5M带宽轻量CPU性能、流量和系统盘测试

腾讯云轻量应用服务器2核4G5M配置,自带5M公网带宽,5M带宽下载速度峰值可达640KB/秒,系统盘为60GB SSD盘,每月500GB流量包,折合每天16GB流量。腾讯云百科来详细说下腾讯云轻量应用服务器2核4G5M配置、CPU型号处理器主频…

威联通nas服务器中勒索病毒被encrypted勒索病毒攻击怎么办有哪些预防措施

威联通是一家专业提供网络存储设备和应用方案的公司,旗下NAS服务器因为实用、多功能而深受用户喜欢,但是NAS服务器在使用过程中也面临许多安全问题,例如被encrypted勒索病毒攻击。下面将为大家介绍encrypted勒索病毒在威联通NAS服务器上的危害…

黑马---Redis入门到实战【实战篇】

一、短信登录 基于session实现短信登录的流程 实现发送短信验证码功能 发送验证码功能: Overridepublic Result sendCode(String phone, HttpSession session) {//1.校验手机号if(RegexUtils.isPhoneInvalid(phone)){//2.如果不符合,返回错误信息return…

Java 基础进阶篇(十三)—— 异常处理机制

文章目录 一、异常概述、体系二、异常的分类三、异常的默认处理流程四、异常的处理机制4.1 编译时异常的处理机制4.1.1 方式一:抛出异常4.1.2 方式二:捕获异常4.1.3 方式三:前两者结合 4.2 运行时异常的处理机制 五、自定义异常5.1 自定义编译…

程序员面试金典10.*

文章目录 10.1合并排序的数组10.02变位词组10.03搜索旋转数组10.05稀疏数组搜索10.09排序矩阵查找10.10 数字流的秩10.11 峰与谷 10.1合并排序的数组 这个就从后往前加入到新数组里就行。如果B的下标是-1则结束,A的下标是-1则一直加B的元素。 class Solution { pub…

挑战14天学完Python---初识python基本图形绘制

往期文章 目录 往期文章前言1."Python蟒蛇绘制"实例2.Python标准库 之turtle库3. 面向对象编程风格3.1 import更多玩法3.1.1使用from和import保留字共同完成3.1.2 使用import和as保留字共同完成 4.turtle的原(wan)理 (fa)4.1 turtle绘图窗体布局---turtul.setup()4.2…

京东小程序折叠屏适配探索 | 京东云技术团队

前言 随着近年来手机行业的飞速发展,手机从功能机进入到智能机,手机屏幕占比也随着技术和系统的进步越来越大,特别是Android 10推出以后,折叠屏逐渐成为Android手机发展的趋势。 图 1 Android手机屏幕发展趋势 京东小程序近年来…

Python程序员辞职后,如何踏出自由职业的第一步,聊聊我自己的看法

大家好,我是兴哥。有个广州的朋友说他辞职了,想要自由职业该怎么开始第一步呢?我问他你之前的收入月薪是多少,他说2万出头。我不得不说,对于写项目的自由职业程序员,2万是一个极高的门槛。但既然他已经辞职…

第三十章 React的路由基本使用

关于React路由,我们在学习之前先了解一下其他知识点:SPA应用、路由的理解、react中如何使用路由。 SPA应用的理解 我们知道React脚手架给我们构建的是一个单页应用程序(SPA),在页面加载时,只会加载一个HT…

2.Redis入门概述

1.Redis是什么 Remote Dictionary Server(远程字典服务)是完全开源的,使用ANSIC语言编写遵守BSD协议, 是一个高性能的Key-Value数据库, 提供了丰富的数据结构,例如String、Hash、List、Set、SortedSet等等。 数据是存在内存中的&a…

学会这几个Word技巧,让你办公省时又省力(二)

Word是我们经常用到的办公软件,下面分享的几个小技巧,可以提高你的办公效率,一起看看吧。 1. 改变Word文档的背景颜色 有时候我们打开的Word文档是有颜色的,如果你想恢复白色背景,或者改成其他颜色,只…

《Linux 内核设计与实现》08. 下半部和推后执行的工作

文章目录 下半部软中断软中断的实现使用软中断 tasklettasklet 的实现使用 tasklet 工作队列工作队列的实现使用工作队列 下半部 中断处理程序的局限性: 中断处理程序以异步方式执行,并且可能打断其它代码,因此为了避免被打断的代码停止时间…

PR控制以及使用PR控制用于单相离/并网逆变器

文章目录 前言基本知识实际使用单相离网逆变器单相并网逆变器 PR控制器离散化基本知识 DSP实现总结 前言 最近想学习一下并网逆变器,需要用到PR控制,全网找遍了许多学习资料,终于掌握的差不多了,在此做个记录,以及个人…

【每日一题】23年4月

文章目录 C 技术点多边三角形剖分的最低得分(dp思路,选不选问题)移动石子到连续(思路)1027. 最长等差数列(动态规划)1105. 填充书架(动态规划)1031 两个非重叠子数组的最大和1163.按字典序排在最…

【Java 】从源码全面解析Java 线程池

文章目录 一、引言二、使用三、源码1、初始化1.1 拒绝策略1.1.1 AbortPolicy1.1.2 CallerRunsPolicy1.1.3 DiscardOldestPolicy1.1.4 DiscardPolicy1.1.5 自定义拒绝策略1.2 其余变量 2、线程池的execute方法3、线程池的addWorker方法3.1 校验3.2 添加线程 4、线程池的 worker …

PostgreSQL 基础知识:psql 提示和技巧

对于积极使用和连接到 PostgreSQL 数据库的任何开发人员或 DBA 来说,能够访问psql命令行工具是必不可少的。在我们的第一篇文章中,我们讨论了 psql的简要历史,并演示了如何在您选择的平台上安装它并连接到 PostgreSQL 数据库。 在本文中&…

使用腾讯云快速完成网站备案的详细过程

最近总是被备案弄得血压飙升,明明是一件很简单的事情,不知道大家为什么搞得那么复杂,首先了解下为什么要备案,根据国务院令第292号《互联网信息服务管理办法》和 《非经营性互联网信息服务备案管理办法》规定,国家对经…

【TCP四次挥手】

文章目录 TCP 四次挥手过程是怎样的?为什么挥手需要四次?第一次挥手丢失了,会发生什么?第二次挥手丢失了,会发生什么?第三次挥手丢失了,会发生什么?第四次挥手丢失了,会发…

Lecture 13(Extra Material):Q-Learning

目录 Introduction of Q-Learning Tips of Q-Learning Double DQN Dueling DQN Prioritized Reply Multi-step Noisy Net Distributional Q-function Rainbow Q-Learning for Continuous Actions Introduction of Q-Learning Critic: The output values of a critic…