搭建全连接网络进行分类(糖尿病为例)

news2025/1/22 13:08:09

拿来练手,大神请绕道。

1.网上的代码大多都写在一个函数里,但是其实很多好论文都是把网络,数据训练等分开写的。

2.分开写就是有一个需要注意的事情,就是要import 要用到的文件中的模型或者变量等。

3.全连接的回归也写了,有空再上传吧。

4.一般都是先写data或者model

import torch
import torch.nn as nn
import torch.nn.functional as F
#nn.func这个里面很多功能其实nn里就有,可以不导入,而且后面新的版本的torch也取消了cc.functional里面的部分函数

#定义网络,需要定义两部分,一部分就是初始化,另一部分就是数据流
class FCNet(nn.Module):
    def __init__(self):
        super(FCNet,self).__init__()
        self.fc1 = nn.Linear(8,16)
        #初始的这个8,要和你的数据的特征数一样才行,后面的数可以随意设置,但是不要太多,容易过拟合
        # self.fc2 = nn.Linear(50,20)
        self.fc3 = nn.Linear(16,2)#二分类,输出2,其实1也可以的
        #最后的就是分类数,因为用的sigmod和交叉熵损失,就不用额外加softmax了,多分类要用softmax
        self.sig = nn.Sigmoid()
        # self.drop = nn.Dropout(0.3)
        #可以把用到的放在这里,也可以用nn.Sequential()放在一起,这样后面的话就可以直接用这个,不用写那么多了
        
        
    def forward(self,x):
        x = self.sig(self.fc1(x))
        # x = self.sig(self.fc2(x))
        x = self.sig(self.fc3(x))
        return x
        #就是x要怎么在网络中走,要写一遍
    
#可以自己输出测试一下看看网络是不是自己想的那样,在真的调用的时候再屏蔽掉
# net= FCNet()
# print(net)

首先看看数据是是啥样,outcome就是有没有糖尿病

其实可以手动把csv分成train和test

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
#导入pands是为了读数据,当然使用numpy也可以读得,sklearn是为了把训练数据分为训练和验证集


data = pd.read_csv('./train.csv')
#就是把对应的数据哪出来,x代表的是feature上的data,y代表的是label,因为pd可以读到最上面的标签,所以从第2行(i=1)开始读就行
x = data.iloc[1:,:-1]
y = data.iloc[1:,[-1]]
#可以输出看看数据对不对,x中不应该包含labels
# print(x)
# print(y)
#test_size就是划分的比例,后面的是种子,意思是每次运行这个函数时候,0.8就是那些,0.2也还是每次一样,如果想要不一样,只要每次运行这个函数时候换个值就行
x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)
#print(x_train,y_test)
# print(x_test,y_test)
#给数据进行归一化,可以用很多方法,我用最简单的归一到-1到1
x_train = x_train.apply(lambda x: (x - x.mean()) / (x.std()))
x_test = x_test.apply(lambda x: (x - x.mean()) / (x.std()))

#写dataset可以用两种方法,第一种就是 每一个数据自己单独处理,第二个就是要自己重写dataset类
#1.
# 可以使用分别的处理,把数据(首先转换为tensor,或者把dataframe.valus拿出来才能转换为tensor)转换为tensor并且数据类型转换为float32,如果测试没有真值,需要单独转换
# x_train = torch.tensor(np.array(x_train),dtype=torch.float32)
# y_train = torch.tensor(np.array(y_train),dtype=torch.float32)
# x_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# y_test = torch.tensor(np.array(x_test),dtype=torch.float32)
# train_dataset = torch.utils.data.TensorDataset(x_train,y_train)
# test_dataset = torch.utils.data.TensorDataset(x_test,y_test)

#2.也可以直接重写dataset

