[译]高效的TensorFlow 2.0:应用最佳实践以及有什么变化

1,851

Tensorflow团队早早就放出了风声,Tensorflow 2.0就快来了,这是一个重要的里程碑版本,重点放在简单和易用性上。我对Tensorflow 2.0的到来充满期待,因此翻译了这篇Tensorflow团队发布的文档:Effective TensorFlow 2.0: Best Practices and What’s Changed。原文地址:https://medium.com/tensorflow/effective-tensorflow-2-0-best-practices-and-whats-changed-a0ca48767aff ,略有删减。点击阅读原文可以跳转到该文章,需要翻墙哦!

最近的一篇文章中,我们提到,TensorFlow 2.0经过重新设计,重点关注开发人员的工作效率、简单性和易用性。

要深入了解所改变的内容及应用最佳实践,请查看新的Effective TensorFlow 2.0指南(发布在GitHub上)。本文简要概述那份指南里的内容。如果您对这些主题感兴趣,请前往指南了解更多信息!

主要变化概述

TensorFlow 2.0中有许多变化可以提高用户的工作效率,包括删除冗余API、使API更加一致(统一的RNN统一的优化器),以及Python运行时更好地集成Eager执行

许多RFC(如果您对它们感到陌生,请查看它们!)已经解释了制定TensorFlow 2.0的变化和思考。本指南展现了在TensorFlow 2.0中开发应该是什么样的。前提假设您对TensorFlow 1.x有一定的了解。

API清理

许多API在TF 2.0中消失或改变位置,有些则被替换为等效的2.0版本 -- tf.summary、tf.keras.metrics和tf.keras.optimizers。自动替换为新方法的最简单方法是使用v2升级脚本

Eager执行

TensorFlow 1.X要求用户调用tf.* API手动将抽象语法树(图)拼接在一起。 然后,用户需要通过将一组输出张量和输入张量传递给 session.run() 函数调用来手动编译抽象语法树。相比之下,TensorFlow 2.0立即执行(就像Python通常做的那样),在tf 2.0中,图形和会话感觉更像实现细节。

减少全局变量

TensorFlow 1.X严重依赖于隐式全局命名空间。调用 tf.Variable() 时,它会被放入默认图形中,它会保留在那里,即使忘记了指向它的Python变量。然后,您可以恢复该 tf.Variable ,但前提是您知道它已创建的名称。如果变量的创建不由您掌控,这就很难做到。结果,各种机制激增,试图帮助用户再次找到他们的变量。

TensorFlow 2.0取消了所有这些机制(Variables 2.0 RFC),启用默认机制:跟踪变量! 如果您失去了对 tf.Variable 的追踪,就会被垃圾回收。

函数,而不是会话

session.run() 调用几乎就像一个函数调用:指定输入和要调用的函数,然后返回一组输出。在TensorFlow 2.0中,您可以使用 tf.function() 来修饰Python函数以将其标记为JIT编译,使得TensorFlow将其作为单个图运行(Functions 2.0 RFC)。

这种机制允许TensorFlow 2.0获得图形模式的所有好处:

  • 性能:可以优化函数(节点修剪、内核融合等)
  • 可移植性:函数可以导出/重新导入(SavedModel 2.0 RFC),允许用户重用和共享模块化TensorFlow函数。

凭借自由分发Python和TensorFlow代码的能力,您可以充分利用Python的表现力。但是,便携式TensorFlow在没有Python解释器上下文时执行 - 移动、C++和JS。为了帮助用户避免在添加@tf.function时重写代码,* AutoGraph *会将部分Python构造转换为他们的TensorFlow等价物。

TensorFlow 2.0约定建议

将代码重构为更小的函数

TensorFlow 1.X中的常见使用模式是“水槽”策略,其中所有可能的计算的合集被预先排列,然后通过 session.run() 评估选择的张量。在TensorFlow 2.0中,用户应将其代码重构为较小的函数,这些函数根据需要调用。通常,没有必要用 tf.function 来修饰这些较小的函数,仅使用 tf.function 来修饰高级计算 - 例如,训练的一个步骤或模型的正向传递。

使用Keras图层和模型来管理变量

Keras模型和图层提供方便的变量和 trainable_variables 属性,以递归方式收集所有关联变量,这样可以轻松地将变量本地管理到它们的使用位置。

Keras层/模型继承自tf.train.Checkpointable并与@ tf.function集成,这使得直接获得检查点或从Keras对象导出SavedModel成为可能。您不一定要使用Keras's.fit()API来进行这些集成。

组合tf.data.Datasets和@tf.function

迭代加载到内存的训练数据时,可以随意使用常规的Python迭代。否则,tf.data.Dataset是从磁盘传输训练数据的最佳方式。数据集是可迭代的(不是迭代器),在Eager模式下和其他Python迭代一样工作。您可以通过将代码包装在tf.function()中来充分利用数据集异步预取/流特性,它会将Python迭代替换为使用AutoGraph的等效图形操作。

@tf.function
def train(model, dataset, optimizer):
 for x, y in dataset:
  with tf.GradientTape() as tape:
   prediction = model(x)
   loss = loss_fn(prediction, y)
  gradients = tape.gradients(loss, model.trainable_variables)
  optimizer.apply_gradients(gradients, model.trainable_variables)

如果您使用 Keras.fit() API,则无需考虑数据集迭代。

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

利用AutoGraph和Python控制流程

AutoGraph提供了一种将依赖于数据的控制流转换为等价图形模式的方法,如 tf.condtf.while_loop

数据相关控制流通常出现在序列模型。tf.keras.layers.RNN 封装了RNN单元格,允许您静态或动态地展开循环。处于演示目的,您可以重新实现动态展开,如下所示:

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell
 
  def call(self, input_data):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    outputs = tf.TensorArray(tf.float32, input_data.shape[0])
    state = self.cell.zero_state(input_data.shape[1], dtype=tf.float32)
    for i in tf.range(input_data.shape[0]):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state

使用tf.metrics聚合数据并用tf.summary来记录日志

最后,一套完整的 tf.summary 符号即将推出。您可以使用以下命令访问 tf.summary 的2.0版本:

from tensorflow.python.ops import summary_ops_v2

下一步

本文提供了Effective TF 2.0指南的简要(如果您对这些主题感兴趣,请到那里了解更多!)要了解有关TensorFlow 2.0的更多信息,我们还推荐这些近期文章: