PyTorch逻辑回归总结

news2025/7/3 23:57:08

目录

  • PyTorch逻辑回归总结
    • 神经网络基础
      • 基本结构
      • 学习路径
    • 线性回归
      • 简单线性回归
      • 多元线性回归
    • 逻辑回归
      • 核心原理
      • 损失函数
    • 梯度下降法
      • 基本思想
      • 关键公式
      • 学习率影响
    • PyTorch实现
      • 数据准备
      • 模型构建
      • 代码优化
    • 核心概念对比

PyTorch逻辑回归总结

神经网络基础

基本结构

  • 输入节点
  • 隐藏节点
  • 输出节点

学习路径

  • 逻辑回归作为神经网络入门基础

线性回归

简单线性回归

  • 模型表达式: y = β 0 + β 1 x + ϵ y = \beta_0 + \beta_1 x + \epsilon y=β0+β1x+ϵ
  • 参数估计方法:最小二乘法
  • 参数求解公式
    • β ^ 1 = ∑ ( x i − x ˉ ) ( y i − y ˉ ) ∑ ( x i − x ˉ ) 2 \hat{\beta}_1 = \frac{\sum (x_i - \bar{x})(y_i - \bar{y})}{\sum (x_i - \bar{x})^2} β^1=(xixˉ)2(xixˉ)(yiyˉ)
    • β ^ 0 = y ˉ − β ^ 1 x ˉ \hat{\beta}_0 = \bar{y} - \hat{\beta}_1 \bar{x} β^0=yˉβ^1xˉ

多元线性回归

  • 模型表达式: y = β 0 + β 1 x 1 + ⋯ + β p x p + ϵ y = \beta_0 + \beta_1 x_1 + \cdots + \beta_p x_p + \epsilon y=β0+β1x1++βpxp+ϵ
  • 矩阵形式求解: β ^ = ( X T X ) − 1 X T y \hat{\beta} = (X^T X)^{-1} X^T y β^=(XTX)1XTy

逻辑回归

核心原理

  • 线性回归结果映射到概率: z = θ T x z = \theta^T x z=θTx
  • Sigmoid函数: σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1
    • 输出范围:[0, 1]
    • 代码实现:sigmoid(z)

损失函数

  • 最大似然估计推导
  • 对数损失函数:
    J ( θ ) = − ∑ [ y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ] J(\theta) = -\sum \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right] J(θ)=[ylog(y^)+(1y)log(1y^)]
  • 防止数值溢出:添加极小值 ϵ \epsilon ϵ

梯度下降法

基本思想

  • 类比下山问题
  • 梯度方向:函数下降最快的方向
  • 学习率(η):控制步长的超参数

关键公式

  • 参数更新: θ n + 1 = θ n − η ∂ J ∂ θ \theta_{n+1} = \theta_n - \eta \frac{\partial J}{\partial \theta} θn+1=θnηθJ
  • 偏导数计算:
    • 权重: ∂ J ∂ θ j = 1 m ∑ ( y i − y ^ i ) x i j \frac{\partial J}{\partial \theta_j} = \frac{1}{m} \sum (y_i - \hat{y}_i) x_{ij} θjJ=m1(yiy^i)xij
    • 截距: ∂ J ∂ b = 1 m ∑ ( y i − y ^ i ) \frac{\partial J}{\partial b} = \frac{1}{m} \sum (y_i - \hat{y}_i) bJ=m1(yiy^i)

学习率影响

  • 过小:收敛缓慢
  • 过大:震荡或发散
  • 优化策略:动态衰减、网格搜索

PyTorch实现

数据准备

  • 使用make_classification生成数据
  • 拆分训练集/测试集:train_test_split

模型构建

  1. 参数初始化

    • 权重:w = torch.randn(1, 10, requires_grad=True)
    • 偏置:b = torch.randn(1, requires_grad=True)
  2. 前向传播

    • 线性运算:z = torch.mm(x, w.T) + b
    • Sigmoid激活:y_hat = torch.sigmoid(z)
  3. 损失计算

    • 二元交叉熵:loss = F.binary_cross_entropy(y_hat, y_true)
  4. 反向传播

    • 自动求导:loss.backward()
    • 梯度清零:w.grad.zero_()
  5. 参数更新

    • w -= lr * w.grad
    • b -= lr * b.grad

