使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

news2025/2/3 6:07:01

1. 导入必要的库

首先,导入我们需要的库:Numpy、Pytorch 和相关工具包。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, f1_score
2. 自定义数据集

使用 Numpy 创建一个简单的线性可分数据集,并将其转换为 Pytorch 张量。

# 创建数据集
X = np.random.rand(100, 2)  # 100 个样本,2 个特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 标签,若特征之和大于1则为 1,否则为 0

# 转换为 PyTorch 张量
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.long)
3. 定义逻辑回归模型

在 Pytorch 中定义一个简单的逻辑回归模型。

class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, 2)  # 二分类问题

    def forward(self, x):
        return self.linear(x)
4. 初始化模型、损失函数和优化器
# 初始化模型
model = LogisticRegressionModel(input_dim=2)

# 损失函数与优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 训练模型

训练模型并保存训练好的权重。

epochs = 100
for epoch in range(epochs):
    # 前向传播
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 20 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# 保存模型
torch.save(model.state_dict(), 'logistic_regression.pth')
6. 加载模型并进行预测

加载保存的模型并进行预测。

# 加载模型
model = LogisticRegressionModel(input_dim=2)
model.load_state_dict(torch.load('logistic_regression.pth'))
model.eval()  # 设为评估模式

# 预测
with torch.no_grad():
    y_pred = model(X_train)
    _, predicted = torch.max(y_pred, 1)
7. 计算精确度、召回率和 F1 分数

使用 sklearn 中的评估函数计算精确度、召回率和 F1 分数。

accuracy = accuracy_score(y_train, predicted)
recall = recall_score(y_train, predicted)
f1 = f1_score(y_train, predicted)

print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
8. 总结

这篇博客展示了如何使用 Numpy 自定义数据集,利用 Pytorch 框架实现逻辑回归模型,并进行训练。训练后的模型被保存,并在加载后进行预测,最后计算了精确度、召回率和 F1 分数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分,但是一些简单的实现可能不会去进行垃圾收集或者进行压缩,方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样,是各个…

【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、字典 1.1、字典是什么 1.2、创建字典 1.3、查找 key 1.4、新增/修改元素 1.5、删除元素 1.6、遍历…

在实际开发中,如何正确使用 INT(1) 和 INT(10)

在实际开发中,如何正确使用 INT(1) 和 INT(10) 前言 在数据库设计和开发过程中,数据类型的选择至关重要。 最近,我在工作中遇到了一个关于MySQL中INT类型的误解问题,这让我意识到很多开发者对INT类型的理解存在误区。 本文将深…

像接口契约文档 这种工件,在需求 分析 设计 工作流里面 属于哪一个工作流

οゞ浪漫心情ゞο(20***328) 2016/2/18 10:26:47 请教一下,像接口契约文档 这种工件,在需求 分析 设计 工作流里面 属于哪一个工作流? 潘加宇(35***47) 17:17:28 你这相当于问用例图、序列图属于哪个工作流,看内容。 如果你的&quo…

GAMES101学习笔记(六):Geometry 几何(基本表示方法、曲线与曲面、网格处理)

