《Pytorch深度学习和图神经网络(卷 1)》学习笔记——第三章

news2024/11/23 14:53:24

学习基于如下书籍,仅供自己学习,用来记录回顾,非教程。

<PyTorch深度学习和图神经网络(卷 1)——基础知识>一书配套代码:
https://github.com/aianaconda/pytorch-GNN-1st
百度网盘链接:https://pan.baidu.com/s/1WjggntFWod0CQh6_y77l4w
提取码:jqtq
压缩包密码:dszn

该章为实例,不做具体分析,只叙述大致流程。

先利用sklearn生成如下数据集。我将随机种子设置为404
在这里插入图片描述
然后开始搭建网络,该网络模型共三层,输入、隐藏、输出。第一层为两个输入,即为点的横纵坐标,第二层为任意设置的神经元个数用来拟合,第三层为2个输出,即为两种类别。
在这里插入图片描述
在这里插入图片描述
第一次训练便出现了欠拟合的问题,进入了局部最优解,由于隐藏层只有3个神经元,将其改为5便得到了很好的拟合。
在这里插入图片描述

在这里插入图片描述
但是其第200到600轮损失很高,明显陷入局部最优,让我们看看200轮是什么效果。
在这里插入图片描述
果然模型欠拟合,增加更多的隐藏层的神经元试试。下图为10个的loss图像,在200轮后,新增加的神经元发挥了作用,能够继续拟合。得到了不错的效果。
在这里插入图片描述
在这里插入图片描述

此图为随机种子为0时,隐藏层神经元个数为100时的有些过拟合的效果图。
在这里插入图片描述

具体代码分析

利用sklearn.datasets中的make_moons函数生成两组半月数据,并加入0.2的噪声。
X为有两个元素的列表,表示横纵坐标。
Y为该坐标对应的标签类别。

np.random.seed(404)           #设置随机数种子
X, Y = sklearn.datasets.make_moons(200,noise=0.4) #生成2组半圆形数据

然后分别获取类别0和1对应的索引,这样X[arg,0],X[arg,1]则为Y对应索引的X的横纵坐标。

arg = np.squeeze(np.argwhere(Y==0),axis = 1)     #获取第1组数据索引
arg2 = np.squeeze(np.argwhere(Y==1),axis = 1)    #获取第2组数据索引

用matplotlib.pyplot绘制出来

plt.title("moons data")
plt.scatter(X[arg,0], X[arg,1], s=100,c='b',marker='+',label='data1')
plt.scatter(X[arg2,0], X[arg2,1],s=40, c='r',marker='o',label='data2')
plt.legend() #显示类别的label
plt.show()

定义网络模型。共三层,输入、隐藏、输出层。都为线性层。forward完后输出x作为预测值,输入进predict和getloss

class LogicNet(nn.Module):
    def __init__(self,inputdim,hiddendim,outputdim):#初始化网络结构
        super(LogicNet,self).__init__()
        self.Linear1 = nn.Linear(inputdim,hiddendim) #定义全连接层
        self.Linear2 = nn.Linear(hiddendim,outputdim)#定义全连接层
        self.criterion = nn.CrossEntropyLoss() #定义交叉熵函数

    def forward(self,x): #搭建用两层全连接组成的网络模型
        x = self.Linear1(x)#将输入数据传入第1层
        x = torch.tanh(x)#对第一层的结果进行非线性变换
        x = self.Linear2(x)#再将数据传入第2层
#        print("LogicNet")
        return x

    def predict(self,x):#实现LogicNet类的预测接口
        #调用自身网络模型,并对结果进行softmax处理,分别得出预测数据属于每一类的概率
        pred = torch.softmax(self.forward(x),dim=1)
        return torch.argmax(pred,dim=1)  #返回每组预测概率中最大的索引

    def getloss(self,x,y): #实现LogicNet类的损失值计算接口
        y_pred = self.forward(x)
        loss = self.criterion(y_pred,y)#计算损失值得交叉熵
        return loss

定义网络模型

model = LogicNet(inputdim=2,hiddendim=10,outputdim=2)#初始化模型
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)#定义优化器

模型训练

xt = torch.from_numpy(X).type(torch.FloatTensor)#将Numpy数据转化为张量
yt = torch.from_numpy(Y).type(torch.LongTensor)
epochs = 1000 #定义迭代次数
losses = [] #定义列表,用于接收每一步的损失值
for i in range(epochs):
    loss = model.getloss(xt,yt)
    losses.append(loss.item())
    optimizer.zero_grad()#清空之前的梯度
    loss.backward()#反向传播损失值
    optimizer.step()#更新参数
plot_losses(losses)

可视化训练结果,定义如下函数

def moving_average(a, w=10):#定义函数计算移动平均损失值
    if len(a) < w:
        return a[:]
    return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]

