一个简单的pytorch项目框架

news2025/1/13 3:08:05

框架的基本功能:

        1. 模型的定义、训练与测试

        2. 数据生成与数据迭代器

        3. 训练日志记录

        4. 训练过程实时监控

有了这个框架,后续所有复杂的AI项目都可以在此基础上拓展开发。 

项目基本结构:

四个文件:

sequence_mean_generate.py 用于数据的生成与迭代
mean_mlp_model.py 用于定义网络模型,还有模型的训练与测试函数
log.py 用于定义日志函数
mean_mlp_main.py 主函数

以下是四个文件的代码

mean_mlp_main.py

import torch
import model.mean_mlp_model as mean_mlp_model
from data import sequence_mean_generate as ds
import os
import shutil
from tools import log
import argparse


def creat_args():
    
    # 检查CUDA是否可用
    if torch.cuda.is_available():
        print("CUDA is available!")
        # 如果CUDA可用,打印使用的GPU设备
        print("Using GPU:", torch.cuda.get_device_name())
    else:
        print("CUDA is not available. Using CPU instead.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if os.path.exists('./log/'):
        shutil.rmtree('./log/')
    model_log_trace = log.creat_log('./log/', 'model_parameters', '.log')
    
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--device', type=torch.device, help='cpu or gpu', default=device)
    arg_parser.add_argument('--model_log_trace', type=str, help='model trace log path and name',
                            default=model_log_trace)
    arg_parser.add_argument('--epochs', type=int, help='training epochs', default=200)
    arg_parser.add_argument('--training_num_samples', type=int, help='training num samples', default=10000)
    arg_parser.add_argument('--training_batch_size', type=int, help='training batch size', default=24)
    arg_parser.add_argument('--training_max_range', type=int, help='training max range', default=100)
    arg_parser.add_argument('--training_min_range', type=int, help='training min range', default=-100)
    arg_parser.add_argument('--test_num_samples', type=int, help='test num samples', default=100)
    arg_parser.add_argument('--test_batch_size', type=int, help='test batch size', default=1)
    arg_parser.add_argument('--test_max_range', type=int, help='test max range', default=10000)
    arg_parser.add_argument('--test_min_range', type=int, help='test min range', default=-10000)
    args = arg_parser.parse_args()
    
    return args


# 主程序
def main():
    
    args = creat_args()
    
    # 初始化模型、损失函数、优化器
    print('start')
    model = mean_mlp_model.MLPModel().to(args.device)

    # 生成数据集
    training_set = ds.create_dataloader(args, mode='training')
    test_set = ds.create_dataloader(args, mode='test')
    
    # 训练模型
    mean_mlp_model.train_model(model, training_set, test_set, args)
    
    # 测试模型
    mean_mlp_model.test_model(model, test_set, args)
    
    for name, param in model.named_parameters():
        args.model_log_trace.info('%s: %s', name, param)
        
    log.close_log(args.model_log_trace)
    

if __name__ == '__main__':
    main()

mean_mlp_model.py

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.tensorboard import SummaryWriter


# 定义MLP模型
class MLPModel(nn.Module):
	def __init__(self):
		input_size = 4
		mean_out = 1
		super(MLPModel, self).__init__()
		self.fc = nn.Linear(input_size, mean_out)
	
	def forward(self, x):
		x = self.fc(x)
		return x


# 测试模型
def test_model(model, test_set, args):
	model.eval()
	total = 0
	with torch.no_grad():
		for inputs, targets in test_set:
			inputs, targets = inputs.to(args.device), targets.to(args.device)
			outputs = model(inputs)
			targets_array = np.array(targets)
			outputs_array = np.array(outputs)
			mse = np.mean((targets_array - outputs_array) ** 2)
			total += mse
	test_set_mse = total / test_set.batch_size
	print(f'MSE on the test data: {test_set_mse:.2f}')
	return test_set_mse


# 训练模型
def train_model(model, training_set, test_set, args):
	
	criterion = nn.MSELoss()
	optimizer = optim.Adam(model.parameters(), lr=0.001)
	writer = SummaryWriter('./log')
	
	iteration = 0
	
	model.train()
	for epoch in range(args.epochs):
		for inputs, targets in training_set:
			iteration += 1
			inputs, targets = inputs.to(args.device), targets.to(args.device)
			optimizer.zero_grad()
			outputs = model(inputs)
			loss = criterion(outputs, targets)
			loss.backward()
			optimizer.step()
		for name, param in model.named_parameters():
			args.model_log_trace.debug('%s: %s', name, param)
		
		writer.add_scalars('Weight', {'fc(0,0)': model.fc.weight[0, 0], 'fc(0,1)': model.fc.weight[0, 1], 'fc(0,2)': model.fc.weight[0, 2], 'fc(0,3)': model.fc.weight[0, 3]}, iteration)
		args.model_log_trace.info(f'Epoch {epoch + 1}/{args.epochs}, Loss: {loss.item()}')

		# 测试模型
		test_loss = test_model(model, test_set, args)
		writer.add_scalars('Loss', {'train': loss.item(), 'test': test_loss}, epoch)
	writer.close()
	

sequence_mean_generate.py

import random
import torch
from torch.utils.data import DataLoader, TensorDataset


def generate_random_sequence(min_range, max_range):
    # 设置器噪声倍数范围
    n = 0.1

    # 生成一个包含四个随机整数的列表,范围在range_lower到range_upper之间
    random_numbers = [random.randint(min_range, max_range) for _ in range(4)]
    random_noise = [random.randint(min_range*0.01, max_range*0.01) for _ in range(4)]
    
    # 创建一个新列表来存储相加的结果
    summed_numbers = []
    
    # 遍历两个列表,将对应元素相加
    for num, noise in zip(random_numbers, random_noise):
        summed_numbers.append(num + noise)

    # 计算平均数
    average = [sum(random_numbers) / len(random_numbers)]

    # 返回列表和平均数
    return summed_numbers, average


# 生成数据集
def generate_dataset(min_range, max_range, num_samples, device):
    quadruples = []
    labels = []
    for _ in range(num_samples):
        quadruple, label = generate_random_sequence(min_range, max_range)
        quadruples.append(quadruple)
        labels.append(label)
    quadruples_tensor = torch.tensor(quadruples, dtype=torch.float32).to(device)
    labels_tensor = torch.tensor(labels, dtype=torch.float32).to(device)
    return quadruples_tensor, labels_tensor


# 创建DataLoader
def create_dataloader(args, mode):
    if mode == 'training':
        min_range = args.training_min_range
        max_range = args.training_max_range
        num_samples = args.training_num_samples
        batch_size = args.training_batch_size
    else:
        min_range = args.test_min_range
        max_range = args.test_max_range
        num_samples = args.test_num_samples
        batch_size = args.test_batch_size
        
    quadruples, labels = generate_dataset(min_range, max_range, num_samples, args.device)
    dataset = TensorDataset(quadruples, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataloader

log.py

import os
import shutil
import logging


def creat_log(log_path, logging_name, suf_name):
	
	if not os.path.exists(log_path):
		os.makedirs(log_path)
	log_full_path = log_path + logging_name + suf_name
	
	logger = logging.getLogger(logging_name)
	logger.setLevel(level=logging.DEBUG)
	
	handler = logging.FileHandler(log_full_path, encoding='UTF-8', mode='w')
	handler.setLevel(logging.INFO)
	formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
	handler.setFormatter(formatter)
	
	console = logging.StreamHandler()
	console.setLevel(logging.DEBUG)
	logger.addHandler(handler)
	logger.addHandler(console)
	return logger


# 关闭log
def close_log(log_trace):
			
	for handler in list(log_trace.handlers):
		log_trace.removeHandler(handler)
		

以下是训练演示:

 tensorboard 的实时监控演示:

关于tensorboard的设置:

第一步:正常安装tensorboard

pip install tensorboard  

第二步:添加系统的环境变量

第三步:添加监控代码

第四步:设置tensorboard的运行脚本

        在”1“中进行配置,并填写对应的选项,填写正确后点应用-》确定-》运行

 4. 运行mean_mlp_main.py主程序,生成数据,这时tensorboard创建的服务器就会调用这些数据生成监控面板上的监控结果

5. 再切换到tensorboard的运行窗口,如下图,点击输出结果中的网址,就可以访问监控结果,在页面按F5可以刷新监控画面 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2036834.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++初阶_2:引用

本节咱们来说说引用: C添加了“引用”,与指针成了两兄弟——这两兄弟对我们今后写C代码可谓各有特点,缺一不可。 何谓引用? 引用:就是取别名 不知诸位可有别名?这里不妨举一本耳熟能详的小说《水浒传》&…

Redis16-批处理优化

目录 Pipeline 集群下的批处理 Pipeline 单个命令的执行流程: N条命令的执行流程: N条命令批量执行: Redis提供了很多Mxxx这样的命令,可以实现批量插入数据,例如: msethmset 利用mset批量插入10万条数…

vivado报错:file ended before end of clause

最近在学习Xilinx FPGA时,遇到 Vivado 报错如下图所示: 刚开始,看到错误是在第1行代码中出现的,我的第一反应是该行代码写错了,然后搜了搜语法,发现没错。 分析报错信息发现,该错误应该是和文件…

VScode + PlatformIO 和 Keil 开发 STM32

以前经常使用 KEIL 写 STM32 的代码,自从使用 VScode 写 ESP32 后感觉 KEIL 的开发环境不美观不智能了,后面学习了 VScode 开发 STM32 。 使用过程中发现 串口重定向在 KEIL 中可以用,搬到 VScode 后不能用,不用勾选 Use Micro LI…

SpringMVC 运行流程

SpringMVC 运行流程 💖The Begin💖点点关注,收藏不迷路💖 SpringMVC的运行流程可概括为以下几个核心步骤: 流程图: #mermaid-svg-l1CeK9JwP5wRQjBL {font-family:"trebuchet ms",verdana,arial,…

医学图像分割新突破:6篇文献带你洞悉最前沿的医学AI技术|顶刊速递·24-08-14

小罗碎碎念 文献日推主题:人工智能在医学图像分割中的最新研究进展 今天这期文章信息量很大,并且不同的人看,获取的信息量也会差距很大。为了缩小这个差距,请在正式阅读前,记住小罗的一句话——模型学会了如何分割图像…

【Spark集群部署系列三】Spark StandAlone HA模式介绍和搭建以及使用

简介: Spark Standalone集群是Master-Slaves架构的集群模式,和大部分的Master-Slaves结构集群一样,存在着Master 单点故障(SPOF)的问题。 高可用HA 如何解决这个单点故障的问题,Spark提供了两种方案&#…

83.SAP ABAP从前台找字段所在表的两种方法整理笔记

目录 方法1:F1查看技术信息 F1 技术信息 方法2:ST05开启跟踪 Activate Trace Input and save data Deactivate Trace Display Trace 分析你想要的表 方法1:F1查看技术信息 从前台找一个屏幕字段所在表,一般通过按F1来查找…

Java Nacos与Gateway的使用

Java系列文章目录 IDEA使用指南 Java泛型总结(快速上手详解) Java Lambda表达式总结(快速上手详解) Java Optional容器总结(快速上手图解) Java 自定义注解笔记总结(油管) Jav…

【大数据】6:MapReduce YARN 初体验

目录 MapReduce & YARN 初体验 集群启停命令 一键启动脚本: 单进程启停 提交MapReduce任务到YARN执行 提交MapReduce程序至YARN运行 提交wordcount示例程序 提交求圆周率示例程序 拓展:蒙特卡罗算法求PI的基础原理 onte Carlo蒙特卡罗算法…

【MySQL 06】表的约束

文章目录 🌈 一、约束的概念🌈 二、空属性约束⭐ 1. 空值无法参与运算⭐ 2. 设置非空属性 🌈 三、默认值约束⭐ 1. 默认值使用案例⭐ 2. 同时设置 not null 和 default 🌈 四、列描述约束🌈 五、zerofill 补零约束&…

贷齐乐漏洞复现+php特性绕过WAF

目录 一、环境搭建 1.将贷齐乐源码放入phpstudy中的www目录下 2.在phpstudy上创建网站: 3.在本地数据库中创建数据库--ctf,并创建users表: 4.往表中插入数据: 5.查看users表: 6.测试能否访问到数据库 二、源码分析…

力扣热题100_链表_234_回文链表

文章目录 题目链接解题思路解题代码 题目链接 234. 回文链表 给你一个单链表的头节点 head ,请你判断该链表是否为 回文链表。如果是,返回 true ;否则,返回 false 。 示例 1: 输入:head [1,2,2,1] 输出…

搭建MoneyPrinterTurbo,利用AI大模型,一键生成高清短视频实战

搭建MoneyPrinterTurbo,利用AI大模型,一键生成高清短视频 1.MoneyPrinterTurbo简介 只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。 github地址&a…

【大模型从入门到精通19】开源库框架LangChain LangChain文档加载器1

目录 理解文档加载器非结构化数据加载器结构化数据加载器 使用文档加载器的实际指南设置和配置安装必要的包(注意:这些包可能已经在你的环境中安装好了)从 .env 文件加载环境变量从环境变量中设置 OpenAI API 密钥 在数据驱动的应用领域&…

企业如何组建安全稳定的跨国通信网络

当企业在海外设有分公司时,如何建立一个安全且稳定的跨国通信网络是一个关键问题。为了确保跨国通信的安全和稳定性,可以考虑以下几种方案。 首先,可以在分公司之间搭建虚拟专用网络。虚拟专用网络通过对传输数据进行加密,保护通信…

Java:jdk8以后开始接口新增的3种方法:default,private,static

文章目录 jdk8以后开始接口新增的方法默认方法:deafult私有方法private如何查看自己的jdk版本静态方法static 问题接口中不止有抽象方法为什么接口中的方法都是public为什么要增加这三种方法 jdk8以后开始接口新增的方法 默认方法:deafult 必须使用defa…

【CentOS 】DHCP 更改为静态 IP 地址并且遇到无法联网

文章目录 引言解决方式标题1. **编辑网络配置文件**:标题2. **确保配置文件包含以下内容**:特别注意 标题3. **重启网络服务**:标题4. **检查配置是否生效**:标题5. **测试网络连接**:标题6. **检查路由表**&#xff1…

思科默认路由配置2

#路由协议实现# #任务二默认路由配置2# #1配置计算机的IP地址、子网掩码和网关 #2配置Router-A的名称及其接口IP地址 Router(config)#hostname Router-A Router-A(config)#int g0/0 Router-A(config-if)#ip add 192.168.1.1 255.255.255.0 Router-A(config-if)#no shutdow…

全网最适合入门的面向对象编程教程:36 Python的内置数据类型-字典

全网最适合入门的面向对象编程教程:36 Python 的内置数据类型-字典 摘要: 字典是非常好用的容器,它可以用来直接将一个对象映射到另一个对象。一个拥有属性的空对象在某种程度上说就是一个字典,属性名映射到属性值。在内部&#…