搭建一个简单的网络结构(Pytorch实现二分类)

news2024/9/30 7:29:21

搭建一个简单的网络结构(Pytorch实现二分类)

搭建一个神经网络并进行训练的话,大致需要分为三步:

  • 第一步是数据的处理,将数据整理成输入网络结构中合适的格式
  • 第二步是网络的搭建,包括每层网络的结构和前向传播
  • 第三步是网络的训练,包括损失计算,优化器选择,梯度清零,反向传播,梯度优化等

这里我们以简单的二分类来进行一个简单的网络搭建:

一、数据处理

首先来生成含200个样本的数据

#生成200个点的数据集,返回的结果是一个包含两个元素的元组 (X, y),X是点,y是x的分类
X, y =  sklearn.datasets.make_moons(200, noise = 0.2)

绘制出样本的散点图如下图所示:

save_path = "/data/wangweicheng/ModelLearning/SimpleNetWork"
plt.scatter(X[:,0],X[:,1],s=40,c=y,cmap=plt.cm.Spectral)
plt.savefig(f'{save_path}/dataset.png')   #保存生成的图片

在这里插入图片描述

**可以看到我们生成了两类数据,分别用 0 和 1 来表示。**我们接下来将要在这个样本数据上构造一个分类器,采用的是一个很简单的全连接网络,网络结构如下:

在这里插入图片描述

这个网**络包含一个输入层,一个中间层,一个输出层。**中间层包含 3 个神经元,使用的激活函数是 tanh。当然,中间层的神经元越多,分类效果一般越好,但这个 3 层的网络对于我们的样本数据已经足够用了。我们来算一下参数数量:上图中一共有 6 6 = 12 条线,就是 12 个权重,加上 3 2 = 5 个 bias,一共 17 个参数需要训练。

最后我们将样本数据从 numpy 转成 tensor:,后面我们就可以进行网络的搭建了

# 将 NumPy 数组转换为 PyTorch 张量,并指定张量的数据类型
X = torch.from_numpy(X).type(torch.FloatTensor)
y = torch.from_numpy(y).type(torch.LongTensor)

二、网络搭建

# 搭建网络
class BinaryClassifier(nn.Module):
    #初始化,参数分别是初始化信息,特征数,隐藏单元数,输出单元数
    def __init__(self,n_feature,n_hidden,n_output):
        super(BinaryClassifier, self).__init__()
        #输入层到隐藏层的全连接
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        #隐藏层到输出层的全连接
        self.output = torch.nn.Linear(n_hidden, n_output)

    #前向传播,把各个模块连接起来,就形成了一个网络结构
    def forward(self, x):
        x = self.hidden(x)  
        x = torch.tanh(x)   #激活函数
        x = self.output(x)
        return x
    
      
    def predict(self,x):
        #训练好之后,重新将x输入到网络中,进行测试
        #softmax得到0到1的一个值
        pred = F.softmax(self.forward(x),dim=1)
        ans = []
        #对不同点分别进行判断,如果第一个位置的数值大于第二个位置,就返回0,反之返回1
        for t in pred:
            if t[0]>t[1]:
                ans.append(0)
            else:
                ans.append(1)
        return torch.tensor(ans)

三、网络训练

选择损失函数CrossEntropyLoss,以及梯度优化器 Adam:

#初始化模型
model = BinaryClassifier(2, 3, 2)
#定义损失函数,用于衡量模型预测结果与真实标签之间的差异。
loss_criterion = nn.CrossEntropyLoss()
#定义优化器,优化器(optimizer)用于更新模型的参数,以最小化损失函数并提高模型的性能
# model.parameters() 表示要优化的模型参数,lr=0.01 表示学习率(learning rate)为 0.01,即每次参数更新的步长。
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

进行迭代,获取当前loss,清除上一次的梯度,进行梯度下降和反向传播,更新参数:

#训练的次数
epochs = 10000
#存储loss
losses = []

for i in range(epochs):
    #得到预测值
    y_pred = model.forward(X)
    #计算当前的损失
    loss = loss_criterion(y_pred, y)
    #添加当前的损失到losses中
    losses.append(loss.item())
    #清除之前的梯度
    optimizer.zero_grad()
    #反向传播更新参数
    loss.backward()
    #梯度优化
    optimizer.step()
    if(i % 500 == 0):
        print('loss: {:.4f}'.format(loss.item()))

查看 训练准确率:

print(accuracy_score(model.predict(X),y))

结果可视化:

def predict(x):
    x = torch.from_numpy(x).type(torch.FloatTensor)
    ans = model.predict(x)
    return ans.numpy()
    
# 画出边框
def plot_decision_boundary(pred_func,X,y):
    # 找到x,y的最大和最小值,并填充一些边框
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    h = 0.01
    # 生成一个网格
    xx,yy=np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # 将两个展平后的一维数组 xx.ravel() 和 yy.ravel() 按列连接起来,生成一个二维数组,然后对这个点进行预测
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    # 将预测结果 Z 重塑为与点网格 xx 相同的形状
    Z = Z.reshape(xx.shape)
    # 画出图像,绘制等高线图的代码
    plt.contourf(xx, yy, Z, cmap=plt.cm.YlGnBu)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.binary)
    plt.savefig(f'{save_path}/dataout.png')
    
    
plot_decision_boundary(lambda x : predict(x) ,X.numpy(), y.numpy())

输出结果:

在这里插入图片描述

结果还是不错的!

完整代码:GitHub

#引入必要的包
#数据处理的包
import matplotlib.pyplot as plt
from sklearn.cluster import SpectralClustering
import sklearn.datasets
import numpy as np
#搭建网络的包
import torch
import torch.nn as nn
import torch.nn.functional as F
#统计分数
from sklearn.metrics import accuracy_score

#生成200个点的数据集,返回的结果是一个包含两个元素的元组 (X, y),X是点,y是x的分类
X, y =  sklearn.datasets.make_moons(200, noise = 0.2)

save_path = "/data/wangweicheng/ModelLearning/SimpleNetWork"
plt.scatter(X[:,0],X[:,1],s=40,c=y,cmap=plt.cm.Spectral)
plt.savefig(f'{save_path}/dataset.png')     #保存生成的图片

# 将 NumPy 数组转换为 PyTorch 张量,并指定张量的数据类型
X = torch.from_numpy(X).type(torch.FloatTensor)
y = torch.from_numpy(y).type(torch.LongTensor)


# 搭建网络
class BinaryClassifier(nn.Module):
    #初始化,参数分别是初始化信息,特征数,隐藏单元数,输出单元数
    def __init__(self,n_feature,n_hidden,n_output):
        super(BinaryClassifier, self).__init__()
        #输入层到隐藏层的全连接
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        #隐藏层到输出层的全连接
        self.output = torch.nn.Linear(n_hidden, n_output)

    #前向传播,把各个模块连接起来,就形成了一个网络结构
    def forward(self, x):
        x = self.hidden(x)  
        x = torch.tanh(x)   #激活函数
        x = self.output(x)
        return x
    
      
    def predict(self,x):
        #训练好之后,重新将x输入到网络中,进行测试
        #softmax得到0到1的一个值
        pred = F.softmax(self.forward(x),dim=1)
        ans = []
        #对不同点分别进行判断,如果第一个位置的数值大于第二个位置,就返回0,反之返回1
        for t in pred:
            if t[0]>t[1]:
                ans.append(0)
            else:
                ans.append(1)
        return torch.tensor(ans)


#初始化模型
model = BinaryClassifier(2, 5, 2)
#定义损失函数,用于衡量模型预测结果与真实标签之间的差异。
loss_criterion = nn.CrossEntropyLoss()
#定义优化器,优化器(optimizer)用于更新模型的参数,以最小化损失函数并提高模型的性能
# model.parameters() 表示要优化的模型参数,lr=0.01 表示学习率(learning rate)为 0.01,即每次参数更新的步长。
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

#训练的次数
epochs = 10000
#存储loss
losses = []

for i in range(epochs):
    #得到预测值
    y_pred = model.forward(X)
    #计算当前的损失
    loss = loss_criterion(y_pred, y)
    #添加当前的损失到losses中
    losses.append(loss.item())
    #清除之前的梯度
    optimizer.zero_grad()
    #反向传播更新参数
    loss.backward()
    #梯度优化
    optimizer.step()
    if(i % 500 == 0):
        print('loss: {:.4f}'.format(loss.item()))


