17

TensorFlow小白教程:Graph计算图教程

 4 years ago
source link: https://www.tuicool.com/articles/E7BbI3b
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

计算图Graph我们可以理解成一个电路板,我们在电路板定义好电路,然后通过插头进行通电,整个电路就开始运作了。

在TensorFlow中会自动维护一个默认的一个计算图,所以我们能够直接定义的tensor或者运算都直接转换为计算图上以节点。

v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
with tf.Session() as sess:
    # 判断v1所在的graph是否是默认的graph
    print(v1.graph is tf.get_default_graph())
    print(add)
    # 输出 True
    # 输出 [[3. 3.]]

我们可以通过 tf.get_default_graph() 来获取当前节点所在的计算图。我们通过判断 v1 tensor所在的计算图和默认的计算图进行比较,发现 v1 的值处于默认的计算图上,由此也验证了:TensorFlow会自动维护一个默认的计算图,并将我们的节点添加到默认的计算图上。

BvUN3yi.png!web

我们可以看到默认的计算图上有三个节点,分别是 v1v1 节点,它们共同组成了 add 节点。

如何创建Graph

我们可以通过 tf.Graph() 新增计算图,并通过 as_default() 将变量和计算添加在当前的计算图中,最后通过Session的 graph=计算图 来计算指定的计算图。

# 新增计算图
new_graph = tf.Graph()
with new_graph.as_default():
    # 在新增的计算图中进行计算
    v1 = tf.constant(value=3, name='v1', shape=(1, 2), dtype=tf.float32)
    v2 = tf.constant(value=4, name='v2', shape=(1, 2), dtype=tf.float32)
    add = v1 + v2
#  通过graph=new_graph指定Session所在的计算图
with tf.Session(graph=new_graph) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))
# 在默认计算图中进行计算
v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 通过graph=tf.get_default_graph()指定Session所在默认的计算图
with tf.Session(graph=tf.get_default_graph()) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))
# 输出:[[7. 7.]]
# 输出:[[3. 3.]]

我们可以看出在不同的计算图中,它们之间的tensor和计算是相互隔离的。这就好比两个电路板,它们上面的电路是相互隔离的。

通过Graph整理资源

我们知道两个Graph上的tensor和计算是相互隔离的,在每一个计算图中,我们会有多个集合来管理不同类别的资源。下面是TensorFlow为我们自动管理了一些常用的集合。

集合名称 集合内容 使用场景 tf.GraphKeys.VARIABLES 所有变量 持久化 TensorFlow 模型 tf.GraphKeys.TRAINABLE_VARIABLES 可学习的变量(一般指神经网络中的参数) 模型训练、生成模型可视化内容 tf.GraphKeys.SUMMARIES 日志生成相关的张量 TensorFlow 计算可视化 tf.GraphKeys.QUEUE_RUNNERS 处理输入的 QueueRunner 输入处理 tf.GraphKeys.MOVING_AVERAGE_VARIABLES 所有计算了滑动平均值的变量 计算变量的滑动平均值

我们也可以通过 tf.add_to_collection(key,value) 方法去添加我们自定义的集合,通过 tf.get_collection(key) 去获取对应key下面的集合资源。

v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 添加自定义的集合
tf.add_to_collection('my_collection',v1)
tf.add_to_collection('my_collection',v2)
with tf.Session(graph=tf.get_default_graph()) as sess:
    # 获得对应的集合
    print(tf.get_collection('my_collection'))
# 输出:[<tf.Tensor 'v1:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'v2:0' shape=(1, 2) dtype=float32>]

上面这段代码,我们自定义了一个名为 my_collection 的集合,并将 v1v2 通过 add_to_collection 方法添加到对应的集合中。并通过 get_collection 方法获取到了对应的集合。

在日后的开发过程中,我们会运用到集合来管理我们不同的类别的资源,以方便在神经网络中方便获取资源。

复盘

我们今天学习了Graph(计算图),我们定义的节点和计算都定义在这个计算图上,当我们通过Session执行对应的计算时,我们的计算图上的资源开始运转,计算得到最终的结果。

计算图也提供了集合,方便了我们在一个计算图中获取我们想要的不同类别的资源。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK