Multiple Dimension Input 处理多维特征的输入

news2024/11/15 11:19:09

文章目录

    • 6、Multiple Dimension Input 处理多维特征的输入
      • 6.1 Revision
      • 6.2 Diabetes Dataset 糖尿病数据集
      • 6.3 Logistic Regression Model 逻辑斯蒂回归模型
      • 6.4 Mini-Batch(N samples)
      • 6.5 Neural Network 神经网络
      • 6.6 Diabetes Prediction 糖尿病预测
        • 6.6.1 Prepare Dataset
        • 6.6.2 Define Model
        • 6.6.3 Construct Loss and Optimizer
        • 6.6.4 Training Cycle
        • 6.6.5 Activate function
        • 6.6.5 完整代码

6、Multiple Dimension Input 处理多维特征的输入

B站视频教程传送门:PyTorch深度学习实践 - 逻辑斯蒂回归

6.1 Revision

我们先来回顾一下回归分类

差别: 主要在于输出值

回归(Regressiom):y ∈ R

分类(Classification):y ∈ { } 离散的集合

6.2 Diabetes Dataset 糖尿病数据集

如果我们安装过sklearn(python编程安装sklearn),其中就包含糖尿病数据集,可以进入该目录(D:\Software\Anaconda\Lib\site-packages\sklearn\datasets\data)下查看,如下图所示:

6.3 Logistic Regression Model 逻辑斯蒂回归模型

由于这里的 x不再是简简单单的一维,而是 8维,所以应该看成下方两个矩阵相乘:

注: σ = 1 1 + e − z \sigma = \frac {1} {1 + e^{-z}} σ=1+ez1

6.4 Mini-Batch(N samples)

import torch


