呼唤远方
As you are new to TF2, I would recommend going through this guide. It covers training, evaluation, and prediction (inference) of models in TensorFlow 2.0 in two broad situations:

- When using the built-in APIs for training and validation (such as model.fit(), model.evaluate(), model.predict()). This is covered in the section "Using built-in training and evaluation loops".
- When writing custom loops from scratch using eager execution and the GradientTape object. This is covered in the section "Writing your own training and evaluation loops from scratch".

Below is a program in which I compute the gradients after every epoch and append them to a list. At the end of the program, for simplicity, I convert the list to an array.

Code - Note that this program can throw an OOM (out-of-memory) error if you use a deep network with many layers and larger filter sizes, because get_gradient_func runs all 50,000 training images through the model in a single forward pass.

```python
# Importing dependencies
# %tensorflow_version 2.x  (Colab-only magic; remove this line outside Colab)
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import datasets
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.layers import BatchNormalization
import numpy as np
import tensorflow as tf

# Import data
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Build model
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dense(10))

# Model summary
model.summary()

# Compile model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# List that collects the gradients computed at the end of each epoch
epoch_gradient = []
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Define the gradient function
@tf.function
def get_gradient_func(model):
    with tf.GradientTape() as tape:
        # Forward pass over the entire training set as one batch;
        # this is the step that can run out of memory for larger networks
        logits = model(train_images, training=True)
        loss = loss_fn(train_labels, logits)
    grad = tape.gradient(loss, model.trainable_weights)
    # Note: this also applies the gradients, i.e. one extra full-batch
    # optimizer step per epoch; remove it if you only want to record them
    model.optimizer.apply_gradients(zip(grad, model.trainable_variables))
    return grad

# Define the required callback
class GradientCalcCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        grad = get_gradient_func(model)
        epoch_gradient.append(grad)

epoch = 4

print(train_images.shape, train_labels.shape)

model.fit(train_images, train_labels, epochs=epoch,
          validation_data=(test_images, test_labels),
          callbacks=[GradientCalcCallback()])

# Convert to a 2-dimensional array of shape (epoch, number of weight tensors).
# The gradient tensors have different shapes, so store them in an object array.
gradient = np.empty((len(epoch_gradient), len(epoch_gradient[0])), dtype=object)
for i, grads in enumerate(epoch_gradient):
    for j, g in enumerate(grads):
        gradient[i, j] = g.numpy()

print("Total number of epochs run:", epoch)
```

Output -

```
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_12 (Conv2D)           (None, 30, 30, 32)        896
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 13, 13, 64)        18496
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 4, 4, 64)          36928
_________________________________________________________________
flatten_4 (Flatten)          (None, 1024)              0
_________________________________________________________________
dense_11 (Dense)             (None, 64)                65600
_________________________________________________________________
dense_12 (Dense)             (None, 10)                650
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________
(50000, 32, 32, 3) (50000, 1)
Epoch 1/4
1563/1563 [==============================] - 109s 70ms/step - loss: 1.7026 - accuracy: 0.4081 - val_loss: 1.4490 - val_accuracy: 0.4861
Epoch 2/4
1563/1563 [==============================] - 145s 93ms/step - loss: 1.2657 - accuracy: 0.5506 - val_loss: 1.2076 - val_accuracy: 0.5752
Epoch 3/4
1563/1563 [==============================] - 151s 96ms/step - loss: 1.1103 - accuracy: 0.6097 - val_loss: 1.1122 - val_accuracy: 0.6127
Epoch 4/4
1563/1563 [==============================] - 152s 97ms/step - loss: 1.0075 - accuracy: 0.6475 - val_loss: 1.0508 - val_accuracy: 0.6371
Total number of epochs run: 4
```

Hope this answers your question. Happy learning.
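If the full-batch gradient runs out of memory, a common workaround is to compute it one mini-batch at a time and combine the per-batch results. The sketch below illustrates only that idea, using plain NumPy on a hypothetical linear least-squares model rather than TensorFlow. The functions `mse_grad` and `batched_grad` and the random data are illustrative assumptions, not part of the program above. It assumes a mean-reduced loss (the default reduction for Keras losses such as `SparseCategoricalCrossentropy`), so each batch gradient is weighted by its share of the samples.

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w (hypothetical model)."""
    residual = X @ w - y
    return X.T @ residual / len(X)

def batched_grad(w, X, y, batch_size):
    """Same gradient, computed one mini-batch at a time.

    Each batch's mean gradient is weighted by len(batch) / n, so the weighted
    sum equals the full-batch mean gradient, while only one batch has to be
    processed at a time. The last batch may be smaller; the weighting handles it.
    """
    n = len(X)
    total = np.zeros_like(w)
    for start in range(0, n, batch_size):
        xb = X[start:start + batch_size]
        yb = y[start:start + batch_size]
        total += mse_grad(w, xb, yb) * (len(xb) / n)
    return total

# Illustrative random data (assumption, not CIFAR-10)
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = rng.normal(size=1000)
w = rng.normal(size=5)

full = mse_grad(w, X, y)
chunked = batched_grad(w, X, y, batch_size=64)
print(np.allclose(full, chunked))  # True
```

In the TensorFlow program, the analogous (untested) change would be to iterate over batches of train_images, for example with tf.data.Dataset.from_tensor_slices(...).batch(...), call tape.gradient per batch, and accumulate the weighted sums before appending them to epoch_gradient.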