
Getting image names when using an autoencoder in TensorFlow

I am using this TensorFlow image search script: https://www.kaggle.com/jonmarty/using-autoencoder-to-search-images


def search(image):
    hidden_states = [sess.run(hidden_state(X, mask, W, b),
                     feed_dict={X: im.reshape(1, pixels), mask:
                     np.random.binomial(1, 1-corruption_level, (1, pixels))})
                     for im in image_set]

    query = sess.run(hidden_state(X, mask, W, b),
                     feed_dict={X: image.reshape(1,pixels), mask: np.random.binomial(1, 1-corruption_level, (1, pixels))})

    starting_state = int(np.random.random()*len(hidden_states)) #choose random starting state
    best_states = [imported_images[starting_state]]
    distance = euclidean_distance(query[0], hidden_states[starting_state][0]) #Calculate similarity between hidden states
    for i in range(len(hidden_states)):
        dist = euclidean_distance(query[0], hidden_states[i][0])
        if dist <= distance:
            distance = dist #as the method progresses, it gets better at identifying similar images
            best_states.append(imported_images[i])
    if len(best_states)>0:
        return best_states
    else:
        return best_states[len(best_states)-101:]

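I would like to know whether it is possible to get the image name (for example: homer.jpg). I am lost and do not know what I should add to the code to get it. This is the part of my script that prints the results: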


print(len(results))
slots = 0
plt.figure(figsize = (125,125))
for im in results[::-1]: #reads through results backwards (more similar images first)
    plt.subplot(10, 10, slots+1)
    plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)); plt.axis('off')
    slots += 1

Thanks a lot! :)


侃侃无极
1 answer

繁星淼淼

You need to modify the search function. Specifically, look at this line: best_states.append(imported_images[i]). If you want a mapping between the returned images and their file names, you need to record and return that index, i. Consider adding a best_states_index variable and returning both, or simply replace imported_images[i] with i and use that index to access both the file name and the image data. More explicitly:

def search(image):
    hidden_states = [sess.run(hidden_state(X, mask, W, b),
                     feed_dict={X: im.reshape(1, pixels), mask:
                     np.random.binomial(1, 1-corruption_level, (1, pixels))})
                     for im in image_set]

    query = sess.run(hidden_state(X, mask, W, b),
                     feed_dict={X: image.reshape(1,pixels), mask: np.random.binomial(1, 1-corruption_level, (1, pixels))})

    starting_state = int(np.random.random()*len(hidden_states)) #choose random starting state
    best_states = [imported_images[starting_state]]
    best_states_index = [starting_state] #indices into imported_images, kept in step with best_states
    distance = euclidean_distance(query[0], hidden_states[starting_state][0]) #Calculate similarity between hidden states
    for i in range(len(hidden_states)):
        dist = euclidean_distance(query[0], hidden_states[i][0])
        if dist <= distance:
            distance = dist #as the method progresses, it gets better at identifying similar images
            best_states.append(imported_images[i])
            best_states_index.append(i)
    if len(best_states)>0: #always true here, because best_states starts with one element
        return best_states, best_states_index
    else:
        return best_states[len(best_states)-101:], best_states_index[len(best_states)-101:]
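The index list only helps if there is a list of file names in the same order as imported_images. The sketch below shows one way the lookup could work. It assumes a hypothetical filenames list that was filled in the same loop that loaded the images; the names_for helper, the sample file names, and the placeholder image values are illustrative and not part of the original script.

```python
# Hypothetical stand-ins for data built while loading the images.
# Assumption: filenames[i] is the file that imported_images[i] was read from.
filenames = ["bart.jpg", "homer.jpg", "lisa.jpg"]
imported_images = ["<pixels of bart>", "<pixels of homer>", "<pixels of lisa>"]


def names_for(indices, names):
    """Map the indices returned by search() back to file names."""
    return [names[i] for i in indices]


# Pretend the modified search() returned these two values.
best_states_index = [2, 1]
best_states = [imported_images[i] for i in best_states_index]

result_names = names_for(best_states_index, filenames)

# Reverse both lists together, as the plotting loop does with results[::-1],
# so that each image stays paired with its own name.
labelled = list(zip(result_names[::-1], best_states[::-1]))

for name, image in labelled:
    print(name, image)
```

In the plotting loop, calling plt.title(name) after plt.imshow(...) would label each subplot with its file name, provided the images and names are iterated together as above.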