This is my WGAN-GP loss function:
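import tensorflow as tf  # TF 1.x API (tf.random_uniform, variable reuse)

# Critic scores for real and generated images
gen_sample = model.generator(input_gen)
disc_real = model.discriminator(real_image, reuse=False)
disc_fake = model.discriminator(gen_sample, reuse=True)
disc_concat = tf.concat([disc_real, disc_fake], axis=0)

# Gradient penalty
# One alpha per sample, broadcast over height, width and channels
alpha = tf.random_uniform(
    shape=[BATCH_SIZE, 1, 1, 1],
    minval=0.,
    maxval=1.)
differences = gen_sample - real_image
interpolates = real_image + (alpha * differences)
# tf.gradients returns a list with one tensor per entry in xs; [0] selects d(critic)/d(interpolates)
gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0]
# Per-sample L2 norm: sum over every non-batch axis (assumes 4-D image tensors, matching alpha's shape)
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)

# Critic and generator losses
d_loss_real = tf.reduce_mean(disc_real)
d_loss_fake = tf.reduce_mean(disc_fake)
disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
gen_loss = -d_loss_fake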
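To sanity-check the arithmetic of these losses outside TensorFlow, here is a minimal pure-Python sketch. It assumes a toy linear critic D(x) = w . x, whose input gradient is simply w, as a hypothetical stand-in for model.discriminator. The names critic, critic_grad and wgan_gp_losses, the flattened two-feature samples, and LAMBDA = 10 are all illustrative assumptions, not taken from the code above.

```python
import math
import random

LAMBDA = 10.0  # assumed penalty weight; the post does not state its value


def critic(x, w):
    """Toy linear critic D(x) = w . x (hypothetical stand-in for the real model)."""
    return sum(wi * xi for wi, xi in zip(w, x))


def critic_grad(x, w):
    """Gradient of the linear critic w.r.t. its input; for D(x) = w . x it is w."""
    return list(w)


def wgan_gp_losses(real, fake, w, rng):
    """Return (disc_loss, gen_loss, gradient_penalty) for flattened samples."""
    penalties = []
    for r, f in zip(real, fake):
        alpha = rng.random()  # one interpolation coefficient per sample
        interp = [ri + alpha * (fi - ri) for ri, fi in zip(r, f)]
        grad = critic_grad(interp, w)
        # L2 norm over ALL features of the sample, not just one axis
        slope = math.sqrt(sum(g * g for g in grad))
        penalties.append((slope - 1.0) ** 2)
    gradient_penalty = sum(penalties) / len(penalties)
    d_loss_real = sum(critic(r, w) for r in real) / len(real)
    d_loss_fake = sum(critic(f, w) for f in fake) / len(fake)
    disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty
    gen_loss = -d_loss_fake
    return disc_loss, gen_loss, gradient_penalty


rng = random.Random(0)
real = [[1.0, 2.0], [3.0, 4.0]]
fake = [[0.0, 0.0], [1.0, 1.0]]
w = [0.6, 0.8]  # unit-norm weights, so the penalty term vanishes
disc_loss, gen_loss, gp = wgan_gp_losses(real, fake, w, rng)
print(disc_loss, gen_loss, gp)
```

In this sketch gen_loss is just the negated mean critic score on generated samples. The critic output is not bounded, so the size of that number depends on the scale of the critic's scores.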
The generator loss oscillates, and its values are very large. My question is: is this generator loss normal or abnormal?