【课程总结】Day6(下):机器学习项目实战–成人收入预测

news2025/1/15 18:21:39

机器学习项目实战:成人收入预测

项目目的

基于个人收入数据(包括教育程度、年龄、性别等)的数据集,通过机器学习算法,预测一个人的年收入是否超过5万美金。

数据集

  • 地址:http://idatascience.cn/dataset-detail?table_id=100368

  • 数据集字段

    字段名称字段类型字段说明
    age数值型年龄
    workclass字符型工作类型
    fnlwgt字符型序号
    education字符型教育程度
    education-num数值型受教育时长
    marital-status字符型婚姻状况
    occupation字符型职业
    relationship字符型关系
    race字符型种族
    sex字符型性别
    capital-gain数值型资本收益
    capital-loss数值型资本损失
    hours-per-week数值型每周工作小时数
    native-country字符型原籍
    salary字符型收入
  • 数据集样例:
    file

解决思路

分析输入/输出

通过分析,本次项目我们要解决的问题:给定一个人的相关信息(包括年龄、教育程度、受教育时长等),预测其收入是否超过5万。

该问题如果是预测其收入是多少,那么就属于线性回归问题;但数据集中在收入标签是>=5万或<5万,所以这应该是一个分类问题。

分析相关输入/输出如下:

  • 输入:一个人的信息(包括年龄、教育程度、受教育时长等)
  • 输出:0为<=50K,1为>50K
构建数据集
分析数据类型

首先,我们先分析一下如何将数据集向量化,通过以下代码来查看数据集每列的内容:

import csv
import numpy as np
def read_file(file_path, skip_header=True):
  """
  读取CSV文件的内容。
  参数:
      file_path (str): CSV文件的路径。
      skip_header (bool): 是否跳过表头数据,默认为True。
  返回:
      list: 包含CSV文件内容的列表。
  """
  print(f'读取原始数据集文件: {file_path}')
  with open(file_path, 'r', encoding='utf-8') as f:
      if skip_header:
          # 跳过表头数据
          f.readline()
      reader = csv.reader(f)
      return [row for row in reader]  # 读取csv文件的内容

    
def print_unique_columns(data):
    """
    打印表格数据中每一列去重后的数据。
    参数:
        data (list): 包含表格数据的列表。
    """
    if not data:
        print("没有读取到任何数据。")
        return

    # 获取每一列的数据
    columns = zip(*data)

    # 打印每一列去重后的数据
    for i, column in enumerate(columns):
        unique_values = set(column)
        unique_count = len(unique_values)
        print(f"第{i}列去重后的数据(最多前20个):")
        
        # 如果数据大于20个,则只打印前20个
        if unique_count > 20:
            for value in list(unique_values)[:20]:
                print(value)
            print(f"共 {unique_count} 个不同的值")
        else:
            for value in unique_values:
                print(value)
            print(f"共 {unique_count} 个不同的值")
        print()
        
data = read_file('./成人收入预测数据集.csv')
print_unique_columns(data)

运行结果:

对上述每一列内容进行梳理,梳理结果如下:

列的序号(第几列)原始数据表头名称表头名称中文解释对应列取值举例取值个数数据类型数据处理办法
0age年龄69, 62, 37, 80, 3473连续量保留
1workclass工作类型Without-pay, State-gov, Federal-gov9离散量保留
2fnlwgt疑似邮编编号226092, 209770, 3184021648连续量舍弃
3education教育程度Masters, Bachelors, Some-college16离散量保留
4education-num教育年限3, 4, 16, 2, 1116连续量保留
5marital-status婚姻状况Divorced, Separated, Never-married7离散量保留
6occupation职业类型Exec-managerial, Adm-clerical, Handlers-cleaners15离散量保留
7relationship家庭关系Wife, Not-in-family, Other-relative6离散量保留
8race种族Amer-Indian-Eskimo, White, Black5离散量保留
9sex性别Female, Male2离散量保留
10capital-gain投资利得3781, 2062, 401119连续量保留
11capital-loss投资损失1816, 1876, 176292连续量保留
12hours-per-week每周工作时长37, 80, 3494连续量保留
13native-country出生国家Japan, England, Guatemala42离散量保留

