计算深度学习的参数

news2024/10/4 18:18:56

构建模型和学习率衰减

model = TextCNN().to(device)
criterion = nn.CrossEntropyLoss().to(device)  #
# optimizer = optim.AdamW(model.parameters(), lr=5e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)  # weight_decay=1e-4 weight_decay 就是 L2 正则化系数  , betas=(0.9, 0.888)

# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=10, verbose=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, min_lr=1e-5,patience=20, verbose=True)

计算相应指标并画图

import torch
import matplotlib.pyplot as plt
from sklearn.metrics import matthews_corrcoef, f1_score, precision_score, recall_score

best_val_accuracy = 0  # 设置初始最佳验证准确率为0

# 用于存储每个 epoch 的训练和验证结果
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(300):
    print('Epoch {}/{}'.format(epoch, 300))

    # 训练过程
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    all_train_preds = []
    all_train_targets = []

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
        optimizer.zero_grad()  # 清除梯度

        train_loss += loss.item()
        _, train_predicted = torch.max(pred, 1)
        train_total += y.size(0)
        train_correct += (train_predicted == y).sum().item()

        all_train_preds.extend(train_predicted.cpu().numpy())
        all_train_targets.extend(y.cpu().numpy())

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100 * train_correct / train_total
    train_mcc = matthews_corrcoef(all_train_targets, all_train_preds)
    train_f1 = f1_score(all_train_targets, all_train_preds, average='weighted')
    train_precision = precision_score(all_train_targets, all_train_preds, average='weighted')
    train_recall = recall_score(all_train_targets, all_train_preds, average='weighted')

    print(f'Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%,Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f},Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')
    # print(f'Train MCC: {train_mcc:.4f}, Train F1: {train_f1:.4f}')
    # print(f'Train Precision: {train_precision:.4f}, Train Recall: {train_recall:.4f}')

    # 保存训练集上的损失和准确率
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)
    # current_lr = scheduler.optimizer.param_groups[0]['lr']
    # print(f'Current Learning Rate: {current_lr}')
    # scheduler.step(avg_val_loss)

    # 验证过程
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_val_preds = []
    all_val_targets = []

    with torch.no_grad():
        for inputs, target in val_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            loss = criterion(output, target)

            val_loss += loss.item()
            _, val_predicted = torch.max(output, 1)
            val_total += target.size(0)
            val_correct += (val_predicted == target).sum().item()

            all_val_preds.extend(val_predicted.cpu().numpy())
            all_val_targets.extend(target.cpu().numpy())

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    val_mcc = matthews_corrcoef(all_val_targets, all_val_preds)
    val_f1 = f1_score(all_val_targets, all_val_preds, average='weighted')
    val_precision = precision_score(all_val_targets, all_val_preds, average='weighted')
    val_recall = recall_score(all_val_targets, all_val_preds, average='weighted')
    ################################
    current_lr = scheduler.optimizer.param_groups[0]['lr']
    print(f'Current Learning Rate: {current_lr}')
    scheduler.step(avg_val_loss)

    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}% ,Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f} ,Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')
    # print(f'Validation MCC: {val_mcc:.4f}, Validation F1: {val_f1:.4f}')
    # print(f'Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}')

    # 保存验证集上的损失和准确率
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)

    # 如果需要保存验证集上表现最好的模型,可以添加如下代码
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model_{}.pth'.format(epoch))
        print('Best model saved best_model_{}.pth'.format(epoch))

