【PyTorch】第六节:乳腺癌的预测(二分类问题)

news2025/4/27 5:06:56

作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍💬

        上一个实验我们讲解了线性问题的求解步骤,本实验我们以乳腺癌的预测为实例,详细的阐述如何利用 PyTorch 求解一个非线性问题

知识点

  • 数据集的标准化
  • 数据集的划分
  • Sigmoid 函数
  • 乳腺癌的预测

数据集的预处理

数据集的加载

        首先,让我们来加载数据集合。这里我们使用 pandas 对数据集合进行加载

import pandas as pd
df = pd.read_csv(  
    'https://labfile.oss.aliyuncs.com/courses/2534/breast_cancer.csv', index_col=False) # index_col 指定某个列为索引。 
df 
mean radiusmean texturemean perimetermean areamean smoothnessmean compactnessmean concavitymean concave pointsmean symmetrymean fractal dimension...worst textureworst perimeterworst areaworst smoothnessworst compactnessworst concavityworst concave pointsworst symmetryworst fractal dimensiontarget
017.9910.38122.801001.00.118400.277600.300100.147100.24190.07871...17.33184.602019.00.162200.665600.71190.26540.46010.118900
120.5717.77132.901326.00.084740.078640.086900.070170.18120.05667...23.41158.801956.00.123800.186600.24160.18600.27500.089020
219.6921.25130.001203.00.109600.159900.197400.127900.20690.05999...25.53152.501709.00.144400.424500.45040.24300.36130.087580
311.4220.3877.58386.10.142500.283900.241400.105200.25970.09744...26.5098.87567.70.209800.866300.68690.25750.66380.173000
420.2914.34135.101297.00.100300.132800.198000.104300.18090.05883...16.67152.201575.00.137400.205000.40000.16250.23640.076780
..................................................................
56421.5622.39142.001479.00.111000.115900.243900.138900.17260.05623...26.40166.102027.00.141000.211300.41070.22160.20600.071150
56520.1328.25131.201261.00.097800.103400.144000.097910.17520.05533...38.25155.001731.00.116600.192200.32150.16280.25720.066370
56616.6028.08108.30858.10.084550.102300.092510.053020.15900.05648...34.12126.701124.00.113900.309400.34030.14180.22180.078200
56720.6029.33140.101265.00.117800.277000.351400.152000.23970.07016...39.42184.601821.00.165000.868100.93870.26500.40870.124000
5687.7624.5447.92181.00.052630.043620.000000.000000.15870.05884...30.3759.16268.60.089960.064440.00000.00000.28710.070391

569 rows × 31 columns

         可以看到该数据集合一共有 569 条数据,每条数据有 30 个和乳腺癌相关的病变特征,最后一列是该患者是否患有乳腺癌的诊断结果。其中 0 表示没有患有乳腺癌,1 表示患有乳腺癌。

        我们可以利用 pandas 中的切片,先将上表中的特征标签分开

X = df[df.columns[0:-1]].values # 取出30个特征的值
# df.columns[0:-1]取出第0列到倒数第二列的指标名
# df[df.columns[0:-1]].values取出这些指标名所在列的值
y = df[df.columns[-1]].values   # 取出1个标签的值
X.shape, y.shape
# ((569, 30), (569,))

        可以看到共有 569 条数据,每条数据有 30 个特征和 1 个标签。

数据集的划分和标准化

        为了能够评价模型的好坏,这里我们利用 sklearn.model_selection 函数,将原数据按比例随机分为训练数据集和测试数据集,如下:

from sklearn.model_selection import train_test_split
# 按照 0.8 和 0.2 的比例随机划分数据集合
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1234) 
# test_size=0.2:测试集为20%;random_state=1234:设置了固定值,每次分割结果一样,不会数据洗牌
X_train.shape, y_train.shape, X_test.shape, y_test.shape
# ((455, 30), (455,), (114, 30), (114,))

        为了加快模型的收敛速度,一般我们都需要对原始数据进行标准化处理,将所有的数据按照比例缩放到一定范围内。这里我们可以使用 sklearn.preprocessing 来对数据集合进行标准化。

