搭建深度学习框架+nn.Module

news2024/12/1 13:11:23

一、搭建项目框架(YOLO框架的简约版本)

最终成品的项目框架如下图,最终实现的效果,通过自己配置的框架函数,对模型进行定义与参数调配。同时在参数配置的过程中可选择是否进行模型初始化参数的使用。适用于YOLO框架的初认识。了解此框架可更好的认识YOLO框架。
在这里插入图片描述

二、Net模型的搭建-nn_net.py

net框架引用nn.Module父类框架,进行框架的搭建。本框架在文件夹net中。

import torch.nn as nn

class Mynet_model(nn.Module):
    def __init__(self,input_size,output_size):
        super(Mynet_model,self).__init__()

        self.hiden1 = nn.Sequential(nn.Linear(input_size,128),nn.LeakyReLU())
        self.hiden2 = nn.Sequential(nn.Linear(128,256),nn.LeakyReLU())
        self.hiden3 = nn.Sequential(nn.Linear(256,512),nn.LeakyReLU())
        self.hiden4 = nn.Sequential(nn.Linear(512,256),nn.LeakyReLU())
        self.hiden5 = nn.Sequential(nn.Linear(256,128),nn.LeakyReLU())
        self.hiden6 = nn.Sequential(nn.Linear(128,64),nn.LeakyReLU())
        self.out = nn.Sequential(nn.Linear(64,output_size),nn.Softmax())
        self._init_weight()
        
    def forward(self,x):
        x = self.hiden1(x)
        x = self.hiden2(x)
        x = self.hiden3(x)
        x = self.hiden4(x)
        x = self.hiden5(x)
        x = self.hiden6(x)
        x = self.out(x)

        return x
    
    def _init_weight(self):
        # 对模型参数的初始化
        nn.init.kaiming_uniform_(self.hiden1[0].weight,nonlinearity='leaky_relu')
        nn.init.kaiming_uniform_(self.hiden2[0].weight,nonlinearity='leaky_relu')
        nn.init.kaiming_uniform_(self.hiden3[0].weight,nonlinearity='leaky_relu')
        nn.init.kaiming_uniform_(self.hiden4[0].weight,nonlinearity='leaky_relu')
        nn.init.kaiming_uniform_(self.hiden5[0].weight,nonlinearity='leaky_relu')
        nn.init.kaiming_uniform_(self.hiden6[0].weight,nonlinearity='leaky_relu')

三、优化器框架的搭建-optimizer.py

主要用于深度学习框架中,优化器的选择。本框架位于net文件夹中。

# 创建优化器,按照参数名称进行优化器匹配,只能在这个文件中进行模型的配置
import torch.optim as optim

def optimizer_parents(model,name,lr,weight_decay,betas,eps):
    if name == 'SGD':
        optimizer = optim.SGD(model.parameters())
    elif name == 'Adam':
        optimizer = optim.Adam(model.parameters(),lr=lr,weight_decay=weight_decay,betas=betas,eps=eps)
    else:
        # 显示报错信息
        raise NotImplementedError('optimizer {} not implemented'.format(name))  
    return optimizer

if __name__ == '__main__':
    from nn_net import Mynet_model
    model = Mynet_model(input_size=20,output_size=10)
    optimizer = optimizer_parents(model,'SGD')

四、数据加载器-dataloder.py

本框架主要为读取csv文件,将csv文件进行数据处理后保存返回TensorDatasets类型的数据,便于后续进行Dataloder数据切分与调用。用于批量读取数据。本框架位于文件夹data中。

# 创建数据集加载的工具
import torch ,os
from torch.utils.data import TensorDataset,Dataset,DataLoader
import pandas as pd 
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.model_selection import train_test_split
import parameter 

