python pytorch 纯算法实现前馈神经网络训练(数据集随机生成)-续

news2024/11/27 14:39:51

python pytorch 纯算法实现前馈神经网络训练(数据集随机生成)-续

上一次的代码博主看了,有两个小问题其实,一个是,SGD优化的时候,那个梯度应该初始化为0,还一个是我并没有用到随机生成batch。
博主修改了代码,并且加入了accuracy图像绘制的代码。代码如下:


#coding=gbk

import torch
from torch.autograd import Variable
from torch.utils import data
import matplotlib.pyplot as plt


dim=5
batch=32
neuron_num=10
def generate_data():
    torch.manual_seed(3)
    X1=torch.randint(0,4,(1000,dim))
    X2=torch.randint(6,10,(1000,dim))

    Y1=torch.randint(0,1,(1000,))
    Y2=torch.randint(1,2,(1000,))
    print(X1)
    print(X2)
    print(Y1)
    print(Y2)
    X_data=torch.cat([X1,X2],0)
    Y_label=torch.cat([Y1,Y2],0)
    print(X_data)
    print(Y_label)
    return X_data,Y_label

def sampling(X_data,Y_label,batch):
    data_size=Y_label.size()
    #print(data_size)
    index_sequense=torch.randperm(data_size[0])
    return index_sequense


def loss_function_crossEntropy(Y_predict,Y_real):
    if Y_real==1:
        return -torch.log(Y_predict)
    else:
         return -torch.log(1-Y_predict)



X_data,Y_label=generate_data()
index_sequense=sampling(X_data,Y_label,batch)



def test():
    l=loss_function_crossEntropy(torch.tensor([0.1]),torch.tensor([1]))
    print(l)



def neuron_net(X,W,b):
    result=torch.matmul(X.type(dtype=torch.float32),W)+b
    result=torch.relu(result).reshape(1,result.size(0))

    #print(result)
    
    #print(result.size())
    return result


def grad(X,W,b,y_predict,y_real,W2,b2):
    g1=y_real/y_predict+(y_real-1)/(1-y_predict)
    result=torch.matmul(X.type(dtype=torch.float32),W)+b
    result=torch.relu(result).reshape(1,result.size(0))
   
    g2=y_predict*(1-y_predict)
    g3=neuron_net(X,W,b)
    g4=W2
    C=torch.matmul(X.type(dtype=torch.float32),W)+b
    a=[]
    for i in C:
        if i<=0:
            a.append(0)
        else:
            a.append(1)
    g5=torch.tensor(a)

    g6=X
    grad_w=g1*g2*g3
    grad_b=g1*g2
    #print("grad_w",grad_w)
    #print(grad_b)
    grad_w2=g1*g2*g4
    grad_w2=grad_w2.reshape(1,10)
   
    grad_w2=grad_w2*g5
  #  print(grad_w2.size())
    grad_w2=grad_w2.reshape(10,1)
  
    g6=g6.reshape(1,5)
    
    grad_b2=grad_w2
    grad_w2=torch.matmul(grad_w2.type(dtype=torch.float32),g6.type(dtype=torch.float32))
   # print(grad_b2.size())

    return grad_w,grad_b,grad_w2,grad_b2

    #print(g1,g2,g3,g4,g5,g6)
    
    #print(grad_w2)
    #print(grad_b2)
  
   

def flat_dense(X,W,b):
    return torch.sigmoid(torch.matmul(X.type(dtype=torch.float32),W)+b)


W=torch.randn(dim,neuron_num)
b=torch.randn(neuron_num)
W2=torch.randn(neuron_num,1)
b2=torch.randn(1)


def net(X,W,b,W2,b2):
    result=neuron_net(X,W,b)

    ans=flat_dense(result,W2,b2)
    return  ans


y_predict=net(X_data[0],W,b,W2,b2)
print(y_predict)

grad_w,grad_b,grad_w2,grad_b2=grad(X_data[0],W,b,y_predict,Y_label[0],W2,b2)

