Convolutional model training error [You must feed a value for placeholder tensor 'y' with dtype float and shape]

Source: 3-7 Training a convolutional model on MNIST with TensorFlow (4)

qq_Q先生_0

2018-08-21 14:14

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'y' with dtype float and shape [?,10]
[[Node: y = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'y', defined at:
  File "E:/pycharmWorkspace/flaskDemo/app/mnist/convolutional.py", line 19, in <module>
    y_ = tf.placeholder(tf.float32, [None, 10] , name='y');

The code (line 19 is the y_ placeholder definition):

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Build the model
with tf.variable_scope('convolutional'):
    x = tf.placeholder(tf.float32, [None, 784], name='x')
    keep_prob = tf.placeholder(tf.float32)
    y, variables = model.convolutional(x, keep_prob)

y_ = tf.placeholder(tf.float32, [None, 10], name='y')  # line 19: the line the error points to
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Predictions
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver(variables)

# Training
with tf.Session() as sess:
    merged_summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('/temp/mnsit_log/1', sess.graph)
    # Add the graph to the summary writer
    summary_writer.add_graph(sess.graph)
    sess.run(tf.global_variables_initializer())

    # Training loop
    for i in range(20000):
        batch = mnist.train.next_batch(50)
        # Every 100 steps
        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x: batch[0], y: batch[1], keep_prob: 1.0})
            print('step : %d <==> train_accuracy %g: ' % (i, train_accuracy))
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print("Convolutional accuracy: ", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

    path = saver.save(sess,
                      os.path.join(os.path.dirname(__file__), 'data', 'convolutional.ckpt'),
                      write_meta_graph=False, write_state=False)

    print('saved:', path)
    sess.close()



1 answer

  • 啊哈HL
    2018-09-21 16:16:31
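    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={x:batch[0] , y:batch[1] , keep_prob : 1.0});

    In this line the feed_dict key should be y_, not y. The error mentions 'y' because that is the graph name given to the y_ placeholder (name='y'); the Python variable y is the model's output, so feeding it leaves the label placeholder unfed.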
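A minimal sketch that reproduces the error and the fix, assuming TensorFlow is installed and using the tf.compat.v1 graph API (available in TensorFlow 1.13+ and 2.x). The toy softmax layer, its variables w and b, and the fake batch are hypothetical stand-ins for model.convolutional() and the MNIST data; only the placeholder names mirror the post.

```python
import numpy as np
import tensorflow as tf

# Use the TF1 graph/session API (also works under TensorFlow 2.x via compat.v1).
tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Hypothetical stand-in for model.convolutional(): a 784 -> 10 softmax layer.
# The bias makes class 0 the prediction for an all-zero image.
x = tf1.placeholder(tf.float32, [None, 784], name='x')
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.one_hot(0, 10))
y = tf.nn.softmax(tf.matmul(x, w) + b)  # model output, NOT a placeholder

# Label placeholder: Python name y_, graph name 'y' (the name the error reports).
y_ = tf1.placeholder(tf.float32, [None, 10], name='y')

correct = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# Fake batch: 4 blank images, all labelled as class 0 (one-hot).
images = np.zeros((4, 784), dtype=np.float32)
labels = np.zeros((4, 10), dtype=np.float32)
labels[:, 0] = 1.0

with tf1.Session() as sess:
    sess.run(tf1.global_variables_initializer())

    # Buggy feed: the labels overwrite the output y, and y_ is never fed.
    try:
        sess.run(accuracy, feed_dict={x: images, y: labels})
        buggy_error = None
    except tf.errors.InvalidArgumentError as e:
        buggy_error = str(e)

    # Fixed feed: the labels go to the placeholder y_.
    fixed_accuracy = sess.run(accuracy, feed_dict={x: images, y_: labels})

print(y_.name)         # y:0
print(buggy_error)
print(fixed_accuracy)  # 1.0 for this toy batch
```

Because error messages report graph names rather than Python variable names, giving the label placeholder a distinct name (for example name='y_') would likely make this kind of mix-up easier to spot.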

Course: Building handwritten digit recognition with TensorFlow and Flask