Lecture7 处理多维特征的输入(Multiple Dimension Input)

news2025/1/10 23:49:00

以实际代码出发,逐行讲解。

完整代码:

import numpy as np
import torch
import matplotlib.pyplot as plt

# load data
xy = np.loadtxt('C:\\Users\\14185\\Desktop\\diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
y_data = torch.from_numpy(xy[:, [-1]])

# define model
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

# define loss function and optimizer
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# define list to store losses
losses = []

# train model
for epoch in range(10000):
    # Forward
    y_pred = model(x_data) # 注意这里没有使用小批量数据集,在后面的课程会讲解如何加载数据集
    loss = criterion(y_pred, y_data)
    losses.append(loss.item())
    # Backward
    optimizer.zero_grad()
    loss.backward()
    # Update
    optimizer.step()

# plot losses
plt.plot(range(len(losses)), losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
标题

首先准备数据集

标题

将 CSV 文件中的数据加载到 PyTorch 张量中,以便进行深度学习模型的训练和评估。

xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32) # 一般GPU只支持32位浮点数
x_data = torch.from_numpy(xy[:,:-1])
y_data = torch.from_numpy(xy[:, [-1]])
  1. x_data = torch.from_numpy(xy[:, :-1]) 这一行代码将 NumPy 数组 xy 的所有行和除最后一列之外的所有列转换为 PyTorch 张量 x_data。这个张量包含了输入特征的数据。

  2. y_data = torch.from_numpy(xy[:, [-1]]) 这一行代码将 NumPy 数组 xy 的所有行和最后一列转换为 PyTorch 张量 y_data。这个张量包含了输出标签的数据。

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

  1. self.linear1 = torch.nn.Linear(8, 6) : 这一行代码定义了一个全连接层(也称为线性层),该层有 8 个输入节点和 6 个输出节点。这个层将数据从输入层传递到隐藏层。

  2. self.linear2 = torch.nn.Linear(6, 4) : 这一行代码定义了第二个全连接层,该层有 6 个输入节点和 4 个输出节点。这个层将数据从第一个隐藏层传递到第二个隐藏层。

  3. self.linear3 = torch.nn.Linear(4, 1) : 这一行代码定义了第三个全连接层,该层有 4 个输入节点和 1 个输出节点。这个层将数据从第二个隐藏层传递到输出层。

  4. self.sigmoid = torch.nn.Sigmoid() : 这一行代码定义了一个 Sigmoid 激活函数,将在模型的前向传播中使用。

  5. def forward(self, x): : 这一行代码定义了模型的前向传播函数。这个函数将输入数据 x 作为参数,将数据从输入层传递到输出层。

  6. x = self.sigmoid(self.linear1(x)) : 这一行代码将输入数据 x 传递到第一个全连接层,并将其输出通过 Sigmoid 激活函数进行处理。

  7. x = self.sigmoid(self.linear2(x)) : 这一行代码将第一个全连接层的输出传递到第二个全连接层,并将其输出通过 Sigmoid 激活函数进行处理。

  8. x = self.sigmoid(self.linear3(x)) : 这一行代码将第二个全连接层的输出传递到输出层,并将其输出通过 Sigmoid 激活函数进行处理。

标题

为什么对每层都应用一个sigmoid?

criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
  1. criterion = torch.nn.BCELoss(size_average=True): 创建二元交叉熵损失函数。该损失函数通常用于二元分类问题,计算模型预测值与真实标签之间的差异,可以用于指导模型参数的更新。 size_average=True 意味着对每个样本的损失值取平均值,并将其用于计算模型的总体损失值。

  2. optimizer = torch.optim.SGD(model.parameters(), lr=0.1): 创建随机梯度下降(SGD)优化器。该优化器通常用于训练神经网络,根据损失函数计算的梯度更新网络中的参数。model.parameters() 是将所有神经网络中可训练的参数(权重和偏置)作为输入传递给优化器,以便更新它们。lr=0.1 表示优化器的学习率为 0.1,它控制了每次更新参数的步长,也影响了模型的训练速度和性能。

逻辑斯蒂模型的变化

预测目标使用的 含多维特征的数据集——糖尿病数据集(sklearn中 其实也有类似的糖尿病数据集):每个样本/记录(sample/record)有8个维度的信息(feature),并以此进行二分类。Y表示一年后 糖尿病 病情是否加重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/484344.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

226. 翻转二叉树【58】

难度等级:容易 上一篇算法: 543. 二叉树的直径【71】 力扣此题地址: 226. 翻转二叉树 - 力扣(Leetcode) 1.题目:226. 翻转二叉树 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返…

DAY 52 LVS+Keepalived群集

Keepalived工具介绍 普通集群容易出现的问题 企业应用中,单台服务器承担应用存在单点故障的危险。 单点故障一旦发生, 企业服务将发生中断,造成极大的危害。 Keepalived工具 Keepalived 是一个基于VRRP协议来实现的LVS服务高可用方案&…

v2c - 从Verilog 转换到 C语言的工具

文章目录 一、如何安装1.下载二进制文件2.基准测试 二、如何使用v2c的应用描述工具流程使用 v2c 转换器的工作示例 三、注意事项情形一:拼接:{4{x}}情形1-1 y&{x,x,x,x}情形1-2 y&{x,x,…

【C++】string 类的实现

目录 构造函数赋值重载关于浅拷贝 迭代器容量相关reserveresize 修改push_backappendinserterase关于npos 流运算符重载流插入流提取 构造函数 无参数构造和传参构造 通过对参数设置缺省值为空串""同时满足无参构造和传参构造成员 _size 和 _capacity 均是针对有效…

自动驾驶—连续系统LQR最优控制的黎卡提方程推导

1. Why use the Riccati equation? 最优控制算法LQR是Linear Quadratic Regulator的缩写,Q、R就是需要设计的半正定矩阵和正定矩阵。考虑根据实车的情况去标定此参数,从理论和工程层面去理解,如果增大Q、减小R,则此时控制系统响应速度比较快速(比较剧烈),直观反映方向…

5月1日 9H45min|5.2 8H20min+30min|时间轴复盘

8:00 起床 8:00-8:30 洗漱吃饭 8:30-10:40 temporary pools阅读真题精读 (真的很慢了 不知道什么原因 感觉也没有彻底完全弄懂)【2h+10min】 10:40-11:10 午餐+酸奶(423+174KJ) 11:20-12:30 三篇阅读【1h+10min】 13:10-14:50 健身 14:50-15:45诵默写list…

Ae:画笔工具

画笔工具 Brush Tool 快捷键:Ctrl B 画笔工具 Brush Tool仅能工作在图层 Layer面板上。 双击纯色图层、像素图层等可打开图层面板。 在 Ae 中的每次画笔绘制都将新建一条路径,然后通过对路径的描边来显示绘制结果,故又称为“绘画描边”或“…

函数-实现交换两个变量的内容

用函数实现交换两个变量的内容&#xff0c;对于该问题我们该如何实现呢&#xff1f;在这里我就用整型变量来说明。 题目&#xff1a;写一个函数可以交换两个整形变量的内容。 我们先来看看如下代码&#xff1a; #include <stdio.h> void swap(int x, int y) {int tem…

Android进阶之光:Dagger2原理简要分析

Dagger2注入框架原理简要分析 使用Dagger2需要的依赖: implementation com.google.dagger:dagger-android:2.46 implementation com.google.dagger:dagger-android-support:2.46 annotationProcessor com.google.dagger:dagger-android-processor:2.46 annotationProcessor c…

第二十七章 碰撞体Collision(下)

本章节我们继续研究碰撞体&#xff0c;并且探索一下碰撞体与刚体之间的联系。我们回到之前的工程&#xff0c;然后给我们的紫色球体Sphere1也添加一个刚体组件。如下所示 此时&#xff0c;两个球体都具备了碰撞体和刚体组件。接下来&#xff0c;我们Play运行查看效果 我们发现&…

从零开始带你开发橙光游戏AVG框架(仿 葬花 )

来源 从零开始带你开发橙光游戏AVG框架【55课数 收费】 从零开始带你开发橙光游戏AVG框架 unity教程【16课数 免费】 。。。。。。 挺大的&#xff0c;因为很多音频&#xff0c;.git就有 2.6G AVG_20230413_2020.2.23f1c1 介绍 QuickSheet使用 bug 包报错 可能是我换了un…

LeetCode138. 复制带随机指针的链表

138. 复制带随机指针的链表 描述示例解题思路以及代码解法1解法2 描述 给你一个长度为 n 的链表&#xff0c;每个节点包含一个额外增加的随机指针 random &#xff0c;该指针可以指向链表中的任何节点或空节点。 构造这个链表的 深拷贝。 深拷贝应该正好由 n 个 全新 节点组成…

电脑文件加密软件哪个最好用:试试文件加密软件排行榜第一的EaseUS LockMyFile吧 | 军事级加密你值得拥有!!!

EaseUS LockMyFile是一款出色且安全可靠的军事级电脑文件加密管理软件&#xff0c;也叫易我文件加密软件&#xff0c;拥有文件隐藏、文件加锁、文件保护、读写监控、安全删除等诸多实用功能&#xff0c;能帮助大家锁定和隐藏闪存驱动器、外部USB 驱动器、内部硬盘驱动器以及局域…

51单片机(六)矩阵键盘和矩阵键盘密码锁

❤️ 专栏简介&#xff1a;本专栏记录了从零学习单片机的过程&#xff0c;其中包括51单片机和STM32单片机两部分&#xff1b;建议先学习51单片机&#xff0c;其是STM32等高级单片机的基础&#xff1b;这样再学习STM32时才能融会贯通。 ☀️ 专栏适用人群 &#xff1a;适用于想要…

几种常见时间复杂度实例分析

多项式量级 常量阶 O(1) 对数阶 O(logn) 线性阶 O(n) 线性对数阶 O(nlogn) 平方阶O(n2 ),立方阶O(n3 )...k次方阶O(nk) 非多项式量级&#xff08;NP&#xff08;Non-Deterministic Polynomial&#xff0c;非确定多项式&#xff09;问题&#xff09; 指数阶O(2n) 阶乘阶…

离线数据同步Sqoop与DataX

文章目录 一、Sqoop安装与使用1、简介2、Sqoop安装3、Sqoop实例3.1 Mysql导入Hadoop3.2 Hadoop导出到Mysql 二、DataX概述与入门1、DataX概述1.1 简介1.2 框架设计1.3 运行原理 2、DataX与 Sqoop 的对比3、快速入门 三、DataX常用入门案例1、从stream 流读取数据并打印到控制台…

前端web3入门脚本六:套利夹子机器人,羊毛党必备

一、前言 DEX上有很多零风险套利的机会&#xff0c;包括三角套利&#xff0c;夹子机器人… 今天主要介绍一下架子机器人的思路和简易实现。 二、实现思路 套利原理&#xff1a; 夹子机器人的核心&#xff1a;在韭菜买入前以更低价格买入&#xff0c;并再韭菜买入后卖出&#…

Curator中的分布式锁解读

目录 基本介绍 基本配置 可重入锁InterProcessMutex 不可重入锁InterProcessSemaphoreMutex 可重入读写锁InterProcessReadWriteLock 联锁InterProcessMultiLock 信号量InterProcessSemaphoreV2 栅栏barrier 倒计数器 基本介绍 Curator是netflix公司开源的一套zookeeper…

C语言力扣简单题-无重复字符的最长子串

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 无重复字符的最长子串 题目&#xff1a; 代码思路&#xff1a; 代码表示&#xff1a; 无重复字符的最长子…

【C++】lambda表达式

文章目录 lambda表达式lambda概念lambda表达式的格式关于捕获列表常见问题: 使用lambda表达式交换两个数lambda表达式底层原理 lambda表达式 lambda概念 lambda表达式本质是一个匿名函数(因为它没有名字),恰当使用lambda表达式可以让代码变得简洁.并且可以提高代码的可读性 例…