TensorFlow 中的回调函数
回调函数是 TensorFlow 训练之中非常重要的一部分,我们在之前的学习之中或多或少地用到了回调函数。比如在之前的过拟合一节之中,我们就曾经用到了早停回调。那么这节课我们就来学习以下 TensorFlow 之中的回调函数。
1. 什么是回调函数
简单来说,回调函数就是在训练到一定阶段的时候而执行的函数,我们最常采用的策略是每个Epoch结束之后执行一次回调函数。
回调函数的绝大多数 API 集中在 tf.keras.callbacks 之中,也就是说这是 Keras 之中的一个 API 。由于之前已经学习过早停回调,这节课我们来学习一下其他的几个常用的回调:
- 模型保存回调:tf.keras.callbacks.ModelCheckpoint;
- 学习率回调;tf.keras.callbacks.LearningRateScheduler;
- 自定义回调:tf.keras.callbacks.CallBack。
对于回调的使用方法,也是非常简单的,假设以下的数组之中定义了我们所需要的全部回调函数:
callbacks = [......]
那么我们在使用回调的时候,之中只需要在训练函数中指定回调即可:
model.fit(..., ..., callbacks=callbacks)
对于要介绍的回调,我们会首先给出介绍,然后再在统一的代码之中示例使用。
2. 模型保存回调
模型保存的回调函数为:
tf.keras.callbacks.ModelCheckpoint(
path, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, save_freq='epoch')
这里只列出来了我们常用的参数,对于其中的每个参数,它们的作用如下:
- path: 保存模型的路径;
- monitor: 用哪个指标来评价模型的好坏,默认是验证集上的损失;
- verbose: 输出日志的等级,只能为 0 或 1;
- save_best_only: 是否只保存最好的模型,模型的好坏由 monitor 指定;
- save_weights_only: 是否只保存权重,默认 False ,也就是保存整个模型;
- save_freq: 保存的频率,可以为 ‘Epoch’ 或者一个整数,默认为每个 Epoch 保存一次模型;若是一个整数N,则是每训练 N 个 Batch 保存一次模型。
3. 学习率回调
学习率回调函数为:
tf.keras.callbacks.LearningRateScheduler(
schedule, verbose=0
)
其中 verbose 参数仍然是日志输出的等级,默认为 0 ;而 schedule 则是一个函数,用来定义一个学习率的变化。其中 schedule 函数的一个示例如下所示:
def my_schedule(epoch, lr):
if epoch < 20:
return lr
else:
return lr * 0.1
该学习率回调是在 20 个 Epoch 之前学习率保持不变,而在 20 个 Epoch 之后,每个 Epoch 学习率变为原来的 0.1 。
可以看出,该 schedule 函数由严格的形式,其中第一个参数为训练的 Epoch ,第二个参数为当前的学习率。
4. 自定义回调
我们在使用回调的过程之中难免会遇到要自定义回调的情况,这时我们便需要编写类来继承 tf.keras.callbacks.CallBack 类,从而实现我们的自定义回调。
在自定义回调的过程之中,你可以覆写不同的函数,从而可以实现在不同的时间来运行我们自定义的函数,这些函数包括:
- on_train_begin(self, logs=None): 在训练开始时调用;
- on_test_begin(self, logs=None): 在测试开始时调用;
- on_predict_begin(self, logs=None): 在预测开始时调用;
- on_train_end(self, logs=None) 在训练结束时调用;
- on_test_end(self, logs=None) 在测试结束时调用;
- on_predict_end(self, logs=None) 在预测结束时调用;
- on_train_batch_begin(self, batch, logs=None) 在训练期间的每个批次之前调用;
- on_test_batch_begin(self, batch, logs=None) 在测试期间的每个批次之前调用;
- on_predict_batch_begin(self, batch, logs=None) 在预测期间的每个批次之前调用;
- on_train_batch_end(self, batch, logs=None) 在训练期间的每个批次之后调用;
- on_test_batch_end(self, batch, logs=None) 在测试期间的每个批次之后调用;
- on_predict_batch_end(self, batch, logs=None) 在预测期间的每个批次之后调用;
- on_epoch_begin(self, epoch, logs=None) 在每次迭代训练开始时调用;
- on_epoch_end(self, epoch, logs=None) 在每次迭代训练结束时调用。
我们可以来使用其中两个简单的函数来做一个简单的示例:
class MyCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
print("Start epoch {}.".format(epoch))
def on_train_begin(self, logs=None):
print("Starting training.")
这个样子,我们便可以在每次训练开始,以及每个 Epoch 开始之时进行输出日志。
5. 程序示例
在这里,我们将同时使用模型保存回调、学习率回调以及自定义回调来做一个简单的示例:
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
lr = 0.01
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss="mse"
)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def my_schedule(epoch, lr):
print('Learning rate: ' + str(lr))
if epoch < 5:
return lr
else:
return lr * 0.1
lr_callback = tf.keras.callbacks.LearningRateScheduler(my_schedule)
save_model_callback = tf.keras.callbacks.ModelCheckpoint(
filepath='/model/', save_weights_only=True, verbose=1,
monitor='val_loss', mode='min', save_best_only=True)
class MyCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, logs=None):
print("Start epoch {}.".format(epoch))
def on_train_begin(self, logs=None):
print("Starting training.")
model.fit(x_train, y_train,
batch_size=64, epochs=10,
validation_data=(x_test, y_test),
callbacks=[MyCallback(), lr_callback, save_model_callback],
)
在这里,我们按照之前学习的方法定义了三个回调函数,分别是模型保存回调、学习率回调、以及自定义回调。其中模型保存回调会在每次训练后保存模型、学习率回调会在第五个 Epoch 之后便每个 Epoch 变为原来的 0.1 ,而自定义回调会在训练开始之前、每个 Epoch 开始之前输出相应的信息。
于是我们可以得到输出:
Starting training.
Start epoch 0.
Learning rate: 0.009999999776482582
Epoch 1/10
931/938 [============================>.] - ETA: 0s - loss: 556.1402
Epoch 00001: val_loss improved from inf to 15.96259, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 552.3954 - val_loss: 15.9626
Start epoch 1.
Learning rate: 0.009999999776482582
Epoch 2/10
927/938 [============================>.] - ETA: 0s - loss: 12.4227
Epoch 00002: val_loss improved from 15.96259 to 10.01533, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 12.3927 - val_loss: 10.0153
Start epoch 2.
Learning rate: 0.009999999776482582
Epoch 3/10
914/938 [============================>.] - ETA: 0s - loss: 9.0919
Epoch 00003: val_loss improved from 10.01533 to 8.50834, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 9.0744 - val_loss: 8.5083
Start epoch 3.
Learning rate: 0.009999999776482582
Epoch 4/10
913/938 [============================>.] - ETA: 0s - loss: 8.3514
Epoch 00004: val_loss improved from 8.50834 to 8.26637, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.3450 - val_loss: 8.2664
Start epoch 4.
Learning rate: 0.009999999776482582
Epoch 5/10
920/938 [============================>.] - ETA: 0s - loss: 8.2481
Epoch 00005: val_loss improved from 8.26637 to 8.25048, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2544 - val_loss: 8.2505
Start epoch 5.
Learning rate: 0.009999999776482582
Epoch 6/10
933/938 [============================>.] - ETA: 0s - loss: 8.2504
Epoch 00006: val_loss improved from 8.25048 to 8.25035, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2502 - val_loss: 8.2504
Start epoch 6.
Learning rate: 0.0009999999310821295
Epoch 7/10
932/938 [============================>.] - ETA: 0s - loss: 8.2509
Epoch 00007: val_loss improved from 8.25035 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 7.
Learning rate: 9.99999901978299e-05
Epoch 8/10
916/938 [============================>.] - ETA: 0s - loss: 8.2600
Epoch 00008: val_loss improved from 8.25034 to 8.25034, saving model to /model/
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 8.
Learning rate: 9.99999883788405e-06
Epoch 9/10
914/938 [============================>.] - ETA: 0s - loss: 8.2541
Epoch 00009: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
Start epoch 9.
Learning rate: 9.99999883788405e-07
Epoch 10/10
925/938 [============================>.] - ETA: 0s - loss: 8.2446
Epoch 00010: val_loss did not improve from 8.25034
938/938 [==============================] - 2s 2ms/step - loss: 8.2501 - val_loss: 8.2503
<tensorflow.python.keras.callbacks.History at 0x7eff7317f748>
可以看到,我们的三个回调函数都能正确地输出相应的信息,说明我们的回调函数已经成功生效。
6. 小结
在这节课之中,我们学习了什么是回调函数、模型保存回调、学习率回调以及如何自定义回调。同时我们又通过相应的示例演示了如何使用回调。