def plot_losses(losses):
    avgloss= moving_average(losses) #获得损失值的移动平均值
    plt.figure(1)
    plt.subplot(211)
    plt.plot(range(len(avgloss)), avgloss, 'b--')
    plt.xlabel('step number')
    plt.ylabel('Training loss')
    plt.title('step number vs. Training loss')
    plt.show()

使用及评估模型
计算准确率accuracy

from sklearn.metrics import accuracy_score
print(accuracy_score(model.predict(xt),yt))

可视化模型
数据为二维数组,可以在直角坐标系中进行可视化。
定义函数plot_decision_boundary()

def predict(model,x):   #封装支持Numpy的预测接口
    x = torch.from_numpy(x).type(torch.FloatTensor)
    ans = model.predict(x)
    return ans.numpy()

def plot_decision_boundary(pred_func,X,Y):#在直角坐标系中可视化模型能力
    #计算取值范围
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    h = 0.01
    #在坐标系中采用数据,生成网格矩阵,用于输入模型
    xx,yy=np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    #将数据输入并进行预测
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    #将预测的结果可视化
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
    plt.title("Linear predict")
    arg = np.squeeze(np.argwhere(Y==0),axis = 1)
    arg2 = np.squeeze(np.argwhere(Y==1),axis = 1)
    plt.scatter(X[arg,0], X[arg,1], s=100,c='b',marker='+')
    plt.scatter(X[arg2,0], X[arg2,1],s=40, c='r',marker='o')
    plt.show()
    
plot_decision_boundary(lambda x : predict(model,x) ,xt.numpy(), yt.numpy())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/684065.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vite优化

1.利用 rollup-plugin-analyzer 插件进行进行代码体积分析&#xff0c;从而优化你的代码。 根据项目体积分析&#xff0c;进行接下来的优化&#xff1a; &#xff08;一&#xff09;使用unplugin-vue-components插件按需加载antd vue 组件&#xff1a; 使用步骤 1、安装插件…

6.18 、Java初级:锁

1 同步锁 1.1 前言 经过前面多线程编程的学习,我们遇到了线程安全的相关问题,比如多线程售票情景下的超卖/重卖现象. 上节笔记点这里-进程与线程笔记 我们如何判断程序有没有可能出现线程安全问题,主要有以下三个条件: 在多线程程序中 有共享数据 多条语句操作共享数据 多…

GPT-4 的创造力全方位持平或碾压人类 | 一项最新研究发现

文章目录 一、前言二、主要内容三、总结 &#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ 一、前言 最近&#xff0c;一项有关 GPT-4 的创造力思维测试火了。来自蒙大拿大学和 UM Western 大学的研究团队发现&#xff0c;GPT-4 在 Torrance 创造性思维…

Sharding-JDBC之RangeShardingAlgorithm(范围分片算法)

目录 一、简介二、maven依赖三、数据库3.1、创建数据库3.2、创建表 四、配置&#xff08;二选一&#xff09;4.1、properties配置4.2、yml配置 五、范围分片算法六、实现6.1、实体层6.2、持久层6.3、服务层6.4、测试类6.4.1、保存订单数据6.4.2、根据时间范围查询订单 一、简介…

还在等待本地渲染?云渲染才是真的省时省心又省钱!

可能很多设计师会觉得本地渲染效果图或动画更灵活&#xff0c;而且没有什么额外的附加费用&#xff0c;但其实不然&#xff01;当你面对多个大型或紧急的项目时&#xff0c;本地渲染就“慌”了。 接下来我将全面对比“本地渲染”和“云渲染”&#xff0c;相信还在等待本地渲染…

黑客常用cmd命令(window版)

1、ping命令 ping命令是一个常用的网络工具&#xff0c;用来测试和诊断网络连接状况。通过发送ICMP&#xff08;Internet控制消息协议&#xff09;数据包到目标主机&#xff0c;并接收回复的数据包&#xff0c;可以测量目标主机的可达性、平均响应时间等指标。 在Windows操作…

前后端实现:行为验证码---文字点选

最近接到一个新的需求&#xff0c;由于客户是内网&#xff0c;你能使用腾讯的验证码了&#xff0c;需要改为前后端实现。 具体的代码已经提交git 项目效果图&#xff1a; 使用的技术栈&#xff1a;vitevue3ts git地址&#xff1a;https://github.com/susanliy/point_captcha…

TCP/IP协议是什么?

78. TCP/IP协议是什么&#xff1f; TCP/IP协议是一组用于互联网通信的网络协议&#xff0c;它定义了数据在网络中的传输方式和规则。作为前端工程师&#xff0c;了解TCP/IP协议对于理解网络通信原理和调试网络问题非常重要。本篇文章将介绍TCP/IP协议的概念、主要组成部分和工…

《程序喵》项目跨域问题解决思路

