文章目录
- 前言
- 一、多输出预测(回归)
-
- 1 坐标数据生成
- 2 网络搭建训练预测
- 二、多标签分类
-
- 1 多标签数据生成
- 2 网络搭建训练
- 总结
前言
前面我们搭建的无论是分类还是回归都只能预测一个标签,这显然效果很局限。下面我们想做到下面这两种效果:
- 多输出预测(回归):例如训练网络拟合北东天坐标转机体坐标的关系,输入是三坐标,输出也是三坐标
- 多标签分类:例如,输入图像数据,训练网络判断图片里面有猫,有狗,还是只有其中一种这样
【注】:在介绍pytorch的内置损失函数博客中已经介绍了pytorch的损失函数是支持这个功能的。
一、多输出预测(回归)
1 坐标数据生成
# 本示例演示如何使用 PyTorch 实现多标签回归模型。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# 构建数据集
# 假设您有一些经纬高度和对应的地心地固坐标的数据
# 这里只是一个示例,您需要根据实际情况准备您自己的数据集
X = np.random.rand(100, 3) # 100个样本,每个样本有3个特征(经度、纬度、高度)
y = np.random.rand(100, 3) # 每个样本有3个目标值(地心地固坐标)
print('y:\n',y)
2 网络搭建训练预测
# 转换数据为 PyTorch 的 Tensor 类型
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)
# 定义模型
class MultiLabelRegressionModel(nn.Module):
def __init__(self, input_size, output_size):
super(MultiLabelRegressionModel, self).__init__()
self.fc = nn.Linear(input_size, output_size)
def forward(self, x):
out = self.fc(x)
return out
# 初始化模型
input_size = 3 # 输入特征的数量
output_size =