5-pytorch-torch.nn.Sequential()快速搭建神经网络

news2024/11/25 10:59:34

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
  • torch.nn.Sequential()快速搭建网络法
    • 1 生成数据
    • 2 快速搭建网络
    • 3 训练、输出结果
  • 总结


前言

本文内容还是基于4-pytorch前馈网络简单(分类)问题搭建这篇的相同例子,只是为了介绍另一种更加快速搭建网络的方法,看个人喜好用哪一种。
【注】:建议先看完上面链接的博客4,在来看本篇。
这里的这种搭建方法是使用**torch.nn.Sequential()**快速搭建,不用我们在继承重写net类了。

torch.nn.Sequential()快速搭建网络法

1 生成数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

n_data = torch.ones(100,2)
x0 = torch.normal(2*n_data,1)
y0 = torch.zeros(100,1)
x1 = torch.normal(-2*n_data,1)
y1 = torch.ones(100,1)

x = torch.cat((x0,x1),0)
# 在分类问题中标签必须用一维tensor,回归中则没有这个要求
y = torch.cat((y0,y1),0).reshape(-1)
# 在分类问题中标签还需要用torch.LongTensor类型
# 将张量 y 的类型转换为 long,这是因为在 PyTorch 中,分类问题的标签通常是整数类型(long),以便与模型输出的类别概率进行比较,从而计算损失。
y = y.long()


fig = plt.figure()
plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=y.data.numpy())
# 给画出来的每一个点标上标签,有点难看,注了吧
# 循环遍历每个数据点,根据其对应的标签添加标签文本
for i in range(len(x)):
    plt.text(x[i][0], x[i][1], str(int(y[i].item())), fontsize=8)
plt.show()

输出:
在这里插入图片描述

2 快速搭建网络

## 搭建网络method1
# class Net(torch.nn.Module):
#     def __init__(self,n_features,n_hidden,n_output):
#         # 继承原来结构体的全部init属性及方法
#         super(Net,self).__init__()
#         # 线性层就是全连接层
#         self.hidden = torch.nn.Linear(n_features,n_hidden)
#         self.predict = torch.nn.Linear(n_hidden,n_output)
#         
#     def forward(self,x):
#         # 重写继承类的向前传播方法,就是在这个里面选择激活函数的
#         x = F.relu(self.hidden(x))
#         # 分类中输出层也可以不用激活函数,我们最后在对输出结果进行softmax处理
#         x = self.predict(x)
#         return x
#         
# net = Net(2,10,2)
# # 输出层定义2个输出,对输出在进行softmax处理,取出概率最大的元素的下标就是我们分类的类别;与回归有所不同
# # 有点类似机器学习里面的独热编码
# print(net)


## 快速搭建法,和前面注释掉的效果是一样的。
net = torch.nn.Sequential(
    torch.nn.Linear(2,10),
    torch.nn.ReLU(), # 这里激活函数大写了要
    torch.nn.Linear(10,2)
)
print(net)

输出:
在这里插入图片描述

3 训练、输出结果

optimizer = torch.optim.SGD(net.parameters(),lr=0.02)
# 分类用交叉熵损失函数
loss_func = torch.nn.CrossEntropyLoss()

# 开启matplotlib的交换模式
plt.ion()
for t in range(100):
    # 这一步其实是调用了类里面的 __call__魔术方法,又学到一个魔术方法
    out = net(x)
    loss = loss_func(out,y)
    # 梯度清零
    optimizer.zero_grad()
    # 误差反向传播,求梯度
    loss.backward()
    # 进行优化器优化
    optimizer.step()
    if t%5 == 0:
        plt.cla()
        prediction = torch.max(F.softmax(out,1),1)[1]
        pred_y = prediction.data.numpy().reshape(-1)
        target_y = y.data.numpy().reshape(-1)
        plt.scatter(x.data.numpy()[:,0],x.data.numpy()[:,1],c=pred_y)
        accuracy = sum(pred_y==target_y)/200
        plt.text(1.2,-4,'accuracy=%.2f' % accuracy, fontdict={'size':20,'color':'red'})
        plt.pause(0.1)
