前馈神经网络正则化例子

news2024/12/27 9:46:38

直接看代码:

import torch  
import numpy as np  
import random  
from IPython import display  
from matplotlib import pyplot as plt  
import torchvision  
import torchvision.transforms as transforms   

mnist_train = torchvision.datasets.MNIST(root='/MNIST', train=True, download=True, transform=transforms.ToTensor())  
mnist_test = torchvision.datasets.MNIST(root='./MNIST', train=False,download=True, transform=transforms.ToTensor())  

batch_size = 256 

train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True,num_workers=0)  
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False,num_workers=0)  

num_inputs,num_hiddens,num_outputs =784, 256,10

def init_param():
    W1 = torch.tensor(np.random.normal(0, 0.01, (num_hiddens,num_inputs)), dtype=torch.float32)  
    b1 = torch.zeros(1, dtype=torch.float32)  
    W2 = torch.tensor(np.random.normal(0, 0.01, (num_outputs,num_hiddens)), dtype=torch.float32)  
    b2 = torch.zeros(1, dtype=torch.float32)  
    params =[W1,b1,W2,b2]
    for param in params:
        param.requires_grad_(requires_grad=True)  
    return W1,b1,W2,b2

def relu(x):
    x = torch.max(input=x,other=torch.tensor(0.0))  
    return x

def net(X):  
    X = X.view((-1,num_inputs))  
    H = relu(torch.matmul(X,W1.t())+b1)  
    #myrelu =((matmal x,w1)+b1),return  matmal(myrelu,w2 )+ b2
    return relu(torch.matmul(H,W2.t())+b2 )
    return torch.matmul(H,W2.t())+b2


def SGD(paras,lr):  
    for param in params:  
        param.data -= lr * param.grad  
        
def l2_penalty(w):
    return (w**2).sum()/2


def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr=None,optimizer=None,mylambda=0):  
    
    train_ls, test_ls = [], []
    
    for epoch in range(num_epochs):
        
        ls, count = 0, 0
        
        for X,y in train_iter :
            X = X.reshape(-1,num_inputs)
            l=loss(net(X),y)+ mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)
            
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            
            ls += l.item()
            count += y.shape[0]
            
        train_ls.append(ls)
        
        ls, count = 0, 0
        
        for X,y in test_iter:
            X = X.reshape(-1,num_inputs)
            l=loss(net(X),y) + mylambda*l2_penalty(W1) + mylambda*l2_penalty(W2)
            ls += l.item()
            count += y.shape[0]
            
        test_ls.append(ls)
        
        if(epoch)%2==0:
            print('epoch: %d, train loss: %f, test loss: %f'%(epoch+1,train_ls[-1],test_ls[-1]))
            
    return train_ls,test_ls



lr = 0.01

num_epochs = 20

Lamda = [0,0.1,0.2,0.3,0.4,0.5]

Train_ls, Test_ls = [], []

for lamda in Lamda:
    print("current lambda is %f"%lamda)
    W1,b1,W2,b2 = init_param()
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD([W1,b1,W2,b2],lr = 0.001)
    train_ls, test_ls = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr,optimizer,lamda)   
    Train_ls.append(train_ls)
    Test_ls.append(test_ls)
    
x = np.linspace(0,len(Train_ls[1]),len(Train_ls[1]))

plt.figure(figsize=(10,8))

for i in range(0,len(Lamda)):
    plt.plot(x,Train_ls[i],label= f'L2_Regularization:{Lamda [i]}',linewidth=1.5)
    
    plt.xlabel('different epoch')
    
    plt.ylabel('loss')
    
plt.legend(loc=2, bbox_to_anchor=(1.1,1.0),borderAxesPad = 0.)

plt.title('train loss with L2_penalty')

plt.show()

运行结果:

在这里插入图片描述

