5

Tensorflow上手1: Print与py_func

 2 years ago
source link: https://yaoyaowd.medium.com/tensorflow%E4%B8%8A%E6%89%8B1-print%E4%B8%8Epy-func-d78571a6a87f
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Tensorflow上手1: Print与py_func

YOUR INSTAGRAM #DOGS AND #CATS ARE TRAINING FACEBOOK’S AI

这个Blog终于又要开始更新了,这几个月每天都在研究Tensorflow,虽然之前工作中也用它训练了不少模型,可是大部分模型都非常简单.最近接触了更多的深度学习模型后,对Tensorflow有了一些新的体会,于是打算写一系列相关文章,记录自己的学习心得.目前计划有简单的函数介绍,关于Keras的使用感触,分布式训练与优化,以及CUDA编程和Tensorflow的自定义函数.

今天就先从简单的调试和功能扩展开始.

深度学习模型的调试在我看来分为两部分,一部分是网络结构的调试,测试结构实现是否正确;另一部分是对于结果的调试,不断提升训练效果.目前我最关注的还是代码本身是否实现正确,每天用到最多的函数就是tf.Print.

善用tf.Print

最常在产品中进行调试的办法就是打印log,而tf.Print就是Tensorflow中最好的加log办法.在使用Tensorflow的过程中,我们可以把需要输出的Tensor做一个自己指向自己的Print操作,从而将Tensor打印到屏幕上,实时监控训练状态,比如:

node = tf.Print(node, [some debug nodes])

利用tf.Print自带的三个参数,基本上可以满足所有的调试日常需要.Message会在屏幕上输出帮助调试的信息.First_n以及summarize控制在屏幕上的调试信息数量,方便我们更好的阅读和理解程序.

扩展双刃剑tf.py_func

第一次看到tf.py_func是在研究如何自己实现PASCAL VOC 2007的衡量函数MAP的时候,我一直觉得Tensorflow对于自定义metrics并不友好,metrics的aggregation做的并不容易扩展.但是看到Tensorflow给出的官方代码采用该tf.py_func去装饰一个参数为numpy array的函数,瞬间茅塞顿开.

这里是该函数的代码地址: https://github.com/tensorflow/models/blob/master/research/object_detection/utils/object_detection_evaluation.py#L437

在tf.py_func包裹的函数里,我们可以很轻松的对数组进行排序,计算各种想要的数据,比如average precision, recall等等.同时我们还可以将传入的numpy输出到屏幕上进行调试,大大扩展了程序的灵活性.

Tensorflow另一种增加灵活性的方法就是定义custom ops,比起该方法,tf.py_func虽然灵活但是也有很多不足.首先tf.py_func的运行属于Python,程序效率上有一定的影响,并且被包装的计算节点必须与调用它的Python程序运行在同一物理设备上,在分布式环境下并不实用(但可以运行在evaluator节点).其次,被包装的函数不能被序列化,所以不利于存储和恢复模型.最后,也因为tf.py_func是脱离Graph存在的,所以不能定义可以训练的参数和进行网络优化.

总的来说,tf.py_func是一个又灵活又有局限性的函数,但是我个人觉得它非常适合用来实现评价函数.


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK