20

谷歌大脑开源Trax代码库,你的深度学习进阶路径

 4 years ago
source link: https://www.jiqizhixin.com/articles/2020-02-27-8
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

感觉深度学习建模只不过调库与堆叠层级?你需要谷歌大脑维护的这条路径 Trax,从头实现深度学习模型。

buYbIv2.jpg!web

从最开始介绍卷积、循环神经网络原理,到后来展示各种最前沿的算法与论文,机器之心与读者共同探索着机器学习。我们会发现,现在读者对那些著名的深度学习模型已经非常熟悉了,经常也会推导或复现它们。

而对于最前沿的一些实现,包括 Transformer 或其它强化学习,我们通常都需要看原作者开源的代码,或者阅读大厂的复现。出于速度等方面的考虑,这些实现通常会显得比较「隐晦」,理解起来不是那么直白。这个时候,你就需要谷歌大脑维护的 Trax,它是 ML 开发者进阶高级 DL 模型的路径。

Trax 是一个开源项目,它的目的在于帮助我们挖掘并理解高一阶的深度学习模型。谷歌大脑表示,该项目希望 Trax 代码做到非常整洁与直观,并同时令 Reformer 这类高阶深度学习达到最好的效果。

项目地址:https://github.com/google/trax

jaiMVnj.jpg!web

什么是 Trax

简单来说,Trax 就是一个代码库,它有点类似于一个极简的深度学习框架。只不过 Trax 关注什么样的代码能让读者更好地理解模型,而不只是关注加速与优化。

Trax 代码及其组织方式希望让我们从头理解深度学习,而不只是简单地调库。整个项目从最基础的数学部分开始,然后向上依次构建层级运算、模型运算,以及有监督与强化学习训练任务。

因为是进阶深度学习高级建模,Trax 还囊括了最前沿的研究结果,例如在 ICLR 2020 上做演讲报告的 Reformer。如下展示的是该项目的代码文件结构:

MvmAraa.jpg!web

如果要从头理解并进阶深度学习,那么 Trax 代码主要可以分为以下 6 部分:

  • math/:最基本的数学运算,以及通过 JAX 和TensorFlow加速运算性能的方法,尤其是在 GPU/TPU 上;

  • layers/:搭建神经网络的所有层级构建块;

  • models/:包含所有基础模型,例如 MLP、ResNet 和 Transformer,还包含一些前沿 DL 模型;

  • optimizers/:包含深度学习所需要的最优化器;

  • supervised/:包含执行监督学习的各种有用模块,以及整体的训练工具;

  • rl/:包含谷歌大脑在强化学习上的一些研究工作;

每一个文件夹下都有对应的实现,例如在 Layers 中,所有神经网络层级都继承自最基础的 Layer 类,实现这个类花了 700 行代码。而后新的层级在继承它后只要实现以下两个方法就行:

3eMNzmn.jpg!web

通过 900 行代码(包括 Err 处理),基础的 Layer 类能完成其它所有处理,包括初始化与调用等。

使用 Trax

我们可以将 Trax 作为 Python 脚本库或者JupyterNotebook 的基础,也可以作为命令行工具执行。Trax 包含很多深度学习模型,并且绑定了大量深度学习数据集,包括 Tensor2Tensor 和TensorFlow采用的数据集。同时,如果我们在 CPU、GPU 或 TPU 上运行这些模型,也不需要改变。

如果读者想要了解如何快速将 Trax 作为一个库来使用,那么可以看看如下 Colab 上的入门示例。它介绍了如何生成样本数据,并连接到 Trax 中的 Transformer 模型。在训练或推断时,我们可以选择 GPU,也可以选择 8 核心的免费 TPU。

ieIrq22.jpg!web

入门简介地址:

https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb

如果要在命令行中使用 Trax,那么带上参数就可以了,例如模型类型、学习率等超参。谷歌大脑团队建议我们可以看看 gin-config,例如训练一个最简单的 MNIST 分类模型,可以看看 mlp_mnist.gin,然后如下运行就行了:

python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin

如果你觉得上面的训练太简单,也可以在 ImageNet64 上训练一下 Reformer:

python -m trax.trainer --config_file=$PWD/trax/configs/reformer_imagenet64.gin

最后,这个项目最重要的还是它的实现代码,我们并不是因为可以直接运行而使用它。相反,我们是因为它的代码直观简洁,能帮助我们一步步更深刻地理解模型而使用它。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK