神经网络模型预训练

news2025/1/23 17:32:40

根据神经网络各个层的计算逻辑用程序实现相关的计算,主要是:前向传播计算、反向传播计算、损失计算、精确度计算等,并提供保存超参数到文件中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
from DeepLearn_Base.common.functions import *
from DeepLearn_Base.common.gradient import numerical_gradient
import pickle

# 三层神经网络处理类(两层隐藏层+1层输出层)
class ThreeLayerNet:

    # input_size:输入层神经元数量,灰度图像的三维表示: 1 * 28 * 28 = 784
    # output_size: 输出层神经元数量,10,表示10个数字
    # hidden_size:第一层隐藏层神经元数量,50
    # second_hidden_size:第二层隐藏层神经元数量,100
    # weight_init_std:权重初始化
    def __init__(self, input_size, hidden_size, output_size, second_hidden_size, weight_init_std=0.01):
        # 初始化权重
        self.params = {}
        self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
        self.params['b1'] = np.zeros(hidden_size)
        self.params['W2'] = weight_init_std * np.random.randn(hidden_size, second_hidden_size)
        self.params['b2'] = np.zeros(second_hidden_size)
        self.params['W3'] = weight_init_std * np.random.randn(second_hidden_size, output_size)
        self.params['b3'] = np.zeros(output_size)

    # 执行预测
    def predict(self, x):
        W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']
        b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']
    
        # 隐藏层第一层
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)

        # 隐藏层第二层
        a2 = np.dot(z1, W2) + b2
        z2 = sigmoid(a2)
        
        # 输出层
        a3 = np.dot(z2, W3) + b3
        y = softmax(a3)
        
        return y
        
    # x:输入数据, t:监督数据
    def loss(self, x, t):
        y = self.predict(x)
        
        return cross_entropy_error(y, t)
    
    # 精确度计算
    def accuracy(self, x, t):
        y = self.predict(x)
        y = np.argmax(y, axis=1)
        t = np.argmax(t, axis=1)
        
        accuracy = np.sum(y == t) / float(x.shape[0])
        return accuracy
        
    # 梯度计算
    def gradient(self, x, t):
        W1, W2, W3 = self.params['W1'], self.params['W2'], self.params['W3']
        b1, b2, b3 = self.params['b1'], self.params['b2'], self.params['b3']
        grads = {}
        
        batch_num = x.shape[0]
        
        # forward
        # 隐藏层第一层
        a1 = np.dot(x, W1) + b1
        z1 = sigmoid(a1)

        # 隐藏层第二层
        a2 = np.dot(z1, W2) + b2
        z2 = sigmoid(a2)
        
        # 输出层
        a3 = np.dot(z2, W3) + b3
        y = softmax(a3)
        
        # backward
        # 两层隐藏层计算梯度
        # 输出层梯度: Loss与输出的导数,分类场景下,等于预测值-真实值
        # 权重梯度: 隐藏层输出的转置 * 损失函数梯度
        dy = (y - t) / batch_num
        grads['W3'] = np.dot(z2.T, dy)
        grads['b3'] = np.sum(dy, axis=0)
        
        # 反向传播到隐藏层
        # 隐藏层梯度:Loss与输出的导数 * 输出层权重的转置
        da2 = np.dot(dy, W3.T)
        dz2 = sigmoid_grad(a2) * da2
        grads['W2'] = np.dot(z1.T, dz2)
        grads['b2'] = np.sum(dz2, axis=0)

        da1 = np.dot(da2, W2.T)
        dz1 = sigmoid_grad(a1) * da1
        grads['W1'] = np.dot(x.T, dz1)
        grads['b1'] = np.sum(dz1, axis=0)

        return grads
    
    # 保存参数到文件
    def save_params(self, file_name="params.pkl"):
        params = {}
        for key, val in self.params.items():
            params[key] = val
        with open(file_name, 'wb') as f:
            pickle.dump(params, f)

预训练实现

读取MNIST训练数据集,总共有60000个。每次从60000个训练数据中随机取出100个数据 (图像数据和正确解标签数据)。然后,对这个包含100笔数据的批数据求梯度,使用随机梯度下降法(SGD)更新参数。这里,梯度法的更新次数(循环的次数)为10000。每更新一次,都对训练数据计算损失函数的值,并把该值添加到数组中。

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from DeepLearn_Base.dataset.mnist import load_mnist
from three_layer_net import ThreeLayerNet

# 读入数据
# x_train.sharp 60000 * 784
# t_train.sharp 60000 * 10
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

