[代码案例] pytorch快速上手写机器学习

news2024/9/24 11:33:21

任务背景

给定未来一段时间的温度,使用神经网络预测输出是天气炎热,温暖,凉爽,偏冷,寒冷
输入是未来 20天内的气温数据,输出标签是 0,1,2,3,4

代码

"""
    @Author : 琛歌很无聊
    @Description: 使用 Pytorch的简单demo:给定一段时间的温度,输出是天气炎热,温暖,凉爽,偏冷,寒冷
                  输入是未来 20天内的气温数据,输出标签是 0,1,2,3,4
"""
import numpy as np
import torch
import torch.utils.data as Data
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


def getlabel(X):  # 获取样本对应的标签
    Y = np.zeros(len(X))
    for i in range(len(X)):
        if np.mean(X[i]) > 30:
            Y[i] = 0  # 平均温度大于30时,标签设置为炎热(0)
        elif np.mean(X[i]) > 20:
            Y[i] = 1
        elif np.mean(X[i]) > 10:
            Y[i] = 2
        elif np.mean(X[i]) > 0:
            Y[i] = 3
        else:
            Y[i] = 4

    return Y


class MyDataset(Data.Dataset):  # 继承父类
    def __init__(self, train_x, train_y):
        self.train_x = train_x
        self.train_y = train_y

    def __getitem__(self, item):
        return self.train_x[item], self.train_y[item]

    def __len__(self):
        return len(self.train_y)


class MyNeuralNet(nn.Module):  # 继承父类
    def __init__(self, input_dim, output_dim):
        super(MyNeuralNet, self).__init__()
        self.activation = nn.ReLU()  # 激活函数
        # 定义网络
        self.n1 = nn.Linear(input_dim, 16)
        self.n2 = nn.Linear(16, 8)
        self.n3 = nn.Linear(8, output_dim)

    def forward(self, x):
        '''
        :param x: 输入矩阵
        :return: 网络输出值
        '''
        y = self.n1(x)
        y = self.activation(y)
        y = self.n2(y)
        y = self.n3(y)
        return y


'''
1.获取训练集和测试集
'''
train_X = np.random.uniform(33, 42, (100, 20))  # 随机生成温度值域[33,42),大小为100*20的数组,100是样本数,20是维度数
train_X = np.append(train_X, np.random.uniform(19, 35, (100, 20)), axis=0)  # np.append是数组拼接
train_X = np.append(train_X, np.random.uniform(5, 24, (100, 20)), axis=0)  # np.append是数组拼接
train_X = np.append(train_X, np.random.uniform(-2, 11, (100, 20)), axis=0)  # np.append是数组拼接
train_X = np.append(train_X, np.random.uniform(-15, 3, (100, 20)), axis=0)  # np.append是数组拼接
train_Y = getlabel(train_X)
# print(train_Y)
print("训练集输入的大小为:", np.shape(train_X))

test_X = np.random.uniform(30, 40, (1, 20))
test_Y=getlabel(test_X)

'''
2.加载数据集和加载模型
'''
train_X = torch.Tensor(train_X)  # 将array转成Tensor,float类型
train_Y = torch.LongTensor(train_Y)  # LongTensor中数字是整数
test_X = torch.Tensor(test_X)  # 将array转成Tensor,float类型

train_dataset = MyDataset(train_x=train_X, train_y=train_Y)
train_loader = Data.DataLoader(train_dataset, batch_size=16, shuffle=True)  # batch_size控制一批样本的数量,shuffle控制是否打乱样本取样

model = MyNeuralNet(input_dim=20, output_dim=5)  # 定义模型网络
loss_fun = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.AdamW(model.parameters(), lr=1e-3)  # lr是学习率

'''
3.在模型上训练
'''
max_epoch = 100
loss_list = []
for epoch in range(max_epoch):
    for i, (x, y) in enumerate(train_loader):
        predict_y = model(x)  # 1.做向前计算
        loss = loss_fun(predict_y, y)  # 2.计算损失函数
        optimizer.zero_grad()  # 3.清除网络状态
        loss.backward()  # 4.loss反向传播
        optimizer.step()  # 5.更新网络参数

        # 输出Loss,并存储
        print(loss)
        loss_list.append(loss.item())

'''
4.绘制损失函数曲线图
'''
print(loss_list)
plt.plot(loss_list)
plt.title("loss")
plt.show()

'''
5.使用训练好的模型进行泛化测试
'''
print("测试集是\n", test_X)

