猿问

WGAN-GP 训练损失大

这是WGAN-GP的损失函数


gen_sample = model.generator(input_gen)

disc_real = model.discriminator(real_image, reuse=False)

disc_fake = model.discriminator(gen_sample, reuse=True)

disc_concat = tf.concat([disc_real, disc_fake], axis=0)

# Gradient penalty

alpha = tf.random_uniform(

    shape=[BATCH_SIZE, 1, 1, 1],

    minval=0.,

    maxval=1.)

differences = gen_sample - real_image

interpolates = real_image + (alpha * differences)

gradients = tf.gradients(model.discriminator(interpolates, reuse=True), [interpolates])[0]    # why [0]

slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))

gradient_penalty = tf.reduce_mean((slopes-1.)**2)


d_loss_real = tf.reduce_mean(disc_real)

d_loss_fake = tf.reduce_mean(disc_fake)


disc_loss = -(d_loss_real - d_loss_fake) + LAMBDA * gradient_penalty

gen_loss = - d_loss_fake


发电机损耗震荡,值这么大。我的问题是:发电机损耗是正常的还是异常的?


慕哥6287543
浏览 952回答 1
1回答

喵喔喔

需要注意的一件事是您的梯度惩罚计算是错误的。以下行:slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))实际上应该是:slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), reduction_indices=[1,2,3]))您在第一个轴上减少,但渐变基于 alpha 值显示的图像,因此您必须在轴上减少[1,2,3]。代码中的另一个错误是生成器损失是:gen_loss = d_loss_real - d_loss_fake对于梯度计算,这没有区别,因为生成器的参数仅包含在 d_loss_fake 中。然而,对于发电机损失的价值,这在世界上造成了很大的不同,这也是为什么会如此震荡的原因。归根结底,您应该查看您关心的实际性能指标,以确定 GAN 的质量,例如初始分数或 Fréchet 初始距离 (FID),因为鉴别器和生成器的损失仅具有轻微的描述性。
随时随地看视频慕课网APP

相关分类

Python
我要回答