猿问

取一个张量的元素,它也在另一个张量中

我有两个张量,我必须迭代第一个以只取另一个张量内的元素。只有一个元素t2也在里面t1。这里有一个例子


t1 = tf.where(values > 0) # I get some indices example [6, 0], [3, 0]

t2 = tf.where(values2 > 0) # I get [4, 0], [3, 0]


t3 = .... # [3, 0]

我尝试使用运算符来评估和迭代它们,.eval()并检查它们是否t2正在t1使用 operator in,但不起作用。TensorFlow 有没有可以做到这一点的函数?


编辑


for index in xrange(max_indices):

    indices = tf.where(tf.equal(values, (index + 1))).eval() # indices: [[1 0]\n [4 0]\n [9 0]]

    cent_indices = tf.where(centers > 0).eval() # cent_indices: [[6 0]\n [9 0]]

    indices_list.append(indices)

    for cent in cent_indices:

        if cent in indices:

           centers_list.append(cent)

           break

第一次迭代cent具有值[6 0]但它进入if条件。


回答


for index in xrange(max_indices):

    indices = tf.where(tf.equal(values, (index + 1))).eval()

    cent_indices = tf.where(centers > 0).eval()

    indices_list.append(indices)

    for cent in cent_indices:

        # batch_item is an iterator from an outer loop

        if values[batch_item, cent[0]].eval() == (index + 1):

           centers_list.append(tf.constant(cent))

           break

该解决方案与我的任务有关,但如果您正在寻找一维张量中的解决方案,我建议您查看 tf.sets.set_intersection


手掌心
浏览 226回答 1
1回答

炎炎设计

那是你想要的吗?我只使用了这两个测试用例。x = tf.constant([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 1]])y = tf.constant([[1, 2, 3, 4, 3, 6], [1, 2, 3, 4, 5, 1]])# x = tf.constant([[1, 2], [4, 5], [7, 7]])# y = tf.constant([[7, 7], [3, 5]])def match(xiterations, yiterations, yvalues, xvalues ):    for i in range(xiterations):        for j in range(yiterations):            if (np.array_equal(yvalues[j], xvalues[i])):                print( yvalues[j])with tf.Session() as sess:    xindex = tf.where( x > 4 )    yindex = tf.where( y > 4 )    xvalues = xindex.eval()    yvalues = yindex.eval()    xiterations =  tf.shape(xvalues)[0].eval()    yiterations =  tf.shape(yvalues)[0].eval()    print(tf.shape(xvalues)[0].eval())    print(tf.shape(yvalues)[0].eval())    if tf.shape(xvalues)[0].eval() >= tf.shape(yvalues)[0].eval():        match( xiterations, yiterations, yvalues, xvalues)    else:        match( yiterations, xiterations, xvalues, yvalues)
随时随地看视频慕课网APP

相关分类

Python
我要回答