Pytorch:有没有类似torch.argmax的函数,真的可以保持原始数据的维度?

例如,代码是


input = torch.randn(3, 10)

result = torch.argmax(input, dim=0, keepdim=True)

input 是


tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],

       [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],

       [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])

并且result是


tensor([[ 0,  2,  1,  2,  2]])

但是,我想要这样的结果


tensor([[ 1,  0,  0,  0,  0],

        [ 0,  0,  1,  0,  0],

        [ 0,  1,  0,  1,  1]])


HUX布斯
浏览 499回答 2
2回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python