class Liang(torch.nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.linear = nn.Linear(8, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear(x))
        return x


model = Liang()

6.5 Neural Network 神经网络

当输入8维,输出2维时:

self.linear = torch.nn.Linear(8, 2)

当输入8维,输出6维时:

self.linear = torch.nn.Linear(8, 6)

可以降维,可以升维,也可以一降(升)一升(降):

6.6 Diabetes Prediction 糖尿病预测

X1~X8:病人相应的指标

Y:一年后病情是否加重(预测)

6.6.1 Prepare Dataset

import numpy as np

xy = np.loadtxt('../data/diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])  # 所有行,除了最后一列
y_data = torch.from_numpy(xy[:, [-1]])  # 所有行,最后一列 转为矩阵而不是向量

6.6.2 Define Model

6.6.3 Construct Loss and Optimizer

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

6.6.4 Training Cycle

for epoch in range(100):
    # Forward
    y_pred = model(x_data) # This program has not use Mini-Batch for training. We shall talk about DataLoader later.
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Backward
    optimizer.zero_grad()
    loss.backward()
    
    # Update
    optimizer.step()

6.6.5 Activate function

神经网络中激活函数的可视化:https://dashee87.github.io/deep%20learning/visualising-activation-functions-in-neural-networks/

PyTorch文档:https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

6.6.5 完整代码

import torch
import numpy as np
import matplotlib.pyplot as plt

xy = np.loadtxt('../data/diabetes.csv.gz', delimiter=',', dtype=np.float32)

x_data = torch.from_numpy(xy[:, :-1])  # 所有行,除了最后一列
y_data = torch.from_numpy(xy[:, [-1]])  # 所有行,最后一列 转为矩阵而不是向量


class Liang(torch.nn.Module):
    def __init__(self):
        super(Liang, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)

        self.sigmoid = torch.nn.Sigmoid()  # Sigmoid
        self.tanh = torch.nn.Tanh()

    def forward(self, x):
        x = self.tanh(self.linear1(x))
        x = self.tanh(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Liang()

criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_list = []
loss_list = []

for epoch in range(100):
    # Forward
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    epoch_list.append(epoch)
    loss_list.append(loss.item())

    # Backward
    optimizer.zero_grad()
    loss.backward()

    # Update
    optimizer.step()

plt.plot(epoch_list, loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Tanh')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/176688.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Exadata存储服务器(又称Exadata存储单元)

存储单元可以说是让Exadata如此大规模普及并且使用效果优异的核心要素。 I/O性能问题始终是Exadata存储或者存储服务器尽力去解决的问题。 Exadata存储服务器概述 Exadata数据库一体机通常预装了3类硬件: 数据库计算节点服务器存储服务器极速的InfiniBand存储交…

离散数学与组合数学-03函数

文章目录离散数学与组合数学-03函数3.1 函数基本概念3.1.1 函数的定义3.1.2 函数举例3.1.3 函数的数量3.1.4 关系与函数的差别3.2函数的类型3.2.1 函数类型3.2.2 函数类型的必要条件3.2.3 函数类型的数学化描述3.2.4 函数类型的证明3.3 函数的运算3.3.1 函数的复合运算3.3.2 函…

统计学习方法 学习笔记(1)统计学习方法及监督学习理论

统计学习方法及监督学习理论1.1.统计学习1.1.1.统计学习的特点1.1.2.统计学习的对象1.1.3.统计学习的目的1.1.4.统计学习的方法1.1.5.统计学习的研究1.1.6.统计学习的重要性1.2.统计学习的分类1.2.1.基本分类1.2.1.1.监督学习1.2.1.2.无监督学习1.2.1.3.强化学习1.2.1.4.半监督…

【HBase入门】2. 集群搭建

安装 上传解压HBase安装包 tar -xvzf hbase-2.1.0.tar.gz -C ../server/ 修改HBase配置文件 hbase-env.sh cd /export/server/hbase-2.1.0/conf vim hbase-env.sh # 第28行 export JAVA_HOME/export/server/jdk1.8.0_241/ export HBASE_MANAGES_ZKfalsehbase-site.xml vim…

【算法】洗牌算法

目录1.概述2.代码实现2.1.暴力法2.2.Fisher-Yates 洗牌算法3.应用本文参考: LeetCode 384. 打乱数组 1.概述 (1)洗牌算法可以理解为:设计算法来打乱一个没有重复元素的数组 nums,并且打乱后,数组的所有排列…

使用C++实现学委作业管理系统

开发环境学委作业管理系统在 Microsoft Visual Studio 2013 编译器开发的 MFC 项目,计算机使用的系统是 window10。1.2 基本原理与技术要求熟悉文件读写、mfc 基本知识、c 类运用、链表使用、排序算法、Microsoft Visual Studio 2013 编译器的使用。1.3 需求说明学委…

【数据结构】二叉搜索树的实现

目录 一、二叉搜索树的概念 二、二叉搜索树的中序遍历用于排序去重 三、二叉搜索树的查找 1、查找的非递归写法 2、查找的递归写法 四、二叉搜索树的插入 1、插入的非递归写法 2、插入的递归写法 五、二叉搜索树的删除 1、删除的非递归写法 2、删除的递归写法 六、…

autojs模仿QQ长按弹窗菜单(二)

牙叔教程 简单易懂 上一节讲了列表和长按事件 autojs模仿QQ长按弹窗菜单 今天讲弹窗菜单 由粗到细, 自顶向下的写代码 我们现在要修改的文件是showMenuWindow.js function showMenuWindow(view) {let popMenuWindow ui.inflateXml(view.getContext(),<column><bu…

基于双层优化的微电网系统规划设计方法(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

机制设计原理与应用(一)机制设计基础

什么是机制设计&#xff1f; 微观经济学和CS /EE的交叉学科。它采用了一种工程方法来设计激励机制&#xff0c;以实现战略环境中不完全信息的预期目标。机制设计具有广泛的应用,特别是在资源管理方面。 文章目录1 机制设计的基础1.1 简介1.2 机制设计与博弈及优化的关系1.3 机…

手撕Pytorch源码#4.Dataset类 part4

写在前面手撕Pytorch源码系列目的&#xff1a;通过手撕源码复习了解高级python语法熟悉对pytorch框架的掌握在每一类完成源码分析后&#xff0c;会与常规深度学习训练脚本进行对照本系列预计先手撕python层源码&#xff0c;再进一步手撕c源码版本信息python&#xff1a;3.6.13p…

大数据之HBase集群搭建

文章目录前言一、上传并解压HBase安装包二、修改HBase配置文件&#xff08;一&#xff09;hbase-env.sh&#xff08;二&#xff09;hbase-site.xml三、配置环境变量四、复制jar包到lib文件夹五、修改regionservers文件六、分发安装包和配置文件七、启动Hbase八、验证HBase是否启…

尚硅谷前端ES6-ES11

ECMAScript 是由 Ecma 国际通过 ECMA-262 标准化得脚本程序设计语言。 1.let变量声明以及变量声明特性 <body><script>//let的声明let a , b10;//特性1&#xff1a;变量不能重复声明&#xff0c;避免命名污染// let star "罗翔"// let star "张…

Java | 浅谈多态中的向上转型与向下转型

文章目录&#x1f333;向上转型&#x1f4d5;概念明细&#x1f4aa;使用场景1&#xff1a;直接赋值&#x1f4aa;使用场景2&#xff1a;方法传参&#x1f4aa;使用场景3&#xff1a;方法返回&#x1f4aa;向上转型的优缺点&#x1f333;向下转型&#x1f529;向下转型解决【调用…

程序员拯救了一次地球

流浪地球2&#xff1a;程序员拯救了一次地球 顺便给我们讲了一个道理&#xff1a; 人类会谋划未来&#xff0c; 但关键的一步是靠勇气迈出去的 趣讲大白话&#xff1a;算得好不如胆量好 *********** 电影工业的皇冠是特效 国产电影的特效进步不小 时时刻刻&#xff0c;分分秒秒…

用户画像计算更新

3.1 用户画像计算更新 目标 目标 知道用户画像建立的流程应用 无 3.1.1 为什么要进行用户画像 要做精准推送同样可以使用多种推荐算法&#xff0c;例如&#xff1a;基于用户协同推荐、基于内容协同的推荐等其他的推荐方式&#xff0c;但是以上方式多是基于相似进行推荐。而构…

ROS移动机器人——ROS基础知识与编程

此文章基于冰达机器人进行笔记整理&#xff0c;使用的环境为其配套环境&#xff0c;可结合之前的ROS&#xff0c;赵虚左老师的文章结合进行观看&#xff0c;后期也会进行整合 1. ROS安装 &#xff08;1&#xff09;配置ubuntu的软件和更新&#xff0c;允许安装不经认证的软件…

JS手动触发PWA安装窗口

✅作者简介&#xff1a;人工智能专业本科在读&#xff0c;喜欢计算机与编程&#xff0c;写博客记录自己的学习历程。 &#x1f34e;个人主页&#xff1a;小嗷犬的博客 &#x1f34a;个人信条&#xff1a;为天地立心&#xff0c;为生民立命&#xff0c;为往圣继绝学&#xff0c;…

仿写Dubbo-初识Dubbo

概念 Dubbo 在Dubbo官网介绍到&#xff0c;Apache Dubbo 是一款 RPC 服务开发框架&#xff0c;用于解决微服务架构下的服务治理与通信问题。 RPC RPC&#xff08;Remote Procedure Call&#xff09;远程过程调用协议&#xff0c;一种通过网络从远程计算机上请求服务&#xff0c…

【Android】手机安装Termux运行nodejs学习Javascript编程入门

Termux 是运行在Android手机上的一个 Linux 终端模拟器&#xff0c;干什么都要输入命令执行&#xff0c;不像 Windows 操作系统桌面用鼠标点点点&#xff0c;这里主要介绍用它来学习Javascript编程入门&#xff0c;当然&#xff0c;这和小时候学过的C语言编程课入门一样的&…