TensorFlow小白教程:Graph计算图教程

4,223 阅读3分钟

第一章 TensorFlow 基础概念

TensorFlow小白教程:深度学习的理解

TensorFlow小白教程:Tensor基础教程

TensorFlow小白教程:Session基础教程

TensorFlow小白教程:Graph计算图教程

什么是Graph

计算图Graph我们可以简单理解成一个电路板,我们在电路板定义好电路(定义计算和tensor),然后通过插头进行通电(通过Session进行计算),整个电路就开始运作了。

在TensorFlow中会自动维护一个默认的一个计算图,所以我们能够直接定义的tensor或者运算都会被转换为计算图上一个节点。

v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
with tf.Session() as sess:
    # 判断v1所在的graph是否是默认的graph
    print(v1.graph is tf.get_default_graph())
    print(add)
    # 输出 True
    # 输出 [[3. 3.]]

我们可以通过tf.get_default_graph()来获取当前节点所在的计算图。我们通过判断v1tensor所在的计算图和默认的计算图进行比较,发现v1的值处于默认的计算图上,由此也验证了:TensorFlow会自动维护一个默认的计算图,并将我们的节点添加到默认的计算图上。

我们可以看到默认的计算图上有三个节点,分别是v1v1节点,它们共同组成了add节点。

如何创建Graph

我们可以通过tf.Graph()新增计算图,并通过as_default()将变量和计算添加在当前的计算图中,最后通过Session的graph=计算图来计算指定的计算图。

# 新增计算图
new_graph = tf.Graph()
with new_graph.as_default():
    # 在新增的计算图中进行计算
    v1 = tf.constant(value=3, name='v1', shape=(1, 2), dtype=tf.float32)
    v2 = tf.constant(value=4, name='v2', shape=(1, 2), dtype=tf.float32)
    add = v1 + v2
#  通过graph=new_graph指定Session所在的计算图
with tf.Session(graph=new_graph) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))
# 在默认计算图中进行计算
v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 通过graph=tf.get_default_graph()指定Session所在默认的计算图
with tf.Session(graph=tf.get_default_graph()) as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(add))

# 输出:[[7. 7.]]
# 输出:[[3. 3.]]

我们可以看出在不同的计算图中,它们之间的tensor和计算是相互隔离的。这就好比两个电路板,它们上面的电路是相互隔离的。

通过Graph整理资源

我们知道两个Graph上的tensor和计算是相互隔离的,在每一个计算图中,我们会有多个集合来管理不同类别的资源。下面是TensorFlow为我们自动管理了一些常用的集合。

集合名称 集合内容 使用场景
tf.GraphKeys.VARIABLES 所有变量 持久化 TensorFlow 模型
tf.GraphKeys.TRAINABLE_VARIABLES 可学习的变量(一般指神经网络中的参数) 模型训练、生成模型可视化内容
tf.GraphKeys.SUMMARIES 日志生成相关的张量 TensorFlow 计算可视化
tf.GraphKeys.QUEUE_RUNNERS 处理输入的 QueueRunner 输入处理
tf.GraphKeys.MOVING_AVERAGE_VARIABLES 所有计算了滑动平均值的变量 计算变量的滑动平均值

我们也可以通过tf.add_to_collection(key,value)方法去添加我们自定义的集合,通过tf.get_collection(key)去获取对应key下面的集合资源。

v1 = tf.constant(value=1,name='v1',shape=(1,2),dtype=tf.float32)
v2 = tf.constant(value=2,name='v2',shape=(1,2),dtype=tf.float32)
add = v1 + v2
# 添加自定义的集合
tf.add_to_collection('my_collection',v1)
tf.add_to_collection('my_collection',v2)
with tf.Session(graph=tf.get_default_graph()) as sess:
    # 获得对应的集合
    print(tf.get_collection('my_collection'))
# 输出:[<tf.Tensor 'v1:0' shape=(1, 2) dtype=float32>, <tf.Tensor 'v2:0' shape=(1, 2) dtype=float32>]

上面这段代码,我们自定义了一个名为my_collection的集合,并将v1v2通过add_to_collection方法添加到对应的集合中。并通过get_collection方法获取到了对应的集合。

在日后的开发过程中,我们会运用到集合来管理我们不同的类别的资源,以方便在神经网络中方便获取资源。

复盘

我们今天学习了Graph(计算图),我们定义的节点和计算都定义在这个计算图上,当我们通过Session执行对应的计算时,我们的计算图上的资源开始运转,计算得到最终的结果。

计算图也提供了集合,方便了我们在一个计算图中获取我们想要的不同类别的资源。