扬帆大鱼
If you want a custom preprocessing layer, you actually don't need to use `PreprocessingLayer` at all. You can simply subclass `Layer`. Take the simplest preprocessing layer, `Rescaling`, as an example; it lives under the `tf.keras.layers.experimental.preprocessing.Rescaling` namespace. But if you check the actual implementation, it is just a subclass of `Layer` (source code link here), decorated with `@keras_export`:

```python
@keras_export('keras.layers.experimental.preprocessing.Rescaling')
class Rescaling(Layer):
  """Multiply inputs by `scale` and adds `offset`.

  For instance:

  1. To rescale an input in the `[0, 255]` range
  to be in the `[0, 1]` range, you would pass `scale=1./255`.

  2. To rescale an input in the `[0, 255]` range
  to be in the `[-1, 1]` range, you would pass `scale=1./127.5, offset=-1`.

  The rescaling is applied both during training and inference.

  Input shape:
    Arbitrary.

  Output shape:
    Same as input.

  Arguments:
    scale: Float, the scale to apply to the inputs.
    offset: Float, the offset to apply to the inputs.
    name: A string, the name of the layer.
  """

  def __init__(self, scale, offset=0., name=None, **kwargs):
    self.scale = scale
    self.offset = offset
    super(Rescaling, self).__init__(name=name, **kwargs)

  def call(self, inputs):
    dtype = self._compute_dtype
    scale = math_ops.cast(self.scale, dtype)
    offset = math_ops.cast(self.offset, dtype)
    return math_ops.cast(inputs, dtype) * scale + offset

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'scale': self.scale,
        'offset': self.offset,
    }
    base_config = super(Rescaling, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
```

So it turns out that the `Rescaling` preprocessing layer is just another ordinary layer. The key part is the `def call(self, inputs)` method: you can implement arbitrarily complex logic there to preprocess your `inputs` and return the result. A simpler guide to custom layers can be found here. In short, you can do preprocessing as a layer, either with `Lambda` for simple operations or by subclassing `Layer` to achieve your goal.
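To make the subclassing approach concrete, here is a minimal sketch of a custom preprocessing layer built the same way as `Rescaling`. The `Standardize` layer name and its constants are illustrative, not from any library; it just centers and scales its inputs in `call`:

```python
import tensorflow as tf

class Standardize(tf.keras.layers.Layer):
    """Illustrative custom preprocessing layer: (inputs - mean) / std."""

    def __init__(self, mean, std, **kwargs):
        super().__init__(**kwargs)
        self.mean = mean
        self.std = std

    def call(self, inputs):
        # All the preprocessing logic lives here, just like in Rescaling.
        inputs = tf.cast(inputs, tf.float32)
        return (inputs - self.mean) / self.std

    def get_config(self):
        # Serialize constructor args so the layer survives model saving.
        config = {'mean': self.mean, 'std': self.std}
        return {**super().get_config(), **config}

# Maps raw [0, 255] pixel values into [-1, 1]:
layer = Standardize(mean=127.5, std=127.5)
print(layer(tf.constant([[0.0, 127.5, 255.0]])).numpy())
```

Because it is an ordinary `Layer`, it can be dropped into a `Sequential` or functional model like any other layer.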
弑天下
I think the best and cleanest solution is to use a simple `Lambda` layer, in which you wrap your preprocessing function. Here is a dummy working example:

```python
import numpy as np
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

X = np.random.randint(0, 256, (200, 32, 32, 3))
y = np.random.randint(0, 3, 200)

inp = Input((32, 32, 3))
x = Lambda(lambda x: x / 255)(inp)  # preprocessing: rescale to [0, 1]
x = Conv2D(8, 3, activation='relu')(x)
x = Flatten()(x)
out = Dense(3, activation='softmax')(x)

m = Model(inp, out)
m.compile(loss='sparse_categorical_crossentropy',
          optimizer='adam',
          metrics=['accuracy'])
history = m.fit(X, y, epochs=10)
```
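A quick sanity check on the idea above (the input values are illustrative): applied on its own, the `Lambda` layer maps raw `[0, 255]` pixel values into `[0, 1]`, so the conv stack always sees rescaled data, at training and inference alike:

```python
import numpy as np
import tensorflow as tf

# The same rescaling Lambda used in the model, tested in isolation.
rescale = tf.keras.layers.Lambda(lambda x: x / 255)
raw = np.array([[0.0, 51.0, 255.0]], dtype=np.float32)
print(rescale(raw).numpy())  # values now lie in [0, 1]
```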