Callbacks

به طور کلی callback به معنی فراخوانی یک تابع در یک رویداد (event) مشخص هست. از callback ها میتوان در مراحل زیر استفاده کرد:

  • Model.fit()
  • Model.evaluate()
  • Model.predict()
از callback ها برای بررسی مدل در مراحل مختلف و تغییر رفتار آن در بعضی از مراحل استفاده می شود. لیست رویداد هایی که تابع می تواند در آن فراخوانی شود در زیر آمده است:
  • on_train_begin
  • on_train_end
  • on_epoch_begin
  • on_epoch_end
  • on_test_begin
  • on_test_end
  • on_predict_begin
  • on_predict_end
  • on_train_batch_begin
  • on_train_batch_end
  • on_test_batch_begin
  • on_test_batch_end
  • on_predict_batch_begin
  • on_predict_batch_end
کراس علاوه بر پیاده سازی callback های کاربردی که ممکن است بیشتر از آن استفاده کنیم، امکان ساخت callback دلخواه خود را هم فراهم کرده است. قبل از اینکه نحوه ساخت یک callback را ببینیم، چند مورد از callback های کاربردی کراس را بررسی کنیم.
  • ModelCheckpoint

از keras.callbacks.ModelCheckpoint() برای ذخیره کردن مدل در فاصله های مشخص، هنگام model.fit() استفاده می شود. از مدل ذخیره شده می توان برای بازیابی مدل و ادامه فرایند train استفاده کرد. منظور از فاصله های مشخص این است که:
  • بهترین مدل را ذخیره کند یا بعد از اتمام هر epoch آن را ذخیره کند.
  • بر اساس کدام کمیت، بهترین مدل را تشخصی دهد که آن را ذخیره کند.
  • مدل را بعد از اتمام epoch آن را ذخیره کند یا بعد از اتمام چند batch مشخص.
  • آیا فقط وزن ها ذخیره شوند یا کل مدل ذخیر شود.

keras.callbacks.ModelCheckpoint(    filepath,    monitor="val_loss",    verbose=0,    save_best_only=False,    save_weights_only=False,    mode="auto",    save_freq="epoch",    initial_value_threshold=None)

از filepath برای آدرس ذخیره کردن فایل و نام گذاری آن و فرمت ذخیره کردن استفاده می شود. monitor و save_best_only برای ذخیره کردن بهترین مدل بر اساس کمیت monitor استفاده می شود. از save_freq برای ذخیره کردن مدل بعد از اتمام هر epoch یا تعداد مشخص batch استفاده می شود. مثال:

model_checkpoint_callback = keras.callbacks.ModelCheckpoint(    filepath='/tmp/ckpt/checkpoint.h5',    monitor='val_accuracy',    save_best_only=True)

 
  • LearningRateScheduler

از keras.callbacks.LearningRateScheduler() برای تغییر دادن learning rate استفاده می شود.

keras.callbacks.LearningRateScheduler(schedule, verbose=0)

schedule یک تابع با ورودی epoch و learning_rate هست که با توجه به مقدار epoch و learning_rate کنونی، مقدار learning rate را تغییر می دهد. مثال:

callback = keras.callbacks.LearningRateScheduler(scheduler) history = model.fit(...,epochs=15, callbacks=[callback], verbose=0)

 
  • ProgbarLogger

هدف از keras.callbacks.ProgbarLogger() در زیر آمده است. این callback یک ورودی دارد به نام count_mode که مقدار آن می تواند “steps” یا “samples” باشد. تفاوت آن در زیر آمده است.  
  • CSVLogger

از keras.callbacks.CSVLogger() برای دخیره کردن نتایج هر epoch در یک فایل csv استفاده می شود.

keras.callbacks.CSVLogger(filename, separator=",", append=False)

filename برای آدرس دهی و نام گذاری فایل csv استفاده می شود.append برای overwrite کردن یا ادامه دادن فایل csv استفاده می شود. مثال:

csv_logger = CSVLogger('training.log')model.fit(X_train, Y_train, callbacks=[csv_logger])

یکی دیگر از callback های کاربردی، TensorBoard هست. بررسی آن کمی طولانی می شود، در صورت علاقه، می توان از منابع معرفی شده درباره آن مطالعه کنید.  
  • پیاده سازی callback دلخواه

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
اگر با مفهوم Inheritance در پایتون آشنایی ندارد، پیشنهاد می شود که آن را مطالعه کنید.

منابع:

https://keras.io/api/callbacks/ https://www.aparat.com/v/Tdbck https://www.aparat.com/v/r05IW