network = ThreeLayerNet(input_size=784, hidden_size=50, second_hidden_size=100, output_size=10)

iters_num = 10000  # 适当设定循环的次数
# 训练集大小 60000
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.1

train_loss_list = []
train_acc_list = []
test_acc_list = []

# 每批次迭代数量:600
iter_per_epoch = max(train_size / batch_size, 1)

for i in range(iters_num):
    # 从训练集中选取100个为一批次进行训练
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    # 更新超参数梯度
    grad = network.gradient(x_batch, t_batch)
    
    # 更新超参数W,b
	# 基于SGD算法更新梯度,上面是随机选择的批数据处理,因此更新时,也是随即更新梯度
    for key in ('W1', 'b1', 'W2', 'b2', 'W3', 'b3'):
        network.params[key] -= learning_rate * grad[key]
    
    loss = network.loss(x_batch, t_batch)
    train_loss_list.append(loss)
    
    if i % iter_per_epoch == 0:
        train_acc = network.accuracy(x_train, t_train)
        test_acc = network.accuracy(x_test, t_test)
        train_acc_list.append(train_acc)
        test_acc_list.append(test_acc)
        print("train acc, test acc | " + str(train_acc) + ", " + str(test_acc))

# 绘制图形
markers = {'train': 'o', 'test': 's'}
x = np.arange(len(train_acc_list))
plt.plot(x, train_acc_list, label='train acc')
plt.plot(x, test_acc_list, label='test acc', linestyle='--')
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc='lower right')
plt.show()

# 输出到文件保存参宿后
network.save_params("E:\\workcode\\code\\DeepLearn_Base\\ch04\\myparams.pkl")

用图像来表示这个损失函数的值的推移,如图所示;并保存最终的超参数到pkl文件

应用自训练超参数

将之前用于预测图像文字中使用的超参数文件替换为自己预训练生成的pkl参数文件,并执行代码,打印出精确度。
这是基于默认的超参数进行推理后的精确度:

替换超参数文件,进行图像识别推理

 

精确度上涨了0.01,因此选择合适的梯度更新超参数,是保证推理精确度好坏的关键。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1279516.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年12月4日支付宝蚂蚁庄园小课堂小鸡宝宝考考你今日正确答案是什么?

问题:你知道电杆上安装的“小风车”有什么用途吗? 答案:防止鸟类筑巢 解析:小风车一般做成橙色,因为橙色是一种可令野鸟产生恐慌感的颜色;小风车在转动时,会发出令野鸟害怕的噪声;…

【iOS】数据持久化(三)之SQLite3及其使用

目录 数据库简介什么是SQLite?在Xcode引入SQLite APISQL语句的种类存储字段类型 SQLite的使用创建数据库创建表和删表数据表操作增(插入数据INSERT)删(删除数据DELETE)改(更新数据UPDATE)查&…

softmax实现

import matplotlib.pyplot as plt import torch from IPython import display from d2l import torch as d2lbatch_size 256 train_iter,test_iter d2l.load_data_fashion_mnist(batch_size) test_iter.num_workers 0 train_iter.num_workers 0 num_inputs 784 # 将图片…

【msg_msg】corCTF2021-msgmsg 套题

前言 该套题共两题,一道简单模式 fire_of_salvation,一道困难模式 wall_of_perdition,都是关于 msg_msg 的利用的。这题跟之前的 TPCTF2023 core 的很像(应该是 TPCTF2023 core 跟他很像,bushi)。 其中 f…

ISP算法简述-BLC

Black Level Calibration, 黑电平矫正 现象 1)在纯黑条件下拍张图,你会发现像素值不为0 2)或者你发现图像整体偏色 这些问题可能是黑电平导致的。 原因 存在黑电平的原因有2个: 1)sensor的电路本身存在暗电流。暗电流主要产生在光电信号转换过程中&#…

人工智能 - 人脸识别:发展历史、技术全解与实战

目录 一、人脸识别技术的发展历程早期探索:20世纪60至80年代技术价值点: 自动化与算法化:20世纪90年代技术价值点: 深度学习的革命:21世纪初至今技术价值点: 二、几何特征方法详解与实战几何特征方法的原理…

【C语言】深入理解指针(1)

前言 C语言是一种直接操作内存的编程语言,我们可以直接访问和操作计算机内存中的地址空间。 而C语言中存在的指针类型,指针指向的就是内存中的地址。我们可以通过指针来访问和修改内存中存储的数据。 因此,深入理解指针,并且理解内…

基于SSH的员工管理系统(一)——包结构

