猿问

Keras 出现内存分配错误并且运行速度极慢

我正在使用卷积神经网络进行字符识别。我有 9 层模型和 19990 个训练数据和 4470 个测试数据。但是当我在 Tensorflow 后端使用 keras 时。当我尝试训练模型时,它运行得非常慢,比如每分钟 100-200 个样本。我尝试在展平后添加批量标准化层,使用正则化,添加 dropout 层,使用 fit_generator 从磁盘批量加载数据,以便使用不同的批量大小使 ram 保持空闲(性能更差),但没有任何效果。因此,我尝试将网络大小减少到 4 层,并向初始层添加更多通道以增加并行计算,但现在我开始出现内存分配错误。它说某个地址的分配超过了 10% 并且我的整个系统都冻结了。我每次都必须重新启动我的笔记本电脑。我尝试回到具有 9 层的早期版本,但现在也给了我同样的错误,即使它更早工作(没有真正工作,但至少开始训练)。那么,这个问题的解决方案是什么?是硬件能力不足还是其他问题?我有 8gb ram 和 2gb gpu,但我不使用 gpu 进行训练。我有英特尔 i5 7gen 处理器。


我的型号代码:


model = Sequential()


#First conv layer

model.add(Conv2D(512,(3,3),padding="same",kernel_initializer="glorot_normal",data_format="channels_last",input_shape=(278,278,1),kernel_regularizer=l1(0.04),activity_regularizer=l2(0.05)))

model.add(LeakyReLU())

model.add(MaxPool2D(pool_size=(2,2),padding="same",data_format="channels_last"))

model.add(Dropout(0.2))


#Second conv layer

model.add(Conv2D(256,(4,4),padding="same",kernel_initializer="glorot_normal",data_format="channels_last",kernel_regularizer=l1(0.02),activity_regularizer=l1(0.04)))

model.add(LeakyReLU())

model.add(MaxPool2D(pool_size=(2,2),strides=2,padding="same",data_format="channels_last"))

model.add(Dropout(0.2))



#Third conv layer

model.add(Conv2D(64,(3,3),padding="same",kernel_initializer="glorot_normal",data_format="channels_last",bias_regularizer=l1_l2(l1=0.02,l2=0.02),activity_regularizer=l2(0.04)))

model.add(LeakyReLU())

model.add(MaxPool2D(pool_size=(2,2),padding="same",data_format="channels_last"))

我的数据加载方法:


def Generator(hdf5_file,batch_size):

X = HDF5Matrix(hdf5_file,"/Data/X")

Y = HDF5Matrix(hdf5_file,"/Data/Y")


size = X.end

idx = 0


while True:

    last_batch = idx+batch_size >size

    end = idx + batch_size if not last_batch else size

    yield X[idx:end],Y[idx:end]

    idx = end if not last_batch else 0


holdtom
浏览 247回答 2
2回答

小怪兽爱吃肉

我遇到了问题。我在模型中有太多参数。我试着减少没有。频道,它奏效了。我想到了它,因为即使对于小数据集,我也会收到错误。

喵喵时光机

我认为(至少)您的问题之一是您正在将整个数据集加载到 ram 中。您的数据集(训练和验证)似乎至少为 5 GB。在您的生成器中,您将它们全部加载。因此,在您的情况下,由于 8Gb ram,它似乎在训练期间遇到了问题。
随时随地看视频慕课网APP

相关分类

Python
我要回答