我有一个张量jdes,其是(?, 100)和常数的矩阵jt_six,其具有的形状(6,100)。我试图得到jdes和 的每一行的余弦接近度jt_six的结果,结果应该有 shape (?, 6)。我看到该dot()层能够计算余弦接近度设置,normalize=True但是使用我拥有的代码,我得到的结果形状(6,1)中没有批量大小。任何人都可以帮助我吗?
def dot_similarity(jdes):
jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)
jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)
return dot([jt_six, jdes], axes=-1, normalize=True) # (6, 1), need (?, 6)
result = Lambda(dot_similarity)(jdes)
POPMUISE
相关分类