Kaggle系列之预测泰坦尼克号人员的幸存与死亡(随机森林模型)

news2025/1/19 3:14:04

Kaggle是开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,本节是对于初次接触的伙伴们一个快速了解和参与比赛的例子,快速熟悉这个平台。当然提交预测结果需要注册,这个可能需要科学上网了。

我们选择一个预测的入门的例子:https://www.kaggle.com/competitions/titanic

目的是使用机器学习创建一个模型,预测哪些乘客在泰坦尼克号沉船事故中幸存了下来。

对于泰坦尼克号沉船事故大家都不陌生,很多都是通过电影了解到的,那由于船上没有足够的救生艇,导致2224名乘客和船员中有1502人死亡,也就是幸存者只有722人。虽然存活下来有一定的运气成分,但似乎有些群体的人比其他人更有可能存活下来。

那什么样的人更有可能生存下来呢?其中乘客的数据包括姓名、年龄、性别、社会经济阶层等,在机器学习领域,我们把这些叫做特征值,根据这些特征值来预测乘客是否是生还者。

数据集

先把数据集下载下来熟悉下,点击"Data",下载下来是三个csv文件,这里我把它们放入dataset目录:

训练集train.csv:12列,包含891位乘客的详细介绍,关系着能否幸存的因素

PassengerId:乘客编号、Pclass:票的类型(1、2、3舱位可表示经济实力)、Name:姓名、Sex:性别、Age:年龄、SibSp:兄弟姐妹与配偶在船上的人数、Parch:父母与小孩在船上的人数、Ticket:票号、Fare:票价、Cabin 船舱号码、Embarked 港口(C:瑟堡、Q:皇后镇、S:南安普顿)

测试集test.csv:11列,除了Survived这列,481位乘客的信息,预测的是Survived是否幸存

预测结果gender_submission.csv:展示了应该如何组织预测。它预测所有的女性乘客都活下来了,所有的男性乘客都死了。

熟悉了数据集之后我们就开始构建模型进行训练,看下官方的示例:https://www.kaggle.com/code/alexisbcook/titanic-tutorial

先读取数据集,对于这种行列数据表格式的形式,一般都会用到pandas模块,能够非常方便的读取。

import pandas as pd
train_data = pd.read_csv("dataset/train.csv")
test_data = pd.read_csv("dataset/test.csv")
print(train_data.head(6))
'''
   PassengerId  Survived  Pclass  ...     Fare Cabin  Embarked
0            1         0       3  ...   7.2500   NaN         S
1            2         1       1  ...  71.2833   C85         C
2            3         1       3  ...   7.9250   NaN         S
3            4         1       1  ...  53.1000  C123         S
4            5         0       3  ...   8.0500   NaN         S
5            6         0       3  ...   8.4583   NaN         Q

[6 rows x 12 columns]
'''

显示的是前6个样本,就是6位乘客的相关资料。

我们回到那个gender_submission.csv,猜测的是女性全部幸存,看下占比怎么样?

print(len(train_data.loc[train_data.Sex=='female']))#314
print(sum(train_data.loc[train_data.Sex=='female']["Survived"]))#233

第一个就是查看有多少女性,总共有314位女性,再看下女性生还者的数量,将Survived为1的累加之后结果是233,也就是说生还的女性有233个,那看下比例:

rate_women = sum(women)/len(women)
print(rate_women)

得出结果是0.7420382165605095,预测结果约等于74.2%的准确率,也就是说女性生还率很高,其Sex性别是一个权重很高的特征值,这个我们在电影中也知道,先让妇女和小孩上船。

当然这种属于简单的预测,还有其他的影响因素都没有考虑进来是吧。

训练模型

这里我们看下官方示例使用的模型,一种叫做random forest model随机森林模型。这个模型由几棵“树”组成(下图中有三棵树,但我们将构建100棵!),它们将分别考虑每位乘客的数据,并投票决定该乘客是否幸存。然后,随机森林模型做出一个民主的决定:得票最多的结果获胜,如图:

from sklearn.ensemble import RandomForestClassifier

如果没有安装这模块,就会报错:

Import "sklearn.ensemble" could not be resolved

Traceback (most recent call last):
File "test.py", line 17, in <module>
from sklearn.ensemble import RandomForestClassifier
ModuleNotFoundError: No module named 'sklearn'

