《PyTorch深度学习实践》第七讲 处理多维特征的输入

news2025/1/10 21:02:55

b站刘二大人《PyTorch深度学习实践》课程第七讲处理多维特征的输入笔记与代码:https://www.bilibili.com/video/BV1Y7411d7Ys?p=7&vd_source=b17f113d28933824d753a0915d5e3a90


Diabetes Dataset

  • 每一行是一个记录
  • 每一列是一个特征,每个样本有8个特征
image-20230701135124288

每个样本不再是一个特征,即模型输入变成了多维,那么就要将所有特征乘以相应的权重然后再累加起来

image-20230701135715087

Mini-Batch(N samples)

  • 向量化形式

    • 可以用上并行计算
    image-20230701140237529
image-20230701140318302
  • Linear的第一个参数是输入特征数,第二个是输出特征数
image-20230701140602305
  • 通过引入激活函数 σ \sigma σ,给线性变换增加非线性因子,这样就可以去拟合非线性变换
image-20230701141213192 image-20230701141243980
  • 学习能力越强,有可能会把输入样本中噪声的规律也学到,而我们要学习的是数据本身真实数据的规律,因此关键的是模型的泛化能力

例子:Artificial Neural Network -> Diabetes Prediction

image-20230701141917363
  • Prepare dataset:

    import numpy as np
    import torch
    
    xy = np.loadtxt('dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
    # :表示所有行;:-1表示第一列开始,最后一列(-1)不要,最后一列是分类(这是输出)
    x_data = torch.from_numpy(xy[:, :-1])  
    # : 表示所有行;[-1]表示只要最后一列,加中括号意味着取出后是一个矩阵,不加则是向量
    y_data = torch.from_numpy(xy[:, [-1]])  
    

    数据集放到和源代码同一个存储目录下即可,代码目录是Liuer_lecturer,数据集放在Liuer_lecturer/dataset

    image-20230701142807927 image-20230701143747036
    • delimiter是分隔符;dtype是指定数据类型

      • 用float32是因为常用的GPU(1080,2080等)中都只支持32位浮点数,因此在神经网络计算中通常使用32位浮点数
    • 可以用print查看数据,如下是print(x_data)

      image-20230701143047352
  • Design model using class:

    image-20230701144621272
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            # 激活函数
            # 之前的是调用torch.nn.Functional的sigmoid函数
            # 现在调用的是nn下的模块,把它当成一个层(运算模块)构建计算图
            self.sigmoid = torch.nn.Sigmoid()
    
        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.sigmoid(self.linear2(x))
            x = self.sigmoid(self.linear3(x))
            return x
    
    model = Model()
    
  • Construct loss and optimizer:

    image-20230701144806211
    # criterion = torch.nn.MSELoss(size_average=True) pytorch更新后被弃用了
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
  • Training cycle:

    image-20230701145102335
    • 此处没做mini-batch,而是将全部数据放进去,mini-batch后续才讲
    # 训练过程
    for epoch in range(100):
        # 前馈
        y_pred = model(x_data)              # 计算y_hat
        loss = criterion(y_pred, y_data)    # 计算损失
        print(epoch, loss.item())
    
        # 反馈
        optimizer.zero_grad()   # 在反向传播开始将上一轮的梯度归零
        loss.backward()         # 反向传播(计算梯度)
    
        # 更新
        optimizer.step()        # 更新权重w和偏置b
    

完整代码:

import torch
import numpy as np
import matplotlib.pyplot as plt

# 建立数据集
xy = np.loadtxt('dataset/diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])   # :表示所有行;:-1表示第一列开始,最后一列(-1)不要,最后一列是分类(这是输出)
y_data = torch.from_numpy(xy[:, [-1]])  # : 表示所有行;[-1]表示只要最后一列,加中括号意味着取出后是一个矩阵,不加则是向量

# 用于绘图
epoch_list = []
loss_list =[]


# 定义模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

# criterion = torch.nn.MSELoss(size_average=True) pytorch更新后被弃用了
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练过程
for epoch in range(1000):
    # 前馈
    y_pred = model(x_data)              # 计算y_hat
    loss = criterion(y_pred, y_data)    # 计算损失
    print(epoch, loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())

    # 反馈
    optimizer.zero_grad()   # 在反向传播开始将上一轮的梯度归零
    loss.backward()         # 反向传播(计算梯度)

    # 更新
    optimizer.step()        # 更新权重w和偏置b

# 绘制loss曲线
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
image-20230701145749724
image-20230701145408508 image-20230701145551850

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/707239.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么我们家里的IP都是192.168开头的?

为什么我们家里的IP都是192.168开头的? 本文为掘金社区首发签约文章,14天内禁止转载,14天后未获授权禁止转载,侵权必究! 是的,还是我小白,什么技术博主,老情感博主了。 来讲个故事。…

网络安全合规-数据安全分类分级

数据安全是指保护数据免受未经授权的访问、使用、泄露、破坏或篡改的措施。数据安全包括物理安全、网络安全、应用程序安全、数据备份和恢复等方面。 数据分级分类是指根据数据的重要性和敏感程度,将数据划分为不同的级别,并根据不同级别的数据制定不同…

enote笔记法之附录1——“语法词”(即“关联词”)(ver0.23)

enote笔记法之附录1——“语法词”(即“关联词”)(ver0.23) 最上面的是截屏的完整版,分割线下面的是纯文字版本: 作者姓名(本人的真实姓名):胡佳吉 居住地&#xff1…

前言-----

因要参加电赛,接触到STC89C52RC(A51)单片机 STC89C52RC引脚功能 1电源: ①VCC - 芯片电源,接5V; ②VSS - 接地端; 2.时钟: XTAL1、XTAL2 - 晶体振荡电路反相输入端和输出端。 3.控制线: 控制线共…

Java 17官方编程手册都针对哪些方面做了更新?

Java 17,官方编程手册, 《International Developer》杂志称为“全世界醉著名的编程书籍创作者之一”的Herbert Schildt倾情解读 《Java官方编程手册》从1996年首次出版以来,已经经历了数次改版,每次改版都反映 了Java不断演化的进…

分享解析,2+1链动模式为何能在市场上经久不衰

​小编介绍:10年专注商业模式设计及软件开发,擅长企业生态商业模式,商业零售会员增长裂变模式策划、商业闭环模式设计及方案落地;扶持10余个电商平台做到营收过千万,数百个平台达到百万会员,欢迎咨询。 随…

服务网格:Istio 架构

什么是服务网格 服务网格(Service Mesh)这个术语通常用于描述构成这些应用程序的微服务网络以及应用之间的交互。随着规模和复杂性的增长,服务网格越来越难以理解和管理。 它的需求包括服务发现、负载均衡、故障恢复、指标收集和监控以及通常更加复杂的运维需求&am…

数据结构--双端队列

数据结构–双端队列 双端队列(Double-ended Queue,简称Deque)是一种具有队列和栈特性的数据结构,可以在队列的两端进行插入和删除操作。双端队列允许从前端和后端同时进行插入和删除操作,因此可以称为“两端都可以进出…

「STC8A8K64D4开发板」第2-6讲:串口通信

第2-6讲:串口通信 学习目的掌握USB转串口电路的原理和设计。学习STC8A8K64D4的串口通信,包括串口初始化、波特率计算、串口发送和接收。编写串口收发程序,尤其是串口接收的软件缓存处理。编写串口发送命令控制LED指示灯亮灭的程序。 硬件电路…

【电商API接口系列】店铺所有商品数据的采集

API接口允许不同应用程序之间共享数据,在系统之间传输、读取和更新数据。例如,一个电商网站可以通过API接口获取支付系统的支付状态。API接口允许开发人员使用他人开发的功能来扩展自己的应用程序。通过调用第三方API接口,开发人员无需重新实…

二进制部署Kubernetes

二进制部署Kubernetes v1.20 k8s集群master01:192.168.142.10 kube-apiserver kube-controller-manager kube-scheduler etcd k8s集群master02:192.168.142.20 k8s集群node01:192.168.142.30 kubelet kube-proxy docker k8s集群node…

基于Java汽车售票网站设计实现(源码+lw+部署文档+讲解等)

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

VUE_网页自定义右键菜单组件

可以在uni-app或vue脚手架项目使用 引入组件会接管页面右键事件&#xff0c;所有options为空数组时&#xff0c;在页面右键将没有反应 rightMenu.vue <template><view><view v-if"show" class"contextMenu" :style"lay_style"…

Kafka:Kafka资料整理

一、官网 二、博主文章 1、kafka是什么 • Worktile社区 三、源码解读

一文了解云计算

目录 &#x1f34e;云服务 &#x1f34e;云计算类型 &#x1f352;公有云 &#x1f352;私有云 &#x1f352;混合云 &#x1f34e;云计算服务模式 &#x1f352;IaaS基础设施即服务 &#x1f352;PaaS平台即服务 &#x1f352;SaaS软件即服务 &#x1f352;三者之间区别 &…

4.springboot原理篇

原理篇 spring与springboot区别 spring是承载容器 springboot做的主要工作&#xff1a; ①简化配置&#xff08;省去了spring中配置xml&#xff0c;引入application.yml文件&#xff09; ②为我们提供了 spring-boot-starter-web 依赖&#xff0c;这个依赖包含了Tomcat和sprin…

二进制搭建Kubernetes集群(二)——部署Worker Node 组件

四.部署node节点 4.1 所有node节点部署 docker引擎 #所有 node 节点部署docker引擎#安装依赖包yum install -y yum-utils device-mapper-persistent-data lvm2#设置阿里云镜像源yum-config-manager --add-repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker- ce.…

Nuget更新全局包、缓存和临时文件夹路径位置

Nuget更新缓存 1、查看默认的Nuget路径2、更改全局包路径2.1 通过环境变量来进行修改2.2通过Nuget.Config配置文件来进行修改 3、更改http-cache路径4、更改temp文件路径5、更改plugins-cache文件路径 NuGet是一个流行的软件包管理器&#xff0c;可以帮助.NET开发人员轻松地添加…

【Python】 【Pandas 】【read_csv()】Pandas库的read_csv()方法的使用,处理:None,NULL

近期&#xff0c;使用read_csv的时候&#xff0c;遇到一个问题&#xff0c;就是本地读取的csv文件中的数据有None和NaN 两种&#xff0c;如&#xff1a; 直接使用 pd.read_csv(rF:\我爱Python\预测\历史样本.csv,encodingutf-8)发现读取的数据是将None 和 NULL 直接处理成 NaN…

SpingData-JDBC(看这篇文章就够了,新手入门指引)

JdbcTemplate 的基本使用 写在前面&#xff1a; 当DDL操作时&#xff0c;一般是用execute方法&#xff0c;这也是一种规范吧&#xff0c;这个也可以运行DML但是通常来说我DML操作是需要返回值的&#xff0c;一般就是返回影响的行数。然后这篇文章主要介绍增删改查&#xff0c…