streamlit+ndraw进行可视化训练深度学习模型

news2025/1/13 13:47:39

简介

如果你喜欢web可视化的方式训练深度学习模型,那么streamlit是一个不可错过的选择!

优点:

  1. 提供丰富的web组件支持
  2. 嵌入python中,简单易用
  3. 轻松构建一个web页面,按钮控制训练过程

本文使用streamlit进行web可视化渲染,并使用ndraw进行模型可视化,做到了:

  1. 训练过程可视化
  2. 模型输入输出shape一目了然

构建环境

首先安装必要的依赖,tensorflow、streamlit和ndraw为必须依赖,其他依赖根据自己的情况进行安装

pip install streamlit
pip install tensorflow
pip install ndraw

其他的环境自行安装,不过多赘述

然后引入模块:

import ndraw
import streamlit as st
import tensorflow as tf
import streamlit.components.v1 as components

编写代码

以mnist数据集为例

1.获取数据

书写数据加载方法,如果你的数据集没有改动的话,可以使用@st.cache装饰器,其作用是缓存数据,不用每次训练都重新加载数据

@st.cache(allow_output_mutation=True)
def get_data(is_onehot = False):
    # 根据自己的训练数据进行加载
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train/255.0
    x_test = x_test/255.0
    if is_onehot:
        y_train = tf.one_hot(y_train,10)
        y_test = tf.one_hot(y_test,10)
    return (x_train, y_train), (x_test, y_test)

2.构建模型

简单构建一个模型:如果是较为复杂模型,可以使用ndraw进行维度的查看

def build_model():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

3.构建逻辑

使用streamlit构建模型的逻辑:

  1. 首先设置一个web页面的标题
  2. 在左侧设置一个导航栏:开始和结束
  3. 点击开始的时候开始训练
  4. 添加一个模型扩展位置,点击的时候可以查看模型
if __name__ == '__main__':
    #设置网页标题
    st.title("训练xx模型")
    #点击开始后进行数据加载和训练
    if st.sidebar.button('开始'):
        (x_train, y_train), (x_test, y_test) = get_data(is_onehot=True)

        st.text("train size: {} {}".format(x_train.shape, y_train.shape))
        st.text("test size: {} {}".format(x_test.shape, y_test.shape))

        model = build_model()
        #点击查看模型后可以可视化模型
        with st.expander("查看模型"):
            components.html(ndraw.render(model,init_x=200, flow=ndraw.VERTICAL), height=1000, scrolling=True)
        model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr=0.001),metrics=["accuracy"])
        model.fit(x_train, y_train, batch_size=128, validation_data=(x_test, y_test), epochs=10, verbose=1,callbacks=[TrainCallback(x_test, y_test)])
        st.success('训练结束')

    if st.sidebar.button('停止'):
        st.stop()


4.自定义指标可视化

tf提供了丰富的自定义功能,包括模型自定义,指标自定义,loss自定义、训练过程自定义等等,此处自定义一个训练过程自定义的Callback,主要用于在训练过程中获取相关的loss和acc进行绘图

class TrainCallback(tf.keras.callbacks.Callback):
    def __init__(self, test_x, test_y):
        super(TrainCallback, self).__init__()
        self.test_x = test_x
        self.test_y = test_y

    def on_train_begin(self, logs=None):
        st.header("训练汇总")
        self.summary_line = st.area_chart()

        st.subheader("训练进度")
        self.process_text = st.text("0/{}".format(self.params['epochs']))
        self.process_bar = st.progress(0)

        st.subheader('loss曲线')
        self.loss_line = st.line_chart()

        st.subheader('accuracy曲线')
        self.acc_line = st.line_chart()

    def on_epoch_end(self, epoch, logs=None):
        self.loss_line.add_rows({'train_loss': [logs['loss']], 'val_loss': [logs['val_loss']]})
        self.acc_line.add_rows({'train_acc': [logs['accuracy']], 'val_accuracy': [logs['val_accuracy']]})
        self.process_bar.progress(epoch / self.params['epochs'])
        self.process_text.empty()
        self.process_text.text("{}/{}".format(epoch, self.params['epochs']))

    def on_batch_end(self, epoch, logs=None):
        if epoch % 10 == 0 or epoch == self.params['epochs']:
            self.summary_line.add_rows({'loss': [logs['loss']], 'accuracy': [logs['accuracy']]})

