自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

news2025/1/30 11:14:27

代码:

import torch
import numpy as np
import torch.nn as nn

# 定义数据:x_data 是特征,y_data 是标签(目标值)
data = [[-0.5, 7.7],
        [1.8, 98.5],
        [0.9, 57.8],
        [0.4, 39.2],
        [-1.4, -15.7],
        [-1.4, -37.3],
        [-1.8, -49.1],
        [1.5, 75.6],
        [0.4, 34.0],
        [0.8, 62.3]]

# 将数据转为 numpy 数组
data = np.array(data)

# 提取 x_data 和 y_data
x_data = data[:, 0]  # 取第一列作为输入特征
y_data = data[:, 1]  # 取第二列作为目标标签

# 将数据转换为 PyTorch 张量
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标标签

# 使用 TensorDataset 来创建一个数据集
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(x_train, y_train)  # 使用训练数据创建数据集
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)  # 将数据集转换为 DataLoader,批大小为 2,且每个 epoch 都会随机打乱数据

# 定义损失函数:均方误差损失 (MSELoss)
criterion = nn.MSELoss()


# 定义线性回归模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 使用一个线性层,输入为1维,输出为1维
        self.layers = nn.Linear(1, 1)

    def forward(self, x):
        # 直接返回线性层的输出
        return self.layers(x)


# 初始化模型
model = LinearModel()

# 定义优化器:使用随机梯度下降 (SGD),学习率为 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 设置训练的 epoch 数量
epochs = 500

# 开始训练
for n in range(1, epochs + 1):
    epoch_loss = 0  # 初始化每个 epoch 的总损失

    # 遍历每一个 batch
    for batch_x, batch_y in dataloader:
        # 对每个输入 batch_x 添加一个维度,使其适应线性回归模型
        y_pred = model(batch_x.unsqueeze(1))

        # 计算当前 batch 的损失
        batch_loss = criterion(y_pred.squeeze(1), batch_y)

        # 清除梯度
        optimizer.zero_grad()

        # 反向传播
        batch_loss.backward()

        # 更新模型参数
        optimizer.step()

        # 累加损失
        epoch_loss += batch_loss

    # 计算当前 epoch 的平均损失
    avg_loss = epoch_loss / len(dataloader)

    # 每 10 个 epoch 或第一个 epoch 打印一次损失
    if n % 10 == 0 or n == 1:
        print(f"epoches:{n}, loss:{avg_loss:.4f}")

        # 在每 10 个 epoch 保存一次模型的状态字典
        torch.save(model.state_dict(), 'model.pth')

# 训练完成后,加载保存的模型
model.load_state_dict(torch.load("model.pth"))

# 将模型切换到评估模式
model.eval()

# 测试数据
x_test = torch.tensor([[1.8]], dtype=torch.float32)  # 测试输入 1.8

# 使用模型进行预测
with torch.no_grad():  # 在预测时不计算梯度
    y_pre = model(x_test)

# 打印预测结果
print(y_pre)

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2284629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

关于使用PHP时WordPress排错——“这意味着您在wp-config.php文件中指定的用户名和密码信息不正确”的解决办法

本来是看到一位好友的自己建站,所以突发奇想,在本地装个WordPress玩玩吧,就尝试着装了一下,因为之前电脑上就有MySQL,所以在自己使用PHP建立MySQL时报错了。 最开始是我的php启动mysql时有问题,也就是启动过…

【蓝桥杯】43694.正则问题

题目描述 考虑一种简单的正则表达式: 只由 x ( ) | 组成的正则表达式。 小明想求出这个正则表达式能接受的最长字符串的长度。 例如 ((xx|xxx)x|(x|xx))xx 能接受的最长字符串是: xxxxxx,长度是 6。 输入描述 一个由 x()| 组成的正则表达式。…

服务器虚拟化技术详解与实战:架构、部署与优化

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 引言 在现代 IT 基础架构中,服务器虚拟化已成为提高资源利用率、降低运维成本、提升系统灵活性的重要手段。通过服务…

jvm--类的生命周期

学习类的生命周期之前,需要了解一下jvm的几个重要的内存区域: (1)方法区:存放已经加载的类信息、常量、静态变量以及方法代码的内存区域 (2)常量池:常量池是方法区的一部分&#x…

TensorFlow实现逻辑回归模型

逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将介绍如何使用TensorFlow框架实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来直观地观察模型的训练过程。 数据准备 首先,我们准备两类数据点,分别表示两个不同…

