深度学习笔记:神经网络(3)

news2024/11/17 13:38:35

关于本章之前内容可以参考以下之前文章:

1 https://blog.csdn.net/Raine_Yang/article/details/128473486?spm=1001.2014.3001.5501

2 https://blog.csdn.net/Raine_Yang/article/details/128584916?spm=1001.2014.3001.5501

神经网络的输出层设计

机器学习问题一般分为分类问题和回归问题。分类文件判断数据属于哪一类别,输出层一般使用softmax函数。回归问题即根据输入预测一个连续的数值,输出层一般使用恒等函数

设有n个输出层,其中第k个神经元的输出值yk
softmax函数表达式:
在这里插入图片描述
在实际应用中,由于e ^ x 的值往往较大,容易导致变量溢出。利用以下简单变换可以很好避免这一问题:
在这里插入图片描述
这里我们将分数上下同乘以常数C,把ln©代入到指数函数内,最后把ln©即为C’。我们将C’取输入数据最大值即可很好避免溢出

程序实现如下:

def softmax(a):
    c = np.max(a)
    exp_a = np.exp(a - c)
    sum_exp_a = np.sum(exp_a)
    y = exp_a / sum_exp_a
    
    return y

2 softmax函数特征

softman函数输出是0到1之间实数,并且输出值总和恒为1。这一性质使得我们将softmax函数输出解释为分类问题各项的概率

network = init_network()
x = np.array([1.6, 5.5])
y = forward(network, x)
print(y)

# output: [0.36750414 0.63249586]

将x代入我们上一篇文章中搭建的神经网络里得到以下结果,输出可以被解释为x[0]的概率是0.36,x[1]的概率是0.63

输出层的神经元个数又问题而定,一般来说对于分类问题神经元个数即为类别的个数,而每个输出值即为对应类别的概率

3 手写数字识别

要实现首先数字识别,我们先引入常用的手写数字库mnist用于训练。运行下面程序前需要先将mnist库下载到该程序文件父文件夹下

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image


def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[0]
label = t_train[0]
print(label)  # 5

print(img.shape)  # (784,)
img = img.reshape(28, 28)  # 把图像的形状变为原来的尺寸
print(img.shape)  # (28, 28)

img_show(img)

1


(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

load_mnist方法以(训练图像,训练标签),(测试图像,测试标签)的形式返回mnist数据。其参数flatten为是否将二维图像展开为一维数组,normalize为是否将图像正规化为0.0-1.0的值,属于对图像的一种预处理

2

def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()

Image.fromaeeay方法将Numpy数组图像转换为PIL数据类型

注:如报错不存在PIL可以下载库Pillow,下载时切换清华源加速
打开cmd,输入

pip install Pillow -i https://pypi.tuna,tsinghua.edu.cn/simple

神经网络的推理处理

该神经网络输入层有784个神经元(来源于展开图像28 * 28),输出层有10个神经元(代表结果0-10),其中隐藏层我们使用两层,第一层50个神经元,第二次100个神经元

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    W1, W2, W3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, W1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, W2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, W3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 获取概率最高的元素的索引
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

1

def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network

这里我们直接使用已经训练好的权重和偏置。使用pickle方法加载sample_weight.pkl文件里面的权重值

pickle方法可以将程序运行中对象保存为文件。加载pickle方法可以复原程序运行中的对象

predict函数实现一个两次隐藏层的神经网络

2

x, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):
    y = predict(network, x[i])
    p= np.argmax(y) # 获取概率最高的元素的索引
    if p == t[i]:
        accuracy_cnt += 1

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

predict得到网络输出结果,为一个长度为10的一维数组,其中每个元素索引为预测的数字,而元素值为该元素判定概率。np.argmax()找到最大值(即出现概率最大的数)的索引

批处理