展示

在这里插入图片描述
在这里插入图片描述

总结

以上就是整个训练过程,不同的模型只需要更改一下加载数据和构建模型的函数即可,其他内容不变或者根据自己的需求添加

完整外码可见 visualneu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/27708.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

会议管理系统SSM记录(一)

目录: (1)环境搭建 (2)整合MyBatis (1)环境搭建 添加:package 配置成web的结构: pom先加入springmvc的依赖就可以实现spring和springmvc的整合 pom.xml中加入依赖&am…

接口的定义与实现

声明类的关键字是class,声明接口的关键字是interface 1.介绍 普通类:只有具体实现 抽象类:具体实现和规范(抽象方法)都有 接口:只有规范 |自己无法写方法,专业的约束 接口就是规范,…

MATLAB | 全网唯一 MATLAB双向弦图(有向弦图)绘制

先赞后看,养成习惯~~ 先赞后看,养成习惯~~ 先赞后看,养成习惯~~ 绘制效果 下面这款弦图我已经出了很久了,也陆陆续续增添了新的功能和修了一些bug: 甚至还用它做出了一些复刻,分成两组的弦图有了后就有很多…

【仿真建模】AnyLogic入门基础教程 第一课

文章目录一、AnyLogic介绍二、设置2.1 设置中文三、新建项目四、行人库介绍五、创建新行人六、切换3D视角七、增加墙八、行人密度图一、AnyLogic介绍 二、设置 2.1 设置中文 三、新建项目 四、行人库介绍 点击面板,选择第三个图标,就是行人库 行人库分…

【第五部分 | JS WebAPI】4:八千字详解 “事件·高级篇”

目录 | 概述 | 注册事件的两种方式 | 删除事件的两种方式 | 事件对象【重要】 事件对象简介和声明 e.target 和 this 的区别 [ 事件对象 的常用属性方法 ] | Dom事件流 什么是Dom事件流? 阻止默认行为 阻止事件冒泡 利用事件冒泡进行事件委托 | 常用的鼠…

1、Git相关操作

目录 一、远程库的拉取 二、远程库创建分支 声明:需要有一定的GIt基础,如果不懂可以自行查看个人学习的Git笔记或者可以通过其他途径学习Git 一、远程库的拉取 步骤: 先创建一个空的文件夹在创建的文件夹中使用git init 命令来初始化本地…

频域中的后门攻击论文笔记

文章一:Rethinking the Backdoor Attacks’ Triggers: A Frequency Perspective 文章贡献: 在频域上对现有的 backdoor trigger 进行分析,发现常见 trigger 存在 high-frequency artifacts 的问题。对这些 artifacts 进行了详细的分析展示了…

什么是中间件

