猿问

Keras有没有办法立即停止训练?

我正在为我的tf.keras训练编写自定义提前停止回调。为此,我可以self.model.stop_training = True在其中一个回调函数中设置变量,例如on_epoch_end(). 然而,Keras 仅在当前时期完成时才停止训练,即使我在一个时期的训练中设置了这个变量,例如在on_batch_end().

因此我的问题是:Keras 有没有办法立即停止训练,即使是在当前时代的进展中?


MYYA
浏览 163回答 2
2回答

人到中年有点甜

在 kerasEarlyStopping中,当受监控的数量停止改善时,您会停止。从您的问题来看,您不清楚要停止的条件是什么。如果您只想监视一个值,EarlyStopping但只想在一批后停止,如果该值没有提高,您可以重写EarlyStopping类并实现逻辑 inon_batch_end而不是on_epoch_end:class EarlyBatchStopping(Callback):    def __init__(self,                 monitor='val_loss',                 min_delta=0,                 patience=0,                 verbose=0,                 mode='auto',                 baseline=None,                 restore_best_weights=False):        super(EarlyStopping, self).__init__()        self.monitor = monitor        self.baseline = baseline        self.patience = patience        self.verbose = verbose        self.min_delta = min_delta        self.wait = 0        self.stopped_epoch = 0        self.restore_best_weights = restore_best_weights        self.best_weights = None        if mode not in ['auto', 'min', 'max']:            warnings.warn('EarlyStopping mode %s is unknown, '                          'fallback to auto mode.' % mode,                          RuntimeWarning)            mode = 'auto'        if mode == 'min':            self.monitor_op = np.less        elif mode == 'max':            self.monitor_op = np.greater        else:            if 'acc' in self.monitor:                self.monitor_op = np.greater            else:                self.monitor_op = np.less        if self.monitor_op == np.greater:            self.min_delta *= 1        else:            self.min_delta *= -1    def on_train_begin(self, logs=None):        # Allow instances to be re-used        self.wait = 0        self.stopped_epoch = 0        if self.baseline is not None:            self.best = self.baseline        else:            self.best = np.Inf if self.monitor_op == np.less else -np.Inf    def on_batch_end(self, epoch, logs=None):        current = self.get_monitor_value(logs)        if current is None:            return        if self.monitor_op(current - self.min_delta, self.best):            self.best = current            self.wait = 0            if self.restore_best_weights:                self.best_weights = self.model.get_weights()        else:            self.wait += 1            if self.wait >= self.patience:                self.stopped_epoch = epoch                self.model.stop_training = True                if self.restore_best_weights:                    if self.verbose > 0:                        print('Restoring model weights from the end of '                              'the best epoch')                    self.model.set_weights(self.best_weights)    def on_train_end(self, logs=None):        if self.stopped_epoch > 0 and self.verbose > 0:            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))    def get_monitor_value(self, logs):        monitor_value = logs.get(self.monitor)        if monitor_value is None:            warnings.warn(                'Early stopping conditioned on metric `%s` '                'which is not available. Available metrics are: %s' %                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning            )        return monitor_value如果您有其他逻辑,则可以根据您的逻辑使用on_batch_end和设置,但我认为您明白了。self.model.stop_training = True

红颜莎娜

您可以使用model.stop_training参数来停止训练。例如,如果我们想在第 2 轮第 3 批次停止训练,那么您可以执行如下操作。import kerasfrom keras.models import Sequentialfrom keras.layers import Densefrom keras.optimizers import SGDimport numpy as npimport pandas as pdclass My_Callback(keras.callbacks.Callback):&nbsp; &nbsp; def on_epoch_begin(self, epoch, logs={}):&nbsp; &nbsp; &nbsp; self.epoch = epoch&nbsp; &nbsp; def on_batch_end(self, batch, logs={}):&nbsp; &nbsp; &nbsp; &nbsp; if self.epoch == 1 and batch == 3:&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; print (f"\nStopping at Epoch {self.epoch}, Batch {batch}")&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; self.model.stop_training = TrueX_train = np.random.random((100, 3))y_train = pd.get_dummies(np.argmax(X_train[:, :3], axis=1)).valuesclf = Sequential()clf.add(Dense(9, activation='relu', input_dim=3))clf.add(Dense(3, activation='softmax'))clf.compile(loss='categorical_crossentropy', optimizer=SGD())clf.fit(X_train, y_train, epochs=10, batch_size=16, callbacks=[My_Callback()])输出:Epoch 1/10100/100 [==============================] - 0s 337us/step - loss: 1.0860Epoch 2/10&nbsp;16/100 [===>..........................] - ETA: 0s - loss: 1.0830Stopping at Epoch 1, Batch 3<keras.callbacks.callbacks.History at 0x7ff2e3eeee10>
随时随地看视频慕课网APP

相关分类

Python
我要回答