tensorflow找到到真实点的最小距离

我有一个 Bx3 张量,fooB= 批量大小的 3D 点。通过某种幻想,我得到了另一个张量,bar形状为 Bx6x3,其中每个 B 6x3 矩阵对应于foo. 该 6x3 矩阵由 6 个复值 3D 点组成。我想做的是,对于我的每个 B 点,从6 in 中找到最接近对应点 in的实值点,最终得到一个新的 Bx3 ,其中包含与 in点的最近点。barfoomin_barbarfoo


在numpy中,我可以使用屏蔽数组来完成这一壮举:


foo = np.array([

    [1,2,3],

    [4,5,6],

    [7,8,9]])

# here bar is only Bx2x3 for simplicity, but the solution generalizes

bar = np.array([

    [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],

    [[6,5,4],[4,5,7]],

    [[1j,1j,1j],[0,0,0]],

])


#mask complex elements of bar

bar_with_masked_imag = np.ma.array(bar)

candidates = bar_with_masked_imag.imag == 0

bar_with_masked_imag.mask = ~candidates


dists = np.sum(bar_with_masked_imag**2, axis=1)

mindists = np.argmin(dists, axis=1)

foo_indices = np.arange(foo.shape[0])

min_bar = np.array(

    bar_with_masked_imag[foo_indices,mindists,:], 

    dtype=float

)


print(min_bar)

#[[2. 3. 4.]

# [4. 5. 7.]

# [0. 0. 0.]]

但是,tensorflow 没有掩码数组等。我如何将其翻译成张量流?


繁花不似锦
浏览 89回答 1
1回答

幕布斯7119047

这是一种方法:import tensorflow as tfimport mathdef solution_tf(foo, bar):    foo = tf.convert_to_tensor(foo)    bar = tf.convert_to_tensor(bar)    # Get real and imaginary parts    bar_r = tf.cast(tf.real(bar), foo.dtype)    bar_i = tf.imag(bar)    # Mask of all real-valued points    m = tf.reduce_all(tf.equal(bar_i, 0), axis=-1)    # Distance to every corresponding point    d = tf.reduce_sum(tf.squared_difference(tf.expand_dims(foo, 1), bar_r), axis=-1)    # Replace distances of complex points with infinity    d2 = tf.where(m, d, tf.fill(tf.shape(d), tf.constant(math.inf, d.dtype)))    # Find smallest distances    idx = tf.argmin(d2, axis=1)    # Get points with smallest distances    b = tf.range(tf.shape(foo, out_type=idx.dtype)[0])    return tf.gather_nd(bar_r, tf.stack([b, idx], axis=1))# Testwith tf.Graph().as_default(), tf.Session() as sess:    foo = tf.constant([        [1,2,3],        [4,5,6],        [7,8,9]], dtype=tf.float32)    bar = tf.constant([        [[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],        [[6,5,4],[4,5,7]],        [[1j,1j,1j],[0,0,0]]], dtype=tf.complex64)    sol_tf = solution_tf(foo, bar)    print(sess.run(sol_tf))    # [[2. 3. 4.]    #  [4. 5. 7.]    #  [0. 0. 0.]]
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python