# 训练结束后绘制损失和准确率曲线
epochs = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# 绘制损失曲线
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, 'b', label='Train Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')
plt.title('Train and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 绘制准确率曲线
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accuracies, 'b', label='Train Accuracy')
plt.plot(epochs, val_accuracies, 'r', label='Validation Accuracy')
plt.title('Train and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()

#
# from google.colab import drive
# drive.mount('/content/drive')
#
# # 保存模型到 Google Drive 中
# model_save_path = '/content/drive/MyDrive/best_model_{}.pth'.format(epoch)
# torch.save(BiGRU.state_dict(), model_save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2188469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Chromium 中前端js XMLHttpRequest接口c++代码实现

在JavaScript中发出HTTP请求的主要方式包括&#xff1a;XMLHttpRequest对象、Fetch API、Axios库和各种其他的HTTP客户端库。 本人主要分析下XMLHttpRequest接口在c中对应实现 一、上前端代码 <!DOCTYPE html> <html lang"en"> <head> <meta…

Go基础学习11-测试工具gomock和monkey的使用

文章目录 基础回顾MockMock是什么安装gomockMock使用1. 创建user.go源文件2. 使用mockgen生成对应的Mock文件3. 使用mockgen命令生成后在对应包mock下可以查看生成的mock文件4. 编写测试代码5. 运行代码并查看输出 GomonkeyGomonkey优势安装使用对函数进行monkey对结构体中方法…

Marp精华总结(二)进阶篇

概述 这是Marp精华总结的第二篇&#xff0c;主要补充第一篇未提到的一些内容。 系列目录 Marp精华总结&#xff08;一&#xff09;基础篇Marp精华总结&#xff08;二&#xff09;进阶篇Marp精华总结&#xff08;三&#xff09;高级篇 自适应标题 通过在标题行中插入<!-…

历经十年/头发都快掉光/秘钥生成器终极版/机器码/到期功能限制/运行时间限制/日期防篡改/跨平台

一、项目介绍 1.0 前言说明 标题一点都不夸张&#xff0c;从第一版的秘钥生成器到今天这个版本&#xff0c;确实经历了十年的时间&#xff0c;最初的版本做的非常简陋&#xff0c;就是搞了个异或加密&#xff0c;控制运行时间&#xff0c;后面又增加设备数量的控制&#xff0…

JavaFX加载fxml文件几种方法

环境&#xff1a;idea&#xff0c;maven创建JavaFX工程 工程目录如下&#xff1a; MusicPlayer.java package cn.com;import java.io.IOException;import javafx.application.Application; import javafx.fxml.FXMLLoader; import javafx.geometry.Insets; import javafx.geo…

目标检测 Deformable DETR(2021)详细解读

文章目录 前言整体网络架构可变形注意力模块backbone生成多尺度特征多尺度位置编码prediction heads两个变体 前言 为解决DETR attention的计算量大导致收敛速度慢、小目标检测效果差的问题&#xff1a;提出了Deformable Attention&#xff0c;其注意力模块只关注一个query周围…

ML 系列: (10)— ML 中的不同类型的学习

一、说明 我们之前将机器学习方法分为三类&#xff1a;监督学习、无监督学习和强化学习。机器学习方法可以分为不同的类型&#xff0c;我们将在下面讨论最重要的类型。 二、懒惰学习与急切学习 预先学习的工作原理是使用训练数据构建模型&#xff0c;然后使用此模型评估测试数据…

STM32F103C8----3-3 蜂鸣器(跟着江科大学STM32)

一&#xff0c;电路图 &#xff08;接线图&#xff09; 面包板的的使用请参考&#xff1a;《面包板的使用_面包板的详细使用方法-CSDN博客》 二&#xff0c;目的/效果 3-3 蜂鸣器 三&#xff0c;创建Keil项目 详细参考&#xff1a;《STM32F103C8----2-1 Keil5搭建STM32项目模…

MySQL 中的 EXPLAIN 命令详解

在 MySQL 数据库中&#xff0c;EXPLAIN命令是一个非常强大的工具&#xff0c;它可以提供关于 SQL 查询执行计划的关键信息。理解这些信息对于优化查询性能至关重要。本文将详细介绍 MySQL 中的EXPLAIN命令提供的关键信息。 一、什么是 EXPLAIN 命令 EXPLAIN命令用于获取 MySQ…

Java多态(向上转型、动态绑定)+结合题目理解原理

第一次尝试使用markdowm写博客哈 文章目录 1.多态的引入2.重写和重载3.避免在构造方法里面去调用重写4.向上转型和向下转型5.让你真正明白什么是多态6.通过一些习题进行理解 1.多态的引入 首先说一下&#xff0c;这个想要使用多态需要我们满足的条件&#xff0c;然后具体的进行…

进程概念(冯诺依曼体系结构、操作系统、进程)-- 详解

目录 一、冯诺依曼体系结构1、概念2、硬件层面的数据流3、关于冯诺依曼的知识点强调4、CPU 工作原理5、补充&#xff08;CPU 和寄存器、高速缓存以及主存之间的关系&#xff09; 二、操作系统&#xff08;Operating System&#xff09;1、概念2、定位3、设计 OS 的目的4、如何理…

Linux高级编程_28_进程

文章目录 进程并行与并发单道与多道程序进程控制块(PCB)了解PCB存储位置进程号&#xff1a;进程号&#xff1a;&#xff08;PID&#xff09;进程组号&#xff1a;&#xff08;PGID&#xff09;父进程号&#xff1a;&#xff08;PPID&#xff09; fork函数 多进程创建进程状态进…

基于vue框架的大学生勤工俭学咨询服务系统的设计与实现60uw9(程序+源码+数据库+调试部署+开发环境)系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;大学生,企业,招聘信息,在线咨询,咨询回复,职位应聘 开题报告内容 基于Vue框架的大学生勤工俭学咨询服务系统的设计与实现 开题报告 一、研究背景 随着高等教育的普及与就业市场的竞争加剧&#xff0c;大学生勤工俭学已成为一种普遍现…

<<机器学习实战>>1-9节笔记

2.前言与导学 从关注算法的分类与特性到关注算法适合解决哪类问题 很多经典算法不再有效&#xff0c;但特征工程、集成学习越来越有效&#xff0c;和深度学习分别适合于不同领域 3、基本概念 如果预测目标是离散的&#xff0c;则是分类问题&#xff0c;否则回归 机器学习相比…

【AIGC】ChatGPT开发者必备:如何获取 OpenAI 的 API Key

博客主页&#xff1a; [小ᶻZ࿆] 本文专栏: AIGC | ChatGPT 文章目录 &#x1f4af;前言&#x1f4af;API Key的重要性&#x1f4af;获取API Key的基本步骤&#x1f4af;定价策略和使用建议&#x1f4af;小结 &#x1f4af;前言 在现代应用开发中&#xff0c;获取OpenAI的…

TCP Analysis Flags 之 TCP ZeroWindowProbe

前言 默认情况下&#xff0c;Wireshark 的 TCP 解析器会跟踪每个 TCP 会话的状态&#xff0c;并在检测到问题或潜在问题时提供额外的信息。在第一次打开捕获文件时&#xff0c;会对每个 TCP 数据包进行一次分析&#xff0c;数据包按照它们在数据包列表中出现的顺序进行处理。可…

什么是沉默成本?超详细+通俗易懂版

沉默成本是一个在会计学、金融学以及经济学中常用的概念&#xff0c;但更常见的表述是沉没成本&#xff08;Sunk Cost&#xff09;。沉没成本指的是已经发生且无法收回的成本&#xff0c;这些成本与当前的决策无关&#xff0c;但往往会影响人们的决策过程。以下是对沉没成本的详…

【MySQL】Ubuntu环境下MySQL的安装与卸载

目录 1.MYSQL的安装 2.MYSQL的卸载 1.MYSQL的安装 首先我们要看看我们环境里面有没有已经安装好的MySQL 我们发现是默认是没有的。 我们还可以通过下面这个命令来确认有没有mysql的安装包 首先我们得知道我们当前的系统版本是什么 lsb_release -a 我们在找apt源的时候&a…

vulnhub-unknowndevice64 2靶机

vulnhub&#xff1a;https://www.vulnhub.com/entry/unknowndevice64-2,297/ 导入靶机&#xff0c;放在kali同网段&#xff0c;扫描 靶机在192.168.81.9&#xff0c;扫描端口 啥啊这都是&#xff0c;详细扫描一下 5555是adb&#xff0c;6465是ssh&#xff0c;12345看样子应该是…