-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
checkpoint.py
90 lines (79 loc) · 3.41 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# Copyright (c) Open-MMLab. All rights reserved.
import os
from ..dist_utils import master_only
from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Default: -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Default: True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Default: True.
        out_dir (str, optional): The directory to save checkpoints. If not
            specified, ``runner.work_dir`` will be used by default.
        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Default: -1, which means unlimited.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``runner.save_checkpoint``. If ``filename_tmpl`` is given it is
            also used to build the paths of old checkpoints to delete.
    """

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.args = kwargs

    @master_only
    def after_train_epoch(self, runner):
        """Save a checkpoint after an epoch when epoch-based saving is due.

        Does nothing unless ``by_epoch`` is true and
        ``self.every_n_epochs(runner, self.interval)`` reports a hit.
        """
        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
            return

        runner.logger.info(f'Saving checkpoint at {runner.epoch + 1} epochs')
        self._save_and_prune(runner, runner.epoch + 1, 'epoch_{}.pth')

    @master_only
    def after_train_iter(self, runner):
        """Save a checkpoint after an iteration when iter-based saving is due.

        Does nothing if ``by_epoch`` is true or
        ``self.every_n_iters(runner, self.interval)`` reports no hit.
        """
        if self.by_epoch or not self.every_n_iters(runner, self.interval):
            return

        runner.logger.info(
            f'Saving checkpoint at {runner.iter + 1} iterations')
        self._save_and_prune(runner, runner.iter + 1, 'iter_{}.pth')

    def _save_and_prune(self, runner, current_step, default_tmpl):
        """Write the checkpoint, then delete ones beyond ``max_keep_ckpts``.

        Args:
            runner: The runner whose ``save_checkpoint`` is called and whose
                ``work_dir`` is used when no ``out_dir`` was configured.
            current_step (int): 1-based epoch or iteration number of the
                checkpoint being written.
            default_tmpl (str): Filename template used to locate old
                checkpoints when ``filename_tmpl`` was not passed in kwargs.
        """
        # Resolve the output directory lazily and remember it on the hook.
        if not self.out_dir:
            self.out_dir = runner.work_dir
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)

        # A non-positive interval has no well-defined spacing between
        # checkpoints (and a zero step would make range() raise), so there
        # is nothing to prune in that case.
        if self.max_keep_ckpts <= 0 or self.interval <= 0:
            return

        filename_tmpl = self.args.get('filename_tmpl', default_tmpl)
        # Checkpoints are written every ``interval`` steps, so the newest one
        # that falls outside the keep-window is ``max_keep_ckpts * interval``
        # steps back. Walk backwards in ``interval`` strides from there; the
        # epoch path previously ignored ``interval`` and either deleted
        # checkpoints that should be kept or never deleted anything.
        first_stale = current_step - self.max_keep_ckpts * self.interval
        for step in range(first_stale, 0, -self.interval):
            ckpt_path = os.path.join(self.out_dir, filename_tmpl.format(step))
            if not os.path.exists(ckpt_path):
                # Older checkpoints are assumed to have been removed already
                # by an earlier call, so stop at the first gap.
                break
            os.remove(ckpt_path)