Keras AttributeError: 'NoneType' object in load_model

I am working on a course assignment in which I have to save and load a model in Keras. My code for creating, training, and saving the model is:


import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense
from tensorflow.keras.callbacks import ModelCheckpoint


def get_new_model(input_shape):
    """
    This function should build a Sequential model according to the above specification. Ensure the
    weights are initialised by providing the input_shape argument in the first layer, given by the
    function argument.
    Your function should also compile the model with the Adam optimiser, sparse categorical cross
    entropy loss function, and a single accuracy metric.
    """
    model = Sequential([
        Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same', name='conv_1', input_shape=input_shape),
        Conv2D(8, kernel_size=(3, 3), activation='relu', padding='same', name='conv_2'),
        MaxPooling2D(pool_size=(8, 8), name='pool_1'),
        tf.keras.layers.Flatten(name='flatten'),
        Dense(32, activation='relu', name='dense_1'),
        Dense(10, activation='softmax', name='dense_2')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
    return model


# x_train: the assignment's training images (defined elsewhere, not shown here)
model = get_new_model(x_train[0].shape)


def get_checkpoint_every_epoch():
    """
    This function should return a ModelCheckpoint object that:
    - saves the weights only at the end of every epoch
    - saves into a directory called 'checkpoints_every_epoch' inside the current working directory
    - generates filenames in that directory like 'checkpoint_XXX' where
      XXX is the epoch number formatted to have three digits, e.g. 001, 002, 003, etc.
    """
    # {epoch:03d} zero-pads the epoch number to three digits: checkpoint_001, checkpoint_002, ...
    path = 'checkpoints_every_epoch/checkpoint_{epoch:03d}'
    checkpoint = ModelCheckpoint(filepath=path, save_weights_only=True, save_freq='epoch')
    return checkpoint
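The docstring asks for three-digit epoch numbers, and `ModelCheckpoint` fills the `{epoch}` placeholder in `filepath` with Python's `str.format`, so the width in the format spec alone decides the filename. A quick pure-Python check of the two specs; the helper `checkpoint_name` is hypothetical and only mimics that formatting step, assuming epochs are numbered from 1:

```python
def checkpoint_name(pattern, epoch):
    """Fill the {epoch} placeholder the way a str.format-based callback would."""
    return pattern.format(epoch=epoch)


two_digit = 'checkpoints_every_epoch/checkpoint_{epoch:02d}'
three_digit = 'checkpoints_every_epoch/checkpoint_{epoch:03d}'

print(checkpoint_name(two_digit, 1))    # checkpoints_every_epoch/checkpoint_01
print(checkpoint_name(three_digit, 1))  # checkpoints_every_epoch/checkpoint_001
```

Both patterns produce valid filenames; only `{epoch:03d}` matches the `checkpoint_XXX` naming the assignment specifies.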

慕桂英3389331
1 Answer

跃然一笑

I figured it out. There was an error in the file path name, and it took me a long time to track down. The correct functions are:

import tensorflow as tf


def get_model_last_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier,
    load on the weights from the last training epoch, and return this model.
    """
    # tf.train.latest_checkpoint returns None if the directory holds no checkpoint,
    # so this name must match the directory used in the ModelCheckpoint filepath
    model.load_weights(tf.train.latest_checkpoint('checkpoints_every_epoch'))
    return model


def get_model_best_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier, load
    on the weights leading to the highest validation accuracy, and return this model.
    """
    #filepath = tf.train.latest_checkpoint('checkpoints_best_only')
    model.load_weights(tf.train.latest_checkpoint('checkpoints_best_only'))
    return model

It no longer raises the error because the directory name passed to tf.train.latest_checkpoint is now correct.
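`tf.train.latest_checkpoint` returns `None` when it finds no checkpoint in the directory it is given, so a wrong directory name only surfaces later, when `load_weights` receives `None` instead of a path. A minimal sketch of a guard that fails early with a clearer message, assuming TensorFlow 2.x; the helper `load_latest_weights` and its `find_latest` parameter are hypothetical, the parameter existing only so the lookup can be swapped out in tests:

```python
import os


def load_latest_weights(model, checkpoint_dir, find_latest=None):
    """Load the newest checkpoint in checkpoint_dir into model, or fail loudly.

    find_latest defaults to tf.train.latest_checkpoint; it is a parameter only
    so the lookup can be replaced (e.g. in tests) without TensorFlow.
    """
    if find_latest is None:
        import tensorflow as tf  # lazy import: the helper stays importable without TF
        find_latest = tf.train.latest_checkpoint

    path = find_latest(checkpoint_dir)
    if path is None:
        # Nothing found: usually the directory name does not match the one
        # used in the ModelCheckpoint filepath.
        raise FileNotFoundError(
            f"No checkpoint found in {os.path.abspath(checkpoint_dir)!r}; "
            "check that it matches the ModelCheckpoint filepath directory."
        )
    model.load_weights(path)
    return model
```

Used with the assignment's own builder, `load_latest_weights(get_new_model(x_train[0].shape), 'checkpoints_every_epoch')` also creates the fresh model instance the docstrings ask for, rather than reusing the trained one.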