TensorFlow实现Softmax回归

news2024/9/25 21:11:05

原理

模型

相比线性回归,Softmax只多一个分类的操作,即预测结果由连续值变为离散值,为了实现这样的结果,我们可以使最后一层具有多个神经元,而输入不变,其结构如图所示:

为了实现分类,我们使用一个Softmax操作,Softmax函数能够将未规范化的预测变换为非负数并且总和为1,同时让模型保持可导的性质。 为了完成这一目标,我们首先对每个未规范化的预测求幂,这样可以确保输出非负。 为了确保最终输出的概率值总和为1,我们再让每个求幂后的结果除以它们的总和。

\hat{y}_j=\frac{exp(o_j)}{\Sigma_k exp(o_k)}

那么对于y的结果,可以采用如下的方式表示:

\hat{y}=Softmax(Wx+b)

由于softmax操作只改变大小的值,不改变大小次序,因此对输出使用Softmax操作后,仍然有

{argmax}_j \hat{y}_j={argmax}_j \hat{o}_j

损失函数

在分类问题中一般使用交叉熵损失函数,这样可以更好的使模型辨别正确的label,而不是每一个label都使用同样的权重判断损失。

结果的可视化

通过构建Animator图像化类和Accumulator累加类完成数据的可视化实现。

Animator类

class Animator:
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

Accumulator类 

class Accumulator:
    def __init__(self, n):
        self.data = [0.0] * n
    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]
    def reset(self):
        self.data = [0.0] * len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

读取数据集

为实现Softmax回归,我们首先引入相关的库并读取数据集。这里使用mnist数据集进行测试。

import tensorflow as tf

batch_size = 256
def load_data_fashion_mnist(batch_size, resize=None):
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,tf.cast(y, dtype='int32'))
    resize_fn = lambda X, y: (
        tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
    return (tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(batch_size).shuffle(len(mnist_train[0])).map(resize_fn),tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(batch_size).map(resize_fn)) 
train_iter, test_iter = load_data_fashion_mnist(batch_size)

初始化模型参数

首先用Sequential构建一个模型容器,然后添加一个Flatten层将28x28的输入展平,然后添加一个全连接层获得输出。

net = tf.keras.models.Sequential()
net.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
weight_initializer = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.01)
net.add(tf.keras.layers.Dense(10, kernel_initializer=weight_initializer))

模型训练

首先定义一个损失函数,这里使用稀疏类别交叉熵损失函数,适应标签是整数而不是独热编码的情况,然后定义训练模型,采用小批量随机梯度下降(SGD)算法进行训练。

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
trainer = tf.keras.optimizers.SGD(learning_rate=.1)
num_epochs = 10

接下来定义模型的训练具体方式,对每一轮采用随机梯度下降的后向计算方式,进行具体的训练。其中train_epoch_ch3是在一轮中进行训练,train_ch3是整体的训练过程。

def train_epoch_ch3(net, train_iter, loss, updater):
    metric = Accumulator(3)
    for X, y in train_iter:
        with tf.GradientTape() as tape:
            y_hat = net(X)
            if isinstance(loss, tf.keras.losses.Loss):
                l = loss(y, y_hat)
            else:
                l = loss(y_hat, y)
        if isinstance(updater, tf.keras.optimizers.Optimizer):
            params = net.trainable_variables
            grads = tape.gradient(l, params)
            updater.apply_gradients(zip(grads, params))
        else:
            updater(X.shape[0], tape.gradient(l, updater.params))
        l_sum = l * float(tf.size(y)) if isinstance(
            loss, tf.keras.losses.Loss) else tf.reduce_sum(l)
        metric.add(l_sum, accuracy(y_hat, y), tf.size(y))
    return metric[0] / metric[2], metric[1] / metric[2]

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])
    for epoch in range(num_epochs):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch + 1, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    assert train_loss < 0.5, train_loss
    assert train_acc <= 1 and train_acc > 0.7, train_acc
    assert test_acc <= 1 and test_acc > 0.7, test_acc

最后调用函数直接进行训练,需要注意的是,不必调用train_epoch_ch3函数,他在训练过程中是自动调用的。

train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

训练结果

