猿问

tf.custom_gradient具有多个输入

tf.custom_gradient仅接受一个Tensor x,如果此操作需要多个输入呢?


例如,定义需要输入x和label?的Softmax的梯度。


更新

感谢@AllenLavoie的建议,我使用Python列表作为输入。


def self_define_op_multiple_inputs():

    @tf.custom_gradient

    def loss_func(input_):

        x = input_[0]

        label = input_[2]


        def grad(dy):

            return [dy, dy]


        return x - label, grad


    x = tf.range(10, dtype=tf.float32)

    y = tf.range(10, dtype=tf.int32)


    loss = loss_func([x, y])



if __name__ == '__main__':

    self_define_op_multiple_inputs()

看来它将把Python转换list为Tensor。上面的代码段将引发TypeError: 

TypeError: Cannot convert a list containing a tensor of dtype <dtype: 'int32'> to <dtype: 'float32'> (Tensor is: <tf.Tensor 'range_1:0' shape=(10,) dtype=int32>)


如何解决?


HUWWW
浏览 141回答 2
2回答
随时随地看视频慕课网APP

相关分类

Python
我要回答