PyTorch深度学习模型训练流程:(二、回归)

news2024/11/13 14:47:31

回归的流程与分类基本一致,只需要把评估指标改动一下就行。回归输出的是损失曲线、R^2曲线、训练集预测值与真实值折线图、测试集预测值散点图与真实值折线图。输出效果如下:

 注意:预测值与真实值图像处理为按真实值排序,图中呈现的升序与数据集趋势无关。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdom

from typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizer


def regress(
        data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        scaler: Optional[TransformerMixin] = None,
        batch_size: int = 64,
        epochs: int = 10,
        device: Optional[torch.device] = None
) -> nn.Module:
    """
    回归任务的训练函数。
    :param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型
    :param model: 回归模型
    :param optimizer: 优化器
    :param criterion: 损失函数
    :param scaler: 数据标准化器
    :param batch_size: 批大小
    :param epochs: 训练轮数
    :param device: 训练设备
    :return: 训练好的回归模型
    """
    if isinstance(data[0], np.ndarray):
        X, y = data
        # 分离训练集和测试集,指定随机种子以便复现
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        # 数据标准化
        if scaler is not None:
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        # 转换为tensor
        X_train = torch.from_numpy(X_train.astype(np.float32))
        X_test = torch.from_numpy(X_test.astype(np.float32))
        y_train = torch.from_numpy(y_train.astype(np.float32))
        y_test = torch.from_numpy(y_test.astype(np.float32))
        # 将X和y封装成TensorDataset
        train_dataset = TensorDataset(X_train, y_train)
        test_dataset = TensorDataset(X_test, y_test)

    elif isinstance(data[0], Dataset):
        train_dataset, test_dataset = data
    else:
        raise ValueError('Unsupported data type')

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    model.to(device)
    vis = Visdom()
    # 训练模型
    for epoch in range(epochs):
        for step, (batch_x_train, batch_y_train) in enumerate(train_loader):
            batch_x_train = batch_x_train.to(device)
            batch_y_train = batch_y_train.to(device)
            # 前向传播
            output = model(batch_x_train)
            loss = criterion(output, batch_y_train)
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            niter = epoch * len(train_loader) + step + 1  # 计算迭代次数
            if niter % 100 == 0:
                # 评估模型
                model.eval()
                with torch.no_grad():
                    eval_dict = {
                        'test_loss': [],
                        'test_r2': [],
                        'y_test': [],
                        'y_pred': [],
                    }
                    for batch_x_test, batch_y_test in test_loader:
                        batch_x_test = batch_x_test.to(device)
                        batch_y_test = batch_y_test.to(device)
                        test_output = model(batch_x_test)
                        test_predicted_tuple = (batch_y_test.numpy(), test_output.numpy())
                        # 计算并记录损失、R^2、真实值、预测值
                        eval_dict['test_loss'].append(criterion(test_output, batch_y_test))
                        eval_dict['test_r2'].append(r2_score(*test_predicted_tuple))
                        eval_dict['y_test'].append(batch_y_test)
                        eval_dict['y_pred'].append(test_output)

                    # 画出损失曲线
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),
                        win='loss',
                        update='append',
                        opts=dict(title='Loss', legend=['train_loss', 'test_loss']),
                    )
                    # 画出R^2曲线
                    train_r2 = r2_score(batch_y_train.numpy(), output.numpy())
                    vis.line(
                        X=torch.ones((1, 2)) * (niter // 100),
                        Y=torch.tensor((train_r2, np.mean(eval_dict['test_r2']))).unsqueeze(0),
                        win='R^2',
                        update='append',
                        opts=dict(title='R^2', legend=['train_R^2', 'test_R^2'], ytickmin=0, ytickmax=1),
                    )
                    # 画出训练集预测值和真实值折线图
                    sorted_train_idx = torch.argsort(batch_y_train)  # 按真实值排序
                    vis.line(
                        X=torch.arange(batch_size).repeat(2, 1).t(),
                        Y=torch.stack((batch_y_train[sorted_train_idx], output[sorted_train_idx]), dim=1),
                        win='batch_train_line',
                        opts=dict(title='Predicted vs. Actual (Train Set)', legend=['Actual', 'Predicted']),
                    )
                    # 画出测试集预测值散点图和真实值折线图
                    x = list(range(len(y_test)))
                    y_test = torch.cat(eval_dict['y_test'])
                    y_pred = torch.cat(eval_dict['y_pred'])
                    sorted_test_idx = torch.argsort(y_test)
                    vis._send({
                        'data': [
                            {'x': x, 'y': y_test[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'lines', 'name': 'Actual'},
                            {'x': x, 'y': y_pred[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'markers', 'name': 'Predicted', 'marker': {'size': 3}}
                        ],
                        'win': 'test_line',
                        'layout': {'title': 'Predicted vs. Actual (Test Set)'},
                    })
    return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2079740.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习概述与应用:深度学习、人工智能与经典学习方法

引言 机器学习(Machine Learning)是人工智能(AI)领域中最为核心的分支之一,其主要目的是通过数据学习和构建模型,帮助计算机系统自动完成特定任务。随着深度学习(Deep Learning)的崛起,机器学习技术在各行各业中的应用变得越来越广泛。在本文中,我们将详细介绍机器学…

Datawhale X 李宏毅苹果书 AI夏令营 Task1笔记

课程内容 学习笔记 (一)术语解释 一 . 机器学习(Machine Learning,ML) 机器学习,在本书的解释中是让机器具备找一个函数的能力。个人理解是基于所拥有的数据构建起概率统计模型来对数据进行预测与分析。…

python可视化-散点图

散点图可以了解数据之间的各种相关性,如正比、反比、无相关、线性、指数级、 U形等,而且也可以通过数据点的密度(辅助拟合趋势线)来确定相关性的强度。另外,也可以探索出异常值(在远超出一般聚集区域的数据…

【Java】Record的使用 (简洁教程)

Java系列文章目录 补充内容 Windows通过SSH连接Linux 第一章 Linux基本命令的学习与Linux历史 文章目录 Java系列文章目录一、前言二、学习内容:三、问题描述四、解决方案:4.1 为什么引入Record4.2 Record与Class区别4.3 使用场景 五、总结:…

使用uni-app开发微信小程序

一、前提环境 1.1 :uniapp开发文档:https://uniapp.dcloud.net.cn/quickstart-cli.html 细节都在这一页,这里不过多解释 二、开发工具下载 2.1 微信开发者工具 下载链接:https://developers.weixin.qq.com/miniprogram/dev/dev…

Java:Calendar类

文章目录 Calendar类常用方法代码 黑马学习笔记 Calendar类 calendar是可变对象,一旦修改后其对象本身表示的时间将发生变化 原始对象会跟着修改,造成原始对象的丢失 常用方法 代码 package Time;import java.util.Calendar; import java.util.Date;/…

【RabbitMQ高级特性】消息可靠性原理

1. 消息确认机制 1.1 介绍 我们可以看到RabbitMQ的消息流转图: 当消息从Broker投递给消费者的时候会存在以下两种情况: consumer消费消息成功consumer消费消息异常 如果说RabbitMQ在每次将消息投递给消费者的时候就将消息从Broker中删除&#xff0c…

用 like concat 不用 like,为了防止sql注入;#{}和${}的区别和用法;#{}预防SQL注入的原理

一、like concat 和 like mybatis中为了防止sql注入&#xff0c;使用like语句时并不是直接使用&#xff0c;而是使用concat函数<if test"goodName ! null and goodName ! "> and good_name like concat(%, #{goodName}, %)</if> concat()函数1、功能&a…

Webbench1.5安装使用Ubuntu

1、安装依赖包 sudo apt-get update sudo apt-get install libtirpc-dev2、安装Webbench1.5 参考https://github.com/baiguo/webbench-1.5 # 可能需要root权限&#xff0c;我是切换到root用户才安装成功 wget http://home.tiscali.cz/~cz210552/distfiles/webbench-1.5.tar.…

APP.vue引入子组件进行页面展示

一.将vue项目启动服务器原始页面进行清空 打开APP.vue文件&#xff0c;将<template>标签里的内容和<style>标签里的内容 ctrl/ 选中进行注释&#xff0c;以及引入的Helloworld.vue文件内容代码进行注释 并且 ctrls 保存 服务器页面从原始页面 变为空白 二.在comp…

树莓派4B安装golang最新版(20210520)

前置条件&#xff1a; 树莓派4B 安装官方系统 Linux raspberrypi 5.10.17-v7l #1414 更换最新版的原因&#xff1a; 截至 2021.5.20 &#xff0c;Raspberry Pi OS 最新版系统中&#xff0c;默认安装golang1.11&#xff0c;但是使用 go get golang.org/x/crypto/ssh 时&#xff…

推荐系统实战(七)-多任务多场景(上)多任务

多任务Multi-Task&#xff0c;有时也被称为多目标Multi-Objective建模。比如说电商场景下&#xff0c;希望曝光的物料被多多点击&#xff0c;还希望商品被下单购买&#xff0c;因此同时建模三个目标&#xff1a;曝光到点击CTR&#xff0c;点击到购买转换率CVR&#xff0c;曝光到…

记一次对某佛教系统的漏洞挖掘

前言 简单记录一次漏洞挖掘&#xff0c;一个系统居然爆了这么多类型的洞&#xff0c;于是想记录哈。(比较基础&#xff0c;我是菜狗&#xff0c;大佬轻喷) 业务介绍 是一个某佛教的系统 有一些佛教的学习资源、一些佛教相关的实物商品可购买&#xff0c;有个人中心&#xff…

PyCharm中python语法要求——消去提示波浪线

PyCharm中python语法要求——消去提示波浪线 关闭代码规范检查 在Setting里边搜索pep&#xff0c;取消勾选pep8 coding style violation 问题产生 解决问题 按照下图操作&#xff0c;也可直接CtrlAlts弹出设置页面 在 Settings 中 &#xff1a; Editor > Color Sheame >…

设计模式26-解析器模式

设计模式26-解析器模式 动机定义与结构定义结构 C代码推导代码说明 优缺点应用总结 动机 在软件构建过程中&#xff0c;如果某一特定领域的问题比较复杂&#xff0c;类似结构会不断重复的出现。如果使用普通的编程方式来实现&#xff0c;将面临非常频繁的变化。 在这种情况下&…

二叉树算法算法【二叉树的创建、插入、删除、查找】

一、原理 1.1、二叉排序树的插入 1.2、二叉树的删除 &#xff08;1&#xff09;删除度为0的节点&#xff0c;就是最后的叶子节点&#xff0c;直接删除就可以了. &#xff08;2&#xff09;删除度为1的节点&#xff0c;就是爷爷节点接收孙子节点。 &#xff08;3&#xff09;删…

什么软件可以约束员工摸鱼行为?「5款软件助力企业管控员工上班摸鱼!」

你的企业是否也在面临这些问题&#xff1a; 1.工作效率下降&#xff1a;频繁的分心会打断工作连贯性&#xff0c;降低任务完成的质量和速度。 2.团队协作受损&#xff1a;个别员工的低效可能导致整个团队进度滞后&#xff0c;影响项目按时交付。 3.资源浪费&#xff1a;非工…

Git —— 1、Windows下安装配置git

Git简介 Git 是一个免费的开源分布式版本控制系统&#xff0c;旨在处理从小型到 快速高效的超大型项目。 Git 易于学习&#xff0c;占用空间小&#xff0c;性能快如闪电。 它超越了 Subversion、CVS、Perforce 和 ClearCase 等 SCM 工具 具有 cheap local branching、 方便的暂…

【分布式架构幂等性总结】

文章目录 幂等性什么场景需要幂等设计&#xff1f;产生幂等性的原因解决重复操作&#xff0c;实现幂等性 幂等性 接口幂等性就是用户对于同一操作发起的一次请求或者多次请求的结果是一致的&#xff0c;不会因为多次点击而产生了副作用。比如&#xff1a;公交车刷卡&#xff0…

.NET8 Web 利用BAT命令 一键部署 IIS - CI-CD基础

1. Windows Server 前置准备 1.1 IIS安装好 1.2 .NET8 Sdk 运行时 安装 官方下载地址&#xff1a;https://dotnet.microsoft.com/zh-cn/download/dotnet/8.0 1.3 创建一个.NET8 WebMvc项目 生成发布包 微软MVC这个项目模板直接创建&#xff0c;发布 2. 利用 BAT 来一键部署…