安装:pip install scikit-learn -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

训练并生成预测结果文件submission.csv

from sklearn.ensemble import RandomForestClassifier

y = train_data["Survived"]
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
X_test = pd.get_dummies(test_data[features])
model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X, y)
predictions = model.predict(X_test)

output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
output.to_csv('submission.csv', index=False)
print("Your submission was successfully saved!")

其中特征值我们选用了features = ["Pclass", "Sex", "SibSp", "Parch"]这四个,比如Ticket这个票号其实对于幸存者来说基本是没什么影响,所以也不需要将一些不要的因素给添加进来。

get_dummies的用法

我们转到定义来看下,解释为将分类变量转换为虚拟/指示变量。这句话是什么意思呢,简单来说就是一些分类使用了非数字这些类型,我们将其转换成0与1这样的数字,且进行单独分类列出,类似以前介绍过的独热编码One-Hot

比如在示例中的Sex性别这栏,使用get_dummies处理之后,Sex中的female与male将分成两列,Sex_female和Sex_male,里面是0或1的值。单独拎出来看个例子就明白了,例子都来自官方源码示例:

s = pd.Series(list('abca'))
>>> pd.get_dummies(s)
   a  b  c
0  1  0  0
1  0  1  0
2  0  0  1
3  1  0  0

可以看到第一个a和第四个a为1,其余为0,b和c也是一样的,所在位置为1,其余位置为0

对于缺失数据NaN的处理:

s1 = ['a', 'b', np.nan]
>>> pd.get_dummies(s1)
   a  b
0  1  0
1  0  1
2  0  0
>>> pd.get_dummies(s1, dummy_na=True)
   a  b  NaN
0  1  0    0
1  0  1    0
2  0  0    1

一个是忽略一个是保留缺失

df = pd.DataFrame({'A': ['a', 'b', 'a'], 'B': ['b', 'a', 'c'], 'C': [1, 2, 3]})
>>> pd.get_dummies(df)
   C  A_a  A_b  B_a  B_b  B_c
0  1    1    0    0    1    0
1  2    0    1    1    0    0
2  3    1    0    0    0    1

这个列名前缀就是列名_值,跟前面的Sex_femal一样。

当然也可以自己指定列名前缀

pd.get_dummies(df, prefix=['col1', 'col2'])
>>> pd.get_dummies(df, prefix=['col1', 'col2'])
   C  col1_a  col1_b  col2_a  col2_b  col2_c
0  1       1       0       0       1       0
1  2       0       1       1       0       0
2  3       1       0       0       0       1

其他参数比如drop_first=True删除第一列,指定值类型dtype=float等,基本上就是上述用法为主,更多详情可以参看其定义

Submit Predictions

最后就是将你自己预测的结果提交即可,点击“Submit Predictions”,提交的文件格式:

包含418个条目和一个标题行的csv文件。如果您有额外的列(除了PassengerId和)或行,则提交将显示错误。

该文件应该有2列:

PassengerId(乘客编号,按任意顺序排序)

Survived(幸存为1,死亡为0)

提交直接的结果如下,0.77511,也就是说准确率在77.5%左右。

提高分数的延展

我们做下修改,让分数提高点看能不能做到,因为我知道年龄对于逃生也是一个很关键的因素,所以我将年龄也添加进来试下,看是什么样的效果,而且我把它放到仅次于性别之后:

features = ["Pclass", "Sex","Age", "SibSp", "Parch"]

当然这样直接运行会报错:

type_err, msg_dtype if msg_dtype is not None else X.dtype
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').

因为在Age列里面有缺失的数据,也就是说有的人年龄是未知的,于是我将缺失的使用一个平均年龄来填充X=X.fillna(24)

完整代码如下:

from sklearn.ensemble import RandomForestClassifier

y = train_data["Survived"]
features = ["Pclass", "Sex","Age", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
X=X.fillna(24)
X_test = pd.get_dummies(test_data[features])
X_test=X_test.fillna(24)
#print(train_data["Age"].sum()/len(train_data["Age"]))#23.79

model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X, y)
predictions = model.predict(X_test)

output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
output.to_csv('submission.csv', index=False)
print("Your submission was successfully saved!")

提交之后分数有提高,达到了0.7799

Kaggle房价预测的练习(K折交叉验证)

