章节索引 :

在 Keras 中保存与加载模型

无论是在学习还是在工作的过程之中,我们都会遇到保存数据的情形。

在我们之前的学习之中,我们所训练到的模型都没有经过保存,也就是所我们得到的模型的结构和参数都是存在于内存之中的,当我们关闭程序的时候这些模型和参数都会消失;如果我们想要使用该模型的话就需要再次训练模型。

这显然是不可取的,因此我们要学会如何保存模型与加载模型。

1. 定义模型结构

由于我们这节课的重点在模型的保存,而不是网络的结构,因此我们使用之前的网络结构: fashion_mnist 分类的网络结构。

具体的网络代码为:


import tensorflow as tf

# 使用内置的数据集合来加载数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# 预处理图片数据,使其归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

def get_model():
    # 定义网络结构
    model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
    ])
  return model

modle = get_model()

# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

在这里我们不仅仅定义了网络的基本结构,同时也载入了基本的图片数据,从而便于后面的训练以及模型保存等操作。

2. 在训练结束后保存模型参数、加载模型参数

我们可以在训练之前直接保存模型参数,但是因为这样的参数是未经过训练的,因此没有太有价值的意义,因此我们在保存模型之前要先训练模型

我们可以通过以下代码来训练模型:

# 训练模型
model.fit(x_train, y_train, epochs=5)

训练的过程之中我们可以得到如下的输出:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4870 - accuracy: 0.8288
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.3616 - accuracy: 0.8679
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3256 - accuracy: 0.8795
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3006 - accuracy: 0.8883
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2867 - accuracy: 0.8931

2.1 保存模型参数

在训练结束之后我们可以手动进行模型参数的保存:

model.save_weights('./checkpoints/ckpt')

通过这样的操作,我们便可以将我们模型的参数保存至当前目录的 “checkpoints” 文件夹下面,并且名命为 ckpt 。

我们可以查看该文件夹下面的文件,可以看到文件夹下面包括三个文件:

79  checkpoint
1.2K  checkpoints.index
2.4M  checkpoints.data-00000-of-00001

这三个文件之中保存的就是我们的模型的参数。

2.2 加载模型参数

如果我们需要加载我们的模型,我们只需要经过以下两步即可:

  • 定义网络结构;
  • 按照保存路径来载入参数。

具体代码如下:

# 创建模型结构
model = get_model()

# 加载参数
model.load_weights('./checkpoints/ckpt')
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
            metrics=['accuracy'])
# 评估模型
model.evaluate(x_test,  y_test, verbose=2)

我们可以看到模型的输出为:

313/313 - 0s - loss: 0.3448 - accuracy: 0.8758
[0.34482061862945557, 0.8758000135421753]

说明我们的模型参数已经成功加载。

3. 使用回调保存模型参数

前面我们知道了如何在模型训练结束后保存模型,那么如何让模型在训练的过程中自动保存模型呢?

那便就需要用到 TensorFlow 的**“回调函数”**这个功能,这个功能允许我们定义一系列的事件,并让其在训练的过程之中执行。

在这个例子之中,我们可以让它在每个 Epoch 结束的时候保存模型参数。

于是我们首先定义了模型保存的回调函数,然后我们又在在 fit 函数之中使用 callbacks 参数来将其传入。

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints2/ckpt', save_weights_only=True)

model.fit(x_train,
          y_train,  
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[cp_callback])

我们可以看到,在每个 Epoch 结束后,模型都会进行模型参数的保存

Epoch 6/10
1868/1875 [============================>.] - ETA: 0s - loss: 0.2704 - accuracy: 0.8985
Epoch 00006: saving model to ./checkpoints2/ckpt
INFO:tensorflow:Assets written to: ./checkpoints2/ckpt/assets

于是我们便可以使得模型能够自动地保存模型参数。

cp_callback 中的几个参数大家需要注意一下:

  • file_path: 与手动保存模型一样,定义了模型参数保存的路径;
  • save_weights_only: 是否只保存模型参数,一般而言只保存参数的文件会比全部保存的文件小很多,因此我们一般只是保存网络参数。

这样可以避免因为意外情况导致程序意外停止时,前面所有的训练都前功尽弃的情况。因为我们可以加载最近一次保存的模型继续训练。

如果想要加载模型,那么便和手动加载模型一样即可:

model.load_weights('./checkpoints2/ckpt')

4. 保存模型与保存参数

前面的保存都是只保存网络中的各种参数,而没有保存网络的模型。相比较而言而这主要有以下差别:

  • 保存参数的文件较小,而保存整个模型的文件较大
  • 加载参数速度较快,而加载整个模型较慢
  • 保存参数不包含网络结构,而保存整个模型则包含网络的结构

4.1 在训练结束后手动保存与加载整个模型

和之前的操作一样,只是我们需要换一下保存的API函数:

model.save('saved_model/model1')

当我们需要加载模型的时候,我们需要使用以下方法来加载模型:

model = tf.keras.models.load_model('saved_model/model1')

4.2 在回调之中保存整个模型

在回调之中保存整个模型比较简单,我们只需要将 save_weights_only 参数设置为 False 即可:

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints3/ckpt', save_weights_only=False)

model.fit(x_train,
          y_train,  
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[cp_callback])

5. 小结

这节课之中我们主要学习了如何进行模型的保存与加载,同时我们又深入了解了保存模型与保存参数的区别以及它们具体的实现方式。

图片描述

TensorFlow 简介、安装与快速入门
TensorFlow 简介 TensorFlow 安装 - CPU TensorFlow 安装 - GPU TensorFlow 快速入门示例
TensorFlow 模型的简洁表示-Keras
Keras 简介 使用 tf.keras 进行图片分类 使用 Keras 进行文本分类 使用 Keras 进行回归 在 Keras 中保存与加载模型 在 Keras 中进行模型的评估 Keras 中的Masking 与 Padding
TensorFlow 中的数据格式
TensorFlow 中的数据核心 使用 TensorFlow 加载 CSV 数据 使用 TensorFlow 加载 Numpy 数据 使用 TF 加载 DateFrame 数据 使用图像数据来训练模型 在 TensorFlow 之中使用文本数据 TF 之中的 Unicode 数据格式的处理
TensorFlow模型的高级表示-Estimat
使用预设的 Estimator 模型 将Keras模型转化为Estimator模型 Estimator实现BoostingTree模型
TensorFlow 高级技巧
过拟合问题 TensorFlow 中的回调函数 文本数据嵌入 在 TensorFlow 之中使用卷积神经网络 在 TensorFlow 之中使用循环神经网络 在 TensorFlow 之中使用注意力模型 在 TensorFlow 之中进行迁移学习 在 TensorFlow 之中进行数据增强 在 TensorFlow 之中进行图像分割 如何进行多 GPU 的分布式训练? 使用 tf.function 提升效率 使用 TF HUB 进行模型复用
TensorFlow高级技巧-自定义
使用 TensorFlow 进行微分操作 在 TensorFlow 之中自定义网络层与模型 在 TensorFlow 之中自定义训练
TF 框架中的可视化工具-TensorBoard
TensorBoard 的简介与快速上手 使用 TensorBoard 记录训练中的各项指标 在 TensorBoard 之中查看模型结构图 在 TensorBoard 之中记录图片数据