通过分析上述每列内容可知:

  • 数据集中,age、education-num、capital-gain、capital-loss、hours-per-week等字段的内容一般都是数字,可以视为连续量。这类数据直接使用即可,不需要做向量化处理(因为是数字,机器能够处理)。
  • 数据集中,workclass、education、marital-status、occupation、relationship、race、sex、native-country等字段内容一般都是表示状态,属于离散量;由于其内容不是数字,为了让机器能够处理,我们后续需要进行向量化处理。
  • 数据集中,fnlwgt疑似是连续量,但是查看取值个数(21648)与总样本个数(32561个)不一致,所以确定该列不是类似身份证号的唯一编号数据(如果身份证可以舍弃,因为身份证跟收入没啥关系);该列可能是类似邮编的数据,虽然邮编可能与收入有一定关系,但是因为其取值数量比较大(21648个),这会导致特征数据过去庞大,所以选择舍弃。
离散量的编码

通过与chatGPT交流,我们了解到,离散量的编码常见方式有One-Hot编码标签编码序数编码Target编码哈希编码实体嵌入

  1. One-Hot编码:
    • 编码方式:将每个离散特征转换为多个二进制特征。
    • 优点:能够很好地表示分类特征,不会引入大小关系。
    • 缺点:会增加特征维度。
    • 适用场景:适用于没有大小关系的分类特征,如性别、产品类别等。
  2. 标签编码(Label Encoding):
    • 编码方式:将每个离散特征的取值映射为一个整数值。
    • 优点:简单直观。
    • 缺点:可能会让算法认为特征之间存在大小关系。
    • 适用场景:适用于没有明确大小关系的分类特征,但要注意可能带来的影响。
  3. 序数编码(Ordinal Encoding):
    • 编码方式:将每个离散特征的取值映射为一个有序的整数值。
    • 优点:能够保留特征之间的大小关系。
    • 缺点:对于没有明确大小关系的离散特征可能不太合适。
    • 适用场景:适用于有明确大小关系的有序离散特征,如学历、星级等。
  4. Target编码:
    • 编码方式:将每个离散特征的取值映射为目标变量的平均值或中位数。
    • 优点:能够捕捉特征取值与目标变量之间的关系,通常能提高模型性能。
    • 缺点:需要提前知道目标变量,不适用于无监督学习。
    • 适用场景:适用于有监督学习问题中,当目标变量与离散特征存在相关性时。
  5. 哈希编码:
    • 编码方式:将每个离散特征的取值通过哈希函数映射为一个整数值。
    • 优点:在处理高基数特征时较为有效,能减少内存占用。
    • 缺点:可能会产生冲突,导致信息丢失。
    • 适用场景:适用于高基数离散特征,且内存受限的情况下。
  6. 实体嵌入(Entity Embedding):
    • 编码方式:将每个离散特征的取值映射为一个低维的稠密向量。
    • 优点:能够学习特征之间的潜在关系,通常能提高模型性能。
    • 缺点:需要额外的训练过程来学习嵌入向量。
    • 适用场景:适用于复杂的机器学习问题,如自然语言处理、推荐系统等,能够更好地捕捉特征之间的潜在关系。

分析上述编码内容,其中

  • Target编码哈希编码实体嵌入暂未学习到,本次暂不做考虑。

  • 标签编码序数编码由于其存在潜在的数字大小对比,让算法认为特征之间存在大小关系,所以并不适用于上述的婚姻状况、职业类型、国家等。

所以,综上所述本次练习使用One-Hot编码,其编码示意图如下:

基于上面的思想,我们通过与GPT沟通了解到,sklearn的库函数中有OneHotEncoder,使用方法如下:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

