我正在使用数据集 API 生成训练数据并将其分类为 NN 的批次。
这是我的代码的最小工作示例:
import tensorflow as tf
import numpy as np
import random
def my_generator():
while True:
x = np.random.rand(4, 20)
y = random.randint(0, 11)
label = tf.one_hot(y, depth=12)
yield x.reshape(4, 20, 1), label
def my_input_fn():
dataset = tf.data.Dataset.from_generator(lambda: my_generator(),
output_types=(tf.float64, tf.int32))
dataset = dataset.batch(32)
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
if __name__ == "__main__":
tf.enable_eager_execution()
model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(4, 20, 1)),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(12, activation=tf.nn.softmax)])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
data_generator = my_input_fn()
model.fit(data_generator)
但batch_size
不是 的公认关键字fit_generator()
。
我对这些错误消息感到困惑,如果有人能对它们有所了解,或者指出我做错了什么,我将不胜感激。
FFIVE
相关分类