【深度学习】1.手动LogisticRegression模型的训练和预测

news2024/11/19 13:38:00

通过这个示例,可以了解逻辑回归模型的基本原理和训练过程,同时可以通过修改和优化代码来进一步探索机器学习模型的训练和调优方法。

步骤:

  1. 生成了一个模拟的二分类数据集:通过随机生成包含两个特征的数据data_x,并基于一定规则生成对应的二分类标签数据data_y
  2. 创建了一个手动实现的逻辑回归模型LogisticRegressionManually,其中包括:
    • 初始化函数__init__:初始化模型的权重参数w和偏置参数b
    • 前向传播函数forward:计算给定输入数据的预测值。
    • 损失函数loss_func:定义了交叉熵损失函数,用于评估模型的预测性能。
    • 训练函数train:在每个epoch中,遍历数据集的每个样本,计算预测值、损失值、梯度,并利用梯度下降法更新模型参数。
  3. 实例化LogisticRegressionManually类,然后调用train方法对模型进行训练。
  4. 在训练过程中,打印每个epoch的损失值。

演示:

# 生成模拟的二分类数据集,其中X数据是随机生成的,Y数据根据一定规则生成。

import torch
import torch.nn.functional as F

n_items = 1000
n_features = 2
learning_rate = 0.001
epochs = 100

# 置了随机种子,以确保每次运行代码时生成的随机数相同,从而使结果具有可重现性。
torch.manual_seed(123) 
# 生成了一个大小为(1000, 2)的张量data_x,其中包含1000个样本,每个样本具有2个特征。这里使用torch.randn生成标准正态分布的随机数作为数据,并将数据类型转换为float。
data_x = torch.randn(size=(n_items, n_features)).float()
# 成了标签数据data_y,通过对第一个特征乘以0.5和第二个特征乘以1.5的差值进行判断,如果差值大于0就将标签设为1,否则为0。这样生成了一个二分类标签数据集,同样将数据类型转换为float。
data_y = torch.where(torch.subtract(data_x[:, 0]*0.5, data_x[:, 1]*1.5) > 0, 1., 0.).float()

# print(data_x)
# print(data_y)

# 在每个epoch中,遍历数据集的每个样本,计算预测值、损失值、梯度,利用梯度下降法更新模型参数。通过这种方式训练模型可以逐渐优化模型参数,以达到更好的预测效果。

class LogisticRegressionManually(object):
    # 初始化函数__init__
    def __init__(self):
        # w是一个大小为(n_features, 1)的张量,用于存储权重参数,并且设置了requires_grad=True表示需要计算梯度;
        self.w = torch.randn(size=(n_features, 1), requires_grad=True)
        # b是一个大小为(1, 1)的张量,用于存储偏置参数,并且设置了requires_grad=True
        self.b = torch.zeros(size=(1, 1), requires_grad=True)
    # 前向传播函数forward
    def forward(self, x):
        # 过矩阵乘法计算预测值y_hat:将参数w转置后与输入数据x相乘,并加上偏置b后通过F.sigmoid函数进行激活,最终返回激活后的预测值。
        y_hat = F.sigmoid(torch.matmul(self.w.transpose(0, 1), x) + self.b)
        return y_hat
    # 损失函数loss_func
    @staticmethod
    def loss_func(y_hat, y):
        # 定义了交叉熵损失函数。通过计算实际标签y和预测值y_hat之间的交叉熵损失来评估模型的预测性能。
        return -(torch.log(y_hat)*y + (1-y)*torch.log(1-y_hat))
    # 训练函数train
    def train(self):
        # 在每个epoch中,遍历数据集中的每个样本
        for epoch in range(epochs):
            for step in range(n_items):
                # 利用模型的前向传播函数forward计算当前样本的预测值y_hat。
                y_hat = self.forward(data_x[step])
                # 获取当前样本的真实标签y
                y = data_y[step]
                # 调用损失函数loss_func计算预测值与真实标签之间的损失。
                loss = self.loss_func(y_hat, y)
                # 利用反向传播计算损失对模型参数的梯度
                loss.backward()
                # 进入torch.no_grad()上下文管理器,保证在该范围内的操作不会被记录用于自动微分。
                with torch.no_grad():
                    # 更新权重参数w和偏置参数b,通过梯度下降法更新参数,learning_rate是学习率。
                    self.w.data -= learning_rate * self.w.grad
                    self.b.data -= learning_rate * self.b.grad
                # 清零梯度,以便进行下一次参数更新时重新计算梯度。
                self.w.grad.data.zero_()
                self.b.grad.data.zero_()
            print("Epoch: %03d, Loss: %.3f" % (epoch, loss.item()))

