I have the following simple example:
import tensorflow as tf

tensor1 = tf.constant(value=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # 4 feature rows, shape (4, 3)
tensor2 = tf.constant(value=[20, 21, 22, 23])  # 4 labels, shape (4,)
print(tensor1.shape)
print(tensor2.shape)
# Each dataset element is a (row, label) pair
dataset = tf.data.Dataset.from_tensor_slices((tensor1, tensor2))
print('Original dataset')
for i in dataset:
    print(i)
dataset = dataset.repeat(3)  # 3 passes over the 4 pairs, 12 elements in total
print('Repeated dataset')
for i in dataset:
    print(i)
If I then batch the dataset as:
dataset = dataset.batch(3)
print('Batched dataset')
for i in dataset:
    print(i)
As expected, I get:
Batched dataset
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([20, 21, 22], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[10, 11, 12],
       [ 1,  2,  3],
       [ 4,  5,  6]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([23, 20, 21], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 7,  8,  9],
       [10, 11, 12],
       [ 1,  2,  3]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([22, 23, 20], dtype=int32)>)
(<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]], dtype=int32)>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([21, 22, 23], dtype=int32)>)
The batched dataset takes consecutive elements.
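The grouping above can be reproduced with a small pure-Python sketch. It only tracks the labels (an assumption to keep it short), and the function name batch is hypothetical; it is not TensorFlow's implementation, just the same idea: take the next batch_size consecutive elements of the stream, keeping a shorter final batch as batch() does with its default drop_remainder=False.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(stream: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical helper: group consecutive elements into lists of batch_size.

    The last batch may be shorter, similar to drop_remainder=False.
    """
    it = iter(stream)
    while True:
        chunk = list(islice(it, batch_size))  # next batch_size consecutive items
        if not chunk:
            return
        yield chunk


labels = [20, 21, 22, 23]
repeated = labels * 3  # mimics dataset.repeat(3) on the label tensor only
for b in batch(repeated, 3):
    print(b)
```

The printed label batches follow the same consecutive pattern as the label tensors in the output above.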
However, when I shuffle first and then batch:
dataset = dataset.shuffle(3)
print('Shuffled dataset')
for i in dataset:
    print(i)
dataset = dataset.batch(3)
print('Batched dataset')
for i in dataset:
    print(i)
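The following is a pure-Python sketch of a buffered shuffle as the tf.data.Dataset.shuffle documentation describes it: fill a buffer with buffer_size elements, then repeatedly emit a randomly chosen buffered element and replace it with the next input element. The function name buffer_shuffle, the seed, and tracking only the labels are assumptions for illustration; this is not TensorFlow's actual implementation. Each call draws a fresh order from the same advancing RNG, which models the fact that shuffle reshuffles on every iteration by default. As a result, the first for loop and the batched loop each see a different shuffled stream.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffer_shuffle(stream: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Hypothetical sketch of a buffered shuffle (not TensorFlow's code)."""
    it = iter(stream)
    buffer: List[T] = []
    # Fill the buffer with the first buffer_size input elements.
    for x in it:
        buffer.append(x)
        if len(buffer) == buffer_size:
            break
    # For each further input element: emit a random buffered element,
    # then put the new input element into the freed slot.
    for x in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = x
    # Input exhausted: emit what is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


rng = random.Random(0)  # arbitrary seed, only for reproducibility of the sketch
labels = [20, 21, 22, 23] * 3  # labels after repeat(3)

# Like the "Shuffled dataset" loop: one full pass over the shuffled stream.
first_pass = list(buffer_shuffle(labels, 3, rng))
# Like the "Batched dataset" loop: a new pass, reshuffled with new random draws.
second_pass = list(buffer_shuffle(labels, 3, rng))
batches = [second_pass[i:i + 3] for i in range(0, len(second_pass), 3)]
print(first_pass)
print(batches)
```

In the sketch, the batches are still consecutive elements, but of second_pass rather than of the first_pass that was printed. With buffer_size=3, an element can appear at most two positions earlier than in its input. In TensorFlow, passing reshuffle_each_iteration=False to shuffle should make both loops see the same order; it is worth confirming this against the documentation for your version.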
I am using Google Colab and TensorFlow 2.x.
My question is: why does shuffling before batching cause batch to return non-consecutive elements?
Thanks for any answers.