python pytorch 纯算法实现前馈神经网络训练(数据集随机生成)

news2025/1/10 13:45:08

python pytorch 纯算法实现前馈神经网络训练(数据集随机生成)

下面这个代码大家可以学习学习,这个代码难度最大的在于反向传播推导, 博主推了很久,整个过程都是纯算法去实现的,除了几个激活函数,可以学习一下下面的代码。

我下面这个代码还是很严谨的,从数据集生成,损失函数,网络结构、梯度求导,优化器等等组件都实现了。


#coding=gbk

import torch
from torch.autograd import Variable
from torch.utils import data
import matplotlib.pyplot as plt


dim=5
batch=32
neuron_num=10
def generate_data():
    torch.manual_seed(3)
    X1=torch.randint(0,5,(1000,dim))
    X2=torch.randint(5,10,(1000,dim))

    Y1=torch.randint(0,1,(1000,))
    Y2=torch.randint(1,2,(1000,))
    print(X1)
    print(X2)
    print(Y1)
    print(Y2)
    X_data=torch.cat([X1,X2],0)
    Y_label=torch.cat([Y1,Y2],0)
    print(X_data)
    print(Y_label)
    return X_data,Y_label

def sampling(X_data,Y_label,batch):
    data_size=Y_label.size()
    #print(data_size)
    index_sequense=torch.randperm(data_size[0])
    return index_sequense


def loss_function_crossEntropy(Y_predict,Y_real):
    if Y_real==1:
        return -torch.log(Y_predict)
    else:
         return -torch.log(1-Y_predict)



X_data,Y_label=generate_data()
index_sequense=sampling(X_data,Y_label,batch)



def test():
    l=loss_function_crossEntropy(torch.tensor([0.1]),torch.tensor([1]))
    print(l)



def neuron_net(X,W,b):
    result=torch.matmul(X.type(dtype=torch.float32),W)+b
    result=torch.relu(result).reshape(1,result.size(0))

    #print(result)
    
    #print(result.size())
    return result


def grad(X,W,b,y_predict,y_real,W2,b2):
    g1=y_real/y_predict+(y_real-1)/(1-y_predict)
    g2=y_predict*(1-y_predict)
    g3=neuron_net(X,W,b)
    g4=W2
    C=torch.matmul(X.type(dtype=torch.float32),W)+b
    a=[]
    for i in C:
        if i<=0:
            a.append(0)
        else:
            a.append(1)
    g5=torch.tensor(a)

    g6=X
    grad_w=g1*g2*g3
    grad_b=g1*g2
    #print("grad_w",grad_w)
    #print(grad_b)
    grad_w2=g1*g2*g4
    grad_w2=grad_w2.reshape(1,10)
   
    grad_w2=grad_w2*g5
  #  print(grad_w2.size())
    grad_w2=grad_w2.reshape(10,1)
  
    g6=g6.reshape(1,5)
    
    grad_b2=grad_w2
    grad_w2=torch.matmul(grad_w2.type(dtype=torch.float32),g6.type(dtype=torch.float32))
   # print(grad_b2.size())

    return grad_w,grad_b,grad_w2,grad_b2

    #print(g1,g2,g3,g4,g5,g6)
    
    #print(grad_w2)
    #print(grad_b2)
  
   

def flat_dense(X,W,b):
    return torch.sigmoid(torch.matmul(X.type(dtype=torch.float32),W)+b)


W=torch.randn(dim,neuron_num)
b=torch.randn(neuron_num)
W2=torch.randn(neuron_num,1)
b2=torch.randn(1)


def net(X,W,b,W2,b2):
    result=neuron_net(X,W,b)

    ans=flat_dense(result,W2,b2)
    return  ans


y_predict=net(X_data[0],W,b,W2,b2)
print(y_predict)

grad_w,grad_b,grad_w2,grad_b2=grad(X_data[0],W,b,y_predict,Y_label[0],W2,b2)

loss_list=[]
learn_rating=0.01
epoch=10000
def train():
    
    index=0
    global W,W2,b,b2
    for i in range(epoch):
       
        W_g=torch.randn(dim,neuron_num)
        b_g=torch.randn(neuron_num)
        W2_g=torch.randn(neuron_num,1)
        b2_g=torch.randn(1)
        loss=torch.tensor([0.0])
        co=0
        for j in range(32):
            try:
                y_predict=net(X_data[index],W,b,W2,b2)
                grad_w,grad_b,grad_w2,grad_b2=grad(X_data[index],W,b,y_predict,Y_label[index],W2,b2)
             #   print(grad_w2.size(),W_g.size())
                grad_w2=torch.t(grad_w2)
                W_g=W_g+grad_w2
                grad_b2=grad_b2.reshape(10)
                #print("b_g",b_g)
                #print("grad_b2",grad_b2)
                b_g=grad_b2+b_g
             
                W2_g=W2_g+torch.t(grad_w)
                b2_g=b2_g+torch.t(grad_b)
                
             #   print("fdafaf",grad_w,grad_b,grad_w2,grad_b2)
                loss=loss+loss_function_crossEntropy(y_predict,Y_label[index])
               # print( Y_label[index],y_predict[0][0])
                if (Y_label[index]==1) &( y_predict[0][0]>0.5):
                    co=co+1
                if (Y_label[index]==0) &( y_predict[0][0]<=0.5):
                    co=co+1
                index=index+1
            except:
                index=0

        print("loss:",loss[0])
        print("accuracy:",co/32)
        loss_list.append(loss[0])
        W_g=W_g/batch
        b_g=b_g/batch
        W2_g=W2_g/batch
        b2_g=b2_g/batch
        #print(W.size())
        #print(b.size())
        #print(W2.size())
        #print(b2.size())

        W=W+learn_rating*W_g
     #   print("b*********************",b,b_g)
        b=b+learn_rating*(b_g)
        #print(W2_g.size())
        #print(b2_g.size())
        W2=W2+learn_rating*W2_g
        b2=b2+learn_rating*b2_g
        #print(W.size())
        #print(b.size())
        #print(W2.size())
        #print(b2.size())



train()
epoch_list=list(range(epoch))
plt.plot(epoch_list,loss_list,label='SGD')
plt.title("loss")
plt.legend()
plt.show()

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/733542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【前端知识】React 基础巩固(十八)——组件化开发(二)

React 基础巩固(十八)——组件化开发&#xff08;二&#xff09; 生命周期 生命周期是一个抽象的概念&#xff0c;在生命周期的整个过程中&#xff0c;分成了很多个阶段 比如装载阶段&#xff08;Mount&#xff09;&#xff0c;组件第一次在 DOM 树中被渲染的过程比如更新过程…

【C语言】你知道浮点数是怎么存储的吗?

前言 &#x1f388;大家好&#xff0c;我是何小侠&#x1f388; &#x1f343;大家可以叫我小何或者小侠&#x1f343; &#x1f490;希望能通过写博客加深自己对于学习内容的理解&#x1f490; &#x1f338;也能帮助更多人理解和学习&#x1f338; 积学以储宝&#xff0c;酌…

Debian 11 x64 安装 MySQL 8.0.33

更新 sudo apt update sudo apt install gnupg安装 DEB Package wget -c https://dev.mysql.com/get/mysql-apt-config_0.8.25-1_all.deb sudo dpkg -i mysql-apt-config_0.8.25-1_all.deb具体版本见官方网站&#xff1a;MySQL Community Downloads&#xff0c;这里仅以版本 …

详解什么是新零售和新零售的四种商业模式

前言 自推出新零售概念以来&#xff0c;新零售已成为当前的热门话题。今天我们将进一步了解什么是新零售。 一、什么是新零售? 新零售&#xff0c;英文是New Retailing&#xff0c;即企业以互联网为依托&#xff0c;通过运用大数据、人工智能等先进技术手段&#xff0c;对商…

VMware虚拟机里的Ubuntu通过主机的代理联网

问题描述&#xff1a;主机win10&#xff0c;通过代理联网。主机里装有VMware的虚拟机Ubuntu&#xff0c;想要通过主机的代理进行上网。 步骤&#xff1a; 1 将虚拟机的网络设置为NAT模式。 2 在win10命令行中输入ipconfig&#xff0c;查询ipv4的局域网地址。&#xff08;注&…

使用docker安装Nacos,远程连接nacos报错,please check server x.x.x.x ,port 9848 is available

报错: please check server 127.0.0.1 ,port 9848 is available 原因: 当nacos客户端升级为2.x版本后&#xff0c;新增了gRPC的通信方式&#xff0c;新增了两个端口。这两个端口在nacos原先的端口上(默认8848)&#xff0c;进行一定偏移量自动生成.。 当客户端升级成2.x版本时&…

[工业互联-20]:常见EtherCAT主站方案:TwinCAT的Windows 解决方案

目录 第1章 TwinCAT简介 第2章 软件架构 第3章 应用程序架构 第1章 TwinCAT简介 TwinCAT是由德国Beckhoff公司开发的一套功能强大的自动化软件平台。 它是一个集成的开发环境&#xff0c;用于实现实时控制、PLC编程、运动控制、HMI&#xff08;人机界面&#xff09;设计和…

service 2 暴露服务的 3种 方式

【k8s 系列】k8s 学习十九&#xff0c;service 2 之前我们简单的了解一下 k8s 中 service 的玩法&#xff0c;今天我们来分享一下 service 涉及到的相关细节&#xff0c;我们开始吧 为什么要有 服务 Service&#xff1f; 因为服务可以做到让外部的客户端不用关心服务器的数量…

