【分类|回归】深度学习中的分类和回归?离散数据or连续数值?

news2024/11/19 20:21:34

【分类|回归】深度学习中的分类和回归?离散数据or连续数值?

【分类|回归】深度学习中的分类和回归?离散数据or连续数值?


文章目录

  • 【分类|回归】深度学习中的分类和回归?离散数据or连续数值?
  • 前言
  • 1.分类问题
    • 1.1分类问题的定义
    • 1.2深度学习中的分类问题
    • 1.3实际分类问题及代码示例:手写数字识别(MNIST)
    • 1.4分类问题的处理方法
  • 2.回归问题
    • 2.1回归问题的定义
    • 2.2深度学习中的回归问题
    • 2.3实际回归问题及代码示例:房价预测
    • 2.4回归问题的处理方法
  • 总结


前言

在机器学习和深度学习中,分类问题和回归问题是两类基本任务。两者的区别在于输出的类型:分类问题的输出是离散的类别标签,而回归问题的输出是连续的数值

1.分类问题

1.1分类问题的定义

分类问题是指给定输入数据,模型需要将其划分为多个类别中的一个。例如,判断一张图片是猫还是狗,这是一个二分类问题(Binary Classification);而预测一幅手写数字图片的数字是0到9中的哪个,这是多分类问题(Multiclass Classification)。

1.2深度学习中的分类问题

深度学习处理分类问题的常用方法是使用神经网络模型,尤其是卷积神经网络(CNN),如果是图像处理问题。在分类问题中,最后一层通常是Softmax激活函数,将网络输出转换为概率分布,表示属于每个类别的概率。

1.3实际分类问题及代码示例:手写数字识别(MNIST)

MNIST 是一个手写数字数据集,包含 28x28 的灰度图像和对应的类别标签(0-9)。

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义一个简单的卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        x = x.view(-1, 64*7*7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

# 数据预处理和加载
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 定义模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(1, 6):
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}')

代码解释:

  • 1.定义卷积神经网络SimpleCNN 是一个简单的CNN,用于处理图像数据,包含两个卷积层和两个全连接层。最后一层输出的维度为10,表示10个类别(数字0到9)。
  • 2.前向传播(forward):图像输入后经过卷积、池化、展平,再经过全连接层,最后通过 log_softmax 输出每个类别的对数概率。
  • 3.数据预处理:使用 transforms.ToTensor 将图像数据转换为张量,并进行归一化。
  • 4.损失函数:使用交叉熵损失函数 CrossEntropyLoss,适合多分类问题。
  • 5.优化器:Adam优化器用于梯度更新。
  • 6.模型训练:循环训练模型,计算损失并进行梯度更新。

1.4分类问题的处理方法

  • 模型选择:适合分类的模型有逻辑回归(LR)、支持向量机(SVM)、决策树、随机森林等,深度学习中常用的是神经网络。
  • 激活函数:最后一层通常使用 Softmax 激活函数,将模型输出的值转换为概率分布。
  • 损失函数:交叉熵损失函数适合分类任务。

2.回归问题

2.1回归问题的定义

回归问题的目标是预测一个连续的数值。典型的回归问题包括预测房价、股票价格、温度等。例如,给定一个地区的房屋面积、卧室数量等特征,预测房屋的价格。

2.2深度学习中的回归问题

在回归问题中,深度学习模型通常使用全连接神经网络(Fully Connected Neural Networks),输出层的激活函数可以是线性函数(Linear Activation),输出一个连续值。

2.3实际回归问题及代码示例:房价预测

假设我们使用一个简化的数据集,包含房屋面积、卧室数量等特征,来预测房价。

