在 TensorFlow 2.x 中看似不连续地打乱后的批处理元素

我有以下简单的例子:


import tensorflow as tf


tensor1 = tf.constant(value = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

tensor2 = tf.constant(value = [20, 21, 22, 23])


print(tensor1.shape)

print(tensor2.shape)


dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))


print('Original dataset')

for i in dataset:

      print(i)


dataset = dataset.repeat(3)


print('Repeated dataset')

for i in dataset:

      print(i)

如果我然后将其批处理dataset为:


dataset = dataset.batch(3)


print('Batched dataset')

for i in dataset:

   print(i)

正如预期的那样,我收到:


Batched dataset

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=

array([[1, 2, 3],

       [4, 5, 6],

       [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=

array([[10, 11, 12],

       [ 1,  2,  3],

       [ 4,  5,  6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=

array([[ 7,  8,  9],

       [10, 11, 12],

       [ 1,  2,  3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)

(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=

array([[ 4,  5,  6],

       [ 7,  8,  9],

       [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)

批处理数据集采用连续的元素。


但是,当我先进行混音,然后进行批处理时:


dataset = dataset.shuffle(3)


print('Shuffled dataset')

for i in dataset:

  print(i)


dataset = dataset.batch(3)


print('Batched dataset')

for i in dataset:

   print(i)

我正在使用 Google Colab 和TensorFlow 2.x.

我的问题是:为什么在批处理之前进行洗牌会导致batch返回非连续元素

感谢您的任何答复。


四季花海
浏览 118回答 1
1回答

12345678_0001

这就是洗牌的作用。你是这样开始的:[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]您已指定,buffer_size=3因此它会创建前 3 个元素的缓冲区:[[1, 2, 3], [4, 5, 6], [7, 8, 9]]您指定了batch_size=3,因此它将从此样本中随机选择一个元素,并将其替换为初始缓冲区之外的第一个元素。假设[1, 2, 3]已被选中,您的批次现在是:[[1, 2, 3]]现在你的缓冲区是:[[10, 11, 12], [4, 5, 6], [7, 8, 9]]对于 的第二个元素batch=3,它将从此缓冲区中随机选择。假设[7, 8, 9]已挑选,您的批次现在是:[[1, 2, 3], [7, 8, 9]]现在你的缓冲区是:[[10, 11, 12], [4, 5, 6]]没有什么新内容可以填充缓冲区,因此它将随机选择这些元素之一,例如[10, 11, 12]。您的批次现在是:[[1, 2, 3], [7, 8, 9], [10, 11, 12]]下一批将只是[4, 5, 6]因为默认情况下, batch(drop_remainder=False).
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python