lrm = LogisticRegressionManually()
lrm.train()

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1699467.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Compose 九:interactionSource 的使用

先上官方文档 InteractionSource InteractionSource represents a stream of Interactions corresponding to events emitted by a component. These Interactions can be used to change how components appear in different states, such as when a component is pressed or…

WordPress安装memcached提升网站速度

本教程使用环境为宝塔 第一步、服务器端安装memcached扩展 在网站使用的php上安装memcached扩展 第二步:在 WordPress 网站后台中,安装插件「Memcached Is Your Friend」 安装完成后启用该插件,在左侧工具-中点击Memcached 查看是否提示“U…

《拯救大学生课设不挂科第四期之蓝桥杯是什么?我是否要参加蓝桥杯?选择何种语言?如何科学备赛?方法思维教程》【官方笔记】

背景: 有些同学在大一或者大二可能会被老师建议参加蓝桥杯,本视频和文章主要是以一个过来人的身份来给与大家一些思路。 比如蓝桥杯是什么?我是否要参加蓝桥杯?参加蓝桥杯该选择何种语言?如何科学备赛?等…

webpack5生产模式

生产模式 生产模式准备 开发模式和生产模式有不同的 配置文件 2修改webpack.prod.js文件修改webpack.dev.js文件 修改webpack.dev.js文件 1》修改输出路径为undefined 2》将绝对路径进行修改,进行回退 此时文件的执行命令为 修改webpack.prod.js文件 1》修改绝…

跨平台之用VisualStudio开发APK嵌入OpenCV(三)

本篇将包含以下内容: 1.使用 Visual Studio 2019 开发一个 Android 的 App 2.导入前篇 C 编译好的 so 动态库 3.一些入门必须的其它设置 作为入门,我们直接使用真机进行调试,一方面运行速度远高于模拟器,另一方面模拟器使用的…

2024年【危险化学品经营单位安全管理人员】考试及危险化学品经营单位安全管理人员考试资料

题库来源:安全生产模拟考试一点通公众号小程序 危险化学品经营单位安全管理人员考试考前必练!安全生产模拟考试一点通每个月更新危险化学品经营单位安全管理人员考试资料题目及答案!多做几遍,其实通过危险化学品经营单位安全管理…

Zoho Campaigns邮件营销怎么发邮件?

Zoho Campaigns,作为业界领先的邮件营销平台,以其强大的功能、用户友好的界面以及深度的分析能力,为企业提供了一站式的邮件营销解决方案,助力企业高效地触达目标受众,构建并巩固庞大的客户基础。云衔科技为企业提供Zo…

电量计量芯片HLW8110的前端电路设计与误差分析校正.pdf 下载

电量计量芯片HLW8110的前端电路设计与误差分析校正.pdf 下载地址: 链接:https://pan.baidu.com/s/1vlCtC3LGFMzYpSUUDY-tEg 提取码:8110

用Prometheus全面监控MySQL服务:一篇文章搞定

简介 在现代应用中,MySQL数据库的性能和稳定性对业务至关重要。有效的监控可以帮助预防问题并优化性能。Prometheus作为一款强大的开源监控系统,结合Grafana的可视化能力,可以提供全面的MySQL监控方案。 设置Prometheus 安装Prometheus 使…

