带有 Tensorflow Dataset API 的 Keras 自动编码器并记录到

我在 Keras 中有简单的自动编码器,我想使用日志记录到张量板(因此我需要传递验证数据),并使用 Tensorflow Dataset API 使用预取从 TFRecord 加载数据。我读了一些关于它的文章,但他们要么省略了验证管道,要么直接传递数据而不使用 feed dict 的事实要慢得多。


源代码是


import tensorflow as tf

from keras.losses import mean_squared_error

from keras.models import Sequential, Model

from keras.layers import Dense, Input, Flatten, Reshape, Convolution2D,     Convolution2DTranspose, Conv2D, Conv2DTranspose

from keras.optimizers import Adam

from keras import backend as K

from keras.callbacks import TensorBoard


def create_dataset(tf_record, batch_size):

    data = tf.data.TFRecordDataset(tf_record)

    data = data.map(TFReader._parse_example_encoded, num_parallel_calls=8)

    data = data.apply(tf.data.experimental.shuffle_and_repeat(buffer_size=100))

    data = data.batch(batch_size, drop_remainder=True)

    data = data.prefetch(4)

    return data



def main(_):

    batch_size = 8  # todo: check and try bigger

    data = create_dataset('../../datasets/anime/no-game-no-life-ep-2.tfrecord', batch_size)

    iterator = data.make_one_shot_iterator()


    K.set_image_data_format('channels_last')  # set format


    input_tensor = Input(tensor=iterator.get_next())

    out = Conv2D(8, (3, 3), activation='elu', border_mode='valid', batch_input_shape=(batch_size, 432, 768, 3))(input_tensor)

    out = Conv2D(16, (3, 3), activation='elu', border_mode='valid')(out)

    out = Conv2D(32, (3, 3), activation='elu', border_mode='valid', name='bottleneck')(out)

    out = Conv2DTranspose(32, (3, 3), activation='elu', padding='valid')(out)

    out = Conv2DTranspose(16, (3, 3), activation='elu', padding='valid')(out)

    out = Conv2DTranspose(8, (3, 3), activation='elu', padding='valid')(out)

    out = Conv2D(3, (3, 3), activation='elu', padding='same')(out)

    m = Model(inputs=input_tensor, outputs=out)

    m.compile(loss=mean_squared_error, optimizer=Adam(), target_tensors=iterator.get_next())

    print(m.summary())


繁星淼淼
浏览 164回答 1
1回答

慕村9548890

几个选项:您是否看过此链接https://github.com/keras-team/keras/issues/3358(juiceboxjoe 的解决方案)?编写一个 TensorboardWrapper,它从生成器加载验证数据并将其作为回调传递。如果您不关心验证,请从训练数据中加载一两个样本并将它们作为数组传递给 validation_data。如果不需要直方图,则设置 histogram_freq = 0。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python