loss_list=[]
accuracy_list=[]
learn_rating=0.01
epoch=2000
def train():
    
    index=0
    global W,W2,b,b2
    for i in range(epoch):
       
        W_g=torch.zeros(dim,neuron_num)
        b_g=torch.zeros(neuron_num)
        W2_g=torch.zeros(neuron_num,1)
        b2_g=torch.zeros(1)
        loss=torch.tensor([0.0])
        co=0
        for j in range(32):
            try:
                y_predict=net(X_data[index_sequense[index]],W,b,W2,b2)
                grad_w,grad_b,grad_w2,grad_b2=grad(X_data[index_sequense[index]],W,b,y_predict,Y_label[index_sequense[index]],W2,b2)
             #   print(grad_w2.size(),W_g.size())
                grad_w2=torch.t(grad_w2)
                W_g=W_g+grad_w2
                grad_b2=grad_b2.reshape(10)
                #print("b_g",b_g)
                #print("grad_b2",grad_b2)
                b_g=grad_b2+b_g
             
                W2_g=W2_g+torch.t(grad_w)
                b2_g=b2_g+torch.t(grad_b)
                
             #   print("fdafaf",grad_w,grad_b,grad_w2,grad_b2)
                loss=loss+loss_function_crossEntropy(y_predict,Y_label[index_sequense[index]])
               # print( Y_label[index],y_predict[0][0])
                if (Y_label[index_sequense[index]]==1) &( y_predict[0][0]>0.5):
                    co=co+1
                if (Y_label[index_sequense[index]]==0) &( y_predict[0][0]<=0.5):
                    co=co+1
                index=index+1
            except:
                index=0

        print("loss:",loss[0])
        print("accuracy:",co/32)
        loss_list.append(loss[0])
        accuracy_list.append(co/32)
        W_g=W_g/batch
        b_g=b_g/batch
        W2_g=W2_g/batch
        b2_g=b2_g/batch
        #print(W.size())
        #print(b.size())
        #print(W2.size())
        #print(b2.size())

        W=W+learn_rating*W_g
     #   print("b*********************",b,b_g)
        b=b+learn_rating*(b_g)
        #print(W2_g.size())
        #print(b2_g.size())
        W2=W2+learn_rating*W2_g
        b2=b2+learn_rating*b2_g
        #print(W.size())
        #print(b.size())
        #print(W2.size())
        #print(b2.size())



train()
epoch_list=list(range(epoch))
plt.plot(epoch_list,loss_list,label='SGD')
plt.title("loss")
plt.legend()

plt.show()

epoch_list=list(range(epoch))
plt.plot(epoch_list,accuracy_list,label='SGD')
plt.title("loss")
plt.legend()
plt.show()

可一下跑出的结果:
在这里插入图片描述
在这里插入图片描述
可以看到这样看下来,效果就很不错了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/742288.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flowable边界事件-定时边界事件

定时边界事件 定时边界事件一、定义1. 图形标记2. 完整的流程图3. XML标记 二、测试用例2.1 定时边界事件xml文件2.2 定时边界事件测试用例 总结 定时边界事件 一、定义 时间达到设定的时间之后触发事件 由于定时边界事件和开始定时事件几乎差不多&#xff0c;四种情况我就不一…

14、双亲委托模型

双亲委托模型 先直接来看一幅图 双亲委派模型的好处&#xff1a; 主要是为了安全性&#xff0c;避免用户自己编写的类动态替换Java的一些核心类&#xff0c;比如 String。 同时也避免了类的重复加载&#xff0c;因为JVM中区分不同类&#xff0c;不仅仅是根据类名&#xff0c…

React 新版官方文档 (二) useState 用法详解

背景 本文默认读者对 useState 有最为基本的了解&#xff0c;比如知道他的写法应当是怎样的&#xff0c;下面着重介绍部分重要的、在开发过程中会踩的坑和一些特性&#xff0c;最后动手实现一个最基本的 useState 代码 useState ⭐️ 注意事项: 状态只在下次更新时异步变化&…

Shiro教程(一):入门概述与基本使用

Shiro 第一章&#xff1a;入门概述 1.1 Shiro是什么 Apache.Shiro是一个功能强大且易于使用的Java安全&#xff08;权限&#xff09;框架。Shiro可以完成&#xff1a;认证、授权、加密、会话管理、与Web集成、缓存等。借助Shiro可以快速轻松地保护任何应用程序——从最小的移…

用于3D渲染和平面设计应该怎么选择显卡?

首先了解快速解决3D渲染本地配置不足&#xff0c;节省硬件成本的方法&#xff1a; 本地电脑资源不足&#xff0c;在不增加额外的硬件成本投入的情况下&#xff0c;想要快速提升渲染速度&#xff0c;可使用渲云云渲染&#xff0c;且可批量渲染&#xff0c;批量出结果&#xff0…

centos7密码忘记恢复方法

