我想对批处理规范化层的变量添加条件操作。具体来说,在浮点数中训练,然后在微调二次训练阶段量化。为此,我想在变量上添加一个 tf.cond 操作(均值和变量的比例、移位和 exp 移动平均值)。
我用tf.layers.batch_normalization 我写的batchnorm层替换了它(见下文)。
这个函数工作得很好(即我用两个函数得到了相同的度量),我可以向变量添加任何管道(在batchnorm 操作之前)。问题是性能(运行时)急剧下降(即通过简单地用我自己的函数替换 layer.batchnorm 有一个 x2 因子,如下所述)。
def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
epsilon = tf.to_float(epsilon)
decay = tf.to_float(decay)
with tf.variable_scope(name):
shape = x.get_shape().as_list()
channels_num = shape[3]
# scale factor
gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
# shift value
beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2]) # per channel
update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)
bn_mean = tf.cond(self.is_training, lambda: tf.identity(batch_mean), lambda: tf.identity(moving_mean))
bn_var = tf.cond(self.is_training, lambda: tf.identity(batch_var), lambda: tf.identity(moving_var))
with tf.variable_scope(name + "_batchnorm_op"):
inv = tf.math.rsqrt(bn_var + epsilon)
inv *= gamma
output = ((x*inv) - (bn_mean*inv)) + beta
return output
我将不胜感激以下任何问题的帮助:
关于如何提高我的解决方案的性能(减少运行时间)的任何想法?
是否可以在 batchnorm 操作之前将我自己的运算符添加到 layer.batchnorm 的变量管道中?
有没有其他解决方案可以解决同样的问题?
慕桂英546537
相关分类