Softmax回归(多类分类模型)

news2025/4/9 15:36:35

目录

  • 1.对真实值类别编码:
  • 2.预测值:
  • 3.目标函数要求:
  • 4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:
  • 5.使用交叉熵作为损失函数:
  • 6.代码实现:

1.对真实值类别编码:

在这里插入图片描述

  • y为真实值,有且仅有一个位置值为1,该位置即为该元素真实类别

2.预测值:

在这里插入图片描述

  • Oi为该元素与类别i匹配的置信度

3.目标函数要求:

在这里插入图片描述

  • 对于正确类y的置信度Oy要远远大于其他非正确类的置信度Oi,才能使识别到的正确类与错误类具有更明显的差距

4.使用Softmax模型将输出置信度Oi计算转换为输出匹配概率y^i:

在这里插入图片描述

  • y^为n维向量,每个元素非负且和为1
  • y^i为元素与类别i匹配的概率

5.使用交叉熵作为损失函数:

在这里插入图片描述

  • L为真实概率y与预测概率y^的差距
  • 分类问题不关心非正确类的预测值,只关心正确类的预测值有多大

6.代码实现:

import sys
import os
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
from d2l import torch as d2l


os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

## 读取小批量数据
batch_size = 256
trans = transforms.ToTensor()
#train_iter, test_iter = common.load_fashion_mnist(batch_size) #无法翻墙的,可以参考这种方法取下载数据集
mnist_train  = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
mnist_test  = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True) # 需要网络翻墙,这里数据集会自动下载到项目跟目录的/data目录下
print(len(mnist_train))  # train_iter的长度是235;说明数据被分成了234组大小为256的数据加上最后一组大小不足256的数据
print('11111111')


