当我在 tensorflow 中实现 unpool 时 tf.scatter_add

我正在尝试使用 tf.scatter_add 在 tensorflow 中实现 unpool,但我遇到了一个奇怪的错误,这是我的代码:


import tensorflow as tf

import numpy as np

import random


tf.reset_default_graph()


mat = list(range(64))

random.shuffle(mat)

mat = np.array(mat)

mat = np.reshape(mat, [1,8,8,1])

M = tf.constant(mat, dtype=tf.float32)

pool1, argmax1 = tf.nn.max_pool_with_argmax(M, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

pool2, argmax2 = tf.nn.max_pool_with_argmax(pool1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')

pool3, argmax3 = tf.nn.max_pool_with_argmax(pool2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')


输出:


[[ 0.  0.]

 [ 0. 63.]]

[[  0.   0.   0.   0.]

 [  0.   0.   0.   0.]

 [  0.   0. 126.   0.]

 [  0.   0.   0.   0.]]

[[  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0. 315.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]

 [  0.   0.   0.   0.   0.   0.   0.   0.]]

位置是对的,但价值是错误的。unpool2 是对的,unpool1 是期望值的两倍,unpool2 是期望值的五倍。我不知道怎么了,谁能告诉我如何修复这个错误?


catspeake
浏览 205回答 2
2回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python