我有两个张量,我必须迭代第一个以只取另一个张量内的元素。只有一个元素t2也在里面t1。这里有一个例子
t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]
t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]
t3 = .... # [3, 0]
我尝试使用运算符来评估和迭代它们,.eval()并检查它们是否t2正在t1使用 operator in,但不起作用。TensorFlow 有没有可以做到这一点的函数?
编辑
for index in xrange(max_indices):
indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]
cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]
indices_list.append(indices)
for cent in cent_indices:
if cent in indices:
centers_list.append(cent)
break
第一次迭代cent具有值[6 0]但它进入if条件。
回答
for index in xrange(max_indices):
indices = tf.where(tf.equal(values, (index + 1))).eval()
cent_indices = tf.where(centers > 0).eval()
indices_list.append(indices)
for cent in cent_indices:
# batch_item is an iterator from an outer loop
if values[batch_item, cent[0]].eval() == (index + 1):
centers_list.append(tf.constant(cent))
break
该解决方案与我的任务有关,但如果您正在寻找一维张量中的解决方案,我建议您查看 tf.sets.set_intersection
炎炎设计
相关分类