猿问

为什么 Keras model.fit() 将整个数据集用作批处理并耗尽内存?

我正在使用 tensorflow 构建一个非常简单的 Keras 模型。当我启动它时,它因 OOM 异常而失败,因为它试图分配一个与整个数据集大小成比例的张量。这里会发生什么?


相关形状:


数据集形状:[60000, 28, 28, 1]

Batch_size(自动):10,

step_per_epoch:6000

错误消息:分配形状为 [60000,256,28,28] 和类型为 float 的张量时出现 OOM

注意:我没有使用顺序模型,因为稍后我将需要非顺序层。


张量流:1.12.0;Keras:2.1.6-tf


最小工作示例:


from tensorflow.keras import layers

import tensorflow as tf

import tensorflow.keras as keras

import numpy as np



def build_mnist_model(input_img):

    conv1 = layers.Conv2D(256, (3,3), activation='relu', padding='same')(input_img)

    conv2 = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(conv1)

    return conv2



(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()


x_train = np.expand_dims(x_train.astype('float32') / 255., -1)

x_test  = np.expand_dims(x_test.astype('float32')  / 255., -1)

print(x_train.shape)

print(x_test.shape)


input_img = keras.Input(shape = (28, 28, 1))

autoencoder = keras.Model(input_img, build_mnist_model(input_img))

autoencoder.compile(loss='mean_squared_error', optimizer = tf.train.AdamOptimizer(0.001))



autoencoder.fit(x_train, x_train,

                epochs=50,

                steps_per_epoch=int(int(x_train.shape[0])/10),

                shuffle=True,

                verbose=1,

                validation_data=(x_test, x_test)

               )

当我将模型定义为 keras.Sequential() 时,问题就消失了。


皈依舞
浏览 355回答 3
3回答

桃花长相依

要分批训练,您应该使用 fit_generator 方法。为此,您需要先制作数据生成器。您需要通过 flow_from_directory 使用 ImageDataGenerator 跟随(例如)。这样 keras 将分批提供数据。您应该调整批量大小以确保 GPU 的内存足够。通常批量大小约为 32-64。一般来说,批量越大越好。Keras 文档:https ://keras.io/preprocessing/image/您可以在此处查看使用示例:https : //www.kaggle.com/vbookshelf/skin-lesion-analyzer-tensorflow-js-web-app

繁花如伊

对我来说同样的问题。我只是检查了一些例子,发现:dummy_x = tf.zeros((1, 224, 224, 1))model._set_inputs(dummy_x)如果此代码在 fit 之前,则不会发生 oom。

慕侠2389804

嗯,我想您忘记定义要在您的网络中输入的 batch_size 了!尝试类似的东西:autoencoder.fit(x_train, x_train,                epochs=50,                batch_size = 32,                steps_per_epoch=int(int(x_train.shape[0])/10),                shuffle=True,                verbose=1,                validation_data=(x_test, x_test)               )
随时随地看视频慕课网APP

相关分类

Python
我要回答