我想tf.estimator.DNNClassifier在互联网被阻止的 Kaggle 笔记本环境中训练模型。因此,我无法使用 Tensorboard 来监控进度。所以相反,我想在标准输出中记录进度(类似于我们fit在 Keras 模型上调用方法时),但我无法让它工作。
到目前为止,我所尝试的是将日志记录级别设置为INFO并将tf.estimator.RunConfig实例传递给估算器。RunConfig有一个log_step_count_steps默认值 = 100的属性,这似乎与我正在寻找的内容有关,但它不起作用。这是代码的一部分:
import logging;
logging.getLogger().setLevel(logging.INFO)
tf.logging.set_verbosity(tf.logging.INFO)
config = tf.estimator.RunConfig()
classifier = tf.estimator.DNNClassifier(
feature_columns = feature_columns,
hidden_units = [128, 64],
n_classes = 2,
config = config
)
classifier.train(input_fn=train_input_fn)
我使用 Tensorflow 版本1.11.0-rc1。
绝地无双
相关分类