Tensorflow Estimator:缓存瓶颈

在遵循 tensorflow 图像分类教程时,首先它会缓存每个图像的瓶颈:

定义:cache_bottlenecks())

我已经使用 tensorflow 的Estimator. 这确实简化了所有代码。但是我想在这里缓存瓶颈功能。

这是我的model_fn. 我想缓存dense层的结果,这样我就可以对实际训练进行更改,而不必每次都计算瓶颈。

我怎样才能做到这一点?

def model_fn(features, labels, mode, params):

    is_training = mode == tf.estimator.ModeKeys.TRAIN


    num_classes = len(params['label_vocab'])


    module = hub.Module(params['module_spec'], trainable=is_training and params['train_module'])

    bottleneck_tensor = module(features['image'])


    with tf.name_scope('final_retrain_ops'):

        logits = tf.layers.dense(bottleneck_tensor, units=num_classes, trainable=is_training)  # save this?


    def train_op_fn(loss):

        optimizer = tf.train.AdamOptimizer()

        return optimizer.minimize(loss, global_step=tf.train.get_global_step())


    head = tf.contrib.estimator.multi_class_head(n_classes=num_classes, label_vocabulary=params['label_vocab'])


    return head.create_estimator_spec(

        features, mode, logits, labels, train_op_fn=train_op_fn

    )


有只小跳蛙
浏览 170回答 2
2回答

慕仙森

TF 无法在您编码时工作。你应该:从原始网络导出瓶颈到文件。使用瓶颈结果作为输入,使用另一个网络来训练您的数据。

守着一只汪

这样的事情应该工作(未经测试):# Serialize the data into two tfrecord filestf.enable_eager_execution()feature_extractor = ...features_file = tf.python_io.TFRecordWriter('features.tfrec')label_file = tf.python_io.TFRecordWriter('labels.tfrec')for images, labels in dataset:  features = feature_extractor(images)  features_file.write(tf.serialize_tensor(features))  label_file.write(tf.serialize_tensor(labels))# Parse the files and zip them togetherdef parse(type, shape):  _def parse(x):    result = tf.parse_tensor(x, out_type=shape)    result = tf.reshape(result, FEATURE_SHAPE)    return result  return parsefeatures_ds = tf.data.TFRecordDataset('features.tfrec')features_ds = features_ds.map(parse(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)labels_ds = tf.data.TFRecordDataset('labels.tfrec')labels_ds = labels_ds.map(parse(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)ds = tf.data.Dataset.zip(features_ds, labels_ds)ds = ds.unbatch().shuffle().repeat().batch().prefetch()...您也可以使用 来完成它Dataset.cache,但我不是 100% 确定细节。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python