# 示例数据
data = np.array([
    ['A', 'X', 'P'],
    ['B', 'Y', 'Q'],
    ['A', 'X', 'R'],
    ['B', 'Z', 'Q']
])

# 创建 OneHotEncoder 实例
encoder = OneHotEncoder(sparse_output=False)

# 对数据进行one-hot编码
encoded_data = encoder.fit_transform(data)

print("原始数据:")
print(data)
print("\n编码后的数据:")
print(encoded_data)

执行结果如下:

分离特征和标签
def split_data(data):
    """
    将数据分割为标签和数据。
    参数:
        data (list): 数据行的列表,第一个元素是标签。
    返回:
        numpy.ndarray: 标签数组。
        numpy.ndarray: 连接元素后的数据数组。
    """

    # 去除每个元素的前后空格
    data = [[col.strip() for col in row] for row in data]

    # 分离数据和标签
    n_label = np.array([row[-1] for row in data])
    n_data = np.array([row[:-1] for row in data])

    return n_label, n_data


csv_data = read_file('./成人收入预测数据集.csv')
label, data = split_data(csv_data)
label, data

运行结果:

处理特征列

这部分的处理较为麻烦,整体思路是这样:

1、实现一个函数,传入三个参数:分别是离散量列的序号、连续量列序号和丢弃列序号

2、函数根据传入的列序号,分别进行如下处理:

  • 如果是连续量列,取出对应的列,不用做处理;

  • 如果是离散量列,取出对应的列,使用OneHotEncoder进行编码

  • 如果是丢弃列,则在矩阵中删除对应的列

    最后,在去除丢弃列之后,将连续量列和离散量列按照列方向堆叠为一个新的矩阵

import numpy as np
from sklearn.preprocessing import OneHotEncoder, StandardScaler

def vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols=None):
    """
    使用scikit-learn将给定的NumPy数组中的数据进行one-hot编码和标准化处理。

    参数:
    data (np.ndarray): 输入的NumPy数组
    onehot_cols (list): 需要进行one-hot编码的列索引
    continuous_cols (list): 不需要one-hot编码的连续量列索引
    exclude_cols (list, optional): 需要排除的列索引

    返回:
    np.ndarray: 经过one-hot编码和标准化的向量化数据
    """
    # 排除不需要处理的列
    if exclude_cols:
        data = np.delete(data, exclude_cols, axis=1)
        onehot_cols = [col - sum(col > exc for exc in exclude_cols) for col in onehot_cols if col not in exclude_cols]
        # 解释过程:
        # 1. 遍历 onehot_cols 中的每个索引 col
        # 2. 检查 col 是否在 exclude_cols 中
        #    - 如果在,计算 col 在 exclude_cols 中的位置 sum(col > exc for exc in exclude_cols)
        #    - 并从 col 中减去这个值,更新 col 的索引
        #    - 例如: col = 1, 在 exclude_cols 中的位置为 0, 则更新后 col = 1 - 0 = 1
        #    - ⭐ 这样可以确保 onehot_cols 中的索引能正确对应到数据的列
        # 3. 如果 col 不在 exclude_cols 中,则保留原始索引

        continuous_cols = [col - sum(col > exc for exc in exclude_cols) for col in continuous_cols if col not in exclude_cols]
        # 解释过程:
        # 1. 遍历 continuous_cols 中的每个索引 col
        # 2. 检查 col 是否在 exclude_cols 中
        #    - 如果在,计算 col 在 exclude_cols 中的位置 sum(col > exc for exc in exclude_cols)
        #    - 并从 col 中减去这个值,更新 col 的索引
        #    - 例如: col = 2, 在 exclude_cols 中的位置为 0, 则更新后 col = 2 - 0 = 2
        #    - ⭐ 这样可以确保 continuous_cols 中的索引能正确对应到数据的列
        # 3. 如果 col 不在 exclude_cols 中,则保留原始索引

    else:
        onehot_cols = onehot_cols[:]
        continuous_cols = continuous_cols[:]

    # 对离散量列进行one-hot编码
    onehot_encoder = OneHotEncoder(sparse_output=False)
    one_hot_data = onehot_encoder.fit_transform(data[:, onehot_cols])

    # 对连续量列进行标准化
    scaler = StandardScaler()
    continuous_data = scaler.fit_transform(data[:, continuous_cols])

    # 将one-hot编码结果和标准化后的连续量列拼接起来
    final_data = np.hstack((one_hot_data, continuous_data))

    return final_data
  
  