class Mydatasets():
    def __init__(self,path,pt_path,device):
        self.data = pd.read_csv(path)
        self.pt_path = pt_path
        self.device = device
        
    def datasets(self):
        data = np.array(self.data)
        x_ = data[:,:-1]
        y_ = data[:,-1]

        x_train,x_text,y_train,y_test = train_test_split(x_,y_,test_size=0.2,random_state=42,stratify=y_)

        # 数据标准化
        stander = StandardScaler()
        x_train = stander.fit_transform(x_train)
        x_text = stander.transform(x_text)

        # 标准化工具参数保存
        stander_dic = {
            "mean":stander.mean_,
            "std":stander.scale_
        }
        path = f"{self.pt_path}/stander.pth"
        if os.path.exists(path):
            pass
        else:
            torch.save(stander_dic,path)

        device = torch.device(self.device)

        # 将数据转换为tensor
        x_train = torch.tensor(x_train,dtype=torch.float32).to(device)
        x_text = torch.tensor(x_text,dtype=torch.float32).to(device)
        y_train = torch.tensor(y_train,dtype=torch.int64).to(device)
        y_test = torch.tensor(y_test,dtype=torch.int64).to(device)
        
        train = TensorDataset(x_train,y_train)
        test = TensorDataset(x_text,y_test)

        return train,test

五、文件路径创建框架-creat_path.py

本框架主要用于模型训练前,创建文件夹路径,用于保存本次训练过程中的相关训练参数以及模型。在每次训练过程中,在文件夹pt中生成相关exp文件,在exp文件内生成model_weight,优化器参数,其他参数等。

import os 

def creat_path():
    name = 'exp_'
    Root = "./Torch/mobile_pheno/pt"
    i = 1
    for root,dir,list in os.walk(Root):
        if root == Root:
            if not dir:
                path = root+'/'+f"{name}{i}"
                os.makedirs(path)
            else:
                i = len(dir)
                path = root+'/'+f"{name}{i+1}"
                if not os.path.exists(path):
                    os.makedirs(path)
    return path

def read_path():
    name = 'exp_'
    Root = "./Torch/mobile_pheno/pt"
    i = 1
    for root,dir,list in os.walk(Root):
        if root == Root:
            i = len(dir)
            path = root+'/'+f"{name}{i}"

    return path

if __name__ == '__main__':
    creat_path()

六、模型相关参数配置-parameter.py

# 模型的相关参数配置文件
import os 
import creat_path as cp

# 对训练文件模型中相关参数进行配置
def train_parameter():
    root_path = cp.creat_path()
    argument = {
        "device": "cpu",   # 运行的设备
        "optimizer_name": "Adam",  # 优化器的名称
        "learning_rate": 1e-4,   # 优化器相关参数
        "weight_decay": 0.01,
        "betas": (0.9, 0.999),
        "eps": 1e-8,
        "epochs": 50,  # 循环轮次
        "model_weight_file_path": './Torch/mobile_pheno/model/model_weight.yaml',   # 初始模型权重路径
        "root_path": root_path,  
        "optimizer_weight_file_path":'./Torch/mobile_pheno/model/optimizer_weight.yaml',  # 初始优化器权重路径
    }
    return argument

# 测试框架模型相关调优参数
def test_parameter():
    root_path = cp.read_path()
    argument = {
        "device": "cpu",
        "model_weight_file_path": os.path.join(root_path,"model_weight.yaml"),
        "root_path": root_path,
        "optimizer_weight_file_path":os.path.join(root_path,"optimizer_weight.yaml"),
        "other_weight_file_path":os.path.join(root_path,"other_weight.yaml")
    }
    return argument

七、训练框架-train.py

from data.dataloder import Mydatasets
from net.nn_net import Mynet_model
from net.optimizer import optimizer_parents
import torch 
import torch.nn as nn
import os 
import time 
import parameter 
from torch.utils.data import DataLoader

