
如何使用 tf.keras.utils.Sequence API 扩充训练集?

TensorFlow 文档有以下示例,可以说明当训练集太大而无法放入内存时,如何创建批量生成器以将训练集批量提供给模型:

from skimage.io import imread

from skimage.transform import resize

import tensorflow as tf

import numpy as np

import math

# Here, `x_set` is a list of paths to the images

# and `y_set` contains the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):
    """Batch generator that loads and resizes images from disk on demand.

    Only the file paths and labels are held in memory; the images of a
    batch are read when that batch is requested.

    Args:
        x_set: List of paths to the image files.
        y_set: Classes associated with the images, in the same order as
            `x_set`.
        batch_size: Number of samples per batch.
    """

    def __init__(self, x_set, y_set, batch_size):
        # Initialise the base class so any state it sets up is in place.
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        """Return batch number `idx` as an `(images, labels)` tuple.

        Every image of the batch is read from disk and resized to
        200 x 200. The last batch is shorter when the number of samples
        is not a multiple of `batch_size`.
        """
        start = idx * self.batch_size
        stop = (idx + 1) * self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        images = np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x
        ])
        return images, np.array(batch_y)

我的目的是通过将每个图像旋转 3 次(每次 90°)来进一步增加训练集的多样性。在训练过程的每个 epoch 中,模型将首先输入“0° 训练集”,然后依次输入旋转 90°、180° 和 270° 后的训练集。




浏览 116回答 1


使用自定义 Callback 并挂钩到 on_epoch_end,在每个 epoch 结束后更改数据迭代器对象的旋转角度。

示例(说明见代码内注释):

from skimage.io import imread
from skimage.transform import resize, rotate
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.utils import Sequence
from keras.models import Sequential
from keras.layers import Conv2D, Activation, Flatten, Dense

# Model architecture  (dummy)
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Data iterator
class CIFAR10Sequence(Sequence):
    def __init__(self, filenames, labels, batch_size):
        self.filenames, self.labels = filenames, labels
        self.batch_size = batch_size
        self.angles = [0,90,180,270]
        self.current_angle_idx = 0

    # Method to loop through the available angles
    def change_angle(self):
      self.current_angle_idx += 1
      if self.current_angle_idx >= len(self.angles):
        self.current_angle_idx = 0

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    # read, resize and rotate the image and return a batch of images
    def __getitem__(self, idx):
        angle = self.angles[self.current_angle_idx]
        print (f"Rotating Angle: {angle}")
        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array([
            rotate(resize(imread(filename), (15, 15)), angle)
               for filename in batch_x]), np.array(batch_y)

# Custom call back to hook into on epoch end
class CustomCallback(keras.callbacks.Callback):
    def __init__(self, sequence):
      self.sequence = sequence

    # after end of each epoch change the rotation for next epoch
    def on_epoch_end(self, epoch, logs=None):
      self.sequence.change_angle()

# Create data reader
sequence = CIFAR10Sequence(["f1.PNG"]*10, [0, 1]*5, 8)

# fit the model and hook in the custom call back
model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])

输出:

Rotating Angle: 0
Epoch 1/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000
Epoch 2/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000
Epoch 3/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 4/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 5/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 6/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 7/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 8/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 9/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 10/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000
<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>

