
How can I augment the training set using the tf.keras.utils.Sequence API?

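The TensorFlow documentation gives the following example of a batch generator that feeds the training set to the model in batches when the set is too large to fit in memory: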


```python
import math

import numpy as np
import tensorflow as tf
from skimage.io import imread
from skimage.transform import resize

# Here, `x_set` is a list of paths to the images
# and `y_set` are the associated classes.


class CIFAR10Sequence(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size, **kwargs):
        # Recent Keras versions warn if the base initializer is not called.
        super().__init__(**kwargs)
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be partial.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        # Images are read and resized only when this batch is requested.
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
```
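To see how `__len__` and `__getitem__` split the data, here is a minimal sketch that follows the same protocol but slices in-memory NumPy arrays instead of reading files. `InMemorySequence` is a hypothetical name, and it deliberately does not subclass `tf.keras.utils.Sequence` so it runs without TensorFlow or image files:

```python
import math

import numpy as np


class InMemorySequence:
    """Same __len__/__getitem__ logic as CIFAR10Sequence, on in-memory arrays."""

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches; math.ceil keeps the final, partial batch.
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        return np.asarray(self.x[start:stop]), np.asarray(self.y[start:stop])


x = np.arange(10 * 2 * 2).reshape(10, 2, 2)  # ten tiny 2x2 "images"
y = np.arange(10)
seq = InMemorySequence(x, y, batch_size=4)
print(len(seq))         # 3
print(seq[2][0].shape)  # (2, 2, 2): the last batch holds the 2 leftover samples
```

With 10 samples and a batch size of 4, Keras would request batches 0, 1 and 2 per epoch, and the last one contains only two samples.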

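My goal is to further increase the diversity of the training set by rotating each image three times by 90°. In every training epoch, the model would first be fed the "0° training set", then the 90°, 180° and 270° rotated sets, in that order.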
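The schedule described above (all batches at 0°, then at 90°, 180° and 270°, all within one epoch) can be expressed as an index mapping: make `__len__` four times the number of batches and split the global batch index with `divmod` into a rotation count and a batch position. The following is a hypothetical, TensorFlow-free sketch: `RotatingSequence` is an illustrative name, the data are in-memory NumPy arrays rather than files, and `np.rot90` (exact for multiples of 90°) is swapped in for `skimage.transform.rotate`:

```python
import math

import numpy as np


class RotatingSequence:
    """Hypothetical sketch: one epoch = 0-degree pass, then 90, 180, 270."""

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.n_batches = math.ceil(len(self.x) / self.batch_size)

    def __len__(self):
        # Four passes over the data per epoch, one per rotation.
        return 4 * self.n_batches

    def __getitem__(self, idx):
        # Split the global index into (quarter-turns, batch within the pass).
        k, batch_idx = divmod(idx, self.n_batches)
        start = batch_idx * self.batch_size
        stop = start + self.batch_size
        # Rotate each image in its first two axes (H, W); labels are unchanged.
        batch_x = np.stack([np.rot90(img, k) for img in self.x[start:stop]])
        return batch_x, np.asarray(self.y[start:stop])


x = np.arange(5 * 3 * 3).reshape(5, 3, 3)  # five tiny 3x3 "images"
y = np.arange(5)
seq = RotatingSequence(x, y, batch_size=2)
print(len(seq))  # 12: 3 batches per pass x 4 rotations
```

In a real generator this class would subclass `tf.keras.utils.Sequence` and read files in `__getitem__`. Note that `model.fit` can shuffle the order in which batches of a `Sequence` are drawn (its `shuffle` argument defaults to `True`), so if the strict 0° → 270° order matters, passing `shuffle=False` is likely needed.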


How can I modify the code above to do this inside the CIFAR10Sequence() data generator?


Please do not use tf.keras.preprocessing.image.ImageDataGenerator(), so that the answer stays general enough to apply to similar problems of a different nature.


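Note: the idea is to create the new data "on the fly" while the model is being fed, rather than creating in advance, and storing on disk, a new augmented training set larger than the original one, which would later be used (also in batches) during training.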
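The "on the fly" requirement does not need a rotated copy of the training set to exist anywhere: in NumPy, `np.rot90` returns a view of its input, so a rotated image costs no extra storage until a batch is actually assembled. A minimal sketch, using a random array as a stand-in for one decoded image (an assumption; the question reads images from disk):

```python
import numpy as np

# One stored image (H, W, C); only this array is kept in memory.
img = np.random.default_rng(0).random((4, 4, 3))

# np.rot90 returns a view, so each rotation is produced at access time
# without storing a second copy of the pixels.
views = [np.rot90(img, k) for k in range(4)]
print(all(np.shares_memory(v, img) for v in views))  # True

# Memory is only allocated when a batch is actually assembled...
batch = np.stack(views)
print(np.shares_memory(batch, img))  # False
# ...and that batch can be discarded once the model has consumed it.
```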


喵喔喔

米脂

Use a custom `Callback` and hook into `on_epoch_end`: after each epoch, change the angle of the data-iterator object. Example (documented inline):

```python
import numpy as np
from skimage.io import imread
from skimage.transform import resize, rotate
from tensorflow import keras
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.utils import Sequence

# Model architecture (dummy); input_shape assumes 15x15 RGBA (4-channel) PNGs
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])


# Data iterator
class CIFAR10Sequence(Sequence):
    def __init__(self, filenames, labels, batch_size, **kwargs):
        super().__init__(**kwargs)
        self.filenames, self.labels = filenames, labels
        self.batch_size = batch_size
        self.angles = [0, 90, 180, 270]
        self.current_angle_idx = 0

    # Method to loop through the available angles
    def change_angle(self):
        self.current_angle_idx += 1
        if self.current_angle_idx >= len(self.angles):
            self.current_angle_idx = 0

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    # Read, resize and rotate the images and return a batch
    def __getitem__(self, idx):
        angle = self.angles[self.current_angle_idx]
        print(f"Rotating Angle: {angle}")

        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

        return np.array([
            rotate(resize(imread(filename), (15, 15)), angle)
            for filename in batch_x]), np.array(batch_y)


# Custom callback to hook into the end of each epoch
class CustomCallback(keras.callbacks.Callback):
    def __init__(self, sequence):
        super().__init__()
        self.sequence = sequence

    # After the end of each epoch, change the rotation for the next epoch
    def on_epoch_end(self, epoch, logs=None):
        self.sequence.change_angle()


# Create the data reader (dummy data: one image file repeated 10 times)
sequence = CIFAR10Sequence(["f1.PNG"] * 10, [0, 1] * 5, 8)

# Fit the model and hook in the custom callback
model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])
```

Output:

```
Rotating Angle: 0
Epoch 1/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000
Epoch 2/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000
Epoch 3/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 4/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 5/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 6/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 7/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 8/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 9/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 10/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000
<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>
```
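The angle bookkeeping in this answer can be checked without TensorFlow or image files. The hypothetical `AngleCycler` and `simulate_fit` below reproduce only the `change_angle` state machine and the call that `CustomCallback.on_epoch_end` makes after each epoch. The angle used per epoch matches the training log: each epoch sees a single angle, so a full 0° to 270° cycle spans four epochs.

```python
class AngleCycler:
    """Hypothetical stand-in for the answer's angle state (no images)."""

    def __init__(self):
        self.angles = [0, 90, 180, 270]
        self.current_angle_idx = 0

    def change_angle(self):
        # Same effect as the answer's increment-and-reset, written with modulo.
        self.current_angle_idx = (self.current_angle_idx + 1) % len(self.angles)

    @property
    def angle(self):
        return self.angles[self.current_angle_idx]


def simulate_fit(sequence, epochs):
    """Hypothetical stand-in for model.fit: record the angle used in each
    epoch, then make the call that CustomCallback.on_epoch_end makes."""
    used = []
    for _ in range(epochs):
        used.append(sequence.angle)
        sequence.change_angle()  # what on_epoch_end triggers
    return used


print(simulate_fit(AngleCycler(), epochs=10))
# [0, 90, 180, 270, 0, 90, 180, 270, 0, 90]
```

Keras also calls a `Sequence`'s own `on_epoch_end()` method at the end of every epoch, so the `change_angle()` call could likely live there instead of in a separate callback; treat that as an alternative to verify against the Keras version in use.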