简介
如果你喜欢web可视化的方式训练深度学习模型,那么streamlit是一个不可错过的选择!
优点:
- 提供丰富的web组件支持
- 嵌入python中,简单易用
- 轻松构建一个web页面,按钮控制训练过程
本文使用streamlit进行web可视化渲染,并使用ndraw进行模型可视化,做到了:
- 训练过程可视化
- 模型输入输出shape一目了然
构建环境
首先安装必要的依赖,tensorflow、streamlit和ndraw为必须依赖,其他依赖根据自己的情况进行安装
pip install streamlit
pip install tensorflow
pip install ndraw
其他的环境自行安装,不过多赘述
然后引入模块:
import ndraw
import streamlit as st
import tensorflow as tf
import streamlit.components.v1 as components
编写代码
以mnist数据集为例
1.获取数据
书写数据加载方法,如果你的数据集没有改动的话,可以使用@st.cache装饰器,其作用是缓存数据,不用每次训练都重新加载数据
@st.cache(allow_output_mutation=True)
def get_data(is_onehot = False):
# 根据自己的训练数据进行加载
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train/255.0
x_test = x_test/255.0
if is_onehot:
y_train = tf.one_hot(y_train,10)
y_test = tf.one_hot(y_test,10)
return (x_train, y_train), (x_test, y_test)
2.构建模型
简单构建一个模型:如果是较为复杂模型,可以使用ndraw进行维度的查看
def build_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
return model
3.构建逻辑
使用streamlit构建模型的逻辑:
- 首先设置一个web页面的标题
- 在左侧设置一个导航栏:开始和结束
- 点击开始的时候开始训练
- 添加一个模型扩展位置,点击的时候可以查看模型
if __name__ == '__main__':
#设置网页标题
st.title("训练xx模型")
#点击开始后进行数据加载和训练
if st.sidebar.button('开始'):
(x_train, y_train), (x_test, y_test) = get_data(is_onehot=True)
st.text("train size: {} {}".format(x_train.shape, y_train.shape))
st.text("test size: {} {}".format(x_test.shape, y_test.shape))
model = build_model()
#点击查看模型后可以可视化模型
with st.expander("查看模型"):
components.html(ndraw.render(model,init_x=200, flow=ndraw.VERTICAL), height=1000, scrolling=True)
model.compile(loss="categorical_crossentropy", optimizer=tf.keras.optimizers.Adam(lr=0.001),metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=128, validation_data=(x_test, y_test), epochs=10, verbose=1,callbacks=[TrainCallback(x_test, y_test)])
st.success('训练结束')
if st.sidebar.button('停止'):
st.stop()
4.自定义指标可视化
tf提供了丰富的自定义功能,包括模型自定义,指标自定义,loss自定义、训练过程自定义等等,此处自定义一个训练过程自定义的Callback,主要用于在训练过程中获取相关的loss和acc进行绘图
class TrainCallback(tf.keras.callbacks.Callback):
def __init__(self, test_x, test_y):
super(TrainCallback, self).__init__()
self.test_x = test_x
self.test_y = test_y
def on_train_begin(self, logs=None):
st.header("训练汇总")
self.summary_line = st.area_chart()
st.subheader("训练进度")
self.process_text = st.text("0/{}".format(self.params['epochs']))
self.process_bar = st.progress(0)
st.subheader('loss曲线')
self.loss_line = st.line_chart()
st.subheader('accuracy曲线')
self.acc_line = st.line_chart()
def on_epoch_end(self, epoch, logs=None):
self.loss_line.add_rows({'train_loss': [logs['loss']], 'val_loss': [logs['val_loss']]})
self.acc_line.add_rows({'train_acc': [logs['accuracy']], 'val_accuracy': [logs['val_accuracy']]})
self.process_bar.progress(epoch / self.params['epochs'])
self.process_text.empty()
self.process_text.text("{}/{}".format(epoch, self.params['epochs']))
def on_batch_end(self, epoch, logs=None):
if epoch % 10 == 0 or epoch == self.params['epochs']:
self.summary_line.add_rows({'loss': [logs['loss']], 'accuracy': [logs['accuracy']]})
展示
总结
以上就是整个训练过程,不同的模型只需要更改一下加载数据和构建模型的函数即可,其他内容不变或者根据自己的需求添加
完整外码可见 visualneu