需要安装:
TensorFlow 安装(包含cudatoolkit、cuDNN)
HDF5 和 h5py (如果你需要将 Keras模型保存到磁盘,则需要这些)
graphviz 和 pydot (用于绘制模型图的可视化工具)
Keras
一、更新驱动
先升级显卡驱动:https://zhuanlan.zhihu.com/p/147552901,确保后面不会因为显卡驱动版本低这个问题被卡住。
二、安装tensorflow
搭建tensorflow环境(keras最高支持到python3.6,若以后更高了再改成3.9、3.10之类的吧~)
conda create -n tensorflow-gpu python=3.6
进入到新环境中
activate tensorflow-gpu
mamba也是一个包管理器,设置环境时比较快,避免停在solving environment不动
conda install mamba
conda install mamba -c conda-forge
安装tensorflow-gpu可以一次性安装CUDA、cuDNN、tensorflow-gpu、tensorflow
mamba install tensorflow-gpu
没自信的话,可以查看一下,conda list
三、安装keras-gpu
mamba install keras-gpu
四、安装 h5py、graphviz 、 pydot
conda install xxxx 就好了
import keras去试一下啊!