神经网络案例实战

news2025/1/12 4:00:12

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。

在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。

🔎思路:

  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。

  2. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。

  3. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。

  4. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。

  5. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。

  6. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。

构建数据集 

💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。

def create_dataset():

    data = pd.read_csv('predict.csv')

    # 特征值和目标值
    x, y = data.iloc[:, :-1], data.iloc[:, -1]
    x = x.astype(np.float32)
    y = y.astype(np.int64)

    # 数据集划分
    x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)

    # 构建数据集
    train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))
    valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))

    return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))


train_dataset, valid_dataset, input_dim, class_num = create_dataset()

💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。

构建分类网络模型

# 构建网络模型
class PhonePriceModel(nn.Module):

    def __init__(self, input_dim, output_dim):
        super(PhonePriceModel, self).__init__()

        self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换
        self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换
        self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换

    def _activation(self, x):
        return torch.sigmoid(x)

    def forward(self, x):

        x = self._activation(self.linear1(x))
        # 通过第一层线性变换和激活函数
        x = self._activation(self.linear2(x))
        # 通过第二层线性变换和激活函数
        output = self.linear3(x)
        # 通过第三层线性变换得到最终输出

        return output
  • self.linear1self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4

编写训练函数 

def train():

    # 固定随机数种子
    torch.manual_seed(66)

    
    model = PhonePriceModel(input_dim, class_num)
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # 优化方法
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    
    num_epoch = 50

    for epoch_idx in range(num_epoch):

        # 初始化数据加载器
        dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)
        # 训练时间
        start = time.time()
        # 计算损失
        total_loss = 0.0
        total_num = 1
        # 准确率
        correct = 0

        for x, y in dataloader:

            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()

            total_num += len(y)
            total_loss += loss.item() * len(y)

        print('epoch: %4s loss: %.2f, time: %.2fs' %
              (epoch_idx + 1, total_loss / total_num, time.time() - start))

    # 模型保存
    torch.save(model.state_dict(), 'price-model.bin')
  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。 

评估函数

def test():

    # 加载模型
    model = PhonePriceModel(input_dim, class_num)
    model.load_state_dict(torch.load('price-model.bin'))

    # 构建加载器
    dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

    # 评估测试集
    correct = 0
    for x, y in dataloader:

        output = model(x)
        y_pred = torch.argmax(output, dim=1)
        correct += (y_pred == y).sum()

    print('Acc: %.5f' % (correct.item() / len(valid_dataset)))

网络性能调优

  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数

🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1654464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目管理-项目资源管理2/2

项目管理:每天进步一点点~ 活到老,学到老 ヾ(◍∇◍)ノ゙ 何时学习都不晚,加油 资源管理:6个过程“硅谷火箭管控” ①规划资源管理: 写计划 ②估算活动资源:估算团队资源&…

2024最新手赚手机软件APP下载排行网站源码及应用商店源码

前言 这是一款简洁蓝色的手机软件下载应用排行、平台和最新发布网站,采用响应式织梦模板。主要包括主页、APP列表页、APP详情介绍页、新闻资讯列表、新闻详情页、关于我们等模块页面。 想入行手赚行业的朋友,这套源码非常适合你,简单的部署上…

JAVA栈相关习题3

1.将递归转化为循环 比如&#xff1a;逆序打印链表 // 递归方式void printList(Node head){if(null ! head){printList(head.next);System.out.print(head.val " ");}} // 循环方式void printList(Node head){if(nullhead){return;}Stack<Node> snew Stack<…

PotPlayer v1.7.22218 全格式影音播放器,无广绿色版!

软件介绍 PotPlayer是一款多功能且免费的媒体播放软件&#xff0c;兼容多种音频和视频格式。提供了丰富的功能性以及个性化设置&#xff0c;以迎合不同用户的需求。友好的用户界面&#xff0c;允许用户自定义皮肤和快捷键&#xff0c;提升了操作的便利性。 此外&#xff0c;Po…

0508_IO3

练习1&#xff1a; 1&#xff1a;使用 dup2 实现错误日志功能 使用 write 和 read 实现文件的拷贝功能&#xff0c;注意&#xff0c;代码中所有函数后面&#xff0c;紧跟perror输出错误信息&#xff0c;要求这些错误信息重定向到错误日志 err.txt 中去 1 #include <stdio.h…

Litedram仿真验证(四):AXI接口完成板级DDR3读写测试(FPGA-Artix7)

目录 日常唠嗑一、仿真中遗留的问题二、板级测试三、工程获取及交流 日常唠嗑 接上一篇Litedram仿真验证&#xff08;三&#xff09;&#xff1a;AXI接口完成仿真&#xff08;FPGA/Modelsim&#xff09;之后&#xff0c;本篇对仿真后的工程进行板级验证。 本次板级验证用到的开…

外企接受大龄程序员吗?

本人知乎账号同公众号&#xff1a;老胡聊Java&#xff0c;欢迎留言并咨询 亲身体会外企经历所见所闻&#xff0c;外企能接受大龄程序员。 1 大概是10年的时候&#xff0c;进一家知名外企&#xff0c;和我一起进的一位manager&#xff0c;后来听下来&#xff0c;年龄35&#xf…

Java 线程池之 ThreadPoolExecutor