predict_Y = model(test_X)
predict_Y = torch.softmax(predict_Y, dim=1)  # 将输出值映射到[0,1]概率分布区间
predict_Y = predict_Y[0].detach().numpy()  # Tensor转数组array
predict_Y = np.round(predict_Y * 100, 3)  # 乘以100为百分比,保留3位小数
print("预测标签概率分布")
print(predict_Y)
print("预测标签值")
print(np.argmax(predict_Y, 0))  # 取最大值的下标为预测值
print("实际标签值")
print(test_Y)

结果展示

在这里插入图片描述

测试集是
 tensor([[38.7667, 37.0874, 30.7980, 34.7945, 37.9251, 39.0055, 37.2950, 38.5757,
         30.0081, 36.3308, 38.7851, 39.2595, 35.2415, 38.8545, 35.7583, 32.1818,
         35.3024, 32.5636, 32.8940, 34.4463]])
预测标签概率分布
[99.759  0.241  0.     0.     0.   ]
预测标签值
0
实际标签值
[0.]

代码讲解地址

https://www.bilibili.com/video/BV1KX4y177x8/?spm_id_from=333.999.0.0&vd_source=de24eb60706a87145c55dda9edb79815

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/823067.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

wordpress 学习贴

安装问题 我的使用环境为docker环境,php、nginx、mysql分别处于3个容器中, 提示异常,打开debug模式,会发现 No such file or directory Warning: mysqli_real_connect(): (HY000/2002): No such file or directory 这个其实问题其…

ConcurrentHashMap底层具体实现以及实现原理

问题描述 ConcurrentHashMap 底层具体实现以及实现原理 分析维度: 1. ConcurrentHashMap的整体架构 2. ConcurrentHashMap的基本功能 3. ConcurrentHashMap在性能方面的优化 解决方案: ConcurrentHashMap 的整体架构 如图所示,这个是 Concu…

int[]数组转Integer[]、List、Map「结合leetcode:第414题 第三大的数、第169题 多数元素 介绍」