代码优化

  • 对比NumPy与PyTorch实现
  • 利用自动求导简化梯度计算

核心概念对比

  • 概率 vs 似然
    • 概率:已知参数预测结果
    • 似然:已知结果估计参数
  • 超参数 vs 权重参数
    • 超参数:手动设置(如学习率)
    • 权重参数:模型自动学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2338311.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EDID结构

EDID DDC通讯中传输显示设备数据 VGA , DVI 的EDID由128字节组成,hdmi的EDID增加扩展块128字节。扩展快的内容主要是和音频属性相关的,DVI和vga没有音频,hdmi自带音频,扩展快数据规范按照cea-861x标准。 Edid为了让pc或其他的图像…

文件包含(详解)

文件包含漏洞是一种常见的Web安全漏洞,其核心在于应用程序未对用户控制的文件路径或文件名进行严格过滤,导致攻击者能够包含并执行任意文件(包括本地或远程恶意文件)。 1. 文件包含原理 动态文件包含机制 开发者使用动态包含函数…

《SpringBoot中@Scheduled和Quartz的区别是什么?分布式定时任务框架选型实战》​

🌟 ​大家好,我是摘星!​ 🌟 今天为大家带来的是Scheduled和Quartz对比分析: 新手常见困惑: 刚学SpringBoot时,我发现用Scheduled写定时任务特别简单。但当我看到同事在项目里用Quartz时&…

安装fvm可以让电脑同时管理多个版本的flutter、flutter常用命令、vscode连接模拟器

打开 PowerShellfvm安装 dart pub global activate fvm安装完成后,如果显示FVM无法识别,那么需要去添加环境变量path添加这个:C:\Users\Administrator\AppData\Local\Pub\Cache\bin 常用命令 fvm releases 查看用户可以装的flutter版本fvm l…

Kafka系列之:计算kafka集群topic占的存储大小

Kafka系列之:计算kafka集群topic占的存储大小 topic存储数据格式统计topic存储大小定时统计topic存储大小topic存储数据格式 单位是字节大小 size_bytes{directory="/data/datum/kafka/optics-all" } 782336计算topic存储大小脚本逻辑是: 计算指定目录或文件的大小…

智谱AI大模型免费开放:开启AI创作新时代

文章摘要:近日,国内领先的人工智能公司智谱AI宣布旗下多款大模型服务免费开放,这一举措标志着大模型技术正式迈入普惠阶段。本文将详细介绍智谱AI此次开放的GLM-4 等大模型,涵盖其主要功能、技术特点、使用步骤以及应用场景&#…

T1结构像+RS-fMRI影像处理过程记录(数据下载+Matlab工具箱+数据处理)

最近需要仿真研究T1结构像RS-fMRI影像融合处理输出目标坐标的路线可行性。就此机会记录下来。 为了完成验证目标处理,首先需要有数据,然后需要准备对应的处理平台和工具箱,进行一系列。那么开始记录~ 前言: 为了基于种子点的功能连…

【前端基础】--- HTML

个人主页  :  9ilk    专栏  :  前端基础 文章目录 🏠 初识HTML🏠 HTML结构认识HTML标签HTML文件基本结构标签层次结构快速生成代码框架 🏠 HTML常见标签注释标签标题标签 h1-h6段落标签 p换行标签 br格式化标签图片标签 img超链接标签…

黑马V11版 最新Java高级软件工程师课程-JavaEE精英进阶课

课程大小:60.2G 课程下载:https://download.csdn.net/download/m0_66047725/90615581 更多资源下载:关注我 阶段一 中台战略与组件化开发专题课程 阶段二 【物流行业】品达物流TMS 阶段三 智牛股 阶段四 千亿级电商秒杀解决方案专题 …

C#插件与可扩展性

