我目前有一个如下所示的图层:
class MyLayer(tf.keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__()
def call(self, img, params):
tf.foo(img)
tf.bar(img, params)
return img
call 方法img使用 shape(128, 128, 3)和paramsshape获取和输入(15)。
我必须更改什么才能使图层可以批量操作?img例如,输入将具有 shape(32, 128, 128, 3)并且params将具有 shape (32, 15)。
所以问题基本上是:我必须如何编辑图层,使其执行与现在相同的操作,但对于批处理中的每个图像?
蛊毒传说
相关分类