## 展示部分数据
def get_fashion_mnist_labels(labels):  # @save
    """返回Fashion-MNIST数据集的文本标签。"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]


def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()


train_data, train_targets = next(iter(data.DataLoader(mnist_train, batch_size=18)))
#展示部分训练数据
show_fashion_mnist(train_data[0:10], train_targets[0:10])

# 初始化模型参数
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)


# 定义模型
def softmax(X):
    X_exp = X.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制


def net(X):
    return softmax(torch.matmul(X.reshape(-1, num_inputs), W) + b)


# 定义损失函数
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))


def cross_entropy(y_hat, y):
    return - torch.log(y_hat.gather(1, y.view(-1, 1)))


# 计算分类准确率
def accuracy(y_hat, y):
    return (y_hat.argmax(dim=1) == y).float().mean().item()


# 计算这个训练集的准确率
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n


num_epochs, lr = 10, 0.1


# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # 梯度清零
            if params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            # 执行优化方法
            if optimizer is not None:
                optimizer.step()
            else:
                d2l.sgd(params, lr, batch_size)

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))


# 训练模型
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

# 预测模型
for X, y in test_iter:
    break
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1392121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java顺序表(2)

🐵本篇文章将对ArrayList类进行讲解 一、ArrayList类介绍 上篇文章我们对顺序表的增删查改等方法进行了模拟实现,实际上Java提供了ArrayList类,而在这个类中就包含了顺序表的一系列方法,这样在用顺序表解决问题时就不用每次都去实…

smartgit选择30天试用后需要输入可执行文件

突然有一天smartgit提示到期了,我按照以往那样删除license和preferences文件后,选择30天试用,弹出了需要选择git可执行文件。 我尝试选择了我的git.exe,发现根本不行,提示让我执行下git --version 执行过后提示我的.gi…

【Shell编程练习】编写 shell 脚本,打印 9*9 乘法表

系列文章目录 输出Hello World 通过位置变量创建 Linux 系统账户及密码 监控内存和磁盘容量,小于给定值时报警 猜大小 输入三个数并进行升序排序 编写脚本测试 192.168.4.0/24 整个网段中哪些主机处于开机状态,哪些主机处于关机状态 系列文章目录编写 shell 脚本,打…

Vue3响应式系统(二)

Vue3响应式系统(一)https://blog.csdn.net/qq_55806761/article/details/135587077 六、嵌套的effect与effect栈。 什么场景会用到effect嵌套呢?听我娓娓道来。 就用Vue.js来说吧,Vue.js的渲染函数就是在effect中执行的: /*Foo组件*/ const…

动态路由协议

一、动态路由协议 动态路由协议,用在多个 Router 之间定期的、自动的、互相交换 Routes(路由信息,包含了网段信息、可达性信息、路径信息等),动态生成 Routing Table Entries,并最终达到全网的路由收敛&am…

Java项目:123SSM高校运动会信息管理系统

博主主页:Java旅途 简介:分享计算机知识、学习路线、系统源码及教程 文末获取源码 一、项目介绍 高校运动会信息管理系统基于SpringSpringMVCMybatis开发,主要用来管理高校运动会信息,系统分为管理员何运动员两种角色。系统主要功…

AI在广告中的应用——预测性定位和调整

营销人员的工作就是在恰当的时间将适合的产品呈现在消费者面前,从而增加他们购买的可能性。随着时间的推移,营销人员能够深入挖掘越来越精准的客户细分市场,他们不仅具备了实现上述目标的能力,而且这种能力还在呈指数级提升。在AI…

如何将github copilot当gpt4用

现在写代码已经离不开ai辅助了我用的是github copilot,一方面是因为它和vscode结合得比较好,另一方面就是copilot chat了。可以在不切换工具的情况下,问它问题,在copilot chat还在内测阶段的时候我就申请使用了(现在已…

【现代密码学】笔记9-10.3-- 公钥(非对称加密)、混合加密理论《introduction to modern cryphtography》

【现代密码学】笔记9-10.3-- 公钥(非对称加密)、混合加密理论《introduction to modern cryphtography》 写在最前面8.1 公钥加密理论随机预言机模型(Random Oracle Model,ROM) 写在最前面 主要在 哈工大密码学课程 张…

Python多线程爬虫——数据分析项目实现详解

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 ChatGPT体验地址 文章目录 前言爬虫获取cookie网站爬取与启动CSDN爬虫爬虫启动将爬取内容存到文件中 多线程爬虫选择要爬取的用户 线程池 爬虫 爬虫是指一种自动化程序,能够模…

Java基础知识整理,驼峰规则、流程控制、自增自减

写在开头 本文接着上一篇文章续写哈。Java基础知识整理,注释、关键字、运算符 在这一篇文章中我们总结了包括注释、关键字、运算符的Java基础知识点,今天继续来聊一聊命名规则(驼峰)、流程控制、自增自减。 一、命名规则 上一…

pygame学习(三)——支持多种类型的事件

大家好!我是码银🥰 欢迎关注🥰: CSDN:码银 公众号:码银学编程 实时事件循环 为了保证程序的持续刷新、保持打开的状态,我们会创建一个无限循环,通常使用的是while语句,w…

第二百六十九回

文章目录 概念介绍设置方法示例代码内容总结 我们在上一章回中介绍了Card Widget相关的内容,本章回中将介绍国际化设置.闲话休提,让我们一起Talk Flutter吧。 概念介绍 我们在这里说的国际化设置是指在App设置相关操作,这样可以让不同国家的…

关于VScode的这个ssh的配置的经验

1.首先,我是因为重装了ubantu系统,不得不重新配置ssh 2.第一步,在本机的终端安装ssh插件: (1) (2)restart开启这个ssh端口 3.然后,就在vscode里面,安装哪个…

【Rust学习】安装Rust环境

本笔记为了记录学习Rust过程,内容如有错误请大佬指教 使用IDE:vs code 参考教程:菜鸟教程链接: 菜鸟教程链接: Rust学习 Rust入门安装Rust编译环境Rust 编译工具 构建Rust 工程目录 Rust入门 安装Rust编译环境 因为我已经安装过VSCode了&am…

【ArcGIS遇上Python】ArcGIS Python批量筛选多个shp中指定字段值的图斑(以土地利用数据为例)

文章目录 一、案例分析二、提取效果二、代码运行效果三、Python代码四、数据及代码下载一、案例分析 以土地利用数据为例,提取多个shp数据中的旱地。 二、提取效果 原始土地利用数据: 属性表: 提取的旱地:(以图层名称+地类名称命名)

mysql — 生产环境发布DDL之避坑操作onlineDDL

一、Mysql onLineDDL特性 1、Mysql 5.6 DDL MySQL 的 DDL(Data Definition Language) 包括增减字段、增减索引等操作。MySQL Online DDL 功能从 5.6 版本开始正式引入,发展到现在的 8.0 版本,那么在 MySQL 5.6 之前,MySQL 的 DDL 操作会按照…

Vue高级(二)

3.搭建vuex环境 创建文件:src/store/index.js //引入Vue核心库import Vue from vue//引入Vueximport Vuex from vuex//应用Vuex插件Vue.use(Vuex)//准备actions对象——响应组件中用户的动作const actions {}//准备mutations对象——修改state中的数据const mutat…

近4w字吐血整理!只要你认真看完【C++编程核心知识】分分钟吊打面试官(包含:内存、函数、引用、类与对象、文件操作)

🌈个人主页:godspeed_lucip 🔥 系列专栏:C从基础到进阶 🏆🏆关注博主,随时获取更多关于C的优质内容!🏆🏆 C核心编程🌏1 内存分区模型&#x1f384…

mac查看maven版本报错:The JAVA_HOME environment variable is not defined correctly

终端输入mvn -version报错: The JAVA_HOME environment variable is not defined correctly, this environment variable is needed to run this program. Java环境变量的问题,打开bash_profile查看 open ~/.bash_profile export JAVA_8_HOME/Library/Java/JavaVirtualMachine…