tf.custom_gradient仅接受一个Tensor x,如果此操作需要多个输入呢?
例如,定义需要输入x和label?的Softmax的梯度。
更新
感谢@AllenLavoie的建议,我使用Python列表作为输入。
def self_define_op_multiple_inputs():
@tf.custom_gradient
def loss_func(input_):
x = input_[0]
label = input_[2]
def grad(dy):
return [dy, dy]
return x - label, grad
x = tf.range(10, dtype=tf.float32)
y = tf.range(10, dtype=tf.int32)
loss = loss_func([x, y])
if __name__ == '__main__':
self_define_op_multiple_inputs()
看来它将把Python转换list为Tensor。上面的代码段将引发TypeError:
TypeError: Cannot convert a list containing a tensor of dtype <dtype: 'int32'> to <dtype: 'float32'> (Tensor is: <tf.Tensor 'range_1:0' shape=(10,) dtype=int32>)
如何解决?
相关分类