使用 TensorFlow 计算句子相似度

我试图确定一个句子和其他句子之间的语义相似性,如下所示:


import tensorflow as tf

import tensorflow_hub as hub

import numpy as np

import os, sys

from sklearn.metrics.pairwise import cosine_similarity


def cos_sim(input_vectors):
    """Return the pairwise cosine-similarity matrix of ``input_vectors``.

    Delegates to sklearn's ``cosine_similarity`` with no second argument, so
    every row of ``input_vectors`` is compared against every other row and
    entry ``[i, j]`` of the result is the similarity of rows ``i`` and ``j``.
    """
    pairwise_matrix = cosine_similarity(input_vectors)
    return pairwise_matrix


# get topN similar sentences

def get_top_similar(sentence, sentence_list, similarity_matrix, topN):

    # find the index of sentence in list

    index = sentence_list.index(sentence)

    # get the corresponding row in similarity matrix

    similarity_row = np.array(similarity_matrix[index, :])

    # get the indices of top similar

    indices = similarity_row.argsort()[-topN:][::-1]

    return [sentence_list[i] for i in indices]



module_url = "https://tfhub.dev/google/universal-sentence-encoder/2" #@param ["https://tfhub.dev/google/universal-sentence-encoder/2", "https://tfhub.dev/google/universal-sentence-encoder-large/3"]


# Import the Universal Sentence Encoder's TF Hub module

embed = hub.Module(module_url)


# Reduce logging output.

tf.logging.set_verbosity(tf.logging.ERROR)


sentences_list = [

    # phone related

    'My phone is slow',

    'My phone is not good',

    'I need to change my phone. It does not work well',

    'How is your phone?',


    # age related

    'What is your age?',

    'How old are you?',

    'I am 10 years old',


    # weather related

    'It is raining today',

    'Would it be sunny tomorrow?',

    'The summers are here.'

]


with tf.Session() as session:


  session.run([tf.global_variables_initializer(), tf.tables_initializer()])

  sentences_embeddings = session.run(embed(sentences_list))


similarity_matrix = cos_sim(np.array(sentences_embeddings))


sentence = "It is raining today"

top_similar = get_top_similar(sentence, sentences_list, similarity_matrix, 3)


# printing the list using loop 

for x in range(len(top_similar)): 

    print(top_similar[x])

#view raw


梦里花落0921
浏览 108回答 1
1回答

慕雪6442864

问题的原因似乎是 TF2 不支持 hub.Module(TF1 格式的 Hub 模块)。一个简单的办法是:您是否尝试过禁用 TensorFlow 2 的行为?

```
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
```

上述命令会禁用 TensorFlow 2 的行为,但仍可能出现一些与导入模块和计算图相关的错误。如果是这样,请尝试以下命令(安装 TensorFlow 1.15 并确认版本):

```
!pip install --upgrade tensorflow==1.15
import tensorflow as tf
print(tf.__version__)
```
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python