jax环境依赖安装过程
原仓库:https://github.com/google-research/l2p?tab=readme-ov-file
参考:flax与jax、jaxlib的安装顺序和版本匹配
笔者环境为cuda12.1:
pip install "jax[cuda12]" flax
依赖环境如下:
flax-0.9.0
、jax-0.4.33
、jaxlib-0.4.33
然后,由于新版flax移除了optim包,所以需要参考该博客修改代码:AttributeError: module ‘flax‘ has no attribute ‘optim‘
数据集下载
- tensorflow-dataset 内网下载 指定目录
- tensorflow-dataset All attempts to get a Google authentication bearer token failed
总结为三步:安装稍低版本的tfds==2.1.0,设置内网下载使用代理,下载到指定目录。
flax.optim兼容性
l2p中使用了旧版的flax,这些optim API在新版都被遗弃了,如果要兼容,请参考Replacing flax.optim with optax。
笔者觉得太麻烦,就放弃了,转而安装旧版本的cuda和jax。笔者的机器全是cuda11和cuda12的,幸好找到云上平台有个镜像能支持cuda10.
https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 所有jax发行版和支持的cuda。
if config.get("optim_wd_ignore"):
# Allow zero weight decay for certain parameters listed in optim_wd_ignore
igns = config.optim_wd_ignore
p = flax.optim.ModelParamTraversal(
lambda path, _: not any([i in path for i in igns]))
p_nowd = flax.optim.ModelParamTraversal(
lambda path, _: any([i in path for i in igns]))
p_opt = flax.optim.Adam(weight_decay=config.weight_decay)
p_nowd_opt = flax.optim.Adam(weight_decay=0)
opt_def = flax.optim.MultiOptimizer((p, p_opt), (p_nowd, p_nowd_opt))
else:
opt_def = flax.optim.Adam(weight_decay=config.weight_decay)
踩坑
踩坑笔记:
- AttributeError: module ‘tensorflow._api.v2.compat.v2.internal‘ has no attribute ‘register_load_c
- jax安装踩坑(1) ImportError: cannot import name ‘linear_util‘ from ‘jax‘
- jax环境安装笔记
- jax安装踩坑(2) ModuleNotFoundError: No module named ‘keras.src.engine‘
- AttributeError: module ‘flax‘ has no attribute ‘optim‘