csv_data = read_file('./成人收入预测数据集.csv')
label, data = split_data(csv_data)

# 对数据进行切分
onehot_cols = [1, 3, 5, 6, 7, 8, 9, 13]
continuous_cols = [0, 2, 4, 10, 11, 12]

# 排除列增加最后一列
exclude_cols = [2, 13]
vectorized_data = vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols)
vectorized_data

运行结果:

处理标签列

因为标签列的内容只有两种情况:‘<=50K’ 和’>50K’,所以只需要将这一列中’<=50K’替换为0,'>50K’替换为1即可。

import numpy as np
from sklearn.preprocessing import LabelBinarizer

def binarize_labels(labels):
    """
    将标签二值化。
    参数:
        labels (numpy.ndarray): 原始标签数组。
    返回:
        numpy.ndarray: 二值化后的标签数组。
    """
    lb = LabelBinarizer()
    binarized_labels = lb.fit_transform(labels)
    return binarized_labels

vectorized_label = binarize_labels(label)
vectorized_label
试验算法

为了能够将整个流程跑通,我们仍然选择决策树算法跑通流程。


from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

def decision_tree(label, data, test_size=0.2):
    """
    决策树模型的训练和评估。
    参数:
        train_data_file_path (str): 向量化数据的文件路径。
        test_size (float): 测试集的比例,默认为0.2。
    """
    print('开始加载训练数据...')

    # 训练集和测试集切分
    X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=test_size)

    print('开始训练决策树模型...')
    # 数据预测
    clf = DecisionTreeClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # 评估
    print('开始决策树预测...')
    accuracy = np.mean(y_pred == y_test)
    print(f'预测准确率:{accuracy}')

# 读取文件
listdata = read_file('./成人收入预测数据集.csv')
# 对数据进行切分
label, data = split_data(listdata)

# 对数据进行切分
onehot_cols = [1, 3, 5, 6, 7, 8, 9, 13]
continuous_cols = [0, 2, 4, 10, 11, 12]

# 排除列增加最后一列
exclude_cols = [2, 13]    
vectorized_data = vectorize_data_with_sklearn(data, onehot_cols, continuous_cols, exclude_cols)
vectorized_label = binarize_labels(label)

decision_tree(vectorized_label, vectorized_data)

运行结果:

工程优化

通过以上的工作,整体构建数据集→训练模型→预测模型流程已经跑通,接下来进行代码重构优化。

1、将整体代码使用面向对象封装为类实现

2、在模型预测部分加入KNN、贝叶斯、线性回归、随机森林、SVC向量机的方式

3、将预测结果使用matplotlib绘制出来

4、给关键代码处增加带有时间戳的日志

以上工作就交给GPT来完成了,最后重构的代码请见Github仓库

遴选算法

通过运行上述工程优化后的代码,执行结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1813081.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【UML用户指南】-14-对高级结构建模-实例

目录 1、实例的组成结构 1.1、类型 1.2、名称 1.3、操作 1.4、状态 1.5、其他特征 1.5.1、主动对象 1.5.2、链 1.5.3、静态属性 1.6、标准元素 实例是抽象的具体表现&#xff0c;可以对它施加一组操作&#xff0c;而且它可能有一组状态&#xff0c;来存储操作的结果。…

leetcode-04-[24]两两交换链表中的节点[19]删除链表的倒数第N个节点[160]相交链表[142]环形链表II