文章目录 1、int[ ] 转 Integer[ ]:2、两道leetcode题遇到的场景:2.1、int[ ] 转 List<Integer> :2.2、int[ ] 转 Map: 1、int[ ] 转 Integer[ ]: public static void main(String[] args) {int[] nums {1, 2, 3}; Integer[] array Arrays.stream(nums).boxed().to…

Java反射(一)

目录 1.了解反射 2.Class类的三种实例化方法 3.反射机制与对象实例化 4.反射与单例设计模式 5.通过反射获取类结构的信息 1.了解反射 什么是反射&#xff0c;反射有什么作用 1.在Java中&#xff0c;反射是一种机制&#xff0c;允许程序在运行时动态地获取、使用和修改类的…

ASL芯片CS5261 替代瑞昱RTD2171替代AG9310芯片 Type-C转HDMI音频单转 CS5261搭配VL171母座正反插原理图性价比方案

在2021年末尾&#xff0c;瑞昱RTD2171已经停产&#xff0c;ASL集睿致远的单转Type-C转HDMI方案芯片&#xff0c;ASL集睿致远 CS5261却可以完 全替代兼容RTD2171和AG9310, CS5261芯片还可以实现对Type-C接口信号转换的同时实现投屏的慢充功能。另外如果使用芯片CS5261VL171支持T…

Nodejs 第四章(Npm install 原理)

在执行npm install 的时候发生了什么&#xff1f; 首先安装的依赖都会存放在根目录的node_modules,默认采用扁平化的方式安装&#xff0c;并且排序规则.bin第一个然后系列&#xff0c;再然后按照首字母排序abcd等&#xff0c;并且使用的算法是广度优先遍历&#xff0c;在遍历依…

【Python】Web学习笔记_flask(3)——上传文件

用GET、POST请求上传图片并呈现出来 首先还是创建文件上传的模板 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>上传图片</title> </head> <body> <form action""…

XGBoost的参数

目录 1. 迭代过程 1.1 迭代次数/学习率/初始&#x1d43b;最大迭代值 1.1.1 参数num_boost_round & 参数eta 1.1.2 参数base_score 1.1.3 参数max_delta_step 1.2 xgboost的目标函数 1.2.1 gamma对模型的影响 1.2.2 lambda对模型的影响 2. XGBoost的弱评估器 2.…

android app控制ros机器人四(调整界面布局)

半吊子改安卓&#xff0c;记录页面布局调整&#xff1a; 在ros-mobile基础上顶端增加一行&#xff0c;用于显示app名称和logo图像&#xff1b;修改标签页。 添加文字简单&#xff0c;但是替换图标长知识了&#xff0c;开始只是简单的把mipmap各个文件夹下的图片进行替换&…

Web-7-深入理解Cookie与Session:实现用户跟踪和数据存储

深入理解Cookie与Session&#xff1a;实现用户跟踪和数据存储 今日目标 1.掌握客户端会话跟踪技术Cookie 2.掌握服务端会话跟踪技术Sesssion 1.会话跟踪技术介绍 会话&#xff1a;用户打开浏览器&#xff0c;访问web服务器的资源&#xff0c;会话建立&#xff0c;直到有一方断…

Spring boot开发实用篇

一、热部署 1.启动热部署 1.导入坐标 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-devtools</artifactId> </dependency> 2.使用构建项目操作启动热部署 3.关于热部署 重启&#xff1a;自定义开发…

【软件测试学习】—软件测试的基本认识(一)

【软件测试学习】—软件测试的基本认识&#xff08;一&#xff09; 文章目录 【软件测试学习】—软件测试的基本认识&#xff08;一&#xff09;一、什么是软件测试二、软件测试的目的三、测试的原则四、测试的标准五、测试的基本要求六、bug的由来七、测试的流程八、开发模式九…

消息中间件应用场景介绍

提高系统性能首先考虑的是数据库的优化&#xff0c;但是数据库因为历史原因&#xff0c;横向扩展是一件非常复杂的工程&#xff0c;所有我们一般会尽量把流量都挡在数据库之前。 不管是无限的横向扩展服务器&#xff0c;还是纵向阻隔到达数据库的流量&#xff0c;都是这个思路。…

Web后端基本设计思想

JavaWeb应用的后端一般基于MVC和三层架构思想实现。 MVC是一种设计模式&#xff0c;用于开发用户界面和交互式应用程序。M即Model&#xff0c;业务模型&#xff0c;负责处理应用程序的业务逻辑和数据&#xff1b;V即View&#xff0c;视图&#xff0c;负责给用户展示界面和数据&…

【Web】web

dns与域名 网络是基于tcp/ip协议进行通信和连接的 应用层——传输层——网络层——数据链路层——物理层 每一定的台主机都有一个唯一且固定的地址标识——IP地址 IP地址的做用&#xff1a;1.区分用户和计算机&#xff1b;2.进行通信 IP地址由32位二进制数组成&#xff0c;…

<C++> 引用

1.引用的概念 引用&#xff08;Reference&#xff09;是一种别名&#xff0c;用于给变量或对象起另一个名称。引用可以理解为已经存在的变量或对象的别名&#xff0c;通过引用可以访问到原始变量或对象的内容。引用在声明时使用 & 符号来定义。 示例&#xff1a; #inclu…

小程序如何从分类中移除商品

​有时候商家可能需要在商品分类中删除某些商品&#xff0c;无论是因为商品已下架、库存不足还是其他原因。在这篇文章中&#xff0c;我们将介绍如何从分类中移除商品。 方式一&#xff1a;分类管理中删除商品。 进入小程序管理后台&#xff0c;找到分类管理&#xff0c;在分…

记录一次通过iostat命令定位系统数据库CPU飙升的案例

一、背景 我们有个移动考勤的系统&#xff0c;运维监控系统显示&#xff0c;每到上下班时间&#xff0c;考勤数据库的CPU就飙升到100%&#xff0c;磁盘读写请求等待时间变长&#xff0c;最初无法确定是磁盘性能下降导致的CPU飙升&#xff0c;还是CPU飙升导致的磁盘性能下降&…

牛客网Verilog刷题——VL55

牛客网Verilog刷题——VL55 题目答案 题目 请用Verilog实现4位约翰逊计数器&#xff08;扭环形计数器&#xff09;&#xff0c;计数器的循环状态如下&#xff1a;   电路的接口如下图所示&#xff1a; 输入输出描述&#xff1a; 信号类型输入/输出位宽描述clkwireInput1系统…

C5.0决策树建立个人信用风险评估模型

通过构建自动化的信用评分模型&#xff0c;以在线方式进行即时的信贷审批能够为银行节约很多人工成本。本案例&#xff0c;我们将使用C5.0决策树算法建立一个简单的个人信用风险评估模型。 导入类库 读取数据 #创建编码所用的数据字典 col_dicts{} #要编码的属性集 cols [che…