在我们上一个例子中,我们一次处理一张展开的图像,相当于处理一个长为784 (28 X 28)的一维数组,数组在神经网络中形状变化如下
在这里插入图片描述
如果我们一次输入100张图片,即相当于输入一个100 X 784的数组,这使得输出也将为一个100 X 10的数组
在这里插入图片描述
这种打包式输入数据被称为批(batch)。批处理可以节省数据传输时间,使得效率高于分开处理

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    with open("sample_weight.pkl", 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()

batch_size = 100 # 批数量
accuracy_cnt = 0

for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print("Accuracy:" + str(float(accuracy_cnt) / len(x)))

该程序中每次将100个数据x输入神经网络并得到结果y。np.argmax参数axis = 1指取每一维的最大值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/160393.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2022年12月青少年软件编程(Python) 等级考试试卷(三级)

2022. 12 青少年软件编程(Python) 等级考试试卷(三级) 一、 单选题(共 25 题, 共 50 分) 1.列表 L1 中全是整数, 小明想将其中所有奇数都增加 1, 偶数不变, 于是编写了如下图 所示的…

高级树结构之平衡二叉树(ALV树)

文章目录平衡二叉树简介失衡类型&处理办法RR型失衡【左旋调整】调整演示代码实现LL型失衡【右旋调整】调整演示代码实现RL型失衡【先右旋后左旋调整】调整演示代码实现LR型失衡【先右旋后左旋调整】调整演示代码实现插入操作综合代码演示平衡二叉树简介 在数据有序的情况下…

Codeforces Round #842 (Div. 2)

Codeforces Round #843 (Div. 2) 传送门 不想搞的很累&#xff0c;对自己不做要求&#xff0c;有兴趣就做。 A. Greatest Convex #include <bits/stdc.h>using namespace std; const int maxn 1e6 10; vector<int> cnt[maxn]; map<int, int> mp;int mai…

nvm安装 疑难问题解决

nvm介绍 NVM是node.js的版本管理器&#xff0c;设计为每个用户安装&#xff0c;并在每个shell中调用。nvm可以在任何兼容posix的shell (sh、dash、ksh、zsh、bash)上运行&#xff0c;特别是在这些平台上:unix、macOS和windows WSL。 nvm安装 &#xff01;&#xff01;重要&a…

强化学习在智能补货场景的应用

本文作者&#xff1a;应如是&#xff0c;观远算法团队工程师&#xff0c;毕业于伦敦帝国理工学院计算机系&#xff0c;主要研究方向为强化学习、时间序列算法及其落地应用。深耕零售消费品场景&#xff0c;解决供应链运筹优化问题。为客户提供基于机器学习的AI解决方案。1. 背景…

2023.Q1 go语言记录

1. Go 语言数组声明和初始化var variable_name [SIZE] variable_type&#xff0c;eg&#xff1a;var balance [10] float32var balance [5]float32{1000.0, 2.0, 3.4, 7.0, 50.0}balance : [5]float32{1000.0, 2.0, 3.4, 7.0, 50.0}长度不确定的初始化var balance [...]float…

少儿Python每日一题(21):八皇后问题

原题解答 本次的题目如下所示: 会下国际象棋的人都很清楚:皇后可以在横、竖、斜线上不限步数地吃掉其他棋子。如何将8个皇后放在棋盘上(有8 8个方格),使它们谁也不能被吃掉!这就是著名的八皇后问题。 对于某个满足要求的8皇后的摆放方法,定义一个皇后串a与之对应,即,…

【Ansible】Ansible Playbook 的任务控制

Ansible Playbook 的任务控制 文章目录Ansible Playbook 的任务控制一、Ansible 任务控制基本介绍二、条件判断1.解决第一个问题2.nginx 语法校验三、循环控制四、Tags 属性五、Handlers 属性一、Ansible 任务控制基本介绍 任务控制类似于编程语言中的 if …、for …等逻辑控制…

MSF社会工程学

● Metasploit发现两个远程代码执行漏洞 ○ 问题都出在WEB组件方面 ○ MSF不受影响 ● 安全面前任何软件都是平等的 ○ 没有绝对安全的软件 为什么要说社会工程学 ● Metasploit可以很好的配合到社会工程学攻击的各个阶段 ● Setoolkit工具包大量依赖Metasploit ● 基于浏览器…

[SUCTF 2019]EasySQL

目录 预先知识 信息收集 思路 源码分析 非预期解 预期解 补充 预先知识 环境 use mysql; create table if not exists my_table( id int PRIMARY key auto_increment, name VARCHAR(20), age int); insert into my_table values(NULL,xiao,19); insert into my_table v…

BUFF80双模蓝牙5.2热插拔PCB

键盘使用说明索引&#xff08;均为出厂默认值&#xff09;软件支持一些常见问题解答&#xff08;FAQ&#xff09;首次使用测试步骤蓝牙配对规则&#xff08;重要&#xff09;蓝牙和USB切换键盘默认层默认触发层0的FN键配置的功能默认功能层1配置的功能默认的快捷键蓝牙参数蓝牙…

Jetpack Compose中的Canvas

Jetpack Compose中的Canvas API 使用起来感觉比传统View中的要简单一些&#xff0c;因为它不需要画笔Paint和画布分开来&#xff0c;大多数直接就是一个函数搞定&#xff0c;当然也有一些限制。 Compose 直接提供了一个叫 Canvas 的 Composable 组件&#xff0c;可以在任何 Co…

containerd环境下build镜像

containerd环境下build镜像安装nerdctl使用nerdctl打包docker镜像下载安装 buildkit编写systemd unit文件&#xff1a;启用buildkit.service并设置开机自动运行修改Dockerfile构建镜像containerd配置代理containerd配置代理ansible剧本安装nerdctl https://blog.csdn.net/omai…

Python采集最热影评 + 制作词云图

人生苦短&#xff0c;我用Python 电影评论&#xff0c;简称影评 是对一部电影的导演、演员、镜头、摄影、剧情、线索、环境、色彩、光线、视听语言、道具作用、转场、剪辑等进行分析和评论。 电影评论的目的在于分析、鉴定和评价蕴含在银幕中的审美价值、认识价值、社会意义、…

Springboot中,异步线程的执行顺序的控制

1、顺序的定义异步任务存在如下几种顺序&#xff1a;顺序的开启子任务&#xff08;运行顺序和结束顺序不确定&#xff09;。顺序的完成&#xff08;就是A先启动&#xff0c;先执行完&#xff0c;再执行B任务&#xff0c;往往A、B之间存在某种依赖关系&#xff09;。还有就是优先…

思科设备-配置静态路由

⬜⬜⬜ &#x1f430;&#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea; (*^▽^*)欢迎光临 &#x1f7e7;&#x1f7e8;&#x1f7e9;&#x1f7e6;&#x1f7ea;&#x1f430;⬜⬜⬜ ✏️write in front✏️ &#x1f4dd;个人主页&#xff1a;陈丹宇jmu &a…

【云原生进阶之容器】第四章Operator原理4.2节--CRD

1 CRD概述 1.1 CRD简介 CRD全称是CustomResourceDefinition,在Kubernetes 中一切都可视为资源,Kubernetes 1.7 之后增加了对 CRD 自定义资源二次开发能力来扩展 Kubernetes API,通过 CRD 我们可以向 Kubernetes API 中增加新资源类型,而不需要修改 Kubernetes 源码来创建自…

Django搭建个人博客Blog-Day03

对user模块进行开发设计数据表Django默认就提供了和用户相关的功能&#xff0c;但是这个Django默认提供的功能有个不好的点: 不太适合我们的项目&#xff0c;例如里面的字段不够等&#xff0c;所以我们要对它进行改造一下&#xff0c;方便项目开发。拓展用户模型进入虚拟环境安…

java1算法

排序–comparable接口 java提供了一个接口Comparable用来定义类的排序规则 eg: 1、定义一个学生类Student&#xff0c;具有年龄age和姓名username连个属性&#xff0c;并通过Comparable接口提供比较规则&#xff1b; 2、定义测试类Test,在测试类中定义测试方法Comparable getM…

还不会用YakitBp?来,我教你

前言 &#x1f340;作者简介&#xff1a;被吉师散养、喜欢前端、学过后端、练过CTF、玩过DOS、不喜欢java的不知名学生。 &#x1f341;个人主页&#xff1a;被吉师散养的职业混子 &#x1f342;相应专栏&#xff1a;CTF专栏 Yakit介绍 好兄弟&#xff0c;你听说过yakit吗&…