文章目录 几何的表示方法隐式几何 Implicit Geometry代数曲面(Algebraic surface)构造实体几何CSG(Constructive Solid Geometry)距离函数(Distance Function)水平集方法(Level Set Methods)分型几何(Fractal) 显式几何 Explicit Geometry点云(Point Cloud)多边形网格(Polygon …

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.24 随机宇宙:生成现实世界数据的艺术

1.24 随机宇宙:生成现实世界数据的艺术 目录 #mermaid-svg-vN1An9qZ6t4JUcGa {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-vN1An9qZ6t4JUcGa .error-icon{fill:#552222;}#mermaid-svg-vN1An9qZ6t4JUc…

爬虫基础(三)Session和Cookie讲解

目录 一、前备知识点 (1)静态网页 (2)动态网页 (3)无状态HTTP 二、Session和Cookie 三、Session 四、Cookie (1)维持过程 (2)结构 正式开始说 Sessi…

HTMLCSS :下雪了

这段代码创建了一个动态的雪花飘落加载动画,通过 CSS 技术实现了雪花的下落和消失效果,为页面添加了视觉吸引力和动态感。 大家复制代码时,可能会因格式转换出现错乱,导致样式失效。建议先少量复制代码进行测试,若未能…

【Windows Server实战】生产环境云和NPS快速搭建

前置条件 本文假定你已达成以下前提条件: 有域控DC。有证书服务器(AD CS)。已使用Microsoft Intune或者GPO为客户机申请证书。服务器上至少有两张网卡(如果用虚拟机做的测试环境,可以用一张HostOnly网卡做测试&#…

RHCSA——搭建FTP文件共享服务器

一、实验目的 1、掌握vsftpd服务器的配置方法 2、熟悉FTP客户端工具的使用 3、掌握常见的FTP服务器的故障排除 二、实验项目背景 某企业像架构一台FTP服务器,为企业局域网中的计算机提供文件传送的任务,为财务部门、销售部门和OA系统提供异地数据备…

IM 即时通讯系统-50-[特殊字符]cim(cross IM) 适用于开发者的分布式即时通讯系统

IM 开源系列 IM 即时通讯系统-41-开源 野火IM 专注于即时通讯实时音视频技术,提供优质可控的IMRTC能力 IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统 IM 即时通讯系统-43-简单的仿QQ聊天安卓APP IM 即时通讯系统-44-仿QQ即…

Python在线编辑器

from flask import Flask, render_template, request, jsonify import sys from io import StringIO import contextlib import subprocess import importlib import threading import time import ast import reapp Flask(__name__)RESTRICTED_PACKAGES {tkinter: 抱歉&…

ZZNUOJ(C/C++)基础练习1041——1050(详解版)

1041 : 数列求和2 题目描述 输入一个整数n&#xff0c;输出数列1-1/31/5-……前n项的和。 输入 输入只有一个整数n。 输出 结果保留2为小数,单独占一行。 样例输入 3 样例输出 0.87注意sum 1相当于sumsum1 注意sum * 1相当于sumsum*1 C语言版 #include<stdio.h> // 包含…

浅析DDOS攻击及防御策略

DDoS&#xff08;分布式拒绝服务&#xff09;攻击是一种通过大量计算机或网络僵尸主机对目标服务器发起大量无效或高流量请求&#xff0c;耗尽其资源&#xff0c;从而导致服务中断的网络攻击方式。这种攻击方式利用了分布式系统的特性&#xff0c;使攻击规模更大、影响范围更广…

深度学习 Pytorch 神经网络的学习

本节将从梯度下降法向外拓展&#xff0c;介绍更常用的优化算法&#xff0c;实现神经网络的学习和迭代。在本节课结束将完整实现一个神经网络训练的全流程。 对于像神经网络这样的复杂模型&#xff0c;可能会有数百个 w w w的存在&#xff0c;同时如果我们使用的是像交叉熵这样…

【回溯】目标和 字母大小全排列

文章目录 494. 目标和解题思路&#xff1a;回溯784. 字母大小写全排列解题思路&#xff1a;回溯 494. 目标和 494. 目标和 给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - &#xff0c;然后串联起所有整数&#xff0c;可以构造一个 表达式…

Linux系统上安装与配置 MySQL( CentOS 7 )

目录 1. 下载并安装 MySQL 官方 Yum Repository 2. 启动 MySQL 并查看运行状态 3. 找到 root 用户的初始密码 4. 修改 root 用户密码 5. 设置允许远程登录 6. 在云服务器配置 MySQL 端口 7. 关闭防火墙 8. 解决密码错误的问题 前言 在 Linux 服务器上安装并配置 MySQL …

记录一次,PyQT的报错,多线程Udp失效,使用工具如netstat来检查端口使用情况。

1.问题 报错Exception in thread Thread-1: Traceback (most recent call last): File "threading.py", line 932, in _bootstrap_inner File "threading.py", line 870, in run File "main.py", line 456, in udp_recv IndexError: list…

群晖NAS安卓Calibre 个人图书馆

docker 下载镜像johngong/calibre-web&#xff0c;安装之 我是本地的/docker/xxx/metadata目录 映射到 /usr/local/calibre-web/app/cps/metadata_provider CALIBREDB_OTHER_OPTION 删除 CALIBRE_SERVER_USER calibre_server_user 缺省用户名口令 admin admin123 另外有个N…

android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色

安卓自定义DatePicker选中日期颜色 背景&#xff1a;解决方案&#xff1a;方案一&#xff1a;方案二&#xff1a;实践效果&#xff1a; 背景&#xff1a; 最近在尝试用原生安卓实现仿element-ui表单校验功能&#xff0c;其中的的选择日期涉及到安卓DatePicker组件的使用&#…