使用 tf.function 提升效率
在之前的入门介绍之中,我们曾经介绍过 TensorFlow1.x 采用的并不是 Eager execution 执行模型;而 TensorFlow2.x 默认采用的是 Eager execution 模式。
这种改变使得我们可以更加容易地学习,但是也会造成性能的损失,因此, TensorFlow 在 2.0 版本之后引入了 tf.function 。
在 TensorFlow1.x 之中,如果我们想要运行一个学习任务,那么我们需要首先创建一个 tf.Sesstion (),然后再调用 Session.run () 进行运行。
其实在 TensorFlow1.x 内部,当我们在 TensorFlow 之中进行工作的时候, TensorFlow 会帮助我们创建一个计算图 tf.graph ,然后通过 tf.Session 对计算图进行计算。
而在 TensorFlow2.x 之中,其默认采用的是 Eager execution 执行方式,在该执行方式之中,我们不再需要定义一个计算图来进行。
这样就产生了一些问题:
- 使用 tf.Sesstion () 的运行效率非常高,但是代码很难懂;
- 使用 Eager execution 方式的代码很简单,但是执行效率比较低。
有什么方法能够兼顾两者吗?
那就是 tf.function 。
tf.function 是一个函数标注修饰,也就是如下的形式:
@tf.function def my_function(): ...
代码块预览 复制
其实如你所见,这就是 tf.function 的全部用法。
我们只需要在我们要修饰的函数之前加上 tf.function 标注既可。
采用 tf.function,TensorFlow 会将该函数转变为计算图 tf.graph 的形式来进行运算,这会使得该函数在进行大量运算的时候会加速非常多。
是不是所有的函数都适合 tf.function 进行修饰呢?
答案是否定的,以下两种情况不适合使用 tf.function 进行修饰:
- 函数本身的计算非常简单,那么构建计算图本身的时间就会相对非常浪费;
- 当我们需要在函数之中定义 tf.Variable 的时候,因为 tf.function 可能会被调用多次,因此定义 tf.Variable 会产生重复定义的情况。
既然了解了 tf.function 的用法,那么我们便来测试一下 tf.function 的性能,我们采用一个简单的卷积神经网络来进行测试:
import tensorflow as tf import timeit def f1(layer, image): y = layer(image) return y @tf.function def f2(layer, image): y = layer(image) return y layer = tf.keras.layers.Conv2D(300, 3) image = tf.zeros([64, 32, 32, 3]) model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), tf.keras.layers.MaxPooling2D((2, 2)), tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), tf.keras.layers.MaxPooling2D((2, 2)), tf.keras.layers.Conv2D(128, (3, 3), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax')]) print(timeit.timeit(lambda: f1(model, image), number=500)) print(timeit.timeit(lambda: f2(model, image), number=500))
代码块预览 复制
在这里,我们定义了两个相同的函数,其中一个使用了 tf.function 进行修饰,而另外一个没有。
在这里我们使用 lambda 函数来让函数重复执行 500 次,并且使用 timeit 来进行时间的统计,得到两个函数的执行时间,从而进行比较。
最终,我们可以得到结果:
17.20403664399987 12.07886406200032
代码块预览 复制
由此可以看出,我们的 tf.function 已经提升了一定的速度,但是提升的速度有限,目前大概提升了 25 % 的速度。这是因为我们的计算仍然还是太简单了,当我们计算非常大的时候,性能会有很大的提升。
在这节课之中,我们学习到了什么是 tf.function ,以及 tf.function 的基本原理,然后我们了解了 tf.function 的使用方法;最后我们通过一个简单的神经网络来进行了性能的测试,最终我们发现我们的 tf.function 确实能给我们性能带来很大的提升。