How do I incorporate PReLU into a quantized model?

I am trying to quantize a model that uses PReLU. Replacing PReLU with ReLU is not an option, because it hurts the network's performance so badly that the model becomes unusable.


As far as I know, PyTorch does not support PReLU when it comes to quantization. So I tried to rewrite this module by hand and implement the multiplication and addition with nn.quantized.FloatFunctional() to get around this limitation.
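(For background, FloatFunctional wraps elementwise arithmetic in a module so that the eager-mode quantization tooling can attach observers to those ops and later swap in quantized kernels. A minimal sketch of its use, not taken from the original post:)

import torch
import torch.nn as nn

# Illustrative only: FloatFunctional stands in for the bare `+` and `*` operators,
# which the eager-mode quantization passes cannot observe or convert on their own.
ff = nn.quantized.FloatFunctional()
a = torch.randn(4)
b = torch.randn(4)
out = ff.add(a, ff.mul(a, b))  # same result as a + a * b while still in float mode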


Here is what I have come up with so far:


class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.weight = prelu_object.weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = torch.max(0, inputs) + self.weight * torch.min(0, inputs)
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(self.weight, torch.min(inputs)[0])
        inputs = self.quantized_op.add(torch.max(inputs)[0], weight_min_res).unsqueeze(0)
        self.weight = self.dequant(self.weight)
        return inputs

And the replacement in the model:


class model(nn.Module):
    def __init__(self):
        super().__init__()
        ....
        self.prelu = PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        ....

Basically, I read the learned parameter of the existing PReLU module and run the computation myself in the new module. In some sense the module seems to work, in that it does not make the whole application fail.
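A hedged sketch of how such a swap could be automated over a whole model; replace_prelu is an illustrative helper name, not something from the original post:

import torch.nn as nn

def replace_prelu(module):
    # Recursively replace every nn.PReLU child with the wrapper defined above,
    # reusing the learned weight of the original module.
    for name, child in module.named_children():
        if isinstance(child, nn.PReLU):
            setattr(module, name, PReLU_Quantized(child))
        else:
            replace_prelu(child)

Calling replace_prelu(model) once before preparing the model for quantization would then perform the substitution everywhere instead of doing it by hand in __init__.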


However, to check whether my implementation is actually correct and produces the same results as the original module, I tried to test it against its counterpart in the normal (i.e., non-quantized) model.

For some reason, the error between the actual PReLU and my implementation is very large!


Here are some example differences across different layers:


diff : 1.1562038660049438

diff : 0.02868632599711418

diff : 0.3653906583786011

diff : 1.6100226640701294

diff : 0.8999372720718384

diff : 0.03773299604654312

diff : -0.5090572834014893

diff : 0.1654307246208191

diff : 1.161868691444397

diff : 0.026089997962117195

diff : 0.4205571115016937

diff : 1.5337920188903809

diff : 0.8799554705619812

diff : 0.03827812895178795

diff : -0.40296515822410583

diff : 0.15618863701820374

And in the forward pass, the diff is computed like this:


def forward(self, x):
    residual = x
    out = self.bn0(x)
    out = self.conv1(out)
    out = self.bn1(out)

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')

    out = self.conv2(out)
...
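(A side note, not from the original post: because the errors can be positive or negative, a signed mean can understate the mismatch through cancellation; comparing the maximum absolute difference is a stricter check:)

    # Illustrative, stricter comparison of the two activations
    print(f'max abs diff : {(out - out2).abs().max().item()}')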

What am I missing here?


慕哥9229398
1 Answer

偶然的你

I figured it out! I had made a big mistake at the very beginning. I needed to compute

PReLU(x) = max(0, x) + a * min(0, x)

element-wise, and not with the actual torch.min or torch.max, which reduce over the whole tensor; that makes no sense here! Note that max(0, x) is just relu(x) and a * min(0, x) equals -a * relu(-x), which is what the code below uses. Here is the final solution for the normal (i.e., non-quantized) model:

class PReLU_2(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight

    def forward(self, inputs):
        pos = torch.relu(inputs)
        neg = -self.weight * torch.relu(-inputs)
        inputs = pos + neg
        return inputs

And this is the quantized version:

class PReLU_Quantized(nn.Module):
    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        # inputs = max(0, inputs) + alpha * min(0, inputs)
        self.weight = self.quant(self.weight)
        weight_min_res = self.quantized_op.mul(-self.weight, torch.relu(-inputs))
        inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
        inputs = self.dequant(inputs)
        self.weight = self.dequant(self.weight)
        return inputs

Side note: I also had a typo in the code that computes the diff:

    out = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out - out2).mean().item()}')
    out = self.conv2(out)

needs to be

    out1 = self.prelu(out)
    out2 = self.prelu2(out)
    print(f'diff : {( out1 - out2).mean().item()}')
    out = self.conv2(out1)

Update: if you run into problems on the quantization side, you can try this version instead:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.quantized as nnq
from torch.quantization import fuse_modules

class QPReLU(nn.Module):
    def __init__(self, num_parameters=1, init: float = 0.25):
        super(QPReLU, self).__init__()
        self.num_parameters = num_parameters
        self.weight = nn.Parameter(torch.Tensor(num_parameters).fill_(init))
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.f_mul_neg_one1 = nnq.FloatFunctional()
        self.f_mul_neg_one2 = nnq.FloatFunctional()
        self.f_mul_alpha = nnq.FloatFunctional()
        self.f_add = nnq.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.quant2 = torch.quantization.QuantStub()
        self.quant3 = torch.quantization.QuantStub()
        # self.dequant2 = torch.quantization.QuantStub()
        self.neg_one = torch.Tensor([-1.0])

    def forward(self, x):
        x = self.quant(x)

        # PReLU, with modules only
        x1 = self.relu1(x)

        neg_one_q = self.quant2(self.neg_one)
        weight_q = self.quant3(self.weight)
        x2 = self.f_mul_alpha.mul(
            weight_q, self.f_mul_neg_one2.mul(
                self.relu2(
                    self.f_mul_neg_one1.mul(x, neg_one_q),
                ),
            neg_one_q)
        )

        x = self.f_add.add(x1, x2)
        x = self.dequant(x)
        return x

m1 = nn.PReLU()
m2 = QPReLU()

# check correctness in fp
for i in range(10):
    data = torch.randn(2, 2) * 1000
    assert torch.allclose(m1(data), m2(data))

# toy model
class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.prelu = QPReLU()

    def forward(self, x):
        x = self.prelu(x)
        return x

# quantize it
m = M()
m.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(m, inplace=True)
# calibrate
m(torch.randn(4, 4))
# convert
torch.quantization.convert(m, inplace=True)
# run some data through
res = m(torch.randn(4, 4))
print(res)

And make sure to read the notes here.
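A quick extra sanity check along these lines (a sketch, not from the original answer) confirms that the corrected float module PReLU_2 matches the stock nn.PReLU:

import torch
import torch.nn as nn

# Illustrative check: the corrected float re-implementation should match nn.PReLU exactly.
m1 = nn.PReLU()
m2 = PReLU_2(m1)          # PReLU_2 as defined in the answer above

x = torch.randn(2, 3, 4)
assert torch.allclose(m1(x), m2(x))
print('PReLU_2 matches nn.PReLU')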
