带有 tf 数据集输入的 Tensorflow keras

我是 tensorflow keras 和数据集的新手。谁能帮我理解为什么下面的代码不起作用?


import tensorflow as tf

import tensorflow.keras as keras

import numpy as np

from tensorflow.python.data.ops import dataset_ops

from tensorflow.python.data.ops import iterator_ops

from tensorflow.python.keras.utils import multi_gpu_model

from tensorflow.python.keras import backend as K



data = np.random.random((1000,32))

labels = np.random.random((1000,10))

dataset = tf.data.Dataset.from_tensor_slices((data,labels))

print( dataset)

print( dataset.output_types)

print( dataset.output_shapes)

dataset.batch(10)

dataset.repeat(100)


inputs = keras.Input(shape=(32,))  # Returns a placeholder tensor


# A layer instance is callable on a tensor, and returns a tensor.

x = keras.layers.Dense(64, activation='relu')(inputs)

x = keras.layers.Dense(64, activation='relu')(x)

predictions = keras.layers.Dense(10, activation='softmax')(x)


# Instantiate the model given inputs and outputs.

model = keras.Model(inputs=inputs, outputs=predictions)


# The compile step specifies the training configuration.

model.compile(optimizer=tf.train.RMSPropOptimizer(0.001),

          loss='categorical_crossentropy',

          metrics=['accuracy'])


# Trains for 5 epochs

model.fit(dataset, epochs=5, steps_per_epoch=100)


斯蒂芬大帝
浏览 171回答 2
2回答

幕布斯7119047

关于您为什么收到错误的原始问题:Error when checking input: expected input_1 to have 2 dimensions, but got array with shape (32,)您的代码中断的原因是因为您没有将.batch()back应用于dataset变量,如下所示:dataset = dataset.batch(10)您只需调用dataset.batch().这会中断,因为没有batch()输出张量不会批量处理,即您得到的是 shape(32,)而不是(1,32).
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python