使用子类模型时,model.summary() 无法打印输出形状

这是创建keras模型的两种方法,但是output shapes两种方法的汇总结果不同。显然,前者打印的信息更多,更容易检查网络的正确性。


import tensorflow as tf

from tensorflow.keras import Input, layers, Model


class subclass(Model):

    def __init__(self):

        super(subclass, self).__init__()

        self.conv = layers.Conv2D(28, 3, strides=1)


    def call(self, x):

        return self.conv(x)



def func_api():

    x = Input(shape=(24, 24, 3))

    y = layers.Conv2D(28, 3, strides=1)(x)

    return Model(inputs=[x], outputs=[y])


if __name__ == '__main__':

    func = func_api()

    func.summary()


    sub = subclass()

    sub.build(input_shape=(None, 24, 24, 3))

    sub.summary()

输出:


_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

input_1 (InputLayer)         (None, 24, 24, 3)         0         

_________________________________________________________________

conv2d (Conv2D)              (None, 22, 22, 28)        784       

=================================================================

Total params: 784

Trainable params: 784

Non-trainable params: 0

_________________________________________________________________

_________________________________________________________________

Layer (type)                 Output Shape              Param #   

=================================================================

conv2d_1 (Conv2D)            multiple                  784       

=================================================================

Total params: 784

Trainable params: 784

Non-trainable params: 0

_________________________________________________________________

那么,我应该如何使用子类方法来获取output shape摘要()?


哔哔one
浏览 836回答 3
3回答

小怪兽爱吃肉

我已经用这个方法解决了这个问题,不知道有没有更简单的方法。class subclass(Model):    def __init__(self):        ...    def call(self, x):        ...    def model(self):        x = Input(shape=(24, 24, 3))        return Model(inputs=[x], outputs=self.call(x))if __name__ == '__main__':    sub = subclass()    sub.model().summary()

Qyouu

我想关键点是_init_graph_network类中的方法Network,它是Model. _init_graph_network如果在调用方法时指定inputs和outputs参数,将被调用__init__。所以会有两种可能的方法:手动调用_init_graph_network方法构建模型图。用输入层和输出重新初始化。并且这两种方法都需要输入层和输出(从 需要self.call)。现在调用summary将给出确切的输出形状。但是它会显示Input层,它不是子类模型的一部分。from tensorflow import kerasfrom tensorflow.keras import layers as klayersclass MLP(keras.Model):    def __init__(self, input_shape=(32), **kwargs):        super(MLP, self).__init__(**kwargs)        # Add input layer        self.input_layer = klayers.Input(input_shape)        self.dense_1 = klayers.Dense(64, activation='relu')        self.dense_2 = klayers.Dense(10)        # Get output layer with `call` method        self.out = self.call(self.input_layer)        # Reinitial        super(MLP, self).__init__(            inputs=self.input_layer,            outputs=self.out,            **kwargs)    def build(self):        # Initialize the graph        self._is_graph_network = True        self._init_graph_network(            inputs=self.input_layer,            outputs=self.out        )    def call(self, inputs):        x = self.dense_1(inputs)        return self.dense_2(x)if __name__ == '__main__':    mlp = MLP(16)    mlp.summary()输出将是:Model: "mlp_1"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================input_1 (InputLayer)         [(None, 16)]              0         _________________________________________________________________dense (Dense)                (None, 64)                1088      _________________________________________________________________dense_1 (Dense)              (None, 10)                650       =================================================================Total params: 1,738Trainable params: 1,738Non-trainable params: 0_________________________________________________________________

侃侃无极

我解决问题的方式与Elazar提到的非常相似。覆盖类中的函数 summary() subclass。然后您可以在使用模型子类化时直接调用 summary() :class subclass(Model):    def __init__(self):        ...    def call(self, x):        ...    def summary(self):        x = Input(shape=(24, 24, 3))        model = Model(inputs=[x], outputs=self.call(x))        return model.summary()if __name__ == '__main__':    sub = subclass()    sub.summary()
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python