深度学习面试问题总结(21)| 模型优化

本文给大家带来的百面算法工程师是深度学习模型优化面试总结,文章内总结了常见的提问问题,旨在为广大学子模拟出更贴合实际的面试问答场景。在这篇文章中,我们还将介绍一些常见的深度学习面试问题,并提供参考的回答及其理论基础&a…

ic基础|时钟篇05:芯片中buffer到底是干嘛的?一文带你了解buffer的作用

大家好,我是数字小熊饼干,一个练习时长两年半的ic打工人。我在两年前通过自学跨行社招加入了IC行业。现在我打算将这两年的工作经验和当初面试时最常问的一些问题进行总结,并通过汇总成文章的形式进行输出,相信无论你是在职的还是…

leecode 637 二叉树的层平均值

leetcode 二叉树相关-层序遍历专题 二叉树的层序遍历一般来说,我们是利用队列来实现的,先把根节点入队,然后在出队后将其对应的子节点入队,然后往复此种操作。相比于二叉树的遍历递归,层序遍历比较简单,有…

2024年5月26日 (周日) 叶子游戏新闻

资深开发者:3A游戏当前处于一种尴尬的中间地带游戏行业整体,尤其是3A游戏正处于艰难时期。尽管2023年3A游戏佳作频出,广受好评,但居高不下的游戏开发成本(传闻《漫威蜘蛛侠2》的制作成本高达3亿美元)正严重…

WEB攻防【1】——ASP应用/HTTP.SYS/短文件/文件解析/Access注入/数据库泄漏

ASP:常见漏洞:本文所写这些 ASPX:未授权访问、报错爆路径、反编译 PHP:弱类型对比、mdb绕过、正则绕过(CTF考得多) JAVA:反序列化漏洞 Python:SSTI、字符串、序列化 Javascript&…

微服务下认证授权框架的探讨

前言 市面上关于认证授权的框架已经比较丰富了,大都是关于单体应用的认证授权,在分布式架构下,使用比较多的方案是--<应用网关>,网关里集中认证,将认证通过的请求再转发给代理的服务,这种中心化的方式并不适用于微服务,这里讨论另一种方案--<认证中心>,利用jwt去中…

elementui中 表格使用树形数据且固定一列时展开子集移入时背景色不全问题(父级和子级所展示的字段是不一样的时候)

原来的效果 修改后实现效果 解决- 需要修改elementui的依赖包中lib/element-ui.common.js中的源码 将js中此处代码改完下面的代码 watch: {// dont trigger getter of currentRow in getCellClass. see https://jsfiddle.net/oe2b4hqt/// update DOM manually. see https:/…

【单片机】STM32F070F6P6 开发指南(一)STM32建立HAL工程

文章目录 一、基础入门二、工程初步建立三、HSE 和 LSE 时钟源设置四、时钟系统&#xff08;时钟树&#xff09;配置五、GPIO 功能引脚配置六、配置 Debug 选项七、生成工程源码八、生成工程源码九、用户程序下载 一、基础入门 f0 pack下载&#xff1a; https://www.keil.arm…

关于XtremIO 全闪存储维护的一些坑(建议)

XtremIO 是EMC过去主推的一款全闪存储系统&#xff0c;号称性能小怪兽&#xff0c;对付那些对于性能要求极高的业务场景是比较合适的&#xff0c;先后推出了1代和2代产品&#xff0c;目前这个产品好像未来的演进到了PowerStor或者PowerMax全闪&#xff0c;应该不独立发展这个产…

Leetcode260

260. 只出现一次的数字 III - 力扣&#xff08;LeetCode&#xff09; class Solution {public int[] singleNumber(int[] nums) {//通过异或操作,使得最终结果为两个只出现一次的元素的异或值int filterResult 0;for(int num:nums){filterResult^num;}//计算首个1(从右侧开始)…