基于Python的人工智能应用案例系列(7):纽约市出租车费预测(分类)

news2024/9/23 16:23:27

        在本篇文章中,我们将使用Kaggle提供的纽约市出租车费数据,构建一个基于深度学习的分类模型,预测出租车费是否超过10美元。我们将结合数据探索、特征工程、深度学习模型构建与优化,并对模型进行测试和评估。

1. 数据加载与预处理

        我们从数据集中加载纽约市出租车费数据,数据包括了乘客的上车和下车的经纬度、乘车时间、乘客人数等字段。我们需要基于这些输入来预测出租车费用的分类:低于10美元或等于/高于10美元。

import pandas as pd

# 加载数据集
df = pd.read_csv('../data/NYCTaxiFares.csv')
df.head()

数据概览

数据集共有120,000条记录,包含了以下字段:

  • pickup_datetime:上车时间
  • pickup_latitude:上车地点纬度
  • pickup_longitude:上车地点经度
  • dropoff_latitude:下车地点纬度
  • dropoff_longitude:下车地点经度
  • passenger_count:乘客人数
  • fare_class:目标变量,表示费用类别,0表示低于10美元,1表示等于或高于10美元
# 查看分类的分布
df['fare_class'].value_counts()

数据分类说明

  • Class 0: 费用低于10美元
  • Class 1: 费用等于或高于10美元

        约有2/3的数据属于Class 0,1/3的数据属于Class 1。这个数据不算严重的类别不平衡,但我们在建模时仍需注意平衡性。

2. 特征工程

        在开始建模之前,我们需要对原始数据进行一些处理,提取出有用的特征。例如,我们可以通过上车和下车的经纬度来计算出行距离,同时提取时间相关的特征(例如小时、星期几等)。

计算两点之间的距离

        我们可以使用haversine公式来计算两组经纬度之间的距离,该公式适用于球面距离计算,地球可以近似看作一个球体。

import numpy as np

def haversine_distance(df, lat1, long1, lat2, long2):
    """
    计算两组经纬度之间的球面距离
    """
    r = 6371  # 地球平均半径,单位为千米
    phi1 = np.radians(df[lat1])
    phi2 = np.radians(df[lat2])
    delta_phi = np.radians(df[lat2] - df[lat1])
    delta_lambda = np.radians(df[long2] - df[long1])
    a = np.sin(delta_phi / 2)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(delta_lambda / 2)**2
    c = 2 * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return r * c  # 返回距离,单位为千米

# 应用公式,生成距离特征
df['dist_km'] = haversine_distance(df, 'pickup_latitude', 'pickup_longitude', 'dropoff_latitude', 'dropoff_longitude')

提取时间特征

        通过将字符串形式的pickup_datetime转换为时间对象,我们可以提取出时间相关的特征,如小时、上午/下午以及星期几。

df['EDTdate'] = pd.to_datetime(df['pickup_datetime'].str[:19]) - pd.Timedelta(hours=4)  # 转换为美国东部时间
df['Hour'] = df['EDTdate'].dt.hour
df['AMorPM'] = np.where(df['Hour'] < 12, 'am', 'pm')
df['Weekday'] = df['EDTdate'].dt.strftime("%a")

        提取时间信息后,我们的数据框将包含以下特征:

  • dist_km: 乘车的行驶距离
  • Hour: 上车时间的小时
  • AMorPM: 上午或下午
  • Weekday: 星期几

3. 数据预处理

        在进行建模前,我们还需要对分类特征进行编码,将字符串类别转换为数字表示。我们使用pandas中的category数据类型来转换这些列。

for col in ['Hour', 'AMorPM', 'Weekday']:
    df[col] = df[col].astype('category')

# 检查转换后的数据类型
df.dtypes

        通过这种转换方式,我们将会为每个分类特征分配一个整数编码,用以表示不同的类别。

4. 模型构建

嵌入层处理

        我们将为分类特征构建嵌入层,通过将类别变量转换为嵌入向量,模型能够更好地学习到不同类别的相似性。

import torch
import torch.nn as nn

class TaxiFareModel(nn.Module):
    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.5):
        super().__init__()
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(p)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        
        layers_list = []
        n_emb = sum((nf for ni, nf in emb_szs))
        n_in = n_emb + n_cont
        
        for i in layers:
            layers_list.append(nn.Linear(n_in, i))
            layers_list.append(nn.ReLU(inplace=True))
            layers_list.append(nn.BatchNorm1d(i))
            layers_list.append(nn.Dropout(p))
            n_in = i
            
        layers_list.append(nn.Linear(layers[-1], out_sz))
        self.layers = nn.Sequential(*layers_list)
    
    def forward(self, x_cat, x_cont):
        embeddings = [e(x_cat[:, i]) for i, e in enumerate(self.embeds)]
        x = torch.cat(embeddings, 1)
        x = self.emb_drop(x)
        x_cont = self.bn_cont(x_cont)
        x = torch.cat([x, x_cont], 1)
        return self.layers(x)

模型训练

        我们使用交叉熵损失函数进行模型的训练,优化器选用Adam优化器。

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 模型训练
epochs = 300
losses = []

for epoch in range(epochs):
    y_pred = model(cat_train, con_train)
    loss = criterion(y_pred, y_train)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

5. 模型测试与评估

        我们使用测试集进行模型评估,并计算模型的准确率。

with torch.no_grad():
    y_val = model(cat_test, con_test)
    loss = criterion(y_val, y_test)
    
print(f'测试集上的交叉熵损失: {loss.item()}')

        我们还可以打印前50个预测值及其对应的实际值,来直观地检查模型的表现。

rows = 50
correct = 0

for i in range(rows):
    pred_label = y_val[i].argmax().item()
    true_label = y_test[i].item()
    if pred_label == true_label:
        correct += 1
        
print(f'{correct} out of {rows} = {100 * correct / rows:.2f}% correct')

6. 模型保存与加载

        为了方便后续的推理和应用,我们可以将训练好的模型保存到文件中。

torch.save(model.state_dict(), 'TaxiFareClssModel.pt')

        在需要时,我们可以重新加载模型并进行推理。

model2 = TaxiFareModel(emb_szs, conts.shape[1], 2, [200, 100], p=0.4)
model2.load_state_dict(torch.load('TaxiFareClssModel.pt'))
model2.eval()  # 设置为评估模式

结语

        通过这个案例,我们学习了如何基于分类问题构建深度学习模型,预测纽约市出租车费是否超过10美元。我们通过特征工程生成了重要的特征,并利用嵌入层对分类变量进行了处理。最终,我们构建了一个深度神经网络,并对模型进行了训练和评估。

        这个案例展示了如何将深度学习应用于实际问题中,并且提供了未来可扩展的模型保存和加载方式,便于对新数据进行推理。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2157967.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

骨架行为识别-论文复现(论文复现)

骨架行为识别-论文复现&#xff08;论文复现&#xff09; 本文所涉及所有资源均在传知代码平台可获取 序言 骨架行为识别的定义 骨架行为识别是指通过分析人体骨架的运动轨迹和姿态&#xff0c;来识别和理解人体的行为动作。它是计算机视觉和模式识别领域的一个重要研究方向&a…

力扣上刷题之C语言实现-Days1

一. 简介 本文记录一下力扣的逻辑题。主要是数组方面的&#xff0c;使用 C语言实现。 二. 涉及数组的 C语言逻辑题 1. 两数之和 给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target的那 两个 整数&#xff0c;并返回它们的…

C++笔试强训15、16、17

文章目录 笔试强训15一、选择题1-5题6-10题 二、编程题题目一题目二 笔试强训16一、选择题1-5题6-10题 二、编程题题目一题目二 笔试强训17一、选择题1-5题6-10题 二、编程题题目一题目二 笔试强训15 一、选择题 1-5题 共有派生下&#xff0c;派生类的成员函数只能访问基类的…

大模型训练不难,三步即可实现

前言 初步认识了大模型长什么样了&#xff0c;接下来一起来看看如何训练出一个大模型。 训练方式&#xff0c;这里主要参考OpenAI发表的关于InstructGPT的相关训练步骤&#xff0c;主流的大模型训练基本形式大多也是类似的&#xff1a; 1、预训练&#xff08;Pretraining&a…

安卓13设置动态修改设置显示版本号 版本号增加信息显示 android13增加序列号

总纲 android13 rom 开发总纲说明 文章目录 1.前言2.问题分析3.代码分析4.代码修改5.编译6.彩蛋1.前言 设置 =》关于平板电脑 =》版本号 在这里显示了系统的一些信息,但是这里面的信息并不包含序列号之类的信息,我们修改下系统设置,在这里增加上相关的序列号。 2.问题分析…

C语言 使用scanf函数时出现错误代码C4996

文章目录 错误样式解决方法方法一&#xff1a;使用安全的函数替代方法二&#xff1a;禁用警告方法三&#xff1a;检查并修改编译器设置 错误样式 C4996 ‘scanf’: This function or variable may be unsafe. Consider using scanf_s instead. To disable deprecation, use _C…

《算法岗面试宝典》正式发布

大家好&#xff0c;历时半年完善&#xff0c;《算法岗面试宝典》 终于可以跟大家见面了。 最近 ChatGPT 爆火&#xff0c;推动了技术圈对大模型算法场景落地的热情&#xff0c;就业市场招聘人数越来越多&#xff0c;算法岗一跃成为竞争难度第一的岗位。 岗位方向 从细分方向…

K8s Calico替换为Cilium,以及安装Cilium过程

一、删除Calico kubectl delete daemonset calico-node -n kube-systemkubectl delete deployment calico-kube-controllers -n kube-system kubectl delete ds kube-flannel-ds -n kube-system kubectl delete cm calico-config -n kube-system kubectl delete secret calico…

YOLOv5训练COCO2017数据集

网上没找到适合新手小白的教程,看了些教程,但还是没法解决自己遇到的问题。记录下自己的过程,希望能提供点帮助。 默认已经部署好了yolov5。 安装部署yolov5可参考以下: ubuntu20.04配置YOLOV5(非虚拟机)_ubuntu系统实现yolov5没有显卡-CSDN博客 目录 一、数据集下载…

Java基础-零拷贝

文章目录 什么是零拷贝&#xff1f;传统IO执行过程零拷贝的意义零拷贝的主要实现方式实际应用场景零拷贝的优势零拷贝的局限性 Java 中的零拷贝实现FileChannel.transferTo()FileChannel.transferFrom() 相关知识点解释什么是DMA内核空间和用户空间什么是用户态、内核态什么是上…

2012年408考研真题-数据结构

8.【2012统考真题】求整数n(n≥0)的阶乘的算法如下&#xff0c;其时间复杂度是(&#xff09;。 int fact(int n){ if(n<1) return 1; return n*fact (n-1); } A. O(log2n) B. O(n) C. O(nlog2n) D. O(n^2) 解析&#xff1a; 观察代码&#xff0c;我们不…

如何在openKylin中配置ssh服务并实现远程连接开放麒麟系统(1)

文章目录 前言1. 安装SSH服务2. 本地SSH连接测试3. openKylin安装Cpolar4. 配置 SSH公网地址5. 公网远程SSH连接6. 固定SSH公网地址7. SSH固定地址连接 前言 本文主要介绍如何在openKlyin系统中设置ssh连接&#xff0c;并结合cpolar内网穿透工具实现远程也可以ssh连接本地局域…

功能 接口测试,详解从抓包 +linux 日志 + 数据库的 bug 定位!

我在跟很多测试人员交流中发现&#xff0c;很大一部分测试工程师在进行功能和接口测试过程中&#xff0c;对于发现的bug很少去进行定位&#xff0c;只是将bug基于业务操作上如何出现的&#xff0c;进行描述&#xff1b;至于bug产生的原因&#xff0c;开发自己排查去吧。本文中&…

多语言文本 AI 纠错格式化 API 数据接口

多语言文本 AI 纠错格式化 API 数据接口 AI / 文本处理 AI 模型智能纠正 语法纠错 / 文本格式化。 1. 产品功能 支持多语言文本的语法纠错&#xff1b;自动识别并纠正拼写错误、语法错误和标点符号使用不当&#xff1b;优化文本格式&#xff0c;提高可读性&#xff1b;基于AI…

《李·斯莫林讲量子引力》:在不断运动的宇宙中探究离散的时空

可能是斯莫林的书读得并不多&#xff0c;感觉他讲故事的能力不如讲物理定律的能力。前半部分纯知识的可读性要好于后面讲述理论的创造过程的故事。如作者所说现代科学没有任何领域是单打独斗&#xff0c;而是不断探索&#xff0c;在团队中&#xff0c;前人和其他专业领域专家合…

vue使用PDF.JS踩的坑--部署到服务器上显示pdf.mjs viewer.mjs找不到资源

之前项目使用的pdf.js 是2.15.349版本&#xff0c;最近换了一个4.6.82的版本&#xff0c;在本地上浏览文件运行的好好的&#xff0c;但是发布到服务器&#xff08;IIS&#xff09;上打不开文件&#xff0c;控制台提示找不到pdf.mjs viewer.mjs。 之前使用的2.15.349pdf和viewer…

76、Python之函数式编程:柯里化都不懂,别说你会函数式编程

引言 很多时候&#xff0c;我们在定义函数处理比较复杂的业务逻辑时&#xff0c;首先是想着遵照“单一职能原则&#xff08;SRP&#xff09;”&#xff0c;尽量拆分为功能单一、足够精简的函数&#xff0c;以便保证代码的可读性和可扩展性。但是&#xff0c;有些逻辑就是没法拆…

2024年双十一有什么好物值得买呢?双十一必买好物清单

双十一买什么犒劳自己既不会浪费钱又可以增添生活的幸福感&#xff1f;以下就整理了五款更适合与秋冬独自生活相伴的好物&#xff0c;精致增加生活氛围感&#xff0c;热爱生活的同时更好的爱自己&#xff01;努力工作和生活当然也要更好的享受生活&#xff0c;给生活创造更多美…

Vue(14)——组合式API①

setup 特点&#xff1a;执行实际比beforeCreate还要早&#xff0c;并且获取不到this <script> export default{setup(){console.log(setup函数);},beforeCreate(){console.log(beforeCreate函数);} } </script> 在setup函数中提供的数据和方法&#xff0c;想要在…

数据结构和算法之树形结构(2)

文章出处&#xff1a;数据结构和算法之树形结构(2) 关注码农爱刷题&#xff0c;看更多技术文章&#xff01;&#xff01; 三、二叉查找树(接前篇) 二叉查找树&#xff0c;又称二叉搜索树或二叉排序树&#xff0c;是在普通二叉树基础上为了实现快速查找而设计出来的一种树形结…