28

大规模图训练调优指南

 3 years ago
source link: https://mp.weixin.qq.com/s?__biz=MzIwMTc4ODE0Mw%3D%3D&%3Bmid=2247513211&%3Bidx=2&%3Bsn=66701f0b53e4661d8f9e87038f9b4386
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

EnAnQz.gif!mobile

©PaperWeekly 原创 · 作者|桑运鑫 

学校|上海交通大学硕士生 

研究方向|图神经网络应用

最近对一个大规模的图训练嵌入,发现相关的中文资料还是很欠缺的,把自己踩的一些坑记下来。本文主要针对 DGL [1]  和 PyTorch  [2]  两个框架。

jiuMzu2.png!mobile

训练大规模图

对于大规模图不能像小图一样把整张图扔进去训练,需要对大图进行采样,即通过 Neighborhood Sampling 方法每次采样一部分输出节点,然后把更新它们所需的所有节点作为输入节点,通过这样的方式做 mini-batch 迭代训练。 具体的用法可以参考官方文档中的 Chapter 6: Stochastic Training on Large Graphs [3]

但是 GATNE-T [4] 中有一种更有趣的做法,即只把 DGL 作为一个辅助计算流的工具,提供 Neighborhood Sampling Message Passing 等过程,把 Node Embedding Edge Embedding 等存储在图之外,做一个单独的 Embedding 矩阵。 每次从 dgl 中获取节点的 id 之后再去 Embedding 矩阵中去取对应的 embedding 进行优化,以此可以更方便的做一些优化。

IbQFVj6.png!mobile

缩小图规模

从图的 Message Passing 过程可以看出,基本上所有的图神经网络的计算都只能传播连通图的信息,所以可以先用 connected_componets [5] 查一下自己的图是否是连通图。如果分为多个连通子图的话,可以分别进行训练或者选择一个大小适中的 component 训练。

如果图还是很大的话,也可以对图整体做一次 Neighborhood Sampling 采样一个子图进行训练。

ZJfe2i7.png!mobile

减小内存占用

对于大规模数据而言,如何在内存中存下它也是一件让人伤脑筋的事情。这时候采用什么样的数据结构存储就很关键了。首先是不要使用原生的 list ,使用 np.ndarray 或者 torch.tensor 。尤其注意不要显式的使用 set 存储大规模数据(可以使用 set 去重,但不要存储它)。

ZRfqQvz.png!mobile

注意:四种数据结构消耗的内存之间的差别(比例关系)会随着数据规模变大而变大。

其次就是在 PyTorch 中,设置 DataLoader num_workers 大于 0 时会出现内存泄露的问题,内存占用会随着 epoch 不断增大。查阅资料有两个解决方法:

  • 根据Num_workers in DataLoader will increase memory usage? [6] ,设置 num_workers 小于实际 cpu 数,亲测无效;

  • 根据 CPU memory gradually leaks when num_workers > 0 in the DataLoader [7] ,将原始 list 转为 np.ndarray 或者 torch.tensor ,可以解决。原因是: There is no way of storing arbitrary python objects (even simple lists) in shared memory in Python without triggering copy-on-write behaviour due to the addition of refcounts, everytime something reads from these objects.

JFbIRjv.png!mobile

减小显存消耗

对于大规模图嵌入而言, Embedding 矩阵会非常大。在反向传播中如果对整个矩阵做优化的话很可能会爆显存。可以参考 pinsage [8]  的代码,设置 Embedding 矩阵的 sparse = True ,使用 SparseAdam  [9] 进行优化。 SparseAdam 是一种为 sparse 向量设计的优化方法,它只优化矩阵中参与计算的元素,可以大大减少 backward 过程中的显存消耗。

此外,如果显存仍然不够的话,可以考虑将 Embedding 矩阵放到 CPU 上。使用两个优化器分别进行优化。

ZJVFFrb.png!mobile

加快训练速度

对于大规模数据而言,训练同样要花很长时间。加快训练速度也很关键。加快训练速度方面主要在两个方面:加快数据预处理和提高 GPU 利用率。

在加快数据预处理中,大部分数据集样本之间都是独立的,可以并行处理,所以当数据规模很大的时候,一定要加大数据预处理的并行度。

不要使用 for 循环逐条处理,可以使用 multiprocess [10] 库开多进程并行处理。但是要注意适当设置 processes ,否则会出现错误 OSError: [Errno 24] Too many open files

此外,数据处理好之后最好保存为 pickle 格式的文件,下次使用可以直接加载,不要再花时间处理一遍。