#进行预测
print(accuracy_score(model.predict(X),y))

def predict(x):
    x = torch.from_numpy(x).type(torch.FloatTensor)
    ans = model.predict(x)
    return ans.numpy()
    
# 画出边框
def plot_decision_boundary(pred_func,X,y):
    # 找到x,y的最大和最小值,并填充一些边框
    x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    h = 0.01
    # 生成一个网格
    xx,yy=np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    # 将两个展平后的一维数组 xx.ravel() 和 yy.ravel() 按列连接起来,生成一个二维数组,然后对这个点进行预测
    Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
    # 将预测结果 Z 重塑为与点网格 xx 相同的形状
    Z = Z.reshape(xx.shape)
    # 画出图像,绘制等高线图的代码
    plt.contourf(xx, yy, Z, cmap=plt.cm.YlOrBr)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.binary)
    plt.savefig(f'{save_path}/dataout.png')
    
    
plot_decision_boundary(lambda x : predict(x) ,X.numpy(), y.numpy())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1540861.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Neo4j桌面版导入CVS文件

之后会出来一个提示框,而且会跳出相关文件夹: 然后我们将CSV文件放在此目录下: 我们的relation.csv是这样的 参见: NEO4J的基本使用以及桌面版NEO4J Desktop导入CSV文件_neo4j desktop使用-CSDN博客

数学建模体育建模和经济建模国防科大版

目录 6.体育中的数学建模 7.经济学问题中的数学建模 7.1.实物交换模型 7.2.边际效应 7.3.最佳消费选择模型 6.体育中的数学建模 体育科学的研究中,也有大量的数学建模问题,例如:棒球的最佳击球点问题、滑板滑雪赛道的设计、越野自行车比…

基于springboot+vue的旅游推荐系统

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

全网最强JavaWeb笔记 | 万字长文爆肝JavaWeb开发——Web开发介绍

万字长文爆肝黑马程序员2023最新版JavaWeb教程。这套教程打破常规,不再局限于过时的老套JavaWeb技术,而是与时俱进,运用的都是企业中流行的前沿技术。笔者认真跟着这个教程,再一次认真学习一遍JavaWeb教程,温故而知新&…

利用免费 GPU 部署体验大型语言模型推理框架 vLLM

vLLM简介 vLLM 是一个快速且易于使用的 LLM(大型语言模型)推理和服务库。 vLLM 之所以快速,是因为: 最先进的服务吞吐量 通过 PagedAttention 高效管理注意力键和值内存 连续批处理传入请求 使用 CUDA/HIP 图快速模型执行 量…

输入N个整数,输出这个整数两两组合且不重复的所有二元组,要求从小到大输出并且用括号的形式。

输入描述: 第一行输入一个整数N&#xff0c;N<30。 第二行输入N个整数。 输出描述: 按题意输出。 输入样例#: 3 1 2 3 输出样例#: (1,2) (1,3) (2,1) (2,3) (3,1) (3,2) #include <stdio.h>void quicksort(int s[],int min,int max); //快速排序int partitio…

那些王道书里的题目-----计算机网络篇

注&#xff1a;仅记录个人认为有启发的题目 p155 34.下列四个地址块中&#xff0c;与地址块 172.16.166.192/26 不重叠&#xff0c;且与172.16.166.192/26聚合后的地址块不会引入多余地址的是&#xff08;&#xff09; A.172.16.166.192/27 B.172.16.166.128/26 …

53 initrd/initramfs 相关

前言 呵呵 这里主要是 探究一下 根文件系统 相关的东西 以及 附加了一些 系统启动的相关信息 计算机启动 硬件重置寄存器 设置初始化数据 计算机访问 0xffff0, 执行 bios 的代码, bios 选择启动设备, 然后执行 启动设备 boolloader 的代码 bootloader 将 boot.img 加载…

玩具蛇(蓝桥杯)

文章目录 玩具蛇题目描述答案&#xff1a;552dfs 玩具蛇 题目描述 本题为填空题&#xff0c;只需要算出结果后&#xff0c;在代码中使用输出语句将所填结果输出即可。 小蓝有一条玩具蛇&#xff0c;一共有 16 节&#xff0c;上面标着数字 1 至 16。每一节都是一个正方形的形…

