如何使用 tensorflow_datasets (tfds) 实现和理解预处理和数据扩充?

我正在学习基于使用Oxford-IIIT Pets 的TF 2.0 教程的分割和数据增强。


对于预处理/数据增强,它们为特定管道提供了一组功能:


# Import dataset

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)


def normalize(input_image, input_mask):

  input_image = tf.cast(input_image, tf.float32) / 255.0

  input_mask -= 1

  return input_image, input_mask


@tf.function

def load_image_train(datapoint):

  input_image = tf.image.resize(datapoint['image'], (128, 128))

  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))


  if tf.random.uniform(()) > 0.5:

    input_image = tf.image.flip_left_right(input_image)

    input_mask = tf.image.flip_left_right(input_mask)


  input_image, input_mask = normalize(input_image, input_mask)


  return input_image, input_mask


TRAIN_LENGTH = info.splits['train'].num_examples

BATCH_SIZE = 64

BUFFER_SIZE = 1000

STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE


鉴于 tf 语法,这段代码给我带来了几个疑问。为了防止我只是做一个 ctrl C ctrl V 并真正理解 tensorflow 是如何工作的,我想问一些问题:


1) 在normalize函数中,tf.cast(input_image, tf.float32) / 255.0可以通过tf.image.convert_image_dtype(input_image, tf.float32)?


2) 在normalize函数中,可以在格式中更改我的 segmentation_mask 值tf.tensor而不更改为numpy?我想做的是只使用两个可能的掩码(0 和 1)而不是(0、1 和 2)。使用 numpy 我做了这样的事情:


segmentation_mask_numpy = segmentation_mask.numpy()

segmentation_mask_numpy[(segmentation_mask_numpy == 2) | (segmentation_mask_numpy == 3)] = 0

可以在没有 numpy 转换的情况下做到这一点吗?


3)在load_image_train函数中,他们说这个函数正在做数据增强,但是怎么做呢?在我看来,他们正在通过给定随机数的翻转来更改原始图像,而不是根据原始图像向数据集提供另一个图像。那么,功能目标是更改图像而不是向我的数据集添加保留原始图像的 aug_image?如果我是正确的,如何更改此函数以提供 aug_image 并将原始图像保留在数据集中?


4) 在其他问题中,例如How to apply data augmentation in TensorFlow 2.0 after tfds.load()和TensorFlow 2.0 Keras: How to write image summary for TensorBoard他们使用了很多.map()顺序调用或.map().map().cache().batch().repeat(). 我的问题是:有这个必要性吗?是否存在更简单的方法来做到这一点?我试图阅读 tf 文档,但没有成功。


5)您建议使用此处ImageDataGenerator介绍的 keras或这种 tf 方法更好?


偶然的你
浏览 167回答 1
1回答

有只小跳蛙

4 - 这些顺序调用的事情是,它们简化了我们操作数据集以应用转换的工作,并且他们还声称这是一种加载和处理数据的更具性能的方式。关于模块化/简单性,我猜它完成了它的工作,因为您可以轻松加载、将其传递给整个预处理管道、随机播放并使用几行代码迭代批量数据。train_dataset =tf.data.TFRecordDataset(filenames=train_records_paths).map(parsing_fn)train_dataset = train_dataset.shuffle(buffer_size=12000)train_dataset = train_dataset.batch(batch_size)train_dataset = train_dataset.repeat()# Create a test datasettest_dataset = tf.data.TFRecordDataset(filenames=test_records_paths).map(parsing_fn)test_dataset = test_dataset.batch(batch_size)test_dataset = test_dataset.repeat(1)# validation_steps = test_size / batch_size history = transferred_resnet50.fit(x=train_dataset,                        epochs=epochs,                        steps_per_epoch=steps_per_epoch,                                                validation_data=test_dataset,                        validation_steps=validation_steps)例如,为了加载我的数据集并为我的模型提供预处理数据,这就是我所要做的。3 - 他们定义了一个预处理函数,他们的数据集被映射到,这意味着每次请求样本时都会应用映射函数,就像在我的情况下,我使用解析函数来解析我的使用前 TFRecord 格式的数据:def parsing_fn(serialized):    features = \        {            'image': tf.io.FixedLenFeature([], tf.string),            'label': tf.io.FixedLenFeature([], tf.int64)                    }    # Parse the serialized data so we get a dict with our data.    parsed_example = tf.io.parse_single_example(serialized=serialized,                                             features=features)    # Get the image as raw bytes.    image_raw = parsed_example['image']    # Decode the raw bytes so it becomes a tensor with type.    image = tf.io.decode_jpeg(image_raw)        image = tf.image.resize(image,size=[224,224])        # Get the label associated with the image.    label = parsed_example['label']        # The image and label are now correct TensorFlow types.    return image, label(另一个例子) - 从上面的解析函数,我可以使用下面的代码来创建一个数据集,遍历我的测试集图像并绘制它们。records_path = DATA_DIR+'/'+'TFRecords'+'/test/'+'test_0.tfrecord'# Create a datasetdataset = tf.data.TFRecordDataset(filenames=records_path)# Parse the dataset using a parsing function parsed_dataset = dataset.map(parsing_fn)# Gets a sample from the iteratoriterator = tf.compat.v1.data.make_one_shot_iterator(parsed_dataset) for i in range(100):    image,label = iterator.get_next()    img_array = image.numpy()    img_array = img_array.astype(np.uint8)    plt.imshow(img_array)    plt.show()
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python