class dataset(Dataset):
    def __init__(self, x, y):
        #把值拿出来或者变为np类型才能转换为tensor
        # self.data = torch.tensor(x.values,dtype=torch.float32)
        # self.labels = torch.tensor(y.values,dtype=torch.float32)
        self.data = torch.tensor(np.array(x),dtype=torch.float32)
        self.labels = torch.tensor(np.array(y),dtype=torch.float32)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return self.data[idx],self.labels[idx]
        #应该返回的是list类型,不是字典也不是set

BATCH_SIZE = 64


#验证集一般不用shuffle
train_dataset = dataset(x_train,y_train)
test_dataset = dataset(x_test,y_test)
# print(train_dataset)
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_lodaer = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)
# print(train_loader)

然后就可以写train或者test了,其实test和train一样

from Model import FCNet
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import data
#导入要调用的net和data,也可以from data import xxx 这样可以直接用xxx,现在的这个需要用data.xxx

#看自己的设备,最好用gpu来跑
if (torch.cuda.is_available()):
    my_device = torch.device('cuda')
else:
    my_device = torch.device('cpu')
    
    
print(my_device)
#实例化一个net,并且放到gpu上,需要放到gpu上的有inputs,labels,net,loss
net = FCNet().to(my_device)
# print(net)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
#一开始是不需要weight_decay(也就是l2正则化),可以等出现过拟合在用,也可以先用上
optimizer = optim.Adam(net.parameters(),lr=0.001,weight_decay=0.01)

epochs = 600
#定义train,因为一边训练一边验证,所有就把两个loader都放进去了,不过写法很多,也可以不放dataloader,放epoches也可以
def train(dataloader,valloader):
    losses = []
    acces = []
    losses_val = []
    for epoch in range(epochs):
        loss_batch = 0
        for i,data in enumerate(dataloader):
            #需要注意的,这里的inputs和labels和之前定义的dataset相关,需要是list类型才可以
            inputs,labels = data
            #print(data)可以打印出来查看一下
            inputs,labels = inputs.to(my_device),labels.to(my_device)
            optimizer.zero_grad()#每次要梯度清零
            outputs = net(inputs)
            #print(outputs)
            #model的最后一层是sigmod
            #labels的格式需要注意,因为现在是[[1],[0],[1],[1]..]这样得格式,无法放到交叉熵了,需要时[0,1,1,1...]这样得格式才行
            loss = criterion(outputs,labels.squeeze(1).long()).to(my_device)
            #print(labels.squeeze(1).long())
            loss.backward()
            optimizer.step()
            loss_batch += loss.item()
            length = i
        #验证的时候不用反向传播和梯度下降这些
        net.eval()
        count = 0
        right = 0
        loss_batch_val =0
        with torch.no_grad():
            for j,data2 in enumerate(valloader):
                val_inputs,val_labels = data2
                val_inputs,val_labels = val_inputs.to(my_device),val_labels.squeeze(1).long().to(my_device)
                val_outputs = net(val_inputs)
                loss_val = criterion(val_outputs,val_labels)
                #因为net的最后一层是2,所以输出的是2维的【0.6,0.4】这种,但是这个可以直接放到交叉熵中
                #——中放的是概率,pred中放的是预测的类别,算损失还是要用outputs,但是算准确率就是用pred和真实labels相比了
                _,pred = torch.max(val_outputs,1)
                #print(pred)
                right = (pred == val_labels).sum().item()
                count = len(val_labels)
                acc = right/count
                loss_batch_val += loss_val.item()
                length2 = j
            
        if epoch % 10 == 9:
            print('train_epoch:',epoch+1,'train_loss:',loss_batch/length,'val_loss:',loss_batch_val/length2,'acc:',acc)
            losses.append(loss_batch/length)
            acces.append(acc)
            losses_val.append(loss_batch_val/length2)
    #可以画一些曲线,输出一些值
    plt.plot(range(60),losses,color ='blue',label ='train_loss')
    plt.plot(range(60),acces, color ='red',label ='val_acc')
    plt.plot(range(60),losses_val,color ='yellow',label ='val_loss')
    plt.legend()
    plt.show()
    torch.save(net.state_dict(),'./weights_epoch1000.pth')
    #保存参数
    
train(data.train_loader,data.test_lodaer)

最后看一下结果,最后的准确率在85%左右,还可以,毕竟数据不多,也是简单的全连接。

在这个结果之前出现了很多问题,比如波动很大,损失先降后升等问题,找个有问题的图

下面是一些总结:

1.跳跃很大,波动:增大batch_size,减小lr。

2.降低过拟合:

        a.降低模型的复杂程度,但是修改具体的神经元个数,因为这个网络本身就不大,所有没啥用,模型非常大没准会有用。

        b.batchsize增大,lr减小是有效的。

        c.输入数据进行归一化是有用的,归一化之后lr可以调大一点,收敛变快了。

        d.L2正则化是有用的,很有用。dropout应该也有用,但是模型本来就很小,我试了试没啥差别。而且有正则化之后可以加速收敛,lr可以稍微调大一点,较少的epoches也可以收敛了,而已acc也会更高一点,稳定一点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1054745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ubuntu下源码编译方式安装opencv

基础条件 ubuntu 20.04 opencv 3.4.3 opencv 源码编译的安装步骤 第一步, 首先clone源码 git clone https://github.com/opencv/opencv.git第二步,依赖包,执行下面的命令 sudo apt-get install build-essential sudo apt-get install cmak…

3.物联网射频识别,(高频)RFID应用ISO14443-2协议,(校园卡)Mifare S50卡

