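I want to gradually increase a coefficient that is used to compute the loss in my Keras model. The coefficient's value depends on the current epoch. However, when I try to set the value, I get the following error: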
'float' object has no attribute 'dtype'
My code:
import numpy as np
import tensorflow as tf

def warm_up(epoch, logs):
    new_value = tf.keras.backend.variable(np.array(1.0, dtype=np.float32), dtype=tf.float32)
    tf.keras.backend.set_value(model.variable1, new_value)

callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=warm_up)
model.fit(..., callbacks=[callback])
How can I change a variable in a custom Keras model during training? I am using TensorFlow 2.2.
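A minimal sketch of one common approach, written against the tf.keras 2.x API and assuming the error comes from `model.variable1` being a plain Python float: create the coefficient as a non-trainable `tf.Variable` inside the model and update it in place from the same kind of `LambdaCallback`. The model class `CustomModel`, the constant `WARMUP_EPOCHS`, the `Dense` layer, the extra loss term and the random data are all hypothetical placeholders, not the asker's actual model.

```python
import numpy as np
import tensorflow as tf

WARMUP_EPOCHS = 4  # hypothetical: number of epochs over which the coefficient ramps from 0 to 1


class CustomModel(tf.keras.Model):
    """Hypothetical custom model whose extra loss term is scaled by variable1."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)
        # variable1 is a non-trainable tf.Variable (not a Python float),
        # so a callback can update it in place during training.
        self.variable1 = tf.Variable(0.0, trainable=False, dtype=tf.float32)

    def call(self, inputs):
        outputs = self.dense(inputs)
        # The variable is read at every training step, so newly assigned
        # values take effect without rebuilding the training function.
        self.add_loss(self.variable1 * tf.reduce_mean(tf.square(outputs)))
        return outputs


model = CustomModel()
model.compile(optimizer="adam", loss="mse")


def warm_up(epoch, logs):
    # Linear ramp 0.0, 0.25, 0.5, ... capped at 1.0 (epoch is 0-based).
    new_value = min(1.0, epoch / WARMUP_EPOCHS)
    # assign() accepts a plain Python number; no backend variable is needed.
    model.variable1.assign(new_value)


callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=warm_up)

x = np.random.rand(32, 4).astype(np.float32)
y = np.random.rand(32, 1).astype(np.float32)
model.fit(x, y, epochs=3, verbose=0, callbacks=[callback])
print(model.variable1.numpy())  # 0.5 (value set at the start of epoch 2)
```

Once `variable1` is a `tf.Variable`, `tf.keras.backend.set_value(model.variable1, new_value)` with a plain number should also work in TF 2.2; `assign()` is used here because it is a method of the variable itself.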
Traceback:
\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
~\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
836 for epoch, iterator in data_handler.enumerate_epochs():
837 self.reset_metrics()
--> 838 callbacks.on_epoch_begin(epoch)
839 with data_handler.catch_stop_iteration():
840 for step in data_handler.steps():
~\Anaconda3\lib\site-packages\tensorflow\python\keras\callbacks.py in on_epoch_begin(self, epoch, logs)
347 logs = self._process_logs(logs)
348 for callback in self.callbacks:
--> 349 callback.on_epoch_begin(epoch, logs)
350 self._reset_batch_timing()
351
c:\Users\..\training.py in warm_up(epoch, logs)
379 def warm_up(epoch, logs):
380 test = tf.keras.backend.variable(np.array(1.0, dtype=np.float32), dtype=tf.float32)
--> 381 tf.keras.backend.set_value(model.variable1, test)
382
383
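In TF 2.2, `tf.keras.backend.set_value(x, value)` starts by reading the dtype of its first argument `x`, which is where the message in the traceback comes from if `model.variable1` is a Python float. The sketch below, which assumes that diagnosis, reproduces the same `AttributeError` on a float and shows that a `tf.Variable` has a dtype and can be updated in place. The names `plain_coefficient` and `variable_coefficient` are hypothetical.

```python
import tensorflow as tf

# Hypothetical stand-ins: a float attribute (as with `self.variable1 = 0.0`)
# versus a tf.Variable.
plain_coefficient = 0.0
variable_coefficient = tf.Variable(0.0, trainable=False, dtype=tf.float32)

# set_value(x, value) reads x.dtype first; a Python float has no such attribute.
try:
    plain_coefficient.dtype
    failed = False
except AttributeError as err:
    failed = True
    print(err)  # 'float' object has no attribute 'dtype'

# A tf.Variable has a dtype and can be updated in place with a plain number.
variable_coefficient.assign(1.0)
print(variable_coefficient.numpy())  # 1.0
```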