Tensorflow:Logits 和标签必须具有相同的第一维

我是 TF 机器学习的新手。我有这个数据集,我生成并导出到 .csv 文件中。它在这里:tftest.csv。


“分布”列对应于一个独特的方程组,我试图将其压缩为 SageMath 中的一系列数字。'probs' 列对应于是否应该根据方程所在的行和列将给定的方程乘以给定的方程的单项式。以上仅用于概述,与我的实际问题无关。


无论如何,这是我的代码。我尽量用注释来解释它。


import csv

import numpy as np

import matplotlib.pyplot as plt

import tensorflow as tf

import tensorflow.keras as keras


distribution_train = []

probs_train = []

# x_train = []

# y_train = []


with open('tftest.csv') as csv_file:

    csv_reader = csv.reader(csv_file, delimiter=',')


    for row in csv_reader:

        distribution_train.append(row[0])

        probs_train.append(row[1])


'''

Get rid of the titles in the csv file

'''

distribution_train.pop(0)

probs_train.pop(0)


'''

For some reason everything in my csv file is stored as strings.

The below function is to convert it into floats so that TF can work with it.

'''

def num_converter_flatten(csv_list):

    f = []

    for j in range(len(csv_list)):

        append_this = []

        for i in csv_list[j]:

            if i == '1' or i == '2' or i == '3' or i == '4' or i == '5' or i == '6' or i == '7' or i == '8' or i =='9' or i =='0':

                append_this.append(float(i))

        f.append((append_this))


    return f


x_train = num_converter_flatten(distribution_train)

y_train = num_converter_flatten(probs_train)


x_train = tf.keras.utils.normalize(x_train, axis=1)

y_train = tf.keras.utils.normalize(y_train, axis=1)


model = tf.keras.models.Sequential()


model.add(tf.keras.layers.Flatten())


model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))

model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))


'''

I'm making the final layer 80 because I want TF to output the size of the

'probs' list in the csv file

'''


model.add(tf.keras.layers.Dense(80, activation=tf.nn.softmax))


model.compile(optimizer='adam',

              loss='sparse_categorical_crossentropy',

              metrics=['accuracy'])


model.fit(x_train, y_train, epochs=5)


我在网上搜索了这个错误,但我似乎无法理解它为什么会出现。任何人都可以帮助我了解我的代码有什么问题吗?如果还有任何问题,请发表评论,我会尽力回答。


一只萌萌小番薯
浏览 202回答 1
1回答
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python