I am trying to quantize a model that uses PReLU. Replacing PReLU with ReLU is not an option, because it hurts the network's performance so badly that the result is unusable.
As far as I know, PReLU is not supported by PyTorch quantization, so I tried to rewrite the module by hand and implement the multiplication and addition through torch.nn.quantized.FloatFunctional() to get around this limitation.
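(For context, the FloatFunctional pattern I am relying on looks roughly like this; the tensors below are just placeholders:)

import torch
import torch.nn as nn

ff = nn.quantized.FloatFunctional()
a = torch.randn(2, 3)
b = torch.randn(2, 3)
s = ff.add(a, b)  # stands in for a + b so quantization can track the op
p = ff.mul(a, b)  # stands in for a * b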
This is what I have come up with so far:
class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs
And the replacement:
class model(nn.Module):
    def __init__(self):
        super().__init__()
        ....
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        ....
Basically, I read the learned parameter of the existing PReLU module and run the computation myself in the new module. In a sense the module seems to work: it does not break the whole application.
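For example, before any quantization steps run, the wrapper should expose exactly the same learned parameter (a quick sanity check, not part of the real model):

prelu = nn.PReLU()
prelu_q = PReLU_Quantized(prelu)
print(torch.equal(prelu.weight, prelu_q.weight))  # expected: True, same parameter values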
However, in order to check whether my implementation is actually correct and produces the same result as the original module, I tried to test it.
Here is its counterpart in the normal (i.e. non-quantized) model:
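(The idea of the counterpart is simply to run the standard PReLU formula, max(0, x) + weight * min(0, x), in plain floating point; the sketch below is illustrative and assumes the default single-parameter PReLU, with a placeholder class name:)

class PReLU_Reference(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight

    def forward(self, inputs):
        # max(0, x) + weight * min(0, x), written with relu so it stays elementwise
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        return pos + neg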
For some reason, the error between the actual PReLU and my implementation is very large!
Here are sample differences from different layers:
diff : 1.1562038660049438
diff : 0.02868632599711418
diff : 0.3653906583786011
diff : 1.6100226640701294
diff : 0.8999372720718384
diff : 0.03773299604654312
diff : -0.5090572834014893
diff : 0.1654307246208191
diff : 1.161868691444397
diff : 0.026089997962117195
diff : 0.4205571115016937
diff : 1.5337920188903809
diff : 0.8799554705619812
diff : 0.03827812895178795
diff : -0.40296515822410583
diff : 0.15618863701820374
The diff is computed like this in the forward pass:
def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)
    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {(out - out2).mean().item()}')
    out = self.conv2(out)
    ...
What am I missing here?