def train(file_path):
    # 参数初始化
    argument = parameter.train_parameter()
    root_path = argument['root_path']
    train,_ = Mydatasets(file_path,root_path,argument['device']).datasets()  
    input_size = train[:][0].shape[1]
    output_size = torch.unique(train[:][1]).shape[0]
    
    # 加载数据加载器
    train = DataLoader(train,batch_size=16,shuffle=True)
    
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(argument['device'])
    model = Mynet_model(input_size,output_size).to(device)

    # 判定是否进行继承训练
    if os.path.exists(argument['model_weight_file_path']):
        model_weight = torch.load(argument['model_weight_file_path'])
        model.load_state_dict(model_weight["model_weight"])  

    loss_func = nn.CrossEntropyLoss()

    # 配置优化器
    # optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
    optimizer = optimizer_parents(model,argument['optimizer_name'],argument['learning_rate'],argument['weight_decay'],argument['betas'],argument['eps'])
    # 进行optimizer参数优化
    if os.path.exists(argument['optimizer_weight_file_path']):
        optimizer_weight = torch.load(argument['optimizer_weight_file_path'])
        optimizer.load_state_dict(optimizer_weight['optimizer_weight'])

    start_time = time.time()
    epochs=argument['epochs']
    for epoch in range(epochs):
        conut = 0
        loss = 0
        for i,(x,y) in enumerate(train):
            y_pred = model(x)
            loss_value = loss_func(y_pred,y)
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()
            conut +=1
            loss += loss_value
        print(f"epoch:{epoch},loss:{loss/conut}")
    end_time = time.time()
    print(f"time:{end_time-start_time}")
    

    # 生成字典对模型参数进行保存
    model_weight_parents = {
        "model_weight":model.state_dict(),
    }
    model_weight_parents_path = os.path.join(root_path,'model_weight.yaml')
    # torch.save(model_weight_parents,'./Torch/mobile_pheno/model/model_weight.yaml')   # 初始化模型保存
    torch.save(model_weight_parents,model_weight_parents_path)  # 在当前训练文件中进行当前训练模型的保存

    optimizer_weight_parents = {
        "optimizer_weight":optimizer.state_dict(),
    }
    optimizer_weight_parents_path = os.path.join(root_path,'optimizer_weight.yaml')
    # torch.save(optimizer_weight_parents,'./Torch/mobile_pheno/model/optimizer_weight.yaml')   # 初始化模型保存
    torch.save(optimizer_weight_parents,optimizer_weight_parents_path)  # 在当前训练文件中进行当前训练模型的保存

    other_parents = {
        "epoch":epochs,
        "input_size" : input_size,
        "output_size" : output_size,
    }
    other_parents_path = os.path.join(root_path,'other_weight.yaml')
    # torch.save(other_parents,'./Torch/mobile_pheno/model/other_weight.yaml')   # 初始化模型保存
    torch.save(other_parents,other_parents_path)  # 在当前训练文件中进行当前训练模型的保存

if __name__ == '__main__':
    path = "./Torch/mobile_pheno/data/手机价格预测.csv"
    train(path)

八,测试框架-test.py

import torch 
import torch.nn as nn
from net.nn_net import Mynet_model
from data.dataloder import Mydatasets
import creat_path
import parameter
from torch.utils.data import DataLoader

def test(path):
    argument = parameter.test_parameter()
    root_path = argument['root_path']
    device = argument['device']
    _,test = Mydatasets(path,root_path,device).datasets()
    
    number = len(test)
    test = DataLoader(test,batch_size=8,shuffle=False)

    # 读取模型中的参数数据
    other_weight_file_path = torch.load(argument['other_weight_file_path'],map_location=device)
    input_size,output_size = other_weight_file_path['input_size'],other_weight_file_path['output_size']

    model_weight_file_path = torch.load(argument['model_weight_file_path'],map_location=device)
    model = Mynet_model(input_size,output_size).to(device)
    model.load_state_dict(model_weight_file_path['model_weight'])

    count = 0
    for x,y in test:
        y_pred = model(x)
        y_pred = torch.argmax(y_pred,dim=1)
        # print(f"预测结果:{y_pred},真实结果:{y}")
        count += (y_pred==y).sum()

    print(f"正确率:{count/number}")
    return count/number

if __name__ == '__main__':
    path = "./Torch/mobile_pheno/data/手机价格预测.csv"
    test(path)

九、推理框架

# 推理验证模型 
import torch,os
import torch.nn as nn
import pandas as pd 
from sklearn.preprocessing import StandardScaler
import parameter
from torch.utils.data import DataLoader,TensorDataset
from net.nn_net import Mynet_model
import numpy as np

def data_sets(path):
    argument = parameter.test_parameter()
    device = argument['device']

    data = pd.read_csv(path)
    data = np.array(data)[:,1:]
    print(data.dtype)
    print(data.shape)
    print('**************************')
    # 导入数据标准化处理工具
    stander_file_path = os.path.join(argument['root_path'],'stander.pth')
    satander_mean_std = (torch.load(stander_file_path))


    satander = StandardScaler()
    satander.mean_ = np.array(satander_mean_std['mean'])
    satander.scale_ = np.array(satander_mean_std['std'])
    data = satander.transform(data)

    data = torch.tensor(data,dtype=torch.float32).to(device)
    return data

def detect(path):
    argument = parameter.test_parameter()
    data = data_sets(path)

    device = argument['device']
    model_weight_file_path = torch.load(argument['model_weight_file_path'],map_location=device)
    other_weight_file_path = torch.load(argument['other_weight_file_path'],map_location=device)

    model = Mynet_model(other_weight_file_path['input_size'],other_weight_file_path['output_size']).to(device)
    model.load_state_dict(model_weight_file_path['model_weight'])

    y_pred = model(data)
    y_pred = torch.argmax(y_pred,dim=1)
    print(y_pred)


if __name__ == "__main__":
    path = './Torch/mobile_pheno/data/detect_test.csv'
    detect(path)

十、main函数-save_best_model.py

# 保存验证效果最好的模型数据参数
import train
import test
import torch

def bset_model():
    list = 0
    path = "./Torch/mobile_pheno/data/手机价格预测.csv"
    train.train(path)  
    result = test.test(path)  
    print(result)
    if result > list:
        model_parents =torch.load('./Torch/mobile_pheno/model/model_weight.yaml')
        torch.save(model_parents,'./Torch/mobile_pheno/model/best_model.yaml')
    
if __name__ == '__main__':
    bset_model()

十一、数据来源

本模型实验数据来源为https://tianchi.aliyun.com/dataset/157241,手机价格预测数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251123.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FFmpeg 简介与编译

1. ffmpeg 简介: FFmpeg是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。采用LGPL或GPL许可证。它提供了录制、转换以及流化音视频的完整解决方案。它包含了非常先进的音频/视频编解码库libavcodec,为了保证高可移…

打latex公式可以练到像手写一样快吗?

这里分享两个Python Latex工具latexify和handcalcs。 latexify生成LaTeX 数学公式 import math import latexify @latexify.with_latex #调用latexify的装饰器 def solve(a, b, c):return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)solve 更多例子.......

【Linux】磁盘 | 文件系统 | inode

🪐🪐🪐欢迎来到程序员餐厅💫💫💫 主厨:邪王真眼 主厨的主页:Chef‘s blog 所属专栏:青果大战linux 总有光环在陨落,总有新星在闪烁 模电好难啊&#xff…

AntFlow 0.20.0版发布,增加多数据源多租户支持,进一步助力企业信息化,SAAS化

传统老牌工作流引擎比如activiti,flowable或者camunda等虽然功能强大,也被企业广泛采用,然后也存着在诸如学习曲线陡峭,上手难度大,流程设计操作需要专业人员,普通人无从下手等问题。。。引入工作流引擎往往需要企业储…

Scrapy管道设置和数据保存

1.1 介绍部分: 文字提到常用的Web框架有Django和Flask,接下来将学习一个全球范围内流行的爬虫框架Scrapy。 1.2 内容部分: Scrapy的概念、作用和工作流程 Scrapy的入门使用 Scrapy构造并发送请求 Scrapy模拟登陆 Scrapy管道的使用 Scrapy中…

洛谷 B3626 跳跃机器人 C语言 记忆化搜索

题目: https://www.luogu.com.cn/problem/B3626 题目描述 地上有一排格子,共 n 个位置。机器猫站在第一个格子上,需要取第 n 个格子里的东西。 机器猫当然不愿意自己跑过去,所以机器猫从口袋里掏出了一个机器人!这…

docker快速部署gitlab