《十七》浏览器基础

浏览器:是安装在电脑里面的一个软件,能够将页面内容渲染出来呈现给用户查看,并让用户与网页进行交互。 常见的主流浏览器: 常见的主流浏览器有:Chrome、Safari、Firefox、Opera、Edge 等。 输入 URL,浏览…

网络安全 | F5-Attack Signatures-Set详解

关注:CodingTechWork 创建和分配攻击签名集 可以通过两种方式创建攻击签名集:使用过滤器或手动选择要包含的签名。  基于过滤器的签名集仅基于在签名过滤器中定义的标准。基于过滤器的签名集的优点在于,可以专注于定义用户感兴趣的攻击签名…

STranslate 中文绿色版即时翻译/ OCR 工具 v1.3.1.120

STranslate 是一款功能强大且用户友好的翻译工具,它支持多种语言的即时翻译,提供丰富的翻译功能和便捷的使用体验。STranslate 特别适合需要频繁进行多语言交流的个人用户、商务人士和翻译工作者。 软件功能 1. 即时翻译: 文本翻译&#xff…

基于微信小程序的助农扶贫系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

我谈区域偏心率

偏心率的数学定义 禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》P312 区域的拟合椭圆看这里。 Rafael Gonzalez的二阶中心矩的表达不说人话。 我认为半长轴和半短轴不等于特征值,而是特征值的根号。…

关于低代码技术架构的思考

我们经常会看到很多低代码系统的技术架构图,而且经常看不懂。是因为技术架构图没有画好,还是因为技术不够先进,有时候往往都不是。 比如下图: 一个开发者,看到的视角往往都是技术层面,你给用户讲React18、M…

若依路由配置教程

1. 路由配置文件 2. 配置内容介绍 { path: "/tool/gen-edit", component: Layout, //在路由下,引用组件的名称,在页面中包括这个组件的内容(页面框架内容) hidden: true, //此页面的内容,在左边的菜单中不用显示。 …

基于ESP8266的多功能环境监测与反馈系统开发指南

项目概述 本系统集成了物联网开发板、高精度时钟模块、环境传感器和可视化显示模块,构建了一个智能环境监测与反馈装置。通过ESP8266 NodeMCU作为核心控制器,结合DS3231实时时钟、DHT11温湿度传感器、光敏电阻和OLED显示屏,实现了环境参数的…

HTML5 Web Worker 的使用与实践

引言 在现代 Web 开发中,用户体验是至关重要的。如果页面在执行复杂计算或处理大量数据时变得卡顿或无响应,用户很可能会流失。HTML5 引入了 Web Worker,它允许我们在后台运行 JavaScript 代码,从而避免阻塞主线程,保…

flutter_学习记录_00_环境搭建

1.参考文档 Mac端Flutter的环境配置看这一篇就够了 flutter的中文官方文档 2. 本人环境搭建的背景 本人的电脑的是Mac的,iOS开发,所以iOS开发环境本身是可用的;外加Mac电脑本身就会配置Java的环境。所以,后面剩下的就是&#x…

自助设备系统设置——对接POS支付

输入管理员密码 一、录入POS网关信息 填写网关信息后保存,重新启动软件

Calibre(阅读转换)-官方开源中文版[完整的电子图书馆系统,包括图书馆管理,格式转换,新闻,材料转换为电子书]

Calibre(阅读&转换)-官方开源中文版 链接:https://pan.xunlei.com/s/VOHbKYUwd3ASVXTi2Ok1vkK3A1?pwd92ny#

【unity游戏开发之InputSystem——06】PlayerInputManager组件实现本地多屏的游戏(基于unity6开发介绍)

文章目录 PlayerInputManager 简介1、PlayerInputManager 的作用2、主要功能一、PlayerInputManager组件参数1、Notification Behavior 通知行为2、Join Behavior:玩家加入的行为3、Player Prefab 玩家预制件4、Joining Enabled By Default 默认启用加入5、Limit Number Of Pl…

算法刷题Day29:BM67 不同路径的数目(一)

题目链接 描述 解题思路: 二维dp数组初始化。 dp[i][0] 1, dp[0][j] 1 。因为到达第一行第一列的每个格子只能有一条路。状态转移 dp[i][j] dp[i-1][j] dp[i][j-1] 代码: class Solution: def uniquePaths(self , m: int, n: int) -> int: #…