基类定义
pytorch损失类也是模块的派生,损失类的基类是_Loss,定义如下
class _Loss(Module): def __init__(self, size_average=None, reduce=None, reduction='elementwise_mean'): super(_Loss, self).__init__() if size_average is not None or reduce is not None: self.reduction = _Reduction.legacy_get_string(size_average, reduce) else: self.reduction = reduction
看这个类,有两点我们知道:
损失类是模块
不改变forward函数,但是具备执行功能
还有其他模块的性质
子类介绍
从_Loss派生的类有
名称 | 说明 | 公式 |
---|---|---|
_WeightedLoss | 这个类只是申请了一个权重空间,功能和_Loss一样 | |
L1Loss | X、Y可以是任意形状的输入,X与Y的 shape相同 | |
PoissonNLLLoss | 适合多目标分类 | |
KLDivLoss | 适用于连续分布的距离计算 | |
MSELoss | 均方差 | |
BCEWithLogitsLoss | 多目标不需要经过sigmoid | |
HingeEmbeddingLoss | Y中的元素只能为1或-1 适用于学习非线性embedding、半监督学习。用于计算两个输入是否相似 | |
MultiLabelMarginLoss | 适用于多目标分类 | |
SmoothL1Loss | ||
SoftMarginLoss | ||
CosineEmbeddingLoss | ||
MarginRankingLoss | ||
TripletMarginLoss |
从_WeightedLoss继续派生的函数有
名称 | 说明 | |
---|---|---|
NLLLoss | ||
BCELoss | ||
CrossEntropyLoss | ||
MultiLabelSoftMarginLoss | ||
MultiMarginLoss |
作者:readilen
链接:https://www.jianshu.com/p/592a680ba3df