一、[24]两两交换链表中的节点 重点&#xff1a;暂存节点 class Solution {public ListNode swapPairs(ListNode head) {ListNode dummyHeadnew ListNode(-1);dummyHead.nexthead;ListNode predummyHead;//重点&#xff1a;存节点while(pre.next!null&&pre.next.next…

AI智能体的分级

技术的分级 人们往往通过对一个复杂的技术进行分级&#xff0c;明确性能、适用范围和价值&#xff0c;方便比较、选择和管理&#xff0c;提高使用效率&#xff0c;促进资源合理分配和技术改进和标准化。 比如&#xff0c;国际汽车工程师学会&#xff08;SAE&#xff09;定义了自…

CANOpen转PROFINET网关连接低压伺服系统

在现代工业自动化领域&#xff0c;随着技术的不断进步&#xff0c;各种总线通讯协议之间的转换和互操作性变得越来越重要。CANOpen和PROFINET作为两种广泛应用的通讯协议&#xff0c;各自具有独特的优势和应用场景。然而&#xff0c;在实际应用中&#xff0c;往往需要将CANOpen…

python使用wkhtmltopdf将html字符串保存pdf,解决出现方框的问题

出现的问题: 解决办法: <html> <head><meta charset="UTF-8"/> </head> <style> * {font-family: Arial,SimSun !important; } </style> </html>在html字符串前面加上上面代码,意思是设置字体编码和样式 html示例:…

vue2前置路由守卫中使用this.$store.state报错解决

1、问题描述&#xff1a;在前置路由守卫逻辑中&#xff0c;要更改vuex中的store的state状态&#xff0c;使用常规的this.$store.state报错 2、问题原因&#xff1a; 在vue2是vueRouter前置路由守卫中&#xff0c;this关键字并不会指向vue实例&#xff0c;因此不能使用this.$st…

如何优雅的实现Excel导入通用处理流程

目录 1.业务背景2.业务导入流程3.流程优化3.1 模板模式3.1.1 导入处理器接口ImportProcessor3.1.2 抽象父类 AbstractImportProcessor3.1.3 子类实现 ImportDemoProcessor 3.2 工厂模式3.2.1 标识子类的枚举ImportTypeEnum3.2.2 工厂类ProcessorHolder3.2.3 工厂类的调用 4. 特…

纹理贴图必须要输入顶点坐标或纹理坐标吗

最近知识星球的一位同学,面试时被问到:纹理贴图必须要输入顶点坐标或纹理坐标吗? 他一下子被这个问题问蒙了,虽然他知道正确答案是否定的,但是说不上来理由。 这个就引出了文本提到的全屏三角形,它不需要顶点缓冲区,而是利用顶点着色器直接生成所需的顶点坐标和纹理坐标…

【CTS】android CTS测试

android CTS测试 1.硬件准备2. 软件准备3. 下载 CTS3.1 cts3.2 解压 CTS 包&#xff1a; 4 配置adb fastboot5 检查 Java 版本6 安装aapt26.1 下载并安装 Android SDK6.2 找到 aapt2 工具6.3 配置环境变量 7. 准备测试设备8. 运行 CTS 测试8.1 启动 CTS&#xff1a; 9. 查看测试…

DDD架构和微服务初步实现

本次记录的是微服务的初步认识和DDD架构的初步实现和思路&#xff0c;在之前的发布里&#xff0c;对Javaweb进行了一次小总结&#xff0c;还有一些东西&#xff0c;不去详细理解说明了&#xff0c;下面开始我对微服务的理解。 什么是微服务&#xff1f; 在刚刚开始学习的时候…

【让AI写高考AI话题作文】看各大模型的回答

文章目录 命题chatGPT问题的消失&#xff0c;思考的萎缩 通义千问标题&#xff1a;在信息洪流中寻找智慧之光 文心一言探寻未知&#xff0c;拥抱无限的问题 命题 阅读下面的材料&#xff0c;根据要求写作。&#xff08;60分&#xff09; 随着互联网的普及、人工智能的应用&am…