在提高 GPU 利用率上,如果 GPU 利用率比较低主要是两个原因: batch_size 较小(表现为 GPU 利用率一直很低)和数据加载跟不上训练(表现为 GPU 利用率上升下降上升下降,如此循环)。解决方法也是两个:

  • 增大 batch_size ,一般来说 GPU 利用率和 batch_size 成正比;

  • 加快数据加载:设置 DataLoader pin_memory=True , 适当增大 num_workers (注意不要盲目增大,设置到使用的 CPU 利用率到 90% 左右就可以了,不然反而可能会因为开线程的消耗拖慢训练速度)。

如果有多块 GPU 的话,可以参照 graphsage [11] 的代码进行多 GPU 训练。一般来说在单机多卡的情况下都可以得到线性加速。但是使用 DistributedDataParallel 有几个需要注意的问题:

  • 最好参照 graphsage [11] 中的代码而不是使用官方教程中的 torch.multiprocessing.spawn 函数开辟多进程。因为使用这个函数会不停的打印 Using backend: pytorch ,暂时还不清楚是什么原因。
  • DataParallel 一样, DistributedDataParallel 会对原模型进行封装,如果需要获取原模型 model 的一些属性或函数的话,需要 model 替换为 model.module
  • 在使用 DistributedDataParallel 时,需要根据 GPU 的数量对 batch_size learning rate 进行调整。根据 Should we split batch_size according to ngpu_per_node when DistributedDataparallel [12] ,简单来说就是保持 batch_size learning rate 的乘积不变,因为我们多 GPU 训练一般不改 batch_size ,所以使用了多少 GPU 就要把 learning rate 扩大为原来的几倍。
  • 如何使 DistributedDataParallel 支持 Sparse Embedding ? 可以参考我在 PyTorch 论坛上的回答 DistributedDataParallel Sparse Embeddings [13] ,设置 torch.distributed.init_process_group 中的 backend=gloo 即可,现在版本(1.6 以及 nightly)的 PyTorch 在 nccl 中仍然不支持 Sparse Embedding 。关于这个问题的最新进展可以看这个 PR:Sparse allreduce for ProcessGroupNCCL [14]

最后,PyTorch 1.6 中提供了混合精度训练 amp [15] 的 API,可以方便的调用,通过将一部分操作从 torch.float32 (float) 变为 torch.float16 (half) 来加快训练速度。但是它并不支持 Sparse Embedding ,如果你的模型中包含 Sparse 向量的话,最好不要使用。

iQfeei.png!mobile

参考文献

iQfeei.png!mobile

[1] https://www.dgl.ai/

[2] https://pytorch.org/

[3]  https://docs.dgl.ai/guide/minibatch.html#chapter-6-stochastic-training-on-large-graphs

[4]  h ttps://github.com/dmlc/dgl/tree/master/examples/pytorch/GATNE-T

[5]  https://networkx.github.io/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.connected_components.html#networkx.algorithms.components.connected_components

[6] https://discuss.pytorch.org/t/num-workers-in-dataloader-will-increase-memory-usage/28522

[7] https://github.com/pytorch/pytorch/issues/13246

[8]  https://github.com/dmlc/dgl/blob/master/examples/pytorch/pinsage/model_sparse.py

[9] https://pytorch.org/docs/stable/optim.html?highlight=sparsea#torch.optim.SparseAdam

[10]  https://docs.python.org/3/library/multiprocessing.html

[11]  https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_multi_gpu.py

[12] https://discuss.pytorch.org/t/should-we-split-batch-size-according-to-ngpu-per-node-when-distributeddataparallel/72769

[13]  https://discuss.pytorch.org/t/distributeddataparallel-sparse-embeddings/60410/2

[14] https://github.com/pytorch/pytorch/issues/22400

[15] https://pytorch.org/docs/stable/amp.html

更多阅读

aAniiyj.png!mobile

iIj2qyb.png!mobile

ZnIriy.png!mobile

2mUBJnB.gif!mobile

# 投 稿 通 道 #

让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是 最新论文解读 ,也可以是 学习心得技术干货 。我们的目的只有一个,让知识真正流动起来。

:memo:  来稿标准:

• 稿件确系个人 原创作品 ,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

:mailbox_with_mail:  投稿邮箱:

• 投稿邮箱: [email protected] 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

:mag:

现在,在 「知乎」 也能找到我们了

进入知乎首页搜索 「PaperWeekly」

点击 「关注」 订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击 「交流群」 ,小助手将把你带入 PaperWeekly 的交流群里。

R7nmyuB.gif!mobile

feMfiqY.jpg!mobile


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK