在 Tensorflow/Keras 中为重复元素创建掩码

我正在尝试为人员重新识别任务编写一个自定义损失函数,该函数在多任务学习设置和对象检测中进行训练。过滤后的标签值的形状为(batch_size, num_boxes)。我想创建一个掩码,以便仅考虑在暗淡 1 中重复的值进行进一步计算。如何在 TF/Keras 后端执行此操作?

简短示例

Input labels = [[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]]
Required output: [[0,0,0,0,1,1,1,1,0],[0,0,1,1,1,0,1,1,0]]

(基本上我只想过滤掉重复项并丢弃损失函数的唯一标识)。

我想可以使用 tf.unique 和 tf.scatter 的组合,但我不知道如何使用。


小唯快跑啊
浏览 114回答 1
1回答

森林海

这段代码的工作原理:x = tf.constant([[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]])def mark_duplicates_1D(x):  y, idx, count = tf.unique_with_counts(x)  comp = tf.math.greater(count, 1)  comp = tf.cast(comp, tf.int32)  res = tf.gather(comp, idx)  mult = tf.math.not_equal(x, 0)  mult = tf.cast(mult, tf.int32)  res *= mult  return resres = tf.map_fn(fn=mark_duplicates_1D, elems=x)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python