快速锁定Bug!掌握Wireshark等抓包技术,提升测试效率

前言 相信做了测试一段时间的小伙伴都会开始意识到抓包对于测试的重要性&#xff0c;它涉及到功能测试、性能测试、自动化测试、安全测试和数据库测试等等。可以说我们要想做好测试就必须和抓包打交道&#xff0c;脱离抓包的测试是不合格的。人们都说黑客利用Wireshark等抓包工…

未来校园的新质生产力:南京江北新区浦口外国语学校校园网升级改造的启示

作者:南京江北新区浦口外国语学校 校长助理 杨美玲 导语:在南京江北新区(第十三个国家级新区),浦口外国语学校,这所拥有77605平方米宽阔校园、169个班级、7335名学生和511位专任教师的九年一贯制公办外语特色学校,正以前所未有的活力和智慧,迎接信息化时代的挑战。作为学校信息…

【JMeter接口测试工具】第二节.JMeter基本功能介绍(下)【进阶篇】

文章目录 前言八、Jmeter常用逻辑控制器 8.1 如果&#xff08;if&#xff09;控制器 8.2 循环控制器 8.3 ForEach控制器九、Jmeter关联 9.1 正则表达式提取器 9.2 xpath提取器 9.3 JSON提取器十、跨越线程组传值 10.1 高并发 10.2 高频…

1996-2023年各省农林牧渔总产值数据(无缺失)

1996-2023年各省农林牧渔总产值数据&#xff08;无缺失&#xff09; 1、 时间&#xff1a;1996-2023年 2、 来源&#xff1a;国家统计局、统计年鉴 3、 指标&#xff1a;农林牧渔总产值 4、 范围&#xff1a;31省 5、 缺失情况&#xff1a;无缺失 6、 指标解释&…

韩顺平0基础学java——第20天

p407-429 接口 一个类可以实现多个接口&#xff08;电脑上可以有很多插口&#xff09; class computer IB&#xff0c;IC{} 接口中的属性只能是final&#xff0c;并且是public static final 接口不能继承其他类&#xff0c;但是可以继承多个别的接口 interface ID extends I…

【PX4-AutoPilot教程-TIPS】离线安装Flight Review PX4日志分析工具

离线安装Flight Review PX4日志分析工具 安装方法 安装方法 使用Flight Review在线分析日志&#xff0c;有时会因为网络原因无法使用。 使用离线安装的方式使用Flight Review&#xff0c;可以在无需网络的情况下使用Flight Review网页。 安装环境依赖。 sudo apt-get insta…

Rust基础学习-标准库

栈和堆是我们Rust代码在运行时可以使用的内存部分。Rust是一种内存安全的编程语言。为了确保Rust是内存安全的&#xff0c;它引入了所有权、引用和借用等概念。要理解这些概念&#xff0c;我们必须首先了解如何在栈和堆中分配和释放内存。 栈 栈可以被看作一堆书。当我们添加更…

数据库错误[ERR] 1071 - Specified key was too long; max key length is 1000 bytes

环境&#xff1a;phpstudy的mysql8 索引长度问题&#xff1a; 试了很多解决办法&#xff0c;例如需改配置&#xff1a; set global innodb_large_prefixON; set global innodb_file_formatBARRACUDA; 试了还是有问题&#xff0c;直接启动不了了。因为mysql8取消了这个配置。…

Linux操作系统学习:day02

内容来自&#xff1a;Linux介绍 视频推荐&#xff1a;[Linux基础入门教程-linux命令-vim-gcc/g -动态库/静态库 -makefile-gdb调试]( day02 5、Linux目录结构 操作系统文件结构的开始&#xff0c;只有一个单独的顶级目录结构&#xff0c;叫做根目录。所有一切都从“根”开始…