机器学习——6.模型训练案例: 预测儿童神经缺陷分类TD/ADHD

news2025/1/26 15:36:05

案例目的

有一份EXCEL标注数据,如下,训练出合适的模型来预测儿童神经缺陷分类。

参考文章:机器学习——5.案例: 乳腺癌预测-CSDN博客

代码逻辑步骤

  1. 读取数据
  2. 训练集与测试集拆分
  3. 数据标准化
  4. 数据转化为Pytorch张量
  5. label维度转换
  6. 定义模型
  7. 定义损失计算函数
  8. 定义优化器
  9. 定义梯度下降函数
  10. 模型训练(正向传播、计算损失、反向传播、梯度清空)
  11. 模型测试
  12. 精度计算

代码实现

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


df = pd.read_excel('/Users/guojun/Desktop/Learning/machine_learning/Preprocess_Without_WDE_Channels_Data.xlsx')

X = df[df.columns[0:8]].values
mapping = {"TD":0,"ADHD":1}
Y = df["Class"].replace(mapping)

# 数据集拆分
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.2,random_state=5)
Y_train = Y_train.to_numpy()
Y_test = Y_test.to_numpy()

# 数据标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)


# 转化为张量
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
Y_train = torch.from_numpy(Y_train.astype(np.float32))
Y_test = torch.from_numpy(Y_test.astype(np.float32))

# 真值转为为二维数据
Y_train = Y_train.view(Y_train.shape[0],-1)
Y_test = Y_test.view(Y_test.shape[0],-1)

# 定义模型
class Model(torch.nn.Module):
    def __init__(self,n_input_features):
        super(Model,self).__init__()
        self.linear = torch.nn.Linear(n_input_features,1)
        
    def forward(self,x):
        return torch.sigmoid(self.linear(x))

model = Model(X_train.shape[1])
# 定义损失函数
loss = torch.nn.BCELoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)

# 梯度下降函数
def gradient_descent():
    # 预测Y值
    pre_y = model(X_train)
    # 计算损失
    l = loss(pre_y,Y_train)
    # 反向传播
    l.backward()
    # 梯度更新
    optimizer.step()
    # 梯度清空
    optimizer.zero_grad()
    return l,list(model.parameters())

# 模型训练
for i in range(10000):
    l,p = gradient_descent()
    print(l,p)

# 模型测试
mapping = {0:"TD",1:"ADHD"}
index = np.random.randint(0,X_test.shape[0])
pre_y = model(X_test[index])
pre_y = mapping[int(pre_y.round().item())]
gt_y = mapping[int(Y_test[index].item())]
print(pre_y,gt_y)


# 计算模型准确率
pres_y = model(X_test).round()
result = np.where(pres_y==Y_test,1,0)
ac = np.sum(result)/result.size
print(ac)

 即使调整参数后,损失在0.68左右就不会再下降了。

最终的准确率只有54%-60%,我会在后面的笔记中使用深度神经网络来重新训练,提升模型精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1662190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手写一个SPI FLASH 读写擦除控制器(未完)

文章目录 flash读写数据的特点1. 扇擦除SE(Sector Erase)1.1 flash_se 模块设计1.1.1 信号连接示意图:1.1.2 SE状态机1.1.3 波形图设计:1.1.4 代码 2. 页写PP(Page Program)2.1 flash_pp模块设计2.1.1 信号连接示意图:…

基于STM32F401RET6智能锁项目(使用库函数点灯、按键)

点灯硬件原理图 1、首先,我们查看一下原理图,找到相对应的GPIO口 LED_R低电平导通,LED4亮,所以LED_R的GPIO口需要配置一个低电平才能亮; LED_G低电平导通,LED3亮,所以LED_R的GPIO口需要配置一…

[ES] ElasticSearch节点加入集群失败经历分析主节点选举、ES网络配置 [publish_address不是当前机器ip]

背景 三台CentOS 7.6.1虚拟机, 每台虚拟机上启动一个ElasticSearch 7.17.3(下面简称ES)实例 即每台虚拟机上一个ES进程(每台虚拟机上一个ES节点) 情况是: 之前集群是搭建成功的, 但是今天有一个节点一…

Dual Aggregation Transformer for Image Super-Resolution论文总结

题目:Dual Aggregation Transformer(双聚合Transformer) for Image Super-Resolution(图像超分辨) 论文(ICCV):Chen_Dual_Aggregation_Transformer_for_Image_Super-Resolution_ICCV…

「TypeScript」TypeScript入门练手题

前言 TypeScript 越来越火&#xff0c;现在很多前端团队都使用它&#xff0c;因此咱们前端码农要想胜任以后的前端工作&#xff0c;就要更加熟悉它。 入门练手题 interface A {x: number;y: number; }type T Partial<A>;const a: T { x: 0, y: 0 }; const b: T { …

Java集合框架之LinkedHashSet详解

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。运营社区&#xff1a;C站/掘金/腾讯云&#xff1b;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一…

uniapp、web网页跨站数据交互及通讯

来来来&#xff0c;说说你的创作灵感&#xff01;这就跟吃饭睡觉一样&#xff0c;饿了就找吃的&#xff0c;渴了就倒水张口灌。 最近一个多月实在是忙的没再更新日志&#xff0c;好多粉丝私信说之前的创作于他们而言非常有用&#xff01;受益菲浅&#xff0c;这里非常感谢粉丝…

分布式与集群的区别

先说区别&#xff1a; 分布式是并联工作的&#xff0c;集群是串联工作的。 分布式中的每一个节点都可以做集群。而集群并不一定就是分布式的。 集群举例&#xff1a;比如新浪网&#xff0c;访问的人很多&#xff0c;他可以做一个集群&#xff0c;前面放一个相应的服务器&…

MySQL变量的四则运算以及取模运算

1、定义多个变量在一条语句中&#xff0c;需要使用,作为分隔符 除法默认保留4位有效数字 2、浮点数运算&#xff1a; 除法默认保留4位有效数字

《这就是ChatGPT》读书笔记

书名&#xff1a;这就是ChatGPT 作者&#xff1a;[美] 斯蒂芬沃尔弗拉姆&#xff08;Stephen Wolfram&#xff09; ChatGPT在做什么&#xff1f; ChatGPT可以生成类似于人类书写的文本&#xff0c;它基本任务是弄清楚如何针对它得到的任何文本产生“合理的延续”。当ChatGPT写…

2024 年最新使用 ntwork 框架搭建企业微信机器人详细教程

NTWORK 概述 基于 PC 企业微信的 api 接口&#xff0c;支持收发文本、群、名片、图片、文件、视频、链接卡片等。 下载安装 ntwork pip install ntwork国内源安装 pip install -i https://pypi.tuna.tsinghua.edu.cn/simple ntwork企业微信版本下载 官方下载&#xff1a;h…

无列名注入

在进行sql注入时&#xff0c;一般都是使用 information_schema 库来获取表名与列名&#xff0c;因此有一种场景是传入参数时会将 information_schema 过滤 在这种情况下&#xff0c;由于 information_schema 无法使用&#xff0c;我们无法获取表名与列名。 表名获取方式 Inn…

使用chatglm3本地部署形成的api给上一篇得到的网页信息text_content做内容提取

使用chatglm3本地部署形成的api给上一篇得到的网页信息text_content做内容提取&#xff0c; chatglm3的api调用见&#xff1a;chatglm3的api调用_启动chatglm3的api服务报错-CSDN博客 import os from openai import OpenAIbase_url "http://localhost:5000/v1/" c…

书生作业:XTuner

作业链接&#xff1a; https://github.com/InternLM/Tutorial/blob/camp2/xtuner/homework.md xtuner: https://github.com/InternLM/xtuner 环境配置 首先&#xff0c;按照xtuner的指令依次完成conda环境安装&#xff0c;以及xtuner库的安装。 然后&#xff0c;我们开始尝试…

基于Vant UI的微信小程序开发(随时更新的写手)

基于Vant UI的微信小程序开发✨ &#xff08;一&#xff09;悬浮浮动1、效果图&#xff1a;只要无脑引用样式就可以了2、页面代码3、js代码4、样式代码 &#xff08;二&#xff09;底部跳转1、效果图&#xff1a;点击我要发布跳转到发布的页面2、js代码3、页面代码4、app.json代…

STM32CubeMX软件使用(超详细)

1、Cube启动页介绍 2、芯片选择页面介绍 3、输入自己的芯片型号&#xff0c;这里以STM32U575RIT6举例 4、芯片配置页码介绍 5、芯片外设配置栏详细说明 6、点击ClockConfiguration进行时钟树的配置&#xff0c;选择时钟树后可以选择自己想使用的时钟源&#xff0c;也可以直接输…

LeetCode题练习与总结:反转链表Ⅱ--92

一、题目描述 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&#…

e 值的故事:从复利到自然增长的数学之旅

自然对数函数的底数 e&#xff08;也称为自然常数或欧拉数&#xff09;与 π 一样&#xff0c;是数学中最伟大的常数之一。它大约为 2.718281828&#xff0c;是一个无理数&#xff0c;意味着它的小数部分无限且不重复。 与 π 和 √2 这些由几何发现而来的常数不同&#xff0c…

【高阶数据结构】图 -- 详解

一、图的基本概念 图 是由顶点集合及顶点间的关系组成的一种数据结构&#xff1a;G (V&#xff0c; E)。其中&#xff1a; 顶点集合 V {x | x属于某个数据对象集} 是有穷非空集合&#xff1b; E {(x,y) | x,y属于V} 或者 E {<x, y> | x,y属于V && Path(x, y…

解决常见的Android问题

常见问题&#xff1a; 1、查杀&#xff1a; 查杀一般分为两个方向一种是内存不足的查杀&#xff0c;一种的是因为温度限频查杀&#xff0c;统称为内存查杀&#xff0c;两个问题的分析思路不同 1、内存不足查杀&#xff1a; 主要是因为当用户出现后台运行多个APP或者是相机等…