【二叉树part09】| 669.修剪二叉搜索树、108.将有序数组转换为二叉搜索树、538.把二叉搜索树转换为累加树

目录 &#x1f388;LeetCode669. 修剪二叉搜索树 &#x1f388;LeetCode108.将有序数组转换为二叉搜索树 &#x1f388;LeetCode538.把二叉搜索树转换为累加树 &#x1f388;LeetCode669. 修剪二叉搜索树 链接&#xff1a;669.修剪二叉搜索树 给你二叉搜索树的根节点 root…

使用Go 语言的三个原因

几个星期前&#xff0c;我一个朋友问我&#xff1a;“为什么要关心 Go 语言”&#xff1f; 因为他们知道我热衷于 Go 语言&#xff0c;但他们想知道为什么我认为其他人也应该关心。有三个原因&#xff1a;安全性、生产力和并发性。有些语言可以涵盖一个也有可能是两个方面&…

代码逐行解析!冠军选手解读锂电池生产温度预测赛事方案

Datawhale干货 作者&#xff1a;鱼佬、骆秀韬&#xff0c;Datawhale成员 本实践是数据挖掘类型的比赛&#xff0c;聚焦于工业场景。实践任务本质上为回归任务&#xff0c;其中会涉及到时序预测相关的知识。 本实践可帮助大家&#xff1a; 快速掌握数据挖掘任务基本流程&#x…

【开源-文章迁移利器】MarkDown本地图片转云端存储脚本-支持目录递归查找转换

从一些笔记软件导出markdown文档后&#xff0c;图片都是本地图片&#xff0c;文档数量过多&#xff0c;用typora一一打开上传图片过于繁琐&#xff0c;特开发一个一键迁移文章图片的脚本&#xff0c;方便markdown文档的迁移。 文章目录 大致需求开源地址设计思路脚本介绍快速使…

蓝桥杯专题-真题版含答案-【大衍数列】【圆周率】【分糖果】【等额本金】

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列点击跳转>蓝桥系列 &#x1f449;关于作者 专注于Android/Unity和各种游…

Java csv文件上传下载中的相关转换

目录 一. 需求二. List<Entity>转List<List<String>>2.1 实体类2.2 转换 三. 上传csv文件转List<Map>3.1 csv文件3.2 前台3.3 实体类3.4 转换3.5 效果 一. 需求 &#x1f914;项目中遇到了两个需求 1.查询数据库&#xff0c;得到List<Entity>这…

快速搭建一个美观且易用的 Django 管理后台 —— django-xadmin

Django-xadmin&#xff08;也称为Xadmin&#xff09;是一个第三方的 Django 应用程序&#xff0c;它提供了一系列工具和模板来快速开发基于 Django 的后台管理界面。使用 Django-xadmin 可以用很少的代码就创建出一个强大的、具备实时查看数据、增、删、改等基本操作的 Django …

leetcode-704.二分查找

leetcode-704.二分查找 文章目录 leetcode-704.二分查找一.题目描述二.第1次代码提交(非二分查找)三.第2次代码提交(非二分查找&#xff0c;std::find和std::distance)四.第3次代码提交(二分查找)五.关于C中int型的奇数除以2 一.题目描述 二.第1次代码提交(非二分查找) class …

Linux 学习记录47(QT篇待完成)

Linux 学习记录47(QT篇) 本文目录 Linux 学习记录47(QT篇)一、将资源文件加载到项目1. 将资源文件放到项目下2. 添加到项目 二、信号与槽机制1. 信号与槽机制概念2. 信号与槽概念 三、四、思维导图练习1. main_page.cpp2. main.cpp3. main_page.h4. login.cpp5. login.h 一、将…

Delphi 11必备指南:使用Git集成Python4Delphi的完整步骤

在Delphi中使用Python有很多好处&#xff0c;可以扩展Delphi的功能并利用Python强大的科学计算和数据分析库。但是&#xff0c;为了将Python集成到Delphi中&#xff0c;我们需要安装Python for Delphi (P4D)组件套件。在这篇博客中&#xff0c;我将介绍如何使用Git安装P4D组件套…

N-122基于springboot,vue网上订餐系统

开发工具&#xff1a;IDEA 服务器&#xff1a;Tomcat9.0&#xff0c; jdk1.8 项目构建&#xff1a;maven 数据库&#xff1a;mysql5.7 前端技术 &#xff1a;VueElementUI 服务端技术&#xff1a;springbootmybatisredis 本系统分用户前台和管理后台两部分&#xff0c;…

python_day3_list

数据容器 &#xff1a; list&#xff08;列表&#xff09; tuple&#xff08;元组&#xff09; str&#xff08;字符串&#xff09; set&#xff08;集合&#xff09; dict&#xff08;字典&#xff09; 列表 list name_list [java, c, python] print(name_list) print(type…