有兴趣的也可以去熟悉下,通过现有的数据集,来预测不同位置面积等房子的售价。

好了,Kaggle的使用就这么愉快的搞定了,多看大神的源码和参加比赛,对于能力的提高是很有帮助的。欢迎大家留言交流,有不正确的地方欢迎指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/197652.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【操作系统】4、设备管理

文章目录四、设备管理4.1 I/O设备基本概念4.2 I/O控制方式4.2.1 程序控制方式4.2.2 中断方式4.2.3 DMA控制方式4.2.4 通道控制方式4.3 缓冲技术4.4 假脱机技术四、设备管理 I/O控制方式&#xff1a;程序控制、中断、DMA、通道&#xff0c; 缓冲技术&#xff1b;假脱机技术(SPO…

大龄学长的浙大MBA提面优秀之路分享

作为今年上岸浙大MBA项目的一名中年老学长&#xff0c;想把自己在提面中取得优秀资格的经验做个梳理供大家参考&#xff0c;因为以我的经历来说&#xff0c;我认为浙大MBA提前批面试是非常有价值的&#xff0c;而且在提面过程中也发现了优秀资格其实遍布于各个年龄段和层级&…

2023-02-04 Elasticsearch环境安装

1 JDK-8的安装 查询资料自我安装即可&#xff0c;这里不做展示。 2 Elasticsearch 的安装 Elasticsearch目录结构: 配置文件&#xff1a; #节点名称&#xff0c;集群内要唯一 node.name: node-1001node.master: true node.data: true#ip 地址 network.host: localhost htt…

细讲TCP三次握手四次挥手(一)

计算机网络体系结构 在计算机网络的基本概念中&#xff0c;分层次的体系结构是最基本的。计算机网络体系结构的抽象概念较多&#xff0c;在学习时要多思考。这些概念对后面的学习很有帮助。 网络协议是什么&#xff1f; 在计算机网络要做到有条不紊地交换数据&#xff0c;就必…

lsof - list open file

lsof 指令全称 list open file&#xff0c;用官方的话说 Lsof revision 4.91 lists on its standard output file information about files opened by processes -i 平常工作中&#xff0c;用到最多的就是 -i 参数&#xff0c;后面跟端口号&#xff0c;可以查看和这个端口有关…

【嵌入式】MDK使用sct文件将代码段放入RAM中执行

sct文件即分散加载文件&#xff0c;是ARMCC编译器使用的链接脚本文件&#xff0c;等同于GCC编译器的ld链接脚本。MDK IDE使用的是ARMCC。 支持NorFlash中运行代码&#xff08;XIP&#xff09;的MCU例如STM32&#xff0c;一般将所有代码&#xff08;text段&#xff09;都放在FL…

[ 云计算 | AWS ] 亚马逊云科技核心服务之计算服务(Part1:AWS EC2 星巴克为什么横向排队)

(星爸爸网络上的一张图) 注意上图中的5个人&#xff0c;对没错这5个人。一般情况星巴克的人员配置大概是这样的&#xff1a; 1个经理&#xff0c;在办公室两个收银&#xff0c;在收银台&#xff08;本文关注的重点&#xff09;三个人做咖啡 当你去过星巴克买咖啡时&#xff0…

【NS2】tcl与c++互相调用/传参

在NS2&#xff0c;做实验的时候&#xff0c;为了能通过循环配合传值实验&#xff0c;一直找不到tcl传参给c的方法&#xff0c;网上的只po出一部分看不懂&#xff0c;只能通过源码自己研究。最后的解决办法就是&#xff0c;模仿源码的操作&#xff0c;以下通过tcl→ex→sat-irid…

Navicat Monitor 3.0 现已上市 | 欢迎下载试用

Navicat Monitor 3.0 现已上市Navicat Montior 3.0 现已发布&#xff01;一经发布&#xff0c;受到广大专业运维人员的关注与选择! 五大新亮点带给运维团队最为实用且有效地提升监控能力。其具备 PostgreSQL 服务器监控能力、支持优化慢查询、构建自定义指标、性能分析工具优化…

flutter问题

问题一1.报错&#xff1a;Flutter ios/Flutter/Debug.xcconfig: unable to open file (in target "Runner" in project "Runner")2.解决&#xff1a;cd 项目目录flutter cleanflutter create --org solanddriver .运行Xcode问题二1.Cannot run with sound …

