Tensorflow 数据:将函数应用于批处理

我正在使用 tf.data 从大型文本语料库中迭代批处理。

我只想将函数应用于数据子集(或批处理子集),而不是一个一个元素。具体来说,我的数据迭代器产生 query, reply批处理。它们都是正对,所以我只想洗牌下一批的子集(在这种情况下,只有“回复”批次)以生成随机负数。

例如,输入:

query1 reply1

query2 reply2

query3 reply3

...

输出:

  • 正对:(query1 reply1与输入相同)

  • 负对:(query1 replyN回复随机洗牌)

当然,可以使用 python 对文本进行混洗,但我想使用 tf.data 使其高效,因为数据大小太大。


月关宝盒
浏览 135回答 1
1回答

千万里不及你

假设你有queries和replies作为两个张量。您需要的是我认为类似于下面的内容,您可以将其与原始批次连接起来。batch_size = 10def reply_shuffle(queries, replies):   shuffled_indices = tf.random_uniform(minval=0, maxval=batch_size+1, shape=[batch_size], dtype=tf.int32)   shuffled_replies = tf.gather_nd(replies, shuffled_indices)    return queries, shuffled_replies
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python