基于SSH的员工管理系统(一)——包结构 包结构 1、整体包结构 2、action包 3、domain实体包 4、service层 5、dao层 6、util工具包 7、页面层

【Oracle】数据库登陆错误:ORA-28000:the account is locked解决方法

问题描述 在连接Oracle数据库的时候出现了ORA-28000:the account is locked报错,登录账号被锁定,出现这种情况就需要将被锁定用户解锁。 解决方法 解锁方法就是通过用system账号登录数据库,然后修改被锁定账户状态,具体如下图所示…

03 数仓平台 Kafka

kafka概述 定义 Kafka 是一个开源的分布式事件流平台(Event Streaming Plantform),主要用于大数据实时领域。本质上是一个分布式的基于发布/订阅模式的消息队列(Message Queue)。 消息队列 在大数据场景中主要采用…

右值引用和移动语句(C++11)

左值引用和右值引用 回顾引用 我们之前就了解到了左值引用,首先我们要了解引用在编译器底层其实就是指针。具体来说,当声明引用时,编译器会在底层生成一个指针来表示引用,但在代码编写和使用时,我们可以像使用变量类…

鸿蒙绘制折线图基金走势图

鉴于鸿蒙下一代剥离aosp,对于小公司而言,要么用h5重构,要么等大厂完善工具、等华为出转换工具后跟进,用鸿蒙重新开发一套代码对于一般公司而言成本会大幅增加。但对于广大开发者来说,暂且不论未来鸿蒙发展如何&#xf…

中国消费电子行业发展趋势及消费者需求洞察|徐礼昭

一、引言 近年来,随着科技的飞速发展,消费电子行业面临着前所未有的挑战与机遇。本文将从行业发展趋势、消费者需求洞察以及企业数字化转型的方向和动作三个方面,对消费电子行业进行深入剖析。 二、消费电子行业发展趋势 5G技术的普及和应…

WEB安全之Python

WEB安全之python python-pyc反编译 python类似java一样,存在编译过程,先将源码文件*.py编译成 *.pyc文件,然后通过python解释器执行 生成pyc文件 创建一个py文件随便输入几句代码(1.py) 通过python交互终端 >>>import py_compil…

【C++练级之路】【Lv.1】C++,启动!(命名空间,缺省参数,函数重载,引用,内联函数,auto,范围for,nullptr)

目录 引言入门须知一、命名空间1.1 作用域限定符1.2 命名空间的意义1.3 命名空间的定义1.4 命名空间的使用 二、C输入&输出2.1 cout输出2.2 cin输入2.3 std命名空间的使用惯例 三、缺省参数3.1 缺省参数概念3.2 缺省参数分类 四、函数重载4.1 函数重载概念4.2 函数重载分类…

JavaSE自定义验证码图片生成器

设计项目的时候打算在原有的功能上补充验证码功能,在实现了邮箱验证码之后想着顺便把一个简单的图片验证码生成器也实现一下,用作分享。 注意,实际开发中验证码往往采用各种组件,通过导入依赖来在后端开发时使用相关功能&#xf…

泊车功能专题介绍 ———— 汽车全景影像监测系统性能要求及试验方法(国标未公布)

文章目录 术语和定义一般要求功能要求故障指示 性能要求响应时间图像时延单视图视野范围平面拼接视图视野平面拼接效果总体要求行列畸变拼接错位及拼接无效区域 试验方法环境条件仪器和设备车辆条件系统响应时间试验图像时延试验单视图视野范围试验平面拼接视图视野试验平面拼接…

【大学英语视听说上】Mid-term Test 2

Section A 【短篇新闻1】 You probably think college students are experts at sleeping, but parties, preparations for tests, personal problems and general stress can rack a students sleep habits, which can be bad for the body and the mind. Texas Tech Univer…

51爱心流水灯32灯炫酷代码

源代码摘自远眺883的文章,大佬是30个灯的,感兴趣的铁汁们可以去看看哦~(已取得原作者的许可):基于STC89C51单片机设计的心形流水灯软件代码部分_单片机流水灯代码_远眺883的博客-CSDN博客 由于博主是个小菜鸡&#xff…

【Python从入门到进阶】43.验证码识别工具结合requests的使用

接上篇《42、使用requests的Cookie登录古诗文网站》 上一篇我们介绍了如何利用requests的Cookie登录古诗文网。本篇我们来学习如何使用验证码识别工具进行登录验证的自动识别。 一、图片验证码识别过程及手段 上一篇我们通过requests的session方法,带着原网页登录…