首先启动系统看到如下界面&#xff1a; 然后按"e"键&#xff0c;看到下面的界面 然后使用"↓"按键移动光标&#xff0c; 移动到linux16 将上图中红色箭头指向的ro替换成下图中画红线的内容&#xff1a; ro替换成rw init/sysroot/bin/sh。 然后按CTRLX进入…

【ABAP】数据类型(五)「结构体概要」

💂作者简介: THUNDER王,一名热爱财税和SAP ABAP编程以及热爱分享的博主。目前于江西师范大学本科在读,同时任汉硕云(广东)科技有限公司ABAP开发顾问。在学习工作中,我通常使用偏后端的开发语言ABAP,SQL进行任务的完成,对SAP企业管理系统,SAP ABAP开发和数据库具有较…

Spring Cloud | No URLs will be polled as dynamic configuration sources.

添加config.properties文件就行了&#xff0c;内容为空的都可以 加上该文件再次运行

基于STM32设计的简易手机

一、项目介绍 基于STM32设计的简易手机可以作为智能手表的模型进行开发&#xff0c;方便老人和儿童佩戴。项目主要是为了解决老年人或儿童使用智能手表时可能遇到的困难&#xff0c;例如操作困难、功能复杂等问题。 在这个项目中&#xff0c;采用了STM32F103RCT6主控芯片和SI…

Effective Java(第三版)目录

本书的目标是帮助读者更加有效地使用Java编程语言及其基本类库java.lang、java.util和java.io&#xff0c;以及子包java.util.concurrent和java.util.function等。本书也会时不时地讨论到其他的类库。 本书一共包含90个条目&#xff0c;每个条目讨论一条规则。这些规则…

驱动 day8 作业

1.在内核模块中启用定时器&#xff0c;定时1s,让led1 一秒亮、一秒灭 2.基于gpio子系统完成LED灯驱动的注册&#xff0c;应用程序测试 1.mychrdev_timer.c #include <linux/init.h> #include <linux/module.h> #include <linux/fs.h> #include <linux/io…

Qt DAY5 Qt制作简易网络聊天室

服务器 widget.h文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QTcpServer> #include <QTcpSocket> #include <QVector>//向量&#xff0c;函数类模板 #include <QMessageBox>namespace Ui { class Widget; }class Wid…

STM32+FreeRTOS 使用SystemView监控系统

前言 本文以STM32F407ZET6 FreeRTOS V9.0作为演示&#xff0c;其它的Cortex M芯片同样可以参考此文&#xff0c;其他内核和RTOS理论上也支持&#xff0c;本文暂时不做研究。 所以开始阅读本文前&#xff0c;需要一块能运行FreeRTOS的Cortex M芯片&#xff0c;如果没有移植好…

发一下接口自动化测试框架(python3+requests+excel)

Git&#xff1a; https://github.com/lilinyu861/Interface-Test 环境配置&#xff1a; 开发工具&#xff1a;pycharm2018Excel 开发框架&#xff1a;python3requestsexcel 接口自动化测试框架介绍&#xff1a; 此接口测试框架&#xff0c;首先由用户设计原始的测试用例并为…

webpack笔记二

文章目录 背景拆分环境清除上次构建产物插件&#xff1a;clean-webpack-plugin合并配置文件插件&#xff1a;webpack-merge实时更新和预览效果&#xff1a;webpack-dev-server babel配置参考 背景 webpack笔记一 在前面的学习&#xff0c;完成了webpack的基本配置&#xff0c…

C++教程——const修饰指针、结构体、文件操作

const修饰指针 常量指针 指针常量 const既修饰指针&#xff0c;又修饰常量 指针与数组 结构体 通过指针访问结构体变量中的数据 结构体中const使用场景 文件操作 写文件 读文件 读取数据的方式 二进制读写文件 写文件 读文件

master、origin master和origin/master

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

线程任务分支合并框架

1、原理 2、实用类 &#xff08;1&#xff09; ForkJoinPool 分支合并池 类比> 线程池 &#xff08;2&#xff09; ForkJoinTask ForkJoinTask 类比> FutureTask &#xff08;3&#xff09; RecursiveTask 递归任务&#xff1a;继承后可以实现递归(自己调自己)调用…

从小白到大神之路之学习运维第57天--------shell脚本实例应用3.0--以及————结合“三贱客”之“grep”的相关用法

第三阶段基础 时 间&#xff1a;2023年7月11日 参加人&#xff1a;全班人员 内 容&#xff1a; shell实例 目录 一、循环的基本使用 while随机循环 二、case控制服务的基本应用 1、case的语法格式 2、使用case写脚本&#xff0c;以以下实验为主 例1&#xff1a;控…