猿问

在 Keras 中,如何使用 dot() 计算张量与常量矩阵的每一行之间的余弦接近度?

我有一个张量jdes,其是(?, 100)和常数的矩阵jt_six,其具有的形状(6,100)。我试图得到jdes和 的每一行的余弦接近度jt_six的结果,结果应该有 shape (?, 6)。我看到该dot()层能够计算余弦接近度设置,normalize=True但是使用我拥有的代码,我得到的结果形状(6,1)中没有批量大小。任何人都可以帮助我吗?


def dot_similarity(jdes):

    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)

    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)

    return dot([jt_six, jdes], axes=-1, normalize=True) # (6, 1), need (?, 6)


result = Lambda(dot_similarity)(jdes)


炎炎设计
浏览 338回答 1
1回答

POPMUISE

可以K.dot()直接使用。因为您已经使用了K.l2_normalize,所以矩阵乘法的结果是余弦接近度。from keras.models import Modelimport keras.backend as Kfrom keras.layers import Lambda,Inputimport numpy as npN = 100def dot_similarity(jdes):    jdes = K.l2_normalize(jdes, axis=-1) # (?, 100)    # define it myself    jt_six = K.constant(np.random.uniform(0, 1, size=(6, N)))    jt_six = K.l2_normalize(K.variable(jt_six), axis=-1) # (6, 100)    return K.dot(jdes,K.transpose(jt_six))jdes = Input(shape=(N,))result = Lambda(dot_similarity)(jdes)model = Model(jdes,result)print(model.summary())_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_1 (InputLayer)         (None, 100)               0         _________________________________________________________________lambda_1 (Lambda)            (None, 6)                 0         =================================================================Total params: 0Trainable params: 0Non-trainable params: 0_________________________________________________________________
随时随地看视频慕课网APP

相关分类

Python
我要回答