Java线程安全问题的原因和解决方案

1.什么是线程安全2.线程不安全的原因 及 解决措施2.1 多线程同时修改同一个变量2.2 修改操作不是原子性加锁操作关键字&#xff1a;synchronized2.3 抢占式执行,随机调度 (根本原因)2.4内存可见性问题volatile 关键字2.5指令重排序1.什么是线程安全 线程安全的确切定义是比较复…

Java——SSM项目(瑞吉外卖)笔记

阅读提醒&#xff1a;最重要的内容都是我手打的字&#xff0c;还有截图上的红字备注部分。 nginx是一个服务器&#xff0c;主要部署一些静态的资源&#xff0c;包括后面做tomcat的集群&#xff0c; 可以接收前端的请求&#xff0c;然后分发给各个tomcat 第一步搭建数据库&…

浏览器网页视频怎么快速下载到本地?

我们在浏览网页时&#xff0c;经常会遇到一些特别喜欢的视频文件&#xff0c;想要下载收藏却苦于不会操作怎恶魔办呢&#xff1f;这时候可以通过一些小插件快速达成下载&#xff0c;比如通过猫爪视频下载插件用户可以轻松的抓取任意网页的视频文件&#xff0c;并将其保存到本地…

Java 利用PriorityQueue进行无InvokerTransformer反序列化

java_PriorityQueue java.util.PriorityQueue 是一个优先队列&#xff08;Queue&#xff09;&#xff0c;节点之间按照优先级大小排序成一棵树。其中PriorityQueue有自己的readObject反序列化入口。 反序列化链为&#xff1a;PriorityQueue#readObject->heapify()->sif…

新网站沙盒期要多久(关于网站走出沙盒期的征兆)

做网站优化首先要明白搜索引擎抓取原理&#xff0c;不管是百度还是谷歌&#xff0c;新站上线总要进入沙盒&#xff0c;接受来自搜索引擎的审查&#xff0c;涉及网站结构、网站内容、网站外链等内容。对于新手朋友来说&#xff0c;难免着急&#xff0c;这段考察期究竟有多长&…

【Python获取相亲网站数据】马上都元宵节了,还在相亲,看看某相亲网站有没有那个有缘人。

前言 马上都元宵节了&#xff0c;还在相亲&#xff0c;看看某相亲网站有没有那个有缘人。今天我们来爬取某相亲网站获取我们想要的数据&#xff0c;比如说&#xff0c;对方的姓名&#xff0c;年龄&#xff0c;身高&#xff0c;体重等等。今天我们主要使用CSS选择的方法来匹配我…

IDEA插件开发入门.01

环境准备Idea插件SDK文档在线地址&#xff1a;https://plugins.jetbrains.com/docs/intellij/welcome.html安装IntelliJ IDEA&#xff0c;这里使用版本2020.1.3 X64IDEA中安装Plugin DevKit插件创建插件项目新建工程。File ->New -> Project选择工程类型&#xff0c;Inte…

无法应用转换程序。请检查指定的转换程序路径是否有效。例子:Adobe Acrobat DC (PDF编辑器)卸载不了或者无法重新安装

不知道大家有没遇到这种情况&#xff0c;Adobe Acrobat DC (PDF编辑器)卸载不了或者无法重新安装&#xff0c;显示&#xff1a;无法应用转换程序。请检查指定的转换程序路径是否有效。 今天小编句遇到了这种情况&#xff0c;卸不了&#xff0c;把文件夹直接删了还是无法重新安装…

Linux安装Mysql8.0

mysql官网 www.mysql.com 这里是新建了个虚拟机 有时候用 rpm -qa|grep mysql和 rpm -qa|grep mariadb检测不到已经安装了mysql或者mariadb 可以使用rpm -qa|grep -i mysql 自己对Linux学习阶段,因此新建虚拟机安装 卸载原来的mariadb rpm -e mariadb-libs rpm -e --node…

微信如何注册小号?一个手机号注册两个微信账号?图文教学

2023年2月3日微信正式开放注册“小号”的功能&#xff0c;也就是可以使用一个手机号来注册两个微信账号。微信作为很多一款国民级别的工具&#xff0c;早就成为了小伙伴日常生活中不可或缺的一部分了。能够注册微信小号自然很好&#xff0c;可是微信如何注册小号呢&#xff1f;…