多热标签编码

我是 Tensorflow 的新手。我有一个图像数据集,其中一张图像有多个标签。据我了解,我需要使用tf.losses.sigmoid_cross_entropy(). 我尝试应用于tf.one_hot标签,但是当我尝试将它们传递给损失函数时,我得到错误,形状不兼容。我怎样才能解决这个问题?


慕婉清6462132
浏览 235回答 1
1回答

HUWWW

你是对的tf.losses.sigmoid_cross_entropy。所有你需要做的就是 wrap tf.one_hotwithtf.reduce_max来减少这样的维度。tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)那应该返回 shape 的张量(num_classes,),正是您的损失函数所需要的。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python