猿问

Keras 生成器和 fit_generator,如何构建生成器以避免“函数形状”错误

我正在为 Keras 构建一个生成器,以便能够加载我的数据集图像,因为它对我的 ram 来说有点大。


我像这样构建了生成器:


# import the necessary packages

import tensorflow

from tensorflow import keras

from keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt

from sklearn.preprocessing import OneHotEncoder

import numpy as np

import pandas as pd

from tqdm import tqdm


#loading

path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset- 

images_improved.txt"

df = pd.read_csv(path_to_txt ,sep='\t')

arr = np.array(df)

#epochs and steps:

NUM_TRAIN_IMAGES = 0

NUM_EPOCHS = 30


def image_generator(arr, bs, mode="train", aug=None):

  while True:

    images = []

    labels = []

    for row in arr:

      if len(images) < bs:

        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + 

        row[0]),(224,224)))

        images.append(img)

        labels.append([row[2]])

        NUM_TRAIN_IMAGES += 1

      else:

        break



  if aug is not None:

    (images, labels) = next(aug.flow(np.array(images),labels, 

     batch_size=bs))


  obj = OneHotEncoder()

  values = obj.fit_transform(labels).toarray()


  yield (np.array(images), labels)

然后我从顺序模型中调用 fit_generator (cnn 一直工作,直到出现 OOM 错误)


#create the augmentation function:

 aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,

    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,

    horizontal_flip=True, fill_mode="nearest")


#create the generator:

gen = image_generator(arr, bs = 32, mode = "train", aug = aug)


history = model.fit_generator(image_generator,

    steps_per_epoch = NUM_TRAIN_IMAGES,

    epochs = NUM_EPOCHS)

从这里,我收到此错误:


# Create generator from NumPy or EagerTensor Input.

--> 377   num_samples = int(nest.flatten(data)[0].shape[0])

378   if batch_size is None:

379     raise ValueError('You must specify `batch_size`')

AttributeError: 'function' object has no attribute 'shape'


繁星淼淼
浏览 194回答 1
1回答

慕森王

我在这里看到两个主要错误。首先,您的生成器函数的内存效率不高。因为您首先加载所有图像(while 循环)。您应该遍历图像文件并在循环内产生带有标签的图像的 np.array。其次,当您应该使用其返回的对象 - gen 时,您将生成器函数名称传递给 fit_generator。
随时随地看视频慕课网APP

相关分类

Python
我要回答