文章目录 场景部署步骤默认账号密码效果 场景 新增了一台机器, 在初始化本地开发环境,docker快速部署gitlab 部署步骤 编写dockerfile version: 3.7services:gitlab:image: gitlab/gitlab-ce:latestcontainer_name: gitlabrestart: alwayshostname: gitlabenviron…

计算机视觉工程师紧张学习中!

在当今这个日新月异的科技时代,计算机视觉作为人工智能的重要分支,正以前所未有的速度改变着我们的生活和工作方式。为了紧跟时代步伐,提升自我技能,一群怀揣梦想与热情的计算机视觉设计开发工程师们聚集在了本次线下培训活动中。…

RabbitMq死信队列(详解)

死信队列的概念 死信(dead message)简单理解就是因为种种原因,无法被消费的信息,就是死信。 有死信,自然就有死信队列。当消息在⼀个队列中变成死信之后,它能被重新被发送到另⼀个交换器中,这个交换器就是DLX( Dead L…

30分钟学会正则表达式

正则表达式是对字符串操作的一种逻辑公式,就是用事先定义好的一些特定字符、及这些特定字符的组合,组成一个“规则字符串”,这个“规则字符串”用来表达对字符串的一种过滤逻辑。 作用 匹配 查看一个字符串是否符合正则表达式的语法 搜索 正…

IDEA无法创建java8、11项目创建出的pom.xml为空

主要是由于Spring3.X版本不支持JDK8,JDK11,最低支持JDK17 解决的话要不就换成JDK17以上的版本,但是不太现实 另外可以参考以下方式解决 修改spring初始化服务器地址为阿里云的 https://start.aliyun.com/

Unity类银河战士恶魔城学习总结(P149 Screen Fade淡入淡出菜单)

【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili 教程源地址:https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了进入游戏和死亡之后的淡入淡出动画效果 UI_FadeScreen.cs 1. Animator 组件的引用 (anim) 该脚本通过 Animator 控制 UI 元…

IDEA 解决Python项目import导入报错、引用不到的问题

使用Idea 23.1 专业版编写Python项目时,import 导入爆红,无法引入其他package的代码,现象如: 解决方案:Idea表头打开 File -> Project Settring 解决效果:

[NSSRound#12 Basic]ordinary forensics

解压出来两个文件,一个是镜像文件另一个不知道 先查看镜像文件 vol.py -f /home/kali/Desktop/forensics.raw imageinfo再查看进程,发现有个cmd的程序 vol.py -f /home/kali/Desktop/forensics.raw --profileWin7SP1x64 pslist进行查看,有…

uniapp中父组件数组更新后与页面渲染数组不一致实战记录

简单描述一下业务场景方便理解: 商品设置功能,支持添加多组商品(点击添加按钮进行增加).可以对任意商品进行删除(点击减少按钮对选中的商品设置进行删除). 问题: 正常添加操作后,对已添加的任意商品删除后,控制台打印数组正常.但是与页面显示不一致.已上图为例,选中尾…

【Git系列】利用 Bash 脚本获取 Git 最后一次非合并提交的提交人

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

hadoop环境配置-vm安装+麒麟ubantu

一.VM版本 选择16版本,15版本存在windows蓝屏的情况,也不用设置HV等相关设置 激活下载参考下述博客:https://blog.csdn.net/matrixlzp/article/details/140674802 提前在bois打开SVM设置,不设置无法打开新建的虚拟机 ubantu下载…

基于SpringBoot的电脑配件销售系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

【linux学习指南】详解Linux进程信号保存

文章目录 📝保存信号🌠 信号其他相关常⻅概念🌉在内核中的表⽰ 🌠 sigset_t🌠信号集操作函数🌉sigprocmask🌉sigpending 🚩总结 📝保存信号 🌠 信号其他相关常…

[在线实验]-Redis Docker镜像的下载与部署

镜像下载 dockerredis镜像资源-CSDN文库 加载镜像 使用以下命令从redis.tar文件中加载Docker镜像 docker load --input redis.tar 创建映射目录 为了确保Redis的数据能够持久化,我们需要创建一个本地目录来存储这些数据 mkdir -p datasource/docker/redis 运…