A callback in tf.keras is an instance of a callback class, usually passed as a parameter when calling model.fit. It provides extra operations at the start or end of training, of each epoch, or of each batch. These operations include recording log information, changing the learning rate, terminating the training early, etc.
Likewise, the callbacks parameter can also be passed to model.evaluate or model.predict, providing extra operations at the start or end of the evaluation, prediction, or each batch. However, this is rarely used.
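For instance, here is a minimal sketch (the model, data, and hyper-parameters below are placeholders, not from the original text) showing that all three methods accept the same callbacks parameter:
import numpy as np
from tensorflow.keras import layers, models, callbacks

# Placeholder model and data; only the callbacks argument matters here (sketch)
model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(100, 4), np.random.rand(100, 1)

cbs = [callbacks.TerminateOnNaN()]  # any list of callback instances
model.fit(x, y, epochs=3, verbose=0, callbacks=cbs)
model.evaluate(x, y, verbose=0, callbacks=cbs)  # also accepted here
model.predict(x, verbose=0, callbacks=cbs)      # and here, though rarely used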
In most cases, the pre-defined callbacks in the sub-module keras.callbacks are sufficient. It is also possible to define a child class inheriting from keras.callbacks.Callback to customize callbacks if necessary.
All callback classes inherit from keras.callbacks.Callback, which contains two attributes: params and model.
params is a dictionary that records the training parameters (e.g. verbosity, batch size, number of epochs, etc.). model is a reference to the current model.
What's more, certain methods of the callback classes, such as on_epoch_begin and on_batch_end, accept an extra argument logs. This dictionary carries information about the current epoch or batch and can be used to store computed results. The same logs dictionary is passed to the methods of the same name across all the callback classes, so one callback can read values written by another.
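As an illustration (a sketch, not from the original text), a minimal custom callback can read all three of params, model, and logs:
import numpy as np
from tensorflow.keras import layers, models, callbacks

# Minimal custom callback showing params, model, and logs (sketch)
class InspectCallback(callbacks.Callback):
    def on_train_begin(self, logs=None):
        # self.params records training parameters such as epochs and verbosity
        print("params:", self.params)
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # logs carries the metrics computed for the current epoch
        print("epoch", epoch, "logs:", logs)
        # self.model is a reference to the model being trained
        print("weights:", len(self.model.weights))

# Placeholder model and data for illustration only
model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(100, 4), np.random.rand(100, 1)
model.fit(x, y, epochs=2, verbose=0, callbacks=[InspectCallback()])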
BaseLogger: it calculates the mean of the metrics over all the batches of each epoch. For the metrics listed in stateful_metrics, it uses the final value instead of averaging over the batches. The resulting epoch-level metrics are added to the variable logs. This callback is automatically applied to every Keras model and is applied first.
History: it records the metrics of each epoch calculated by BaseLogger in a dictionary and is returned by model.fit. This callback is automatically applied to every Keras model, after BaseLogger.
EarlyStopping: this callback terminates the training if the monitored metric stops improving after a certain number of epochs.
TensorBoard: this callback saves logs for visualization in TensorBoard. It supports the visualization of metrics, graphs, and parameters of the model.
ModelCheckpoint: this callback saves the model after each epoch.
ReduceLROnPlateau: this callback reduces the learning rate by a given factor if the monitored metric stops improving after a certain number of epochs.
TerminateOnNaN: this callback terminates the training if the loss is NaN.
LearningRateScheduler: it sets the learning rate before each epoch using a given function of the epoch number and the current learning rate lr.
CSVLogger: it saves the logs of each epoch to a CSV file.
ProgbarLogger: it prints the logs of each epoch to the standard I/O stream.
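To make the list concrete, here is a hedged sketch combining several of the pre-defined callbacks above (the model, data, file paths, and hyper-parameter values are illustrative, not from the original text):
import numpy as np
from tensorflow.keras import layers, models, callbacks

# Placeholder model and data for illustration only
model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(100, 4), np.random.rand(100, 1)

history = model.fit(
    x, y,
    validation_split=0.2,
    epochs=10,
    verbose=0,
    callbacks=[
        callbacks.EarlyStopping(monitor="val_loss", patience=5),
        callbacks.ModelCheckpoint("../data/checkpoint.h5", save_best_only=True),
        callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
        callbacks.TerminateOnNaN(),
        callbacks.CSVLogger("../data/train_log.csv"),
    ],
)
print(history.history.keys())  # History records the per-epoch metrics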
It is possible to write a simple callback with callbacks.LambdaCallback, or a more complicated callback by inheriting from the base class callbacks.Callback.
Don't hesitate to read the source code to learn more details of the callbacks in tf.keras.
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, losses, metrics, callbacks
import tensorflow.keras.backend as K

# Example of a simple callback using LambdaCallback
import json
json_log = open('../data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch=epoch, **logs)) + '\n'),
    on_train_end=lambda logs: json_log.close()
)
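A usage sketch for the callback above (the model and data are placeholders, not from the original text): passing json_logging_callback to model.fit writes one JSON line of metrics per epoch.
# Placeholder model/data to demonstrate the JSON-logging callback (sketch)
model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(100, 4), np.random.rand(100, 1)
model.fit(x, y, epochs=3, verbose=0, callbacks=[json_logging_callback])
# ../data/keras_log.json now contains one line of metrics per epoch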
# Example of a complicated callback through base class inheritance.
# This is the source code of LearningRateScheduler.
class LearningRateScheduler(callbacks.Callback):
    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility
            lr = self.schedule(epoch)
        # tf.Tensor is used here; the TF source uses the equivalent ops.Tensor
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
            raise ValueError('The dtype of Tensor should be float')
        K.set_value(self.model.optimizer.lr, K.get_value(lr))
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)
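For completeness, a hedged usage sketch of the class above (the schedule function, model, and data are illustrative, not from the original text):
# Illustrative usage of LearningRateScheduler (sketch)
def schedule(epoch, lr):
    # Halve the learning rate every 2 epochs
    return lr * 0.5 if epoch > 0 and epoch % 2 == 0 else lr

model = models.Sequential([layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(100, 4), np.random.rand(100, 1)
model.fit(x, y, epochs=5, verbose=0,
          callbacks=[LearningRateScheduler(schedule, verbose=1)])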
Please leave comments in the WeChat official account "Python与算法之美" (Elegance of Python and Algorithms) if you want to communicate with the author about the content. The author will try their best to reply given the limited time available.
You are also welcome to join the group chat with the other readers by replying 加群 (join group) in the WeChat official account.