Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?

我只想在自定义层中进行一些数值验证。


假设我们有一个非常简单的自定义层:


class test_layer(keras.layers.Layer):

    def __init__(self, **kwargs):

        super(test_layer, self).__init__(**kwargs)


    def build(self, input_shape):

        self.w = K.variable(1.)

        self._trainable_weights.append(self.w)

        super(test_layer, self).build(input_shape)


    def call(self, x, **kwargs):

        m = x * x            # Set break point here

        n = self.w * K.sqrt(x)

        return m + n

和主程序:


import tensorflow as tf

import keras

import keras.backend as K


input = keras.layers.Input((100,1))

y = test_layer()(input)


model = keras.Model(input,y)

model.predict(np.ones((100,1)))

如果我在该行上设置了断点调试,则m = x * x执行时程序将在此处暂停y = test_layer()(input),这是因为生成了图形,因此call()调用了该方法。


但是当我使用model.predict()它来赋予它真正的价值,并且想在图层内部查看它是否工作正常时,它并不会停在那一行m = x * x


我的问题是:


是call()计算图形正在兴建时,方法只叫什么名字?(提供实际价值时不会调用它吗?)


给实数输入时,如何调试(或在何处插入断点)以查看变量的值?


鸿蒙传说
浏览 383回答 2
2回答

holdtom

是的。该call()方法仅用于构建计算图。至于调试。我更喜欢使用TFDBG,这是针对tensorflow的推荐调试工具,尽管它不提供断点功能。对于Keras,您可以将以下行添加到脚本中以使用TFDBGimport tf.keras.backend as Kfrom tensorflow.python import debug as tf_debugsess = K.get_session()sess = tf_debug.LocalCLIDebugWrapperSession(sess)K.set_session(sess)
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python