一、什么是中间件 中间件(Middleware)是处于操作系统和应用程序之间的软件,也有人认为它应该属于操作系统中的一部分。人们在使用中间件时,往往是一组中间件集成在一起,构成一个平台(包括开发平台和运行平…

企业内训app源码,在线培训小程序,随时随地想学就学

近年来,在线学习逐渐被广泛应用于人才培养领域。公司要想长远发展,内部培训必不可少。公司的发展离不开公司整体员工的进步,而人员管理往往是公司管理中最重要也最难的一个环节。许多公司开始通过企业内训app源码开发来优化公司人员管理方式、…

基于PHP+MySQL学生信息管理系统的设计与实现

我国是一个高等教育逐渐普及的国度,相应的每年也有上百万的大学生入校,如此庞大的学生数量如何进行更加科学的管理是教育工作者一直关心的一个问题,为了能够实现高校对学生信息管理的科学化,信息化,我们开发了本基于PH…

C++ 手动实现双向链表(作业版)

双向链表&#xff0c;并实现增删查改等功能 首先定义节点类&#xff0c;类成员包含当前节点的值&#xff0c; 指向下一个节点的指针和指向上一个节点的指针 //节点定义 template <typename T> class Node { public:Node<T>* prior;T value;Node<T>* next;N…

减少乘法次数的优化算法(Gauss、Strassen、Winograd)

目录 Gauss算法 Strassen算法 Winograd算法 Winograd 1D Winograd 2D 在硬件设计中&#xff0c;乘法无论是在逻辑资源的使用上还是组合逻辑的延时上都要比加法高很多。从硬件方面考虑&#xff0c;我们都更倾向于将乘法转换成移位和加法&#xff0c;譬如乘以8&#xff0c;可…

stm32项目平衡车详解(stm32F407)

stm32项目 stm32项目介绍值平衡车 本文章学习借鉴于创客学院团队&#xff0c;以表感谢。教学视频 文章目录stm32项目前言一、平衡小车平衡小车的功能介绍平衡小车功能开发需求平衡小车整体框架小车环境数据采集进程1. 平衡小车姿态信息介绍2. 平衡小车项目工程框架搭建3. Mpu6…

【面试题】原型和原型链

1. 如何用class实现继承 // 父类 class People{constructor(name){this.name name}eat(){console.log(${this.name} eat something)} }// 子类 class Student extends People{constructor(name, number){super(name)this.number number}sayHi(){console.log(姓名&#xff1a…

自动化脚本如何切换环境?Pytest这些功能你必须要掌握

文章目录一、前言二、安装三、使用第1种:使用方式是终端添加–base-url这个命令第2种:使用方式是在pytest.ini配置文件种去配置base_url,然后自动读取url的数据&#xff0c;这样就不用添加–base-url这个命令行参数了&#xff1a;第3种:pytest有个hooks函数&#xff0c;可以自定…

最优二叉搜索树问题(Java)

最优二叉搜索树问题&#xff08;Java&#xff09; 文章目录最优二叉搜索树问题&#xff08;Java&#xff09;1、前置介绍2、算法设计思路2.1 最优二叉搜索树的结构2.2 一个递归算法2.3 计算最优二叉搜索树的期望搜索代价3、代码实现4、复杂度分析5、参考资料1、前置介绍 设S{x…

R语言探索 BRFSS 数据和预测

加载包 library(ggplot2) library(dplyr) library(Hmisc) library(corrplot) 加载数据 load("brfss2013.RData") 第1部分&#xff1a;关于数据 行为风险因素监测系统&#xff08;BRFSS&#xff09;是美国的年度电话调查。BRFSS旨在识别成年人口的风险因素并报告…

docker启动出现Error response from daemon: Cannot restart container的报错

1、发现问题 突然发现启动(重启)容器的时候报这个错 Error response from daemon: Cannot restart container 容器id: driver failed programming external connectivity on endpoint 容器名 (容器id): (iptables failed: iptables --wait -t nat -A DOCKER -p tcp -d 0/0 --…

图像超分辨率:优化最近邻插值Super-Resolution by Predicting Offsets

文章目录3. Super-Resolution by Predicting Offsets3.1. 这篇论文用于处理栅格化图像的超分&#xff0c;不知道这样翻译对不对&#xff0c;3.2. 作者认为栅格图像的边缘比较规则&#xff0c;可以训练一个offset map移动栅格图像的 边缘点&#xff08;背景和前景像素 移动 和交…

能率携手梦想改造家,打造适老化住宅新典范

家装改造类节目《梦想改造家》第九季温情回归&#xff0c;日本一级建筑设计师本间贵史携手知名燃热品牌能率&#xff0c;与节目组一起关注民生&#xff0c;走进由一家五口组成的“足不出户的家”&#xff0c;共启老宅改造计划&#xff0c;倾情助力普通家庭拥抱生活与梦想&#…