PyTorch学习6:多维特征输入

news2024/11/29 4:33:20

文章目录

  • 前言
  • 一、模型说明
  • 二、示例
    • 1.求解步骤
    • 2.示例代码
  • 总结


前言

介绍了如何处理多维特征的输入问题

一、模型说明

多维问题分类模型
在这里插入图片描述

二、示例

1.求解步骤

1.载入数据集:数据集用路径D:\anaconda\Lib\site-packages\sklearn\datasets\data下的diabetes.csv,输入有8个维度
2.创建模型:维度8-6-4-2-1
3.选择损失函数和优化器
3.进行训练

2.示例代码

代码如下(示例):

import numpy as np
import torch
import matplotlib.pyplot as plt

# prepare dataset
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])  # 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要
print("input data.shape", x_data.shape)
y_data = torch.from_numpy(xy[:, [-1]])  # [-1] 最后得到的是个矩阵


# print(x_data.shape)
# design model using class


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))  # y hat
        x = self.sigmoid(self.linear4(x))  # y hat
        return x


model = Model()

# construct loss and optimizer
criterion = torch.nn.BCELoss(size_average = True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_list = []
loss_list = []
# training cycle forward, backward, update
for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    # print(epoch, loss.item())
    print(epoch, loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 100 == 99:
        y_pred_label = torch.where(y_pred >= 0.5, torch.tensor([1.0]), torch.tensor([0.0]))

        acc = torch.eq(y_pred_label, y_data).sum().item() / y_data.size(0)
        print("loss = ", loss.item(), "acc = ", acc)


plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

得到如下结果:
在这里插入图片描述
在这里插入图片描述

总结

PyTorch学习6:多维特征输入

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1804546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C++ STL】模拟实现 string

标题:【C :: STL】手撕 STL _string 水墨不写bug (图片来源于网络) C标准模板库(STL)中的string是一个可变长的字符序列,它提供了一系列操作字符串的方法和功能。 本篇文章,我们将模拟实现STL的…

Polar Web【中等】写shell

Polar Web【中等】写shell Contents Polar Web【中等】写shell思路&探索EXP运行&总结 思路&探索 初看题目,预测需要对站点写入木马,具体操作需要在过程中逐步实现。 打开站点(见下图),出现 file_put_contents 函数,其…

pdf文件如何防篡改内容

PDF文件防篡改内容的方法有多种,以下是一些常见且有效的方法,它们可以帮助确保PDF文件的完整性和真实性: 加密PDF文档: 原理:通过设置密码来保护PDF文档,防止未经授权的访问和修改。注意事项:密…

如何对stm32查看IO功能。

有些同学对于别人的开发板的资源,或者IO口,或者串口等资源不知道怎么分配。 方法1、看硬石、野火、正点原子的开发板,看下他们的例子,那个资源用什么。自己多看几个原理图,多看几个视频,做一下笔记。以后依…

通过无障碍控制 Compose 界面滚动的实战和原理剖析

前言 针对 Compose UI 工具包,开发者不仅需要掌握如何使用新的 UI 组件达到 design 需求,更需要了解和实现与 UI 的交互逻辑。 比如 touch 事件、Accessibility 事件等等。 Compose 中对 touch 事件的处理和原理,笔者已经在《通过调用栈快…

Point-LIO:鲁棒高带宽激光惯性里程计

1. 动机 现有系统都是基于帧的,类似于VSLAM系统,频率固定(例如10Hz),但是实际上LiDAR是在不同时刻进行顺序采样,然后积累到一帧上,这不可避免地会引入运动畸变,从而影响建图和里程计精度。此外…

NASA数据集——SARAL 近实时增值业务地球物理数据记录海面高度异常

SARAL Near-Real-Time Value-added Operational Geophysical Data Record Sea Surface Height Anomaly SARAL 近实时增值业务地球物理数据记录海面高度异常 简介 2020 年 3 月 18 日至今 ALTIKA_SARAL_L2_OST_XOGDR 这些数据是近实时(NRT)&#xff…

【稳定检索/投稿优惠】2024年材料科学与能源工程国际会议(MSEE 2024)

2024 International Conference on Materials Science and Energy Engineering 2024年材料科学与能源工程国际会议 【会议信息】 会议简称:MSEE 2024大会地点:中国苏州会议官网:www.iacmsee.com会议邮箱:mseesub-paper.com审稿结…

WPF音乐播放器 零基础4个小时左右

前言:winfrom转wpf用久的熟手说得最多的是,转回去做winfrom难。。当时不明白。。做一个就知道了。 WPF音乐播放器 入口主程序 FontFamily"Microsoft YaHei" FontSize"12" FontWeight"ExtraLight" 居中显示WindowStartupLocation&quo…

【越界写null字节】ACTF2023 easy-netlink

前言 最近在矩阵杯遇到了一道 generic netlink 相关的内核题,然后就简单学习了一下 generic netlink 相关概念,然后又找了一到与 generic netlink 相关的题目。简单来说 generic netlink 相关的题目仅仅是将用户态与内核态的交互方式从传统的 ioctl 变成…

以sqlilabs靶场为例,讲解SQL注入攻击原理【42-53关】

【Less-42】 使用 or 11 -- aaa 密码,登陆成功。 找到注入点:密码输入框。 解题步骤: # 获取数据库名 and updatexml(1,concat(0x7e,(select database()),0x7e),1) -- aaa# 获取数据表名 and updatexml(1,concat(0x7e,(select group_conca…

CSS函数: translate、translate3d的使用

translate()和translate3d()函数可以实现元素在指定轴的平移的功能。函数使用在CSS转换属性transform的属性值。实现转换的函数类型有: translate():2D平面实现X轴、Y轴的平移translate3d():3D空间实现位置的平移translateX():实…

Spring Boot整合Jasypt 库实现配置文件和数据库字段敏感数据的加解密

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

idea如何根据路径快速在项目中快速打卡该页面

在idea项目中使用快捷键shift根据路径快速找到该文件并打卡 双击shift(连续按两下shift) -粘贴文件路径-鼠标左键点击选中跳转的路径 自动进入该路径页面 例如:我的实例路径为src/views/user/govType.vue 输入src/views/user/govType或加vue后缀src/views/user/go…

ChatGLM2-6b的本地部署

** 大模型玩了一段时间了,一直没有记录,借假期记录下来 ** ChatGlm2介绍: chatglm2是清华大学发布的中英文双语对话模型,具备强大的问答和对话功能,拥有长达32K的上下文,可以输出比较长的文本。6b的训练参…

Python:处理矩阵之NumPy库(上)

目录 1.前言 2.Python中打开文件操作 3.初步认识NumPy库 4.使用NumPy库 5.NumPy库中的维度 6.array函数 7.arange函数 8.linspace函数 9.logspace函数 10.zeros函数 11.eye函数 前言 NumPy库是一个开源的Python科学计算库,它提供了高性能的多维数组对象、派生对…

linux centos redis-6.2.6一键安装及配置密码

linux centos redis-6.2.6一键安装及配置密码 redis基本原理一、操作阶段,开始安装 redis基本原理 redis作为非关系型nosql数据库,一般公司会作为缓存层,存储唯一会话id,以及请求削峰作用 一、数据结构 Redis支持多种数据结构&a…

操作系统期末复习整理知识点

操作系统的概念:①控制和管理整个计算机系统的硬件和软件资源,并合理地组织调度计算机的工作和资源的分配;②提供给用户和其他软件方便的接口和环境;③是计算机中最基本的系统软件 功能和目标: ①操作系统作为系统资源…

【JAVASE】详讲JAVA语法

这篇你将收获到以下知识: (1)方法重载 (2)方法签名 一:方法重载 什么是方法重载? 在一个类中,出现了多个方法的名称相同,但是它们的形参列表是不同的,那…

【Linux系统编程】进程地址空间

目录 前言 进程虚拟地址空间的引入 进程地址空间的概念 进一步理解进程地址空间 为什么需要进程地址空间? 系统层面理解malloc/new内存申请 前言 首先,在我们学习C语言的时候一定会见过如下这张图。(没见过也没关系,接下来…