如何从 Keras 模型中获取数据以进行可视化?

我正在使用 Tensorflow 1.12,它将 Keras 与 Python 3.6.x 集成在一起


我希望使用 Keras 来简化模型构建,但也希望使用中间层上的数据进行特征图和内核的可视化,以更好地理解机器学习的工作原理(尽管这确实不那么明显)


我正在使用 mnist 数据库和一个非常基本的 Keras 模型来尝试做我想做的事情。


这是代码


import tensorflow as tf

from tensorflow.keras import layers

from tensorflow import keras


print(tf.VERSION)

print(tf.keras.__version__)


tf.keras.backend.clear_session()


mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train_shaped = np.expand_dims(x_train, axis=3) / 255.0

x_test_shaped = np.expand_dims(x_test, axis=3) / 255.0


def create_model():


  model = tf.keras.models.Sequential([

    keras.layers.Conv2D(32, kernel_size=(4, 4),strides=(1,1),activation='relu', input_shape=(28,28,1)),

    keras.layers.Dropout(0.5),

    keras.layers.MaxPooling2D(pool_size=(2,2), strides=(2,2)),

    keras.layers.Conv2D(24, kernel_size=(8, 8),strides=(1,1)),

    keras.layers.Flatten(),

    keras.layers.Dropout(0.5),

    keras.layers.Dense(128, activation=tf.nn.relu),

    keras.layers.Dense(10, activation=tf.nn.softmax)

  ])


  model.compile(optimizer=tf.keras.optimizers.Adam(), 

            loss=tf.keras.losses.sparse_categorical_crossentropy,

            metrics=['accuracy'])


  return model

以上设置了数据集和模型接下来我为 Tensorflow 定义我的会话并进行训练。


这一切都很好,但现在我想获取我的数据,例如,第一层作为理想的 numpy 数组,我可以在上面进行可视化。


我model.layers[0].output给了我Tensor的(?,25,25,32)预期,现在我尝试做一个eval()和thenafter一个.numpy()方法来获取我的结果。


错误信息是


You must feed a value for placeholder tensor 'conv2d_6_input' with dtype float and shape [?,28,28,1]

我正在寻求有关如何将我的数据(32 个 25x25 像素的特征图)作为 numpy 数组进行可视化的帮助。


sess = tf.Session(graph=tf.get_default_graph())

tf.keras.backend.set_session(sess)


with sess.as_default():

   model = create_model()

   model.summary()


   model.fit(x_train_shaped[:10000], y_train[:10000], epochs=2, 

   batch_size=64, validation_split=.2,)


   model.layers[0].output

   print(model.layers[0].output.shape)

   my_array = model.layers[0].output

   my_array.eval()


tf.keras.backend.clear_session()

sess.close()


呼啦一阵风
浏览 171回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python