在刚刚的训练过程中我们使用了animator和accumulator来可视化训练结果,因此训练结果较为直观,如图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2073582.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI编程简介

文章目录 AI 编程的特点常见编程工具copilot的工作原理AI编程常用技巧 AI 编程的特点 AI 编程是指利用人工智能技术来辅助开发过程的一种编程方式。包括但不限于&#xff1a;代码生成、优化、调试、审查&#xff0c;文档生成、测试自动化。 编程能力是大模型各项能力的天花板&…

可用于便携音箱的18V同步升压变换器TPS61288

在音频市场中,便携式音箱因其自带电源、方便携带深受消费者喜爱。便携式音箱通常配备2节可充电锂离子电池,当输出功率要求高于10W时,电池电压不足以为音频功放提供足够的功率,一般需要升压电路将电池电压升压至12V~18V以满足功率要求。 TPS61288是德州仪器公司推出的一款最…

力扣2025.分割数组的最多方案数

力扣2025.分割数组的最多方案数 哈希表 前缀和 用两个哈希表分别存元素(之后会遍历)左侧和右侧的前缀和 typedef long long LL;class Solution {public:int waysToPartition(vector<int>& nums, int k) {int n nums.size(),ans 0;vector<LL> sum(n);unor…

【Redis】Redis编程技巧

Redis编程技巧 一、StringVeiw是什么&#xff1f;二、OptionalString是什么&#xff1f;三、怎么看keys *1、vector配合back_inserter2、set配合inserter 四、chrono_literals技巧 一、StringVeiw是什么&#xff1f; 是一种轻量级的字符串视图类型&#xff0c;通常提供的是一种…

Mora:多智能体框架实现通用视频生成

人工智能咨询培训老师叶梓 转载标明出处 尽管已有一些模型能够生成视频&#xff0c;但大多数模型在生成超过10秒的长视频方面存在局限。Sora模型的出现标志着视频生成能力的一个新时代&#xff0c;它不仅能够根据文本提示生成长达一分钟的详细视频&#xff0c;而且在编辑、连接…

026集—CAD中多段线批量增加折点(相交点)——vba代码实现

当需要批量在多段线中加入顶点&#xff08;与多段线相交的点&#xff09;时&#xff0c;如下图所示&#xff1a;若干条线相交&#xff1a; 我们想在相交处增加折点&#xff0c;可通过vba插件一键完成。 &#xff08;使用方法命令行输入&#xff1a;vbaman,加载插件&#xff0c…

一本读懂数据库发展史的书

数据库及其存储技术&#xff0c;一直以来都是基础软件的主力。数据库系统的操作接口标准&#xff0c;也是应用型软件的重要接口&#xff0c;关系重大。 作为最“有感”的系统软件&#xff0c;数据库的历史悠久、品类繁多、创新活跃。 对数据库历史发展的介绍&#xff0c;有利…

JavaEE(1):web后端开发环境搭建和创建一个Servlet项目

web后端(javaEE)程序需要运行在服务器的&#xff0c;这样前端才可以访问得到 web后端开发&#xff1a; 服务器&#xff1f; 解释1&#xff1a;服务器就是一款软件&#xff0c;可以向其发送请求&#xff0c;服务器会作出一个响应。可以在服务器中部署文件&#xff0c;让他人访问…

MySQL必会知识精华2(了解基础篇)

我们的目标是&#xff1a;按照这一套资料学习下来&#xff0c;大家可以完成数据库增删改查的实际操作。轻松应对面试或者笔试题中MySQL相关题目 上篇文章我们先做一下MySQL学习的准备工作&#xff0c;如安装MySQL 服务&#xff0c;配置MySQL&#xff0c;连接MySQL。本篇文章注重…

大模型学习笔记 - LLM 之RAG

RAG RAG RAG SuveryRAG 简介RAG 范式的演变 1. 初级 RAG2. 高级 RAG3. 模块化的 RAG 介绍 RAG框架简述 检索技术文本生成增强技术简介 RAG 与 微调的区别RAG 模型评估解析RAG 研究的挑战与前景构建 RAG 系统的工具 在学习RAG中, 发现这个网站的内容特别好&#xff0c;也比较…

决策树算法:ID3与C4.5的对比分析

决策树是一种非常直观且易于理解的机器学习方法&#xff0c;被广泛应用于分类和回归任务中。在这篇文章中&#xff0c;我们将探讨两种经典的决策树算法&#xff1a;ID3与C4.5&#xff0c;并分析它们之间的区别。 一 算法概述 我们每天都做着各种形形色色的决策——周末怎么嗨…

普元EOS-微前端的base基座介绍

1 前言 微前端开发的时候要使用base基座。 这个base基座到底是什么&#xff1f; base基座能提供哪些功能&#xff1f; 本文将进行简单的介绍。 2 高开前端引用base基座 在高开页面引入base基座的语法如下&#xff1a; <script>import { BaseVue, AjaxUtil } from …

五、Centos7-安装Jenkins

目录 一、基础环境准备 1.安装JDK 2.安装Tomcat 二、安装Jenkins 1.配置Jenkins插件镜像源 2.问题&#xff1a;进入manager jenkins页面报错 3.配置Git 4.配置jdk 三、重新安装Jenkins 四、另一种Centos安装jenkins的方式--最终可用版 克隆了一个base的虚拟机&#x…

深度学习入门:循环神经网络------RNN概述,词嵌入层,循环网络层及案例实践!(万字详解!)

目录 &#x1f354; RNN 概述 1.1 循环神经网络 1.2 自然语言处理 &#x1f354; 词嵌入层 2.1 词嵌入层的使用 2.2 关于词嵌入层的思考 2.3 小节 &#x1f354; 循环网络层 3.1 RNN 网络原理 3.1.1 RNN计算过程 3.1.2 如何计算神经元内部 3.2 PyTorch RNN 层的使用…

机器学习(前六关大总结)生动讲解+代码实例

老粉都知道&#xff08;还不点关注&#xff09;我这机器学习已经有几天没更了&#xff0c;主要是最近忙碌比赛&#xff0c;所以时间紧张 那么我为大家总结一下&#xff0c;之前的机器学习知识点&#xff0c;让大家更好了解机器学习领域。 在此阅读前&#xff0c;感谢大家的关…

HTMl标签;知识回忆;笔记分享;

HTML标签是用于定义和组织网页内容的基础构建块。每个标签都有特定的作用。 一&#xff0c;标准结构标签&#xff1a; HTML文档标准结构&#xff1a; <html><head></head><body>this is my second html... </body> </html> 【1】htm…

代码随想录 | day 15 | 二叉树part03

完全二叉树的节点个数 方法一&#xff1a;可以用递归法遍历一遍左子树和右子树的个数之和再加1等于全部节点个数 class Solution { public:int getcount(TreeNode* cur){if(curNULL) return 0;int leftcount getcount(cur->left);int rightcount getcount(cur->right…

Python3.11二进制AI项目程序打包为苹果Mac App(DMG)-应用程序pyinstaller制作流程(AppleSilicon)

众所周知&#xff0c;苹果MacOs系统虽然贵为Unix内核系统&#xff0c;但由于系统不支持N卡&#xff0c;所以如果想在本地跑AI项目&#xff0c;还需要对相关的AI模块进行定制化操作&#xff0c;本次我们演示一下如何将基于Python3.11的AI项目程序打包为MacOS可以直接运行的DMG安…

90. UE5 RPG 实现技能的装配

在上一篇里&#xff0c;我们实现了在技能面板&#xff0c;点击技能能够显示出技能的相关描述以及下一级的技能的对应描述。 在这一篇里&#xff0c;我们实现一下技能的装配。 在之前&#xff0c;我们实现了点击按钮时&#xff0c;在技能面板控制器里存储了当前选中的技能的相关…

企业高性能web服务器(nginx)

目录 Web服务器基础介绍 正常情况下的单次web服务器访问流程 Apache 经典的 Web服务端 Apache prefork 模型 Apache work模型 Apache event模型 服务端的I/O流程 服务器的I/O 磁盘I/O 网络I/O 网络I/O处理过程 I/O模型 I/O模型相关概念 同步/异步 阻塞/非阻塞 网…