跨域问题&#xff1a;由于浏览器的 同源策略 限制&#xff0c;当一个请求url的协议、域名、端口号三者之间有任意一个与当前的url不同即为跨域。 同源策略是一种约定&#xff0c;它是浏览器中最核心也最基本的安全功能。同源策略会阻止一个域的 Javascript 脚本和另一个域的内…

举例说明梯度下降算法与最小二乘法的区别

梯度下降算法和最小二乘法都是用于求解线性回归问题中参数的优化方法。我们可以通过一个简单的例子来说明它们之间的区别。 假设我们有以下数据点&#xff1a;(1, 2)&#xff0c;(2, 3)&#xff0c;(3, 4)&#xff0c;(4, 5)&#xff0c;我们希望找到一条最佳拟合线 y wx b&a…

Android 中Looper机制详解

版本基于&#xff1a;Android R 0. 前言 在《Android 基于Handler 剖析消息机制》一文中&#xff0c;以 Handler 类为起点详细分析了异步通信&#xff0c;分析了Java 端 Handler 与Looper、MessageQueue、Message 之前的通信关系。 框架如下&#xff1a; 在Java 端的 Looper …

2. IO 流原理及流的分类

2.1 Java IO 原理 • Java 程序中&#xff0c;对于数据的输入/输出操作以“流(stream)” 的方式进行&#xff0c;可以看做是一种数据的流动。 • I/O 流中的 I/O 是 Input/Output 的缩写&#xff0c; I/O 技术是非常实用的技术&#xff0c;用于处理设备之间的数据传输。如读/写…

浅谈 Cache

1. Cache的历史 在科研领域&#xff0c;C. J. Conti等人于1968年在描述360/85和360/91系统性能差异时最早引入了高速缓存&#xff08;cache&#xff09;一词。Alan Jay Smith于1982年的一篇论文中引入了空间局部性和时间局部性的概念。 Mark Hill在1987年发明了3C&#xff08…

OpenCV项目开发实战--实现平均脸功能教程附(C++/Python)源码实现

文末附基于Python和C++两种方式实现的测试代码下载链接 图 1:计算生成的平均脸 介绍 在本教程中,我们将学习如何使用 OpenCV (C++ / Python) 创建平均面孔。 大多数人会同意图 1 中的女人很漂亮。你能猜出她的种族吗?为什么她的皮肤完美无瑕?好吧,她不是真的。她也不是完…

如何识别手写笔记?这些方法助你快速制作电子版笔记

小张是一名大学生&#xff0c;每天需要上多门课程&#xff0c;整理笔记就成了他不得不常常面对的事情&#xff0c;但是&#xff0c;手写笔记管理起来也有些麻烦&#xff0c;有时候还容易丢失。所以在朋友的推荐下&#xff0c;他使用了一款识别软件并将手写笔记转化为可编辑的电…

推荐好用的AI工具集

AI技术未来已来&#xff0c;我们要拥抱变化 &#xff0c;笔记试用好用AI工具&#xff0c;也在代码中试用chatGPT 一、工具集 解决任何问题&#xff1a;ChatGPT 写文案&#xff1a;Jasper Al 、Copysmith 生成真人视频&#xff1a;Synthesia、 CogView2 AI AI 解决法律问题…

如何对加密字段进行模糊查询

当我们在日常开发中设置数据表时&#xff0c;经常需要定义一些敏感字段&#xff0c;如&#xff1a;身份证号、手机号、住址、账号密码等信息&#xff0c;对待这些敏感信息&#xff0c;我们必须进行加密存储&#xff0c;以保证数据存储安全。但是&#xff0c;当用户查询个人信息…

DEV-C++安装OpenGL详细教程

Dev C配置OpenGL环境——计算机图形学 一、首先自行下载dev-c 二、以下过程请认真阅读~ 确保你的C:\Windows\System32与C:\Windows\SysWOW64中有上述链接中的.dll文件(即&#xff1a;glut.dll,glut32.dll)确保你的~\Dev-CPP\MinGW64\x86_64-w64-mingw32\lib中有上述链接中的…

Mybatis源码分析_解析大流程梳理_解析配置文件 (3)

学习mybatis&#xff0c;绕不开一个核心类 Configuration。这个类相当于一个小型数据库&#xff0c;把mybatis里面所有的配置信息基本全部给存储起来了。 package org.apache.ibatis.session;import java.util.Arrays; import java.util.Collection; import java.util.HashMap;…

常见的网络抓包工具推荐

因为发现好多人想抓包&#xff0c;但是不知道有哪些工具&#xff0c;今天我给大家推荐几款抓包工具&#xff0c;希望对大家有所帮助。 网络抓包工具的用途 网络抓包工具的主要功能是将网络执行的过程&#xff0c;详细的记录下来。如果你是一个程序员&#xff0c;肯定对网络抓…