Generator 只进行 12 次迭代 - 无论批量大小

我有以下数据生成器。它工作并返回预期数据。除了我将 epochs 或 batchsize 设置为等于什么之外,它只执行 12 次迭代然后给出错误(见下文)


我曾尝试更改纪元数和批量大小。


# initialize the number of epochs to train for and batch size

NUM_EPOCHS = 10 #100

BS = 32 #64 #32


NUM_TRAIN_IMAGES = len(train_uxo_scrap)

NUM_TEST_IMAGES = len(test_uxo_scrap)

def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):

    cnt=0

    while True:

        images = []

        labels = []

        #cnt=0


        while len(images) < batchsize and cnt < len(imgfns):

            images.append(imgfns[cnt])

            labels.append(imglabels[cnt])

            cnt=cnt+1


        print(images)

        print(labels)

        print('********** cnt = ', cnt)

        yield images, labels

train_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS, class_mode='binary')


valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS, class_mode='binary')

# train the network

H = model.fit_generator(

    train_gen,

    steps_per_epoch=NUM_TRAIN_IMAGES // BS,

    validation_data=valid_gen,

    validation_steps=NUM_TEST_IMAGES // BS,

    epochs=NUM_EPOCHS)

我希望代码在每次迭代中通过 32 个样本经历 10 个时期。我每次迭代得到 32 个样本,但在第一个时期我只得到 12 个迭代,然后我得到以下错误。无论设置什么批次大小或纪元,都会发生这种情况。


---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-83-26f81894773d> in <module>()

      5     validation_data=valid_gen,

      6     validation_steps=NUM_TEST_IMAGES // BS,

----> 7     epochs=NUM_EPOCHS)


~\AppData\Local\Continuum\anaconda3\envs\dltf1\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)

   1424         use_multiprocessing=use_multiprocessing,

   1425         shuffle=shuffle,

-> 1426         initial_epoch=initial_epoch)

   1427 

   1428   def evaluate_generator(self,

PIPIONE
浏览 177回答 1
1回答

qq_笑_17

看看这是否有效:def datagenerator(imgfns, imglabels, batchsize, mode="train", class_mode='binary'):&nbsp; &nbsp; while True:&nbsp; &nbsp; &nbsp; &nbsp; start = 0&nbsp; &nbsp; &nbsp; &nbsp; end = batchsize&nbsp; &nbsp; &nbsp; &nbsp; while start&nbsp; < len(imgfns):&nbsp;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; x = imgfns[start:end]&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; y = imglabels[start:end]&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; yield x, y&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; start += batchsize&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; end += batchsize假设imgfns, imglabels是 numpy 数组。
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python