代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 创建简单的回归模型
class SimpleRegressionModel(nn.Module):
    def __init__(self):
        super(SimpleRegressionModel, self).__init__()
        self.fc1 = nn.Linear(2, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # 回归输出不需要激活函数
        return x

# 模拟房价数据 (面积, 卧室数量)
data = np.array([[1200, 3], [1500, 4], [1700, 3], [2100, 4], [2500, 5]], dtype=np.float32)
prices = np.array([300000, 400000, 425000, 525000, 600000], dtype=np.float32)

# 转换为PyTorch的张量
data_tensor = torch.from_numpy(data)
prices_tensor = torch.from_numpy(prices).view(-1, 1)

# 定义模型、损失函数和优化器
model = SimpleRegressionModel()
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(1, 101):
    optimizer.zero_grad()
    output = model(data_tensor)
    loss = criterion(output, prices_tensor)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

代码解释:

  • 1.定义回归模型SimpleRegressionModel 是一个简单的三层全连接网络,输入维度为2(房屋面积和卧室数量),输出维度为1(房价)。
  • 2.数据准备:模拟了一个小的房屋数据集,包含面积和卧室数量作为特征,房价作为标签。将数据转换为PyTorch的张量。
  • 3.损失函数:使用均方误差损失(MSELoss),这是回归问题的常用损失函数,衡量预测值与真实值的差异。
  • 4.优化器:Adam优化器,用于更新模型的权重。
  • 5.模型训练:每个epoch中,计算预测结果与实际值的损失,通过反向传播更新权重。

2.4回归问题的处理方法

  • 模型选择:适合回归的模型包括线性回归、决策树回归、支持向量回归(SVR)、深度学习中的全连接神经网络等。
  • 激活函数:回归任务的输出层通常不使用激活函数,直接输出一个连续值。
  • 损失函数:常用均方误差损失函数(MSE)来优化回归模型。

总结

在这里插入图片描述

  • 分类问题的输出是离散的类别标签,处理方法包括Softmax激活函数和交叉熵损失函数,常用于图像分类、文本分类等任务。
  • 回归问题的输出是连续值,处理方法包括线性激活函数和均方误差损失函数,适用于房价预测、股票价格预测等连续值预测问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2166220.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

帆软通过JavaScript注入sql,实现数据动态查询

将sql语句设置为参数 新建数据库查询 设置数据库查询的sql语句 添加控件 JavaScript实现sql注入 添加事件 编写JavaScript代码 //获取评价人id var pjrid this.options.form.getWidgetByName("id").getValue();//显示评价人id alert("评价人:&…

单片机串口AT指令操作SIM800、900拨打电话

文章目录 一、前言1.1 功能简介1.2 拨打电话功能的应用场景1.3 SIM900A与SIM800C模块介绍1.4 原理图 三、模块调试3.1 工具软件下载3.2 准备好模块3.3 串口调试助手的设置3.4 初始化配置3.5 拨打电话的测试流程 四、代码实现4.1 底层的命令发送接口4.2 底层数据接收接口4.3 检测…

Cisco Packet Tracer的安装加汉化

这个工具学计算机网络的同学会用到 1.下载安装 网盘链接:https://pan.baidu.com/s/1CmnxAD9MkCtE7pc8Tjw0IA 提取码:frkb 点击第一个进行安装,按步骤来即可。 2.汉化 (1)复制chinese.ptl文件 (2&…

四元组问题

目录 问题描述 输入格式 输出格式 样例输入 样例输出 说明 评测数据规模 运行限制 原题链接 代码思路 问题描述 从小学开始,小明就是一个非常喜欢数学的孩子。他喜欢用数学的方式解决各种问题。在他的高中时期,他遇到了一个非常有趣的问题&…

【Unity服务】如何使用Unity Version Control

Unity上的线上服务有很多,我们接触到的第一个一般就是Version Control,用于对项目资源的版本管理。 本文介绍如何为项目添加Version Control,并如何使用,以及如何将项目与Version Control断开链接。 其实如果仅仅是对项目资源进…

华盈伯乐 | Bio-Plex多重细胞因子检测技术培训及研讨会现场回顾

精彩华盈现场回顾 迎着朝霞,与会的伙伴们一早踏上了旅程,参与华盈生物与伯乐生命科学联合举办的Bio-Plex多重细胞因子检测技术培训及研讨会活动。随着活动的序幕缓缓拉开,我们迎来了一段充实而富有成效的学习之旅。 精彩开幕 华盈生物的副总…

SpringMVC4-SpringMVC获取请求参数

test_param.html&#xff1a; <!DOCTYPE html> <html lang"en" xmlns:th"http://www.thymeleaf.org"> <head><meta charset"UTF-8"><title>测试请求参数</title> </head> <body> <h1>测…

解决Pymysql has no attribute ‘escape_string‘ 并且无法引入该模块

打印出的pymysql版本是1.4.6 需要import这个module&#xff0c;并且根据pymysql的版本import的方式还不同 import pymysqlif pymysql.__version__ >1.0.0:from pymysql.converters import escape_string else:escape_string lambda x: pymysql.escape_string(x)然而&am…

如何借助Java批量操作Excel文件?

最新技术资源&#xff08;建议收藏&#xff09; https://www.grapecity.com.cn/resources/ 前言 | 问题背景 在操作Excel的场景中&#xff0c;通常会有一些针对Excel的批量操作&#xff0c;批量的意思一般有两种&#xff1a; 对批量的Excel文件进行操作。如导入多个Excel文件…

鸿蒙OpenHarmony【小型系统基础内核(虚实映射)】子系统开发

虚实映射 基本概念 虚实映射是指系统通过内存管理单元&#xff08;MMU&#xff0c;Memory Management Unit&#xff09;将进程空间的虚拟地址与实际的物理地址做映射&#xff0c;并指定相应的访问权限、缓存属性等。程序执行时&#xff0c;CPU访问的是虚拟内存&#xff0c;通…

实现微信小程序中点击单词显示在input的交互功能指南

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 &#x1f49e;当前专栏…

卷积神经网络-学习率

文章目录 一、学习率的定义二、学习率的作用三、学习率的调整方法1.有序调整(1).有序调整StepLR(等间隔调整学习率&#xff09;(2).有序调整MultiStepLR(多间隔调整学习率)(3).有序调整ExponentialLR (指数衰减调整学习率)(4).有序调整CosineAnnealing (余弦退火函数调整学习率…

TypeScript 设计模式之【单例模式】

文章目录 **单例模式**: 独一无二的特工我们为什么需要这样的特工?单例模式的秘密&#xff1a;如何培养这样的特工?特工的利与害代码实现单例模式的主要优点单例模式的主要缺点单例模式的适用场景总结 单例模式: 独一无二的特工 单例模式就像是一个秘密组织里的特殊特工。这…

Java介绍及JDK 21详细安装教程

文章目录 1. 文章简介2. Java和JDK的介绍与关系2.1 Java2.2 JDK 3. Java版本的发展历程4. Java 21安装步骤 1. 文章简介 本文介绍如何Java、JDK、Java的发展及如何快速安装JDK 21。内容详细充实&#xff0c;旨在帮助您快速了解并使用Java。 2. Java和JDK的介绍与关系 2.1 Jav…

828华为云征文|华为云Flexus云服务器X实例——部署EduSoho网校系统、二次开发对接华为云视频点播实现CDN加速播放

EduSoho 是一款功能强大的网校系统&#xff0c;能够帮助教育机构快速搭建在线学习平台。本文将详细介绍如何在华为云服务器上安装和部署 EduSoho 网校系统&#xff0c;以及二次开发对接华为云视频点播VOD来实现CDN加速播放。 edusoho本地存储的视频播放存在诸多弊端。一方面&a…

「C++系列」命名空间

【人工智能教程】&#xff0c;前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。 点击跳转到网站&#xff1a;【人工智能教程】 文章目录 一、命名空间1. 定义命名空间2. 使用命名空间中的成员3. 命名空间的…

SAP 利润分配-未分配利润的年初余额和年末余额不一致的问题

SAP OB53 本年利润科目的年初余额和年末余额不一致的问题 关于OB53科目的问题 OB53维护的本年利润科目 现象&#xff1a;为何去年年末的本年利润金额和今年年初的本年利润金额不一致。 解释原因&#xff1a; 本年利润科目的这种现象归根结底是“表结法”产生的&#xff0c;换…

QT----Creater14.0,qt5.15无法启动调试,Launching GDB Debugger报红

问题描述 使用QT Creater 14.0 和qt5.15,无法启动调试也没有报错,加载debugger报红 相关文件都有 解决方案 尝试重装QT,更换版本5.15.2,下载到文件夹,shift鼠标右键打开powershell输入 .\qt-online-installer-windows-x64-4.8.0.exe --mirror http://mirrors.ustc.edu.cn…

VMware 虚拟机配置固定 IP

1. VMware 配置 参考&#xff1a;https://blog.csdn.net/jsryin/article/details/123304582 参考&#xff1a;https://zhuanlan.zhihu.com/p/455097916 1.1. 点击编辑 -> 虚拟网络编辑器 1.2. Net 设置 选择VMnet8 进行配置 查看当前虚拟机的网关是192.168.17.2&#x…

HAproxy-7层负载均衡集群根据不同服务请求分配服务器

搭建HAproxy----7层负载均衡集群的补充 https://blog.csdn.net/qq_73990369/article/details/142500451?spm1001.2014.3001.5501 一、再准备两台虚拟机进行测试 192.168.229.15/24 ----php1 192.168.229.16/24 ----php2 1、PHP1 & php2(192.168.229.15/24 ,192…