Tensorflow:Logits 和 Label 的大小必须相同

我目前正在 Google/Udacity 的 Tensorflow 课程中尝试一个项目,使用获取的数据集如下:


_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"


zip_file = tf.keras.utils.get_file(origin=_URL,

                                   fname="flower_photos.tgz",

                                   extract=True)


不幸的是,我遇到了以下错误:


InvalidArgumentError:  logits and labels must have the same first dimension, got logits shape [100,5] and labels shape [500]

     [[node sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits (defined at <ipython-input-43-02964d57939c>:8) ]] [Op:__inference_test_function_3591]

我看了其他帖子,但似乎还是有点难以弄清楚。我最初的想法是我可能使用了错误的损失函数。


这是遇到问题的代码:


image_gen = ImageDataGenerator(rescale = 1./255, horizontal_flip=True, zoom_range=0.5, rotation_range=45, width_shift_range=0.15, height_shift_range=0.15)


train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE, directory = train_dir, shuffle=True, target_size=(IMG_SHAPE,IMG_SHAPE),class_mode='binary')


image_gen = ImageDataGenerator(rescale = 1./255)


val_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE, directory = val_dir, shuffle=True, target_size=(IMG_SHAPE,IMG_SHAPE))

批量大小为 100,输入维度为 150,150 摘要如下: 模型:“sequential_4”


层(类型)输出形状参数#

conv2d_12(Conv2D)(无、148、148、16)448


max_pooling2d_12(最大池化(无、74、74、16)0


conv2d_13(Conv2D)(无、72、72、32)4640


max_pooling2d_13(最大池化(无、36、36、32)0


conv2d_14(Conv2D)(无、34、34、64)18496


max_pooling2d_14(最大池化(无、17、17、64)0


dropout_4(辍学)(无、17、17、64)0


flatten_4(压平)(无,18496)0


密集_8(密集)(无,512)9470464


密集_9(密集)(无,5)2565

总参数:9,496,613 可训练参数:9,496,613 不可训练参数:0

对可能出什么问题有什么想法吗?

慕沐林林
浏览 285回答 2
2回答

皈依舞

注意生成器中的 class_mode'int':表示标签被编码为整数(例如对于稀疏分类交叉熵损失)。“分类”意味着标签被编码为分类向量(例如,对于 categorical_crossentropy 损失)。'binary' 意味着标签(只能有 2 个)被编码为值为 0 或 1 的 float32 标量(例如,对于 binary_crossentropy)。无(无标签)。看来你需要“int”而不是“binary”来用于训练和验证生成器

眼眸繁星

在生成器中,我将 class_mode 更新为“稀疏”,并且工作正常。train_data_gen&nbsp;=&nbsp;image_gen.flow_from_directory(train_dir,&nbsp;target_size&nbsp;=&nbsp;(IMG_SHAPE,&nbsp;IMG_SHAPE),&nbsp;batch_size&nbsp;=&nbsp;batch_size,&nbsp;class_mode&nbsp;=&nbsp;'sparse')
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python