GCC制作静态库详解

目录 前言 一.静态动态库区别 二.静态库制作 2.1 库文件命名 三.静态库文件制作 3.1 静态库制作 3.1.1 先获得.o文件 3.1.2 生成静态库文件 3.1.3 删除不必要文件 3.1.4 使用静态库 3.1.5 使用运行运行 前言 带大家快速入门&#xff0c;学会制作静态库。本文详细介绍在Linux系统…

“玩转文本魔法师:Python编程轻松变格式“

Hey小伙伴们&#xff0c;今天我们要一起打造一个文本转换器&#xff0c;就像神奇的魔法棒&#xff0c;能把普通的文字变成各种奇妙的格式&#xff01;想象一下&#xff0c;你的输入是&#xff1a;“Hello, World!”&#xff0c;输出可以是Markdown、HTML或者粗体、斜体的文字&a…

大语言模型(Large Language Model,LLM)简介

1. 什么是大语言模型 它是一种基于深度学习的人工智能模型&#xff0c;它从大量来自书籍、文章、网页和图像等来源的数据中学习&#xff0c;以发现语言模式和规则&#xff0c;如处理和生成自然语言文本。通常&#xff0c;大语言模型含数百亿&#xff08;或更多&#xff09;参数…

外包干了4年,技术退步明显.......

先说一下自己的情况&#xff0c;大专生&#xff0c;19年通过校招进入杭州某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了四年的功能测…

GJB5000软件配置管理计划模板

1 范围 1.1 标识 本条应描述本文档所适用的系统和软件的完整标识&#xff0c;适用时&#xff0c;包括其标识号、名称、缩略名、版本号和发布号。 1.2 系统概述 本条应概述本文档所适用的系统和软件的用途。它还应描述软件的一般特性&#xff1b;概述软件开发、运行和维护…

【 Mysql8.0 忘记登录密码 可以试试 】

** Mysql8.0 忘记登录密码 可以试试 ** 2024-3-21 段子手168 1、首先停止 mysql 服务 &#xff0c;WIN R 打开运行&#xff0c;输入 services.msc 回车打开服务&#xff0c;找到 mysql 服务&#xff0c;停止。 然后 WIN R 打开运行&#xff0c;输入 CMD 打开控制台终端输…

深度学习绘制热力图heatmap、使模型具有可解释性

思路 获取想要解释的那一层的特征图&#xff0c;然后根据特征图梯度计算出权重值&#xff0c;加在原图上面。 Demo 加上类激活(cam) 可以看到&#xff0c;cam将模型认为有利于分类的特征标注了出来。 下面以ResNet50为例: Trick: 使用 for i in model._modules.items():可以…

springboot做自定义校验注解

目录 自定义校验注解的实现 注意&#xff1a; 首先&#xff0c;我们需要自定义一个校验注解&#xff1a; 注解含义&#xff1a; Target({ElementType.FIELD}) Retention(RetentionPolicy.RUNTIME) Constraint(validatedBy PhoneValidator.class) 校验注解逻辑实现类&a…

数据结构:图的最短路径

目录 一、最短路径的基本概念 二、无权图单源最短路径 三、Dijkstra算法&#xff08;正权图单源&#xff09; 3.1、算法的基本步骤 3.2、算法的实现 3.3、习题思考 3.3.1、网络延迟时间 四、A*算法&#xff08;正权图单源单目标点&#xff09; 4.1、算法的基本概念 4…

阿里必问:Spring源码背后的10大设计奥秘!

如有疑问或者更多的技术分享,欢迎关注我的微信公众号“知其然亦知其所以然”! 各位小米粉丝们,大家好!今天小米要和大家分享的是一个备受关注的话题——“阿里巴巴面试题:Spring源码中的设计模式?”设计模式是软件工程领域中的经典话题,也是技术面试中的常见考点之一。而…

UE5学习日记——Rope Swing 人物与绳索摆动知识准备

rope swing荡绳 比我想的要复杂&#xff0c;目前还没查到简单的做法。本文为查资料的记录&#xff0c;积累后再做一个自己满意的荡绳蓝图。 一、某国外网友的解释 原文 https://forums.unrealengine.com/t/implementing-rope-swing/83098/15 Project Flake - Physics Rope De…