使用PyTorch实现逻辑回归:从训练到模型保存与性能评估

news2025/2/2 19:11:09

1. 引入必要的库

首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,scikit-learn用于计算性能指标。

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, f1_score

2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。我们使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用PyTorch来构建逻辑回归模型。

# 构建逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self, num_features):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(num_features, 1)
    
    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# 初始化模型
num_features = X.shape[1]
model = LogisticRegression(num_features)

4. 定义损失函数和优化器

我们使用二元交叉熵损失函数和随机梯度下降(SGD)优化器。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

5. 训练模型

使用自定义数据集训练模型。

# 将数据转换为PyTorch的张量
X_tensor = torch.tensor(X)
y_tensor = torch.tensor(y.reshape(-1, 1))

# 训练模型
num_epochs = 100
batch_size = 32
for epoch in range(num_epochs):
    for i in range(0, len(X), batch_size):
        X_batch = X_tensor[i:i+batch_size]
        y_batch = y_tensor[i:i+batch_size]
        
        # 前向传播
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

6. 保存模型

训练完成后,我们可以使用PyTorch的state_dict方法保存模型。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')

7. 加载模型并进行预测

在需要时,我们可以使用PyTorch的load方法加载模型,并进行预测。

# 加载模型
model = LogisticRegression(num_features)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()

# 进行预测
with torch.no_grad():
    X_test = torch.tensor(X[:5])
    predictions = model(X_test)
    predicted_labels = (predictions > 0.5).float().numpy().flatten()

print("Predicted Labels:", predicted_labels)

8. 性能评估

计算预测结果的精确度、召回率和F1分数。

# 假设前5个样本为测试集,真实标签如下
y_true = y[:5]

# 计算性能指标
accuracy = accuracy_score(y_true, predicted_labels)
recall = recall_score(y_true, predicted_labels)
f1 = f1_score(y_true, predicted_labels)

