猿问

将额外输入传递给 Keras Deep Network 以计算自定义成本函数

我们有 1000x512x512x1 的输入图像 (input_x),1000x512x512x1 的权重图 (input_w)。事实上,每张图像都有自己的权重图,它是在网络运行之前生成的,因此我们必须将它们作为第二个输入传递。两者都被馈送到网络,尽管这些权重图只是为了乘以损失函数而不是真正的张量(它们不来自任何层并且在到达损失函数之前保持模型的输入)。首先,我们的模型有两个输入,只有一个输出:


 model = keras.models.Model(inputs=[input_x, input_w], outputs=final_output)

并且输入形状在网络开始时发生变化:


input_x = layers.Input(shape=(512,512,1))

input_w = layers.Input(shape=(512,512,1))

input_x 穿过网络层,但 input_w 仅用于 customLoss :


model.compile(optimizer=optimizer, loss=customLoss(input_w), metrics=[dice_coef, mean_iou])

由于 input_w 的附加参数,这是一个包装器:


def customLoss(input_w): 

  def loss_fcn(y_true, y_pred):

     bce = keras.losses.binary_crossentropy(y_true, y_pred)

     dice_term = K.exp(1 + dice_coef(y_true, y_pred, 1.0))

     return input_w * (bce - dice_term)

  return loss_fcn

在从数据集生成 X 和 W 之后,我们称拟合为 2 个输入,X 是 input_x(图像),W 是(权重图)。


history= my_model.fit([X,W],y,validation_split=0.1, epochs=5000,batch_size=8, callbacks=[best_check])

一切对我来说似乎都是正确的,但我收到了错误


Epoch 1/5000

---------------------------------------------------------------------------

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-11-b756059772c5> in <module>()

      6                               patience=6,

      7                               verbose=1, mode='auto')

----> 8 history= my_model.fit([X,W],y,validation_split=0.1, epochs=5000,batch_size=8, callbacks=[best_check])


/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)

   1037                                         initial_epoch=initial_epoch,

   1038                                         steps_per_epoch=steps_per_epoch,

-> 1039                                         validation_steps=validation_steps)

   1040 

   1041     def evaluate(self, x=None, y=None,


至尊宝的传说
浏览 151回答 1
1回答

www说

所有这些代码都是真实编写的,但是,由于该项目是在 google Colab 中开发的,因此发生了一些奇怪的错误,因此在多次重新连接页面后问题解决了!错误可能是由于断开连接!
随时随地看视频慕课网APP

相关分类

Python
我要回答