外接程序为主机应用程序提供了扩展功能或服务。.net framework提供了一个编程模型,开发人员可以使用该模型来开发加载项并在其主机应用程序中激活它们。该模型通过在主机和外接程序之间构建通信管道来实现此目的。该模型是使用: System.AddIn, System.AddIn.Hosting, System.…

CVPR‘25 | 高文字渲染精度的商品图文海报生成

本文分享阿里妈妈智能创作与AI应用团队在图文广告创意方向上提出的商品图文海报生成模型,通过构建字符级视觉表征作为控制信号,可以实现精准的图上中文逐像素生成。基于该项工作总结的论文已被 CVPR 2025录用,并在阿里妈妈业务场景落地&#…

Golang|抽奖相关

文章目录 抽奖核心算法生成抽奖大转盘抽奖接口实现 抽奖核心算法 我们可以根据 单商品库存量/总商品库存量 得到每个商品被抽中的概率,可以想象这样一条 0-1 的数轴,数轴上的每一段相当于一种商品,概率之和为1。 抽奖时,我们会生…

idea maven 命令后控制台乱码

首先在idea中查看maven的编码方式 执行mvn -v命令 查看编码语言是GBK C:\Users\13488>mvn -v Apache Maven 3.6.3 (cecedd343002696d0abb50b32b541b8a6ba2883f) Maven home: D:\maven\apache-maven-3.6.3\bin\.. Java version: 1.8.0_202, vendor: Oracle Corporation, runt…

C# dll 打包进exe

Framework4.x推荐使用 Costura.Fody 1. 安装 NuGet 包 Install-Package Costura.Fody工程自动生成packages文件夹,300M左右。生成FodyWeavers.xml、FodyWeavers.xsd文件。 2. 自动嵌入 编译后,所有依赖的 DLL 会被自动嵌入到 EXE 中。 运行时自动解压…

【数据融合实战手册·实战篇】二维赋能三维的5种高阶玩法:手把手教你用Mapmost打造智慧城市标杆案例

在当今数字化时代,二三维数据融合技术的重要性不言而喻。二三维数据融合通过整合二维数据的结构化优势与三维数据的直观性,打破了传统数据在表达和分析上的局限,为各行业提供了更全面、精准的数据分析手段。从智慧城市建设到工业智能制造&…

ValueError: model.embed_tokens.weight doesn‘t have any device set

ValueError: model.embed_tokens.weight doesn’t have any device set model.embed_tokens.weight 通常在深度学习框架(如 PyTorch)中使用,一般是在处理自然语言处理(NLP)任务时,用于指代模型中词嵌入层(Embedding layer)的权重参数。下面详细解释: 词嵌入层的作用 …

解决:QTcpSocket: No such file or directory

项目场景: 使用QTcpSocket进行网络编程: 调用connectToHost连接服务器,调用waitForConnected判断是否连接成功,连接信号readyRead槽函数,异步读取数据,调用waitForReadyRead,阻塞读取数据。 问题描述 找不…

支付宝商家转账到账户余额,支持多商户管理

大家好,我是小悟 转账到支付宝账户是一种通过 API 完成单笔转账的功能,支付宝商家可以向其他支付宝账户进行单笔转账。 商家只需输入另一个正确的支付宝账号,即可将资金从本企业支付宝账户转账至另一个支付宝账户。 该产品适用行业较广&am…

3.Chromium指纹浏览器开发教程之chromium119版本源码拉取

获取Chromium最新版源码 Git是一个分布式版本控制系统,用于管理代码的版本和协作开发,它是目前最流行和广泛使用的版本控制系统之一。在Chromium项目中,通常使用gclient来获取Chromium的源代码,并使用Git来对代码进行版本控制和管…

【计算机视觉】OpenCV项目实战- Artificial-Eyeliner 人脸眼线检测

Artificial-Eyeliner 人脸眼线检测 项目介绍运行方式运行步骤常见问题及解决方法1. dlib 安装失败其他注意事项 2. 缺少 make / gcc3. **依赖库安装问题**:4. *人脸关键点检测失败:5. 眼线效果不理想:6. 实时处理延迟:7. 保存文件…