疑问和心得:

  1. 画图的实现和细节还是有些模糊。
  2. 正则化系数一般是一个可以根据算法有一定变动的常数。
  3. 前馈神经网络中,二分类最后使用logistic函数返回,多分类一般返回softmax值,若是一般的回归任务,一般是直接relu返回。
  4. 前馈神经网络的实现,从物理层上应该是全连接的,但是网上的代码一般都是两层单个神经元,这个容易产生误解。个人感觉,还是要使用nn封装的函数比较正宗。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/893561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【云原生】kuberneter中Helm入门到实践

引言 helm是k8s的包管理工具,使用helm,可以使用更为简化和系统化的方式对k8s应用进行部署、升级。 helm是CNCF已毕业的项目,社区也是相当活跃的,在 https://artifacthub.io/ 上,能找到很多现成的helm chart&#xff…

轻松学会网络编程

目录 一、UDP 和 TCP 特点对比 1、有连接和无连接 2、可靠传输和不可靠传输 3、面向字节流和面向数据报 4、全双工和半双工 二、UDP 的 socket.api 1、DatagramSocket 2、DatagramPacket 回显服务器的实现 (1)服务器代码 (2&#…

GaussDB数据库SQL系列-子查询

目录 一、前言 二、GaussDB SQL子查询表达式 1、EXISTS/NOT EXISTS 2、IN/NOT IN 3、ANY/SOME 4、ALL 三、GaussDB SQL子查询实验示例 1、创建实验表 2、EXISTS/NOT EXISTS示例 3、IN/NOT IN 示例 4、ANY/SOME 示例 5、ALL示例 四、注意事项及建议 五、小结 一、…

投资不识筹码峰,炒遍A股也枉然? | 如何用python计算筹码分布数据

你听说过股市上著名的丁蟹效应吗? 你知道丁蟹报仇点到为止,丁蟹报恩家破人亡吗? 你又是否曾在微信群中见过这些表情包? 01 大时代 不知道大家有没有看过《大时代》这部剧,看过的欢迎点我头像交流讨论。 剧中逆天强运…

Java:JVM虚拟机的三种模式

在JVM中有三种模式: 混合模式:解释器热点代码编译 编译模式:启动快,执行慢 解释模式:启动慢,执行快 使用 在我们的JVM虚拟机中一般默认的是混合模式 如上所示,我们可以看到后面有mixed&#xf…

【mysql异常】Specified key was too long; max key length is 1000 bytes

最近在创建数据库的时候,报错内容如下所示: Caused by: java.sql.SQLSyntaxErrorException: Specified key was too long; max key length is 1000 bytesat com.mysql.cj.jdbc.exceptions.SQLError.createSQLException(SQLError.java:120) ~[mysql-conn…

Vue中实现自动匹配搜索框内容 关键字高亮文字显示

实现效果如下: 1.首先需要给输入框进行双向绑定 2.拿到搜索的结果去渲染页面 将返回的结果和搜索的关键字进行比对 如果相同的 就变红 上代码 html部分 //输入框<div class"search"><div class"shuru"><input type"请输入要查询的…

软件测试报告有哪些测试内容?

软件测试报告可以包含以下测试内容&#xff1a; 1、功能测试&#xff1a;测试软件的基本功能是否实现&#xff0c;是否符合要求。 2、性能测试&#xff1a;测试软件的响应速度、并发能力、稳定性等性能指标。 3、界面测试&#xff1a;测试软件的用户界面是否友好、易于使用。 …

开车打电话买什么样的蓝牙好,分享几款通话性能最好的蓝牙耳机

随着时间的推移&#xff0c;如今的年轻人越来越倾向于使用骨传导耳机&#xff0c;因为他们都知道&#xff0c;骨传导耳机最大的优点就是带着很舒服的感觉&#xff0c;它不仅比普通的入耳式耳机更容易戴上&#xff0c;而且还比普通的入耳式耳机更安全&#xff0c;能有效地减少中…

try-with-resource语法使用

try-with-resources 是 Java 7 引入的一种语法结构&#xff0c;用于更方便地管理需要关闭的资源&#xff08;如 I/O 流、数据库连接等&#xff09;。它可以在代码块结束后自动关闭资源&#xff0c;无需显式调用 close() 方法&#xff0c;从而避免资源泄漏。 基本结构 try (Res…

opencv-python使用鼠标点击图片显示该点坐标和像素值IPM逆透视变换车道线

OpenCV的鼠标操作 实现获取像素点的功能主要基于OpenCV的内置函数cv2.setMouseCallback()&#xff0c;即鼠标事件回调 setMouseCallback(winname, onMouse,userdata0) winname: 接收鼠标事件的窗口名称 onMouse: 处理鼠标事件的回调函数指针 userdata: 传给回调函数的用户数据…

交流充电桩控制主板的优点

你是否曾经担心过充电桩可能会对你的电动车电池造成危害?让我们来探讨一下交流充电桩主板的优点&#xff0c;让你安心充电。 首先&#xff0c;交流充电桩主板采用了高安全性的电源设计&#xff0c;能够有效地保护电池免受电流、电压过高的危害&#xff0c;确保电池的安全使用。…

解决执行 spark.sql 时版本不兼容的一种方式

场景描述 hive 数据表的导入导出功能部分代码如下所示&#xff0c;使用 assemble 将 Java 程序和 spark 相关依赖一起打成 jar 包&#xff0c;最后 spark-submit 提交 jar 到集群执行。 public class SparkHiveApplication {public static void main(String[] args){long sta…

Dubbo—核心优势

一、快速易用 无论你是计划采用微服务架构开发一套全新的业务系统&#xff0c;还是准备将已有业务从单体架构迁移到微服务架构&#xff0c;Dubbo 框架都可以帮助到你。Dubbo 让微服务开发变得非常容易&#xff0c;它允许你选择多种编程语言、使用任意通信协议&#xff0c;并且…

什么是低价治理服务

当商品的销售价低于品牌要求的建议价时&#xff0c;就会被认为是低价销售&#xff0c;销售的主体是店铺&#xff0c;那店铺的运营方就成了低价的主导者&#xff0c;低价行为大部分品牌都会跟进&#xff0c;低价店铺的信息品牌也会去收集&#xff0c;因为只有掌握了低价链接、低…

什么是 脏写,脏读,幻读,不可重复读?怎样能解决这四种问题?

我们通过如下语句先创建一个 student 学生表。我就以对学生表的操作来解释什么是脏写&#xff0c;脏读&#xff0c;幻读&#xff0c;不可重复读 创建完成之后随便插入一条数据 1. 脏写&#xff1f; 对于两个事务 SessionA&#xff0c;SessionB&#xff0c;如果SessionA修改了另…

无公网IP,公网SSH远程访问家中的树莓派教程

文章目录 前言 如何通过 SSH 连接到树莓派步骤1. 在 Raspberry Pi 上启用 SSH步骤2. 查找树莓派的 IP 地址步骤3. SSH 到你的树莓派步骤 4. 在任何地点访问家中的树莓派4.1 安装 Cpolar内网穿透4.2 cpolar进行token认证4.3 配置cpolar服务开机自启动4.4 查看映射到公网的隧道地…

Timeplate Definition

timeplate定义描述单个tester cycle&#xff0c;并指定所有event edges被放置在cycle的位置。 必须在引用之前定义所有的timeplates。一个procedure必须有至少一个timeplate定义&#xff0c;所有的时钟必须在timeplate定义中进行定义&#xff0c;timeplate的定义有以下格式&am…

C++ STL关联式容器(详解)

STL关联式容器 C STL关联式容器是什么&#xff1f; 在《C STL容器》一节中讲到&#xff0c;C 容器大致分为 2 类&#xff0c;即序列式容器和关联式容器。其中&#xff0c;序列式容器&#xff08;包括 array、vector、list、deque 和 forward_list&#xff09;已经在前面章节中…

【校招VIP】前端JS语言考点之Vue考察

考点介绍&#xff1a; Vue是一套用于构建用户界面的渐进式框架。与其它大型框架不同的是&#xff0c;Vue 被设计为可以自底向上逐层应用。Vue 的核心库只关注视图层&#xff0c;不仅易于上手&#xff0c;还便于与第三方库或既有项目整合。另一方面&#xff0c;当与现代化的工具…