from sklearn.preprocessing import StandardScaler
sc = StandardScaler()  # 对特征进行标准化,标签不要标准化,因为标签只有 0 和 1
X_train = sc.fit_transform(X_train)  # fit_transform对训练集进行标准化
X_test = sc.transform(X_test)        # transform对测试集进行标准化
X_train

 最后,为了将数据放入 PyTorch 定义的模型之中,我们必须将所有的数据转为 张量类型:

import torch
import numpy as np
# 将 NumPy 类型的变量转为 Tensor
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))

# 将标签也转为 2 维,否则放入模型之中训练时,可能出错(因为标签原本只有一列,每一行就一个数,现在要变为二维(标签的第二维为1),和x匹配)
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)
X_train.size(), y_train.size()
# (torch.Size([455, 30]), torch.Size([455, 1]))

乳腺癌的预测

模型的定义

        在处理完数据后,接下来,我们就需要建立相应的模型,用于乳腺癌的预测了。

        线性函数是一条没有上界和下界的直线,即线性函数预测出来的值可以很大如 112321442,也可以很小如 -1231242412。而本实验的数据标签只有 0(患病) 或 1(不患病),因此用线性函数来拟合乳腺癌的数据点是不合理的。

        我们需要找到输出始终为 0-1 之间的函数模型。如果拥有这样的函数模型,那么将任意 x 放入该模型中,都会输出一个 0-1 之间的值。这个值我们可以看做是患有乳腺癌的概率。如果这个概率值小于 0.5 则表示没有患乳腺癌。如果这个概率值大于 0.5 则表示患有乳腺癌。

        逻辑回归函数 Sigmoid 就是这样一种函数,该函数又叫做激活函数,公式如下:

\sigma = \frac{1}{1+e^{-z}}

该函数的几何形式如下所示:

        从图中我们可以看出,该函数就是一个上下界分别为 1 和 0 的有界非线性函数。我们可以让通过了线性函数的输出再通过一次上面的激活函数,进而得到 0-1 之间的结果。

        综上,乳腺癌的预测模型如下:

import torch.nn as nn
# 我们的模型是一个线性函数+激活函数的非线性模型
# modle(x) = sigmoid(w*x+b)


class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1) # 参数:输入神经元的个数和输出神经元的个数

    def forward(self, x):
        # torch 中已经定义了 sigmoid 函数模型
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


# 获得样本量和特征数
n_samples, n_features = X.shape
# 模型的初始化
model = Model(n_features)
model

        至此,我们就得到了一个乳腺癌的初始模型。由于最后通过了一层逻辑回归函数,无论输入的值为多少,模型的输出都必定属于 0-1 之间。

损失函数和优化器

        接下来的步骤和上个实验中的步骤类似。

        首先,让我们来定义一下损失函数,由于我们的标签只有 0 和 1,因此这里使用二元交叉熵损失nn.BCELoss()计算真实值和预测值之间的距离了。该损失函数的公式如下:

L = -\sum_{i=1}^N y^i log \hat{y}^i + (1-y^i)log(1-\hat{y}^i)

        当然,我们不必手写上面的损失函数, 直接使用 nn.BCELoss() 即可:

# 损失和优化器的定义
# 迭代次数
num_epochs = 100
# 学习率
learning_rate = 0.01
# 二元交叉熵损失
criterion = nn.BCELoss()
# SGD 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion, optimizer

  

模型的训练

定义完损失函数和优化器后,接下来的模型训练步骤就是固定的了,如下:

  • 通过模型的正向传播,输出预测结果
  • 通过预测结果和真实标签计算损失
  • 通过后向传播,获得梯度
  • 通过梯度更新模型的权重
  • 进行梯度的清空
  • 循环上面操作,直到损失较小为止。

让我们用代码完成上面的步骤:

for epoch in range(num_epochs):
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    # 后向传播、梯度更新、梯度清空
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')
print("模型训练完毕!!")

  

        综上,我们训练好了一个乳腺癌的预测模型。我们可以尝试对任意一条数据进行预测:

index = np.random.randint(0, len(X_test)) # 生成随机整数
y_predicted = model(X_test[index]) # 向模型中传入x,模型返回y
# 小于 0.5 则输出 0 ,大于0.5 则输出 1
y_predicted_cls = y_predicted.round()   # round()  四舍五入取整

# 将结果转为 numpy类型
real = y_test[index].detach().numpy()[0]      # 真实值
# real = y_test[index].detach():通过.detach() “分离”得到的的变量会和原来的变量共用同样的数据,y_test[index]发生了变化,原来的张量也会发生变化,而且新分离得到的张量real是不可求导的。
predict = y_predicted_cls.detach().numpy()[0] # 预测值
print("第 {} 条测试数据的真实结果为 {} ,预测结果为 {} "
      .format(index, real, predict))
# 第 82 条测试数据的真实结果为 1.0 ,预测结果为 1.0 

        由于模型准确率不是 100%,因此,上面的预测结果和真实结果也可能会不相同。但是,你多运行几次上面代码,必定会出现预测结果和真实结果相同的情况。

        那么,我们训练出来的模型准确率到底是多少呢?

with torch.no_grad():
    y_predicted = model(X_test)
    y_predicted_cls = y_predicted.round()
    acc = y_predicted_cls.eq(y_test).sum().numpy() / float(y_test.shape[0]) 
    #  y_predicted_cls.eq(y_test).sum() # 测试值与预测值相同的样本总数
    print(f'accuracy: {acc.item():.4f}')
# accuracy: 0.9123

        我们利用测试数据,计算出了整个模型的预测准确率大概在 90% 左右,证明我们的模型可以很好地进行乳腺癌的诊断预测。

