我有一个 Bx3 张量,fooB= 批量大小的 3D 点。通过某种幻想,我得到了另一个张量,bar形状为 Bx6x3,其中每个 B 6x3 矩阵对应于foo. 该 6x3 矩阵由 6 个复值 3D 点组成。我想做的是,对于我的每个 B 点,从6 in 中找到最接近对应点 in的实值点,最终得到一个新的 Bx3 ,其中包含与 in点的最近点。barfoomin_barbarfoo
在numpy中,我可以使用屏蔽数组来完成这一壮举:
foo = np.array([
[1,2,3],
[4,5,6],
[7,8,9]])
# here bar is only Bx2x3 for simplicity, but the solution generalizes
bar = np.array([
[[2,3,4],[1+0.1j,2+0.1j,3+0.1j]],
[[6,5,4],[4,5,7]],
[[1j,1j,1j],[0,0,0]],
])
#mask complex elements of bar
bar_with_masked_imag = np.ma.array(bar)
candidates = bar_with_masked_imag.imag == 0
bar_with_masked_imag.mask = ~candidates
dists = np.sum(bar_with_masked_imag**2, axis=1)
mindists = np.argmin(dists, axis=1)
foo_indices = np.arange(foo.shape[0])
min_bar = np.array(
bar_with_masked_imag[foo_indices,mindists,:],
dtype=float
)
print(min_bar)
#[[2. 3. 4.]
# [4. 5. 7.]
# [0. 0. 0.]]
但是,tensorflow 没有掩码数组等。我如何将其翻译成张量流?
幕布斯7119047
相关分类