Java线程池&#xff0c;特别是ThreadPoolExecutor&#xff0c;是构建高性能、可扩展应用程序的基石之一。它不仅关乎效率&#xff0c;还直接关系到资源管理与系统稳定性。想象一下&#xff0c;如果每来一个请求就创建一个新的线程&#xff0c;服务器怕是很快就要举白旗了。而Th…

Pytharm2020安装详细教程

Pytharm2020版提取链接链接&#xff1a; https://pan.baidu.com/s/1eDvwYmUJ4l7kIBXewtN4EA?pwd1111 提取码&#xff1a;1111 演示版本为2019版&#xff0c;链接包为2020版pytharm。 1.双击exe文件页面会提示更改选项&#xff0c;点击“是”。 2.点击下一步next 自…

52岁前宝丽金小花懒理旧爱郭晋安离婚,大晒美腿甜蜜放闪

TVB三届视帝郭晋安与欧倩怡早前在社交平台共同宣布离婚&#xff0c;并透露二人已分居两年&#xff0c;18年夫妻情画上句号&#xff0c;惊爆全城。郭晋安曾受访指&#xff0c;遇上欧倩怡前只有两段深刻的感情&#xff0c;一段是初恋&#xff0c;另一段则是刘小慧。 旧爱刘小慧懒…

事业单位向媒体投稿发文章上级领导交给了我投稿方法

作为一名事业单位的普通职员,负责信息宣传工作,我见证了从传统投稿方式到智能化转型的全过程,这段旅程既是一次挑战,也是一次宝贵的成长。回想起初涉此领域的日子,那些通过邮箱投稿的时光,至今仍然历历在目,其中的酸甜苦辣,构成了我职业生涯中一段难忘的经历。 邮箱投稿:费时费…

C++之大数运算

溪云初起日沉阁 山雨欲来风满楼 契子✨ 我们知道数据类型皆有范围&#xff0c;一旦超出了这个范围就会造成溢出问题 今天说说我们常见的数据类型范围&#xff1a; 我们平时写代码也会遇到数据类型范围溢出问题&#xff1a; 比如 ~ 我们之前写的学生管理系统在用 int类型 填写…

Leetcode—933. 最近的请求次数【简单】

2024每日刷题&#xff08;128&#xff09; Leetcode—933. 最近的请求次数 实现代码 class RecentCounter { public:RecentCounter() {}int ping(int t) {q.push(t);while(t - 3000 > q.front()) {q.pop();}return q.size();} private:queue<int> q; };/*** Your Re…

Kafka和Spark Streaming的组合使用(Spark 3.5.1)

一、安装Kafka 1.执行以下命令完成Kafka的安装&#xff1a; cd ~ //默认压缩包放在根目录 sudo tar -zxf kafka_2.11-2.3.1.tgz -C /usr/local cd /usr/local sudo mv kafka_2.11-2.3.1 kafka-2.3.1 sudo chown -R qiangzi ./kafka-2.3.1 二、启动Kafaka 1.首先需要启动K…

Java内存是怎样分配的

Java内存是怎样分配的 一、 1. 有些编程语言编写的程序会直接向操作系统请求内存&#xff0c;而 Java 语言为保证其平台无关性&#xff0c;并不允许程序直接向操作系统发出请求&#xff0c;而是在准备执行程序时由Java虚拟机&#xff08;JVM&#xff09;向操作系统请求一定的…

大模型方向好书推荐

我们已经加速进入了大模型的时代。以ChatGPT为首的一些超强模型服务&#xff0c;背后是百亿或千亿参数的基础模型&#xff0c;它们学到了丰富的世界知识&#xff0c;领悟了“与人类打交道”的门路&#xff0c;甚至开始连接和使用外部工具、成为“万物接口”。 新的时代有新的机…

源代码加密的重要性

在数字化时代&#xff0c;企业面临的最大挑战之一是如何保护其核心数据不被泄露。企业源代码防泄密是指企业采取措施保护其软件或应用程序源代码不被未授权的人员获取、泄露或盗用的一种安全措施。源代码是软件的核心组成部分&#xff0c;其中包含了程序员编写的具体指令和算法…

C++反射之检测struct或class是否实现指定函数

目录 1.引言 2.检测结构体或类的静态函数 3.检测结构体或类的成员函数 3.1.方法1 3.2.方法2 1.引言 诸如Java, C#这些语言是设计的时候就有反射支持的。c没有原生的反射支持。并且&#xff0c;c提供给我们的运行时类型信息非常少&#xff0c;只是通过typeinfo提供了有限的…

Julia 语言环境安装与使用

1、Julia 语言环境安装 安装教程&#xff1a;https://www.runoob.com/julia/julia-environment.html Julia 安装包下载地址为&#xff1a;https://julialang.org/downloads/。 安装步骤&#xff1a;注意&#xff08;勾选 Add Julia To PATH 自动将 Julia 添加到环境变量&…

【Linux 性能详解】CPU性能分析工具篇

目录 uptime mpstat 实时监控 查看特定CPU核心 pidstart 监控指定进程 组合多个监控类型 监控线程资源 按用户过滤进程 vmstart 用途 基本用法 输出字段 perf execsnoop dstat 通俗解释 技术层面解释 使用示例 总结 uptime uptime 是一个在 Linux 和 Unix…