# 关闭matplotlib的交换模式
plt.ioff()
plt.show()

输出:
在这里插入图片描述

# 输出out经softmax处理过后才变成概率
out2probability = F.softmax(out,1)
#print(out2probability.round(decimals=2))
# 取出概率向量里面概率最大的下标就是最终的分类结果
prediction = torch.max(F.softmax(out,1),1)[1]
print(prediction)在这里插入代码片

输出:
在这里插入图片描述

总结

选择那种方法搭建,看个人喜好,效果完全一样。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1600431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL刷题---2021年11月每天新用户的次日留存率

解题思路: 1.首先算出每个新用户注册的日期,将其命名为表a select uid,min(date(in_time)) dt from tb_user_log group by uid2.计算出每个用户登录的天数,将其命名为表b select uid,date(in_time) dt from tb_user_log union select uid,date(out_time) dt fro…

【Windows10】Anaconda3安装+pytorch+tensorflow+pycharm

文章目录 一、下载anaconda0.双击下载的文件1. 选择All users2. 安装路径3. 勾选环境变量和安装python4.安装完成5.添加环境变量6.测试是否安装成功 二、安装pytorch(先看四!先检查一下自己电脑是不是只能安装GPU版的1.查看conda图形化界面2.在安装pytor…

PHP-extract变量覆盖

[题目信息]: 题目名称题目难度PHP-extract变量覆盖1 [题目考点]: 变量覆盖指的是用我们自定义的参数值替换程序原有的变量值,一般变量覆盖漏洞需要结合程序的其它功能来实现完整的攻击。 经常导致变量覆盖漏洞场景有:$$&#x…

【Git】安装 Git

文章目录 1. CentOS 下安装2. Ubuntu 下安装 Git 是开放源代码的代码托管工具,最早是在 Linux 下开发的。开始也只能应用于 Linux 平台,后面慢慢的被移植到 Windows 下。现在,Git 可以在 Linux、Unix、Mac 和 Windows 这几大平台上正常运行了…

RabbitMQ的简单

前言 RabbitMQ是一套开源(MPL)的消息队列服务软件,是由 LShift 提供的一个 Advanced Message Queuing Protocol (AMQP) 的开源实现,由以高性能、健壮以及可伸缩性出名的 Erlang 写成。 目录 介绍 RabbitMQ系统结构 RabbitMQ成员…

Flutter 插件站新升级: 加入优秀 GitHub 开源项目

Flutter 插件站新升级: 加入优秀 GitHub 开源项目 视频 https://youtu.be/qa49W6FaDGs https://www.bilibili.com/video/BV1L1421o7fV/ 前言 原文 https://ducafecat.com/blog/flutter-awesome-github-repo-download 这几天晚上抽空把 Flutter 插件站升级,现在支…

超越GPT-4V,苹果多模态大模型上新,神经网络形态加速MLLM(一)

4月8日,苹果发布了其最新的多模态大语言模型(MLLM )——Ferret-UI,能够更有效地理解和与屏幕信息进行交互,在所有基本UI任务上都超过了GPT-4V! 苹果开发的多模态模型Ferret-UI增强了对屏幕的理解和交互&am…

智谱AI通用大模型:官方开放API开发基础

目录 一、模型介绍 1.1主要模型 1.2 计费单价 二、前置条件 2.1 申请API Key 三、基于SDK开发 3.1 Maven引入SDK 3.2 代码实现 3.3 运行代码 一、模型介绍 GLM-4是智谱AI发布的新一代基座大模型,整体性能相比GLM3提升60%,支持128K上下文&#x…

实用图像视频修复工具:完善细节、提高分辨率 | 开源日报 No.225

xinntao/Real-ESRGAN Stars: 25.6k License: BSD-3-Clause Real-ESRGAN 是一个旨在开发实用的图像/视频恢复算法的项目。 该项目主要功能、关键特性和核心优势包括: 提供动漫视频小模型和动漫插图模型支持在线 Colab 演示和便携式 Windows/Linux/MacOS 可执行文件…

Flex弹性盒子布局案例(认识弹性布局)

一、导航菜单 此示例创建了一个水平导航菜单&#xff0c;其中链接在 Flex 容器中等距分布。 HTML结构&#xff1a; <nav class"nav-menu"><a href"#">Home</a><a href"#">About</a><a href"#">…

【云计算】混合云分类

《混合云》系列&#xff0c;共包含以下 3 篇文章&#xff1a; 【云计算】混合云概述【云计算】混合云分类【云计算】混合云组成、应用场景、风险挑战 &#x1f60a; 如果您觉得这篇文章有用 ✔️ 的话&#xff0c;请给博主一个一键三连 &#x1f680;&#x1f680;&#x1f68…

【Java】spring+springmvc+hibernate最完整整合配置教程

附Maven的依赖 <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><mo…

【一竞技CS2】VP战队官宣签下electroNic取代mir

1、近日VP战队官宣签下electroNic&#xff0c;以取代阵容中的mir。 electroNic自己也表示&#xff1a;“VP是一支顶级队伍。阵容核心曾赢得Major冠军&#xff0c;所有队员都处于巅峰状态并且时刻准备着去争夺冠军。我们有着一样的雄心壮志。 此外我还对和Jame很感兴趣&#xf…

AUTOCAD输出或打印PDF文件时,如何将图形居中且布满图纸?

AUTOCAD输出或打印PDF文件时,如何将图形居中且布满图纸? 如下图所示,我们打开一份DWG格式的图纸文件,然后点击上方的“打印“图标, 如下图所示, 打印机/绘图仪这里选择“DWG To PDF“; 图纸尺寸:这里以普通的A4纸为例进行说明; 打印比例选择“布满图纸“; 打印偏移…

【Redis 神秘大陆】005 常见性能优化方式

五、Redis 性能优化 5.1 系统层面的优化 https://github.com/sohutv/cachecloud/blob/main/redis-ecs/script/cachecloud-init.sh initConfig() {# 支持虚拟内存分配sysctl vm.overcommit_memory1# 最大排队连接数设置为 511&#xff0c;一般默认是 128echo 511 >/proc/sy…

免费SSL证书和付费SSL证书的区别和申请

免费SSL证书和付费SSL证书的区别点还是比较多的。对来说免费证书适用的环境会单一一些&#xff0c;一般使用免费证书的环境都是个人门户网站或者是小微企业的门户官网&#xff08;无隐私信息&#xff09;。受免费证书安全等级以及安全性的限制影响&#xff0c;如果是为了自身网…

RISC-V微架构验证

对于RISC-V处理器因其灵活性和可扩展性而受到广泛关注&#xff0c;但如果没有高效验证策略&#xff0c;错误的设计实现可能会影响RISC-V的继续推广。 在RISC-V出现之前&#xff0c;对于大多数半导体公司来说&#xff0c;处理器验证几乎成为一门屠龙之技。专业知识被浓缩到少数几…

C语言基础入门案例(3)

目录 第一题&#xff1a;一维数组的最大值和最小值求解 第二题&#xff1a;求一维数组中的第二大的数 第三题&#xff1a;计算5个整数的平均值 第四题&#xff1a;查找整数在数组中的索引位置 第五题&#xff1a;统计字符串中数字字符的个数 第一题&#xff1a;一维数组的…

vue的就地更新与v-for的key属性

vue的就地更新 Vue中的就地更新到底是怎么回事&#xff0c;为什么会存在就地更新的现象&#xff1f; 注意下面的例子&#xff0c;使用v-for指令时&#xff0c;没有绑定key值&#xff0c;才有就地更新的现象&#xff0c;因为Vue默认按照就地更新的策略来更新v-for渲染的元素列表…

【echarts】使用 ECharts 绘制3D饼图

使用 ECharts 绘制3D饼图 在数据可视化中&#xff0c;饼图是表达数据占比信息的常见方式。ECharts 作为一个强大的数据可视化库&#xff0c;除了标准的二维饼图&#xff0c;也支持更加生动的三维饼图绘制。本文将指导你如何使用 ECharts 来创建一个3D饼图&#xff0c;提升你的…