print(f'Accuracy: {accuracy:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2290908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言】main函数解析

文章目录 一、前言二、main函数解析三、代码示例四、应用场景 一、前言 在学习编程的过程中,我们很早就接触到了main函数。在Linux系统中,当你运行一个可执行文件(例如 ./a.out)时,如果需要传入参数,就需要…

深度学习练手小例子——cifar10数据集分类问题

CIFAR-10 是一个经典的计算机视觉数据集,广泛用于图像分类任务。它包含 10 个类别的 60,000 张彩色图像,每张图像的大小是 32x32 像素。数据集被分为 50,000 张训练图像和 10,000 张测试图像。每个类别包含 6,000 张图像,具体类别包括&#x…

【Git】初识Git Git基本操作详解

文章目录 学习目标Ⅰ. 初始 Git💥注意事项 Ⅱ. Git 安装Linux-centos安装Git Ⅲ. Git基本操作一、创建git本地仓库 -- git init二、配置 Git -- git config三、认识工作区、暂存区、版本库① 工作区② 暂存区③ 版本库④ 三者的关系 四、添加、提交更改、查看提交日…

【JavaEE进阶】应用分层

目录 🎋序言 🍃什么是应用分层 🎍为什么需要应用分层 🍀如何分层(三层架构) 🎄MVC和三层架构的区别和联系 🌳什么是高内聚低耦合 🎋序言 通过上⾯的练习,我们学习了SpringMVC简单功能的开…

【数据结构篇】时间复杂度

一.数据结构前言 1.1 数据结构的概念 数据结构(Data Structure)是计算机存储、组织数据的⽅式,指相互之间存在⼀种或多种特定关系的数 据元素的集合。没有⼀种单⼀的数据结构对所有⽤途都有⽤,所以我们要学各式各样的数据结构, 如&#xff1a…

【数据结构】_链表经典算法OJ(力扣/牛客第二弹)

目录 1. 题目1:返回倒数第k个节点 1.1 题目链接及描述 1.2 解题思路 1.3 程序 2. 题目2:链表的回文结构 2.1 题目链接及描述 2.2 解题思路 2.3 程序 1. 题目1:返回倒数第k个节点 1.1 题目链接及描述 题目链接: 面试题 …

深度学习之“缺失数据处理”

缺失值检测 缺失数据就是我们没有的数据。如果数据集是由向量表示的特征组成,那么缺失值可能表现为某些样本的一个或多个特征因为某些原因而没有测量的值。通常情况下,缺失值由特殊的编码方式。如果正常值都是正数,那么缺失值可能被标记为-1…

MYSQL--一条SQL执行的流程,分析MYSQL的架构

文章目录 第一步建立连接第二部解析 SQL第三步执行 sql预处理优化阶段执行阶段索引下推 执行一条select 语句中间会发生什么? 这个是对 mysql 架构的深入理解。 select * from product where id 1;对于mysql的架构分层: mysql 架构分成了 Server 层和存储引擎层&a…

C++解决输入空格字符串的三种方法

一.gets和fgets char * gets ( char * str ); char * fgets ( char * str, int num, FILE * stream ); 1. gets 是从第⼀个字符开始读取,⼀直读取到 \n 停⽌,但是不会读取 \n ,也就是读取到的内容 中没有包含 \n ,但是会在读取到的内…

多模态论文笔记——NaViT

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细解读多模态论文NaViT(Native Resolution ViT),将来自不同图像的多个patches打包成一个单一序列——称为Patch n’ Pack—…

云中漫步:精工细作铸就免费公益刷步平台

云中漫步,历经三年深度研发与优化,平台以高稳定性、零成本及公益属性为核心特色,依托前沿技术手段与多重安全防护机制,确保用户步数数据的精准修改与隐私安全。我们致力于提供无缝流畅的用户体验,让每一次步数更新都轻…

neo4j入门

文章目录 neo4j版本说明部署安装Mac部署docker部署 neo4j web工具使用数据结构图数据库VS关系数据库 neo4j neo4j官网Neo4j是用ava实现的开源NoSQL图数据库。Neo4作为图数据库中的代表产品,已经在众多的行业项目中进行了应用,如:网络管理&am…

【ts + java】古玩系统开发总结

src别名的配置 开发中文件和文件的关系会比较复杂,我们需要给src文件夹一个别名吧 vite.config.js import { defineConfig } from vite import vue from vitejs/plugin-vue import path from path// https://vitejs.dev/config/ export default defineConfig({pl…

【Docker】快速部署 Nacos 注册中心

【Docker】快速部署 Nacos 注册中心 引言 Nacos 注册中心是一个用于服务发现和配置管理的开源项目。提供了动态服务发现、服务健康检查、动态配置管理和服务管理等功能,帮助开发者更轻松地构建微服务架构。 仓库地址 https://github.com/alibaba/nacos 步骤 拉取…

SpringCloud篇 微服务架构

1. 工程架构介绍 1.1 两种工程架构模型的特征 1.1.1 单体架构 上面这张图展示了单体架构(Monolithic Architecture)的基本组成和工作原理。单体架构是一种传统的软件架构模式,其中所有的功能都被打包在一个单一的、紧密耦合的应用程序中。 …

tf.Keras (tf-1.15)使用记录4-model.fit方法及其callbacks参数

model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。 其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API 1 model.fit() 方法的参数及使用 必需参数 x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生…

Easy系列PLC尺寸测量功能块ST代码(激光微距仪应用)

激光微距仪可以测量短距离内的产品尺寸,产品规格书的测量 精度可以到0.001mm。具体需要看不同的型号。 1、激光微距仪 2、尺寸测量应用 下面我们以测量高度为例子,设计一个高度测量功能块,同时给出测量数据和合格不合格指标。 3、高度测量功能块 4、复位完成信号 5、功能…

996引擎 -地图-添加安全区

996引擎 -地图-添加安全区 文件位置配置 cfg_startpoint.xls特效效果1345参考资料文件位置 文件位置服务端D:\996M2-lua\MirServer-lua\Mir200客户端D:\996M2-lua\996M2_debug\dev配置 cfg_startpoint.xls 服务端\Mir200\Envir\DATA\cfg_startpoint.xls 填歪了也有可能只画一…

[Collection与数据结构] B树与B+树

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

redex快速体验

第一步: 2.回调函数在每次state发生变化时候自动执行