实验总结

        本实验以乳腺癌的预测为例,引入了激活函数 sigmoid 的概念。建立了一个简单的非线性模型用于诊断患者是否患有乳腺癌。其实,本实验建立的一个线性函数+激活函数的模型就是一个简单的神经网络模型。全连接神经网络的实质其实就是无数个线性函数和非线性网络组成的集合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/434212.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SCI电气】考虑不同充电需求的电动汽车有序充电调度方法(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

【Python】【进阶篇】二十、Python爬虫实现Cookie模拟登录

目录 二十、Python爬虫实现Cookie模拟登录20.1 注册登录20.2 分析网页结构20.3 编写完整程序 二十、Python爬虫实现Cookie模拟登录 在使用爬虫采集数据的规程中,我们会遇到许多不同类型的网站,比如一些网站需要用户登录后才允许查看相关内容&#xff0c…

【获奖案例巡展】信创先锋之星——甘肃省住房和城乡建设厅住建数据大脑

为表彰使用大数据、人工智能等基础软件为企业、行业或世界做出杰出贡献和巨大创新的标杆项目,星环科技自2021年推出了“新科技 星力量” 星环科技科技实践案例评选活动,旨在为各行业提供更多的优秀产品案例,彰显技术改变世界的力量&#xff0…

推进数字化转型进程,AntDB数据库协同神州云动共促新发展

当今,数字化转型已成为企业发展的必由之路。随着新技术的不断发展和市场的日益竞争,越来越多的企业开始意识到数字化转型的重要性,在帮助企业数字化转型过程中,高效的客户关系管理软件和具有灵活性、可伸缩的数字底座,…

关于于pyreadline模块的py3k_compat.py的函数collections.Callable兼容性问题

上图警告的官网链接地址 上图警告信息是一个警告信息,提醒你在代码中使用了即将被弃用的函数或配置项,建议及时修改以避免在将来的版本中出现不兼容的情况。具体解释如下: 这段段警告信息来自于pyreadline模块的py3k_compat.py文件,提示你使用了collections模块中即将被弃用…

code=45, title=禁止登录, message=登录失败,建议升级最新版本后重试,或通过问题反馈与我们联系。

如果你是采用 java 开发的,你可以参考本文章,java 和 kotlin 都是可以相互转换的。 在解决之前,先说明环境: JDK版本:java version "17.0.3.1" 【Oracle JDK】 Kotlin版本:1.8.20 采取simbot核心包开发&am…

PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning论文学习

一、概述 Motivation:直接提升PLATO的size训练不work Methods: 通过curriculum learning技术来构建一个高质量的开放领域机器人第一阶段:coarse-gained generation model:再简单的one-to-one框架下学习粗力度的回复生成模型第二…

【Micropython】ESP8266驱动mpu6050读取数据

【Micropython】ESP8266驱动mpu6050读取数据 📌相关篇《【MicroPython ESP32】ssd1306驱动0.96“I2C屏幕mpu6050图形控制》 ✨本案例基于Thonny平台开发。✨ 🔖esp8266固件版本:MicroPython v1.19.1 on 2022-06-18 📍本篇需要使…

2023 年打破认知,这个开源 API 管理工具你应该知道

关于 API 管理工具,如今的市场已经把用户教育的差不多了,毫不夸张地说,如果我随机抽取一位幸运读者,他都能给我罗列出一二三四款大家耳熟能详的工具,但我今天还是要推荐这一款我上手后,亲测觉得不错的开源 …

深入浅出OpenGL三维渲染管线

1 前言 在计算机图形学中,渲染是根据模型描述在显示器上生成图像的过程。3D图形渲染管线输入根据图元顶点(如三角形、点、线和四边形)对3D模型的描述,并为显示器上的像素生成颜色值。 如下图所示的是3D图形渲染管线的流程。 3D图形渲染管线主要包含以…

JVM知识

类加载机制 虚拟机把class文件加载到内存,并对数据进行校验,转换解析和初始化,形成虚拟机可以直接使用的Java类型,即java.lang.class 装载(Load) ClassFile -> 字节流 ->类加载器 查找和导入class文件 1:通…

解决方案|以大数据为抓手,打造粮食安全智慧监管平台

食为政首,粮安天下,粮食问题一直深受总书记记挂,总书记多次提到:“中国十三亿多人口,吃饭主要靠自己,不能靠外面来解决。” 近年来粮食安全事件频发,中央纪委国家监委在全国开展粮食购销领域腐败…

程序员的那些事儿

作者主页:爱笑的男孩。 持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。 如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666foxmail.com 目…

编译后的hue 替换cdh默认版本hue步骤

基于hue源码编译的hue 进行替换cdh6.x.x默认带的hue版本,主要解决hue滚动条 拉的时候,一下就到末尾的bug,通过源码编译githu上的hue解决问题 一. 拷贝编译好的hue到cdh目录替换原来hue目录 三.启动hue报错 问题一:没有pip命令 pip list -bash: pip: command not found …

4.17、TCP三次握手

4.17、TCP三次握手 1.TCP三次握手2.TCP通信具体流程①三次握手②服务器客户端进行通信 1.TCP三次握手 TCP 是一种面向连接的单播协议,在发送数据前,通信双方必须在彼此间建立一条连接。所谓的“连接”,其实是客户端和服务器的内存里保存的一…

热点数据监测方法

在日常开发中,我们需要着重注意一种场景-热点数据。他可能是一种请求,每次请求的数据类型都是一样的;可能是同一个数据,比如页面上公用的类型数据;可能是同一个用户大量的请求。他们都有着同一个特点,瞬时爆…

Redis---哨兵服务

一、配置哨兵服务 1、哨兵服务介绍 监视 master 服务器,发现 master 宕机后,将 slave 服务器提升为 master 服务器 主配置文件:sentinel.conf 模板文件:redis-4.0.8/sentinel.conf 哨兵服务:类似于mha的管理节点&#…

ELK日志

思维导图 一、ELK介绍 ELK是Elasticsearch、Logstash、Kibana首字母大写缩写,后续加入了Beats(Beats是负责单一用途数据采集并推送给Logstash或Elasticsearch的轻量级产品),就更名为 Elastic Stack。 Elastic Stack技术栈的功能…

JVM-GC回收机制

目录 1.判定垃圾 1.引用计数 2.可达性分析 2.清理垃圾 1.标记清除 2.复制算法 3.标记整理 4.分代回收 上文讲述的Java运行时内存划分,对于程序计数器,虚拟机栈,本地方法栈来说,生命周期是和线程有关的,随着线程而生,随线程而灭,当方法结束或者线程结束时,它们的内存就自…

细讲const与引用的关系

目录 先了解语言层面的权限 进入正题引用与const权限关系 引用权限的概念 const引用返回值时错误情况 一:返回到临时空间时权限问题 二:临时空间到调用处保存问题 结论:如果不对子函数内部数据修改、那么在父函数变量ret的类型可以设置…