PyTorch Lightning: save a checkpoint every epoch (code example)

Example: PyTorch Lightning — save a checkpoint at the end of every training epoch

class CheckpointEveryEpoch(pl.Callback):
    """Save a full Lightning checkpoint at the end of each training epoch.

    A checkpoint is written to ``f"{save_path}_e{epoch}.ckpt"`` for every
    finished training epoch whose 0-based index (``trainer.current_epoch``)
    is greater than or equal to ``start_epoc``.

    Args:
        start_epoc: First 0-based epoch index at which to start saving
            (e.g. ``2`` skips epochs 0 and 1). The parameter name is kept
            as-is for backward compatibility with existing callers.
        save_path: Path prefix for checkpoint files; ``_e<epoch>.ckpt`` is
            appended to it.
    """

    def __init__(self, start_epoc, save_path):
        super().__init__()
        self.start_epoc = start_epoc
        # The hook reads ``save_path``; ``file_path`` is kept as an alias for
        # any code that accesses the attribute under that name.
        self.save_path = save_path
        self.file_path = save_path

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module):
        """Save a checkpoint once the finished train epoch is >= ``start_epoc``.

        ``on_train_epoch_end`` is used instead of the generic ``on_epoch_end``,
        which also fires for validation/test epochs in Lightning 1.x and was
        removed in Lightning 2.0.
        """
        epoch = trainer.current_epoch
        if epoch < self.start_epoc:
            return
        trainer.save_checkpoint(f"{self.save_path}_e{epoch}.ckpt")
            
            
# Begin saving at epoch index 2: epochs 0 and 1 are skipped, and a checkpoint
# is written at the end of every epoch from the third one onward.
trainer = Trainer(
    callbacks=[CheckpointEveryEpoch(start_epoc=2, save_path=args.save_path)],
)

Tags:

Misc Example