一。ISO14443-2协议简介 1.ISO14443协议组成及部分缩略语 (1)14443协议组成(下面的协议简介会详细介绍) 14443-1 物理特性 14443-2 射频功率和信号接口 14443-3 初始化和防冲突 (分为Type A、Type B两种接口&…

c语言系统编程之多进程

程序与进程的区别? 程序是静态的未运行的二进制文件,存储在磁盘中 进程是已经运行的二进制文件,存储在内存中 进程的内存划分图有哪几部分? 堆(存储malloc和calloc出来的空间)、栈(局部变量…

字符串函数(一)

✨博客主页:小钱编程成长记 🎈博客专栏:进阶C语言 字符串函数(一) 0.前言1.求字符串长度的函数1.1 strlen(字符串长度) 2.长度不受限制的字符串函数2.1 strcpy(字符串拷贝&#xff0…

CTF-python爬虫学习笔记

学习链接 【Python爬虫】爆肝两个月!拜托三连了!这绝对是全B站最用心(没有之一)的Python爬虫公开课程,从入门到(不)入狱 ! 。知识 1.1 出现错误 复制红框中的内容去查找 1.2 打印…

七、2023.10.1.Linux(一).7

文章目录 1、 Linux中查看进程运行状态的指令、查看内存使用情况的指令、tar解压文件的参数。2、文件权限怎么修改?3、说说常用的Linux命令?4、说说如何以root权限运行某个程序?5、 说说软链接和硬链接的区别?6、说说静态库和动态…

字符串函数(二)—— 长度受限制的字符串函数

✨博客主页:小钱编程成长记 🎈博客专栏:进阶C语言 🎈相关博文:字符串函数(一) 字符串函数(二)—— 长度受限制的字符串函数 3.长度受限制的字符串函数3.1 strncpy&#x…

594.最长和谐子序列(滑动窗口)

目录 一、题目 二、代码 一、题目 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 二、代码 class Solution { public:int findLHS(vector<int>& nums) {sort(nums.begin(), nums.end());int left 0, right 0;int MaxLength 0;while…

【网络安全-SQL注入】SQL注入----一篇文章教你access数据库SQL注入以及注入点利用。SQL注入【3】

前言&#xff1a; 本篇文章以凡诺企业网站管理系统为例&#xff0c;讲解了access数据库是如何进行SQL注入的&#xff0c;以及注入点如何利用&#xff0c;如何判断查询字段个数&#xff0c;如果用联合查询爆出数据库数据等&#xff1b; 之前有两篇文章详细介绍了MySQL数据库的…

Qt Creator 预览界面 快捷键

一般来说&#xff0c;我们运行Qt程序所花费的时间是比较长的&#xff0c;那有时我们只改变了界面&#xff0c;那么此时花费如此长的时间去运行程序来观察界面改动的效果是非常浪费时间的行为。 此时我们可以选择预览界面来观察界面改动后的效果&#xff1a;

九、GC收集日志

JVM由浅入深系列一、关于Java性能的误解二、Java性能概述三、了解JVM概述四、探索JVM架构五、垃圾收集基础六、HotSpot中的垃圾收集七、垃圾收集中级八、垃圾收集高级👋GC收集日志 ⚽️1. 认识GC收集日志 垃圾收集日志是一个重要的信息来源,对于与性能相关的一些悬而未决的…

基本的五大排序算法

目录&#xff1a; 一&#xff0c;直接插入算法 二&#xff0c;希尔排序算法 三&#xff0c;选择排序 四&#xff0c;堆排序 五&#xff0c;冒泡排序算法 简介&#xff1a; 排序算法目前是我们最常用的算法之一&#xff0c;据研究表明&#xff0c;目前排序占用计算机CPU的时…

1003 我要通过!

一.问题&#xff1a; “答案正确”是自动判题系统给出的最令人欢喜的回复。本题属于 PAT 的“答案正确”大派送 —— 只要读入的字符串满足下列条件&#xff0c;系统就输出“答案正确”&#xff0c;否则输出“答案错误”。得到“答案正确”的条件是&#xff1a; 字符串中必须仅…

数组和切⽚ - Go语言从入门到实战

数组和切⽚ - Go语言从入门到实战 数组的声明 package main import "fmt" func main() { var a [3]int //声明并初始化为默认零值 a[0] 1 fmt.Println("a:", a) // 输出: a: [1 0 0] b : [3]int{1, 2, 3} //声明同时初始化 fmt.Println("b:…

番外6:下载+安装+配置Linux

#########配置Linux---后续 step08: 点击编辑虚拟机设置&#xff0c;选择下载好的映像文件.iso进行挂载&#xff1b; step09: 点击编辑虚拟机选项&#xff0c;选择UEFI启动模式并点击确定&#xff1b; step10: 点击开启虚拟机&#xff0c;选择Install rhel &#xff1b; 备注&…

架构的未来:微前端与微服务的融合

文章目录 微服务架构简介微前端架构简介微前端与微服务的融合1. 共享服务2. 基于事件的通信3. 统一的身份和认证4. 交付管道的集成 示例&#xff1a;使用微服务和微前端的电子商务平台微服务架构微前端架构融合微服务和微前端 结论 &#x1f389;欢迎来到架构设计专栏~架构的未…

【Linux系统编程】僵尸进程与孤儿进程

文章目录 1. 僵尸进程2. 僵尸进程的危害3. 孤儿进程 1. 僵尸进程 上一篇文章进程的状态中最后我们提出了僵尸状态&#xff1a; 为了方便子进程退出后父进程或操作系统获取该进程的退出结果&#xff0c;Linux进程退出时&#xff0c;进程一般不会立即死亡&#xff0c;而是要维持…

【Spring底层原理】BeanFactory的实现

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 容器实现 一、BeanFactory实现的特点1.1 Be…

2023年中国半导体IP行业发展概况及趋势分析:半导体IP的市场空间广阔[图]

半导体指IP指芯片设计中预先没计、验证好的功能模块&#xff0c;处于半导体产业链最上游&#xff0c;为芯片设计厂商提供设计模块。半导体IP按交付方式可分为软核、硬核和固核&#xff1b;按产品类型可分为处理器IP、接口IP、其他物理IP及其他数字IP。 半导体IP分类 资料来源&…

K-Means(上):数据分析 | 数据挖掘 | 十大算法之一

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ &#x1f434;作者&#xff1a;秋无之地 &#x1f434;简介&#xff1a;CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作&#xff0c;主要擅长领域有&#xff1a;爬虫、后端、大数据…