--resume issues #292

Closed
Laughing-q opened this issue Jul 4, 2020 · 28 comments
Labels: bug (Something isn't working)

@Laughing-q (Member)

When I resume my training, the learning rate is reset to 0.01 at the first epoch.
It seems that initializing the scheduler resets the learning rate, so I printed the optimizer both after loading it from my last.pt and after initializing the scheduler.
Code:
optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
    optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
...

# load optimizer
if ckpt['optimizer'] is not None:
    optimizer.load_state_dict(ckpt['optimizer'])
    print('before:', optimizer)
    best_fitness = ckpt['best_fitness']
...

lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
scheduler.last_epoch = start_epoch - 1  # do not move
print('after scheduler :', optimizer)

result:
before:
SGD (
Parameter Group 0
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0

Parameter Group 1
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0.0005

Parameter Group 2
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0
)
after scheduler :
SGD (
Parameter Group 0
dampening: 0
initial_lr: 0.01
lr: 0.01
momentum: 0.937
nesterov: True
weight_decay: 0

Parameter Group 1
dampening: 0
initial_lr: 0.01
lr: 0.01
momentum: 0.937
nesterov: True
weight_decay: 0.0005

Parameter Group 2
dampening: 0
initial_lr: 0.01
lr: 0.01
momentum: 0.937
nesterov: True
weight_decay: 0
)

Epoch gpu_mem GIoU obj cls total targets img_size
5/299 4.15G 0.03826 0.01146 0.02194 0.07166 15 640: 12%|█▏ | 31/267 [00:19<02:04, 1.90it/s]
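For reference, here is a minimal standalone sketch that reproduces the reset outside of train.py (the dummy model, epoch count, and helper names are made up for illustration; only the lambda matches the training code):

# Standalone repro (hypothetical setup, not from train.py): building LambdaLR
# after optimizer.load_state_dict() resets lr back to initial_lr * lf(0) = 0.01.
import math
from torch import nn, optim
from torch.optim import lr_scheduler

epochs = 300
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine

def make_optimizer():
    model = nn.Linear(10, 1)  # dummy model standing in for the YOLO parameter groups
    return optim.SGD(model.parameters(), lr=0.01, momentum=0.937, nesterov=True)

# "previous run": optimizer + scheduler stepped for a few epochs, then checkpointed
opt_prev = make_optimizer()
sched_prev = lr_scheduler.LambdaLR(opt_prev, lr_lambda=lf)
for _ in range(5):
    opt_prev.step()
    sched_prev.step()
ckpt = {'optimizer': opt_prev.state_dict()}  # lr has decayed slightly by now

# "resume": load the optimizer first, then build the scheduler (train.py's order)
opt_resume = make_optimizer()
opt_resume.load_state_dict(ckpt['optimizer'])
print('after load:     ', opt_resume.param_groups[0]['lr'])  # ~0.009994 (restored)

scheduler = lr_scheduler.LambdaLR(opt_resume, lr_lambda=lf)
print('after scheduler:', opt_resume.param_groups[0]['lr'])  # 0.01 again (reset)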

Then I put the scheduler initialization above the optimizer loading, and it seems to be right.
Code:
optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
    optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
...

# load optimizer
if ckpt['optimizer'] is not None:
    optimizer.load_state_dict(ckpt['optimizer'])
    print(optimizer)
    best_fitness = ckpt['best_fitness']
...

scheduler.last_epoch = start_epoch - 1  # do not move
print(optimizer)

results:
before:
SGD (
Parameter Group 0
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0

Parameter Group 1
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0.0005

Parameter Group 2
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0
)
after scheduler :
SGD (
Parameter Group 0
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0

Parameter Group 1
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0.0005

Parameter Group 2
dampening: 0
initial_lr: 0.01
lr: 0.009996052735444863
momentum: 0.937
nesterov: True
weight_decay: 0
)

Is this a problem? The code is in train.py; please check it.

@Laughing-q Laughing-q added the bug Something isn't working label Jul 4, 2020
github-actions bot (Contributor) commented Jul 4, 2020

Hello @Laughing-q, thank you for your interest in our work! Please visit our Custom Training Tutorial to get started, and see our Jupyter Notebook (Open In Colab), Docker Image, and Google Cloud Quickstart Guide for example environments.

If this is a bug report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom model or data training question, please note that Ultralytics does not provide free personal support. As a leader in vision ML and AI, we do offer professional consulting, from simple expert advice up to delivery of fully customized, end-to-end production solutions for our clients, such as:

  • Cloud-based AI systems operating on hundreds of HD video streams in realtime.
  • Edge AI integrated into custom iOS and Android apps for realtime 30 FPS video inference.
  • Custom data training, hyperparameter evolution, and model exportation to any destination.

For more information please visit https://www.ultralytics.com.

@glenn-jocher (Member)

@Laughing-q thanks for looking into this! If you believe you have a solution which --resumes better, please submit a PR with your proposed changes.

--resume is not 100% mature yet, as the EMA is saved during checkpointing but not the normal model, so there would be additional steps to take to create a seamless resume, but fixing this LR bug you seem to have found would be a huge step in the right direction!

Can you show before and after TensorBoard plots for interrupted coco128.yaml trainings, for example, using your fix?

@Laughing-q (Member, Author)


I actually found this problem when I was training my own dataset. I resumed at about epoch 200/299 and then found that mAP@0.5 dropped from 80 to 67, but I deleted those results.
So I retrained my dataset for the experiment, resuming at epoch 15/19. In order to eliminate the influence of burn-in, I set n_burn = -1.
Here are the results; the blue line is the run with my fix. Without my fix, precision decreased and the losses increased significantly at epoch 15. Although the final results on my dataset differ only a little, I think the blue line is what I expect.
[results plot]
I also plotted the LR changes. Without my fix, the LR is clearly reset at epoch 15.
[LR plot]

@glenn-jocher (Member)

@Laughing-q hey good job, that's a very good experiment! It looks good to me. Can you submit a PR with the fix please? Thanks!!

@Laughing-q (Member, Author)


Of course.

@glenn-jocher (Member)

I started training a few yolov5s variants just before the PR was merged. One of the VMs that was training terminated early (for Google Cloud maintenance, I assume), so I used the --resume command to see what it would look like, and this is the result.

This huge drop must be from the LR=0.01 change in the first epoch after resume, as you discovered. Hopefully your PR fixed this. From now on, new runs will include your fix, so if this happens again in the next few weeks I'll repeat the experiment to see the new results.

[results plot]

@glenn-jocher (Member)

@Laughing-q ok, I have some more information here. I --resumed 2 models that had quit unexpectedly. These two models should have the updated scheduler PR in place. Unfortunately I still see the same style of recovery.

One major aspect I noticed is that in both cases the train losses recover very quickly, within a few epochs, while the validation losses and performance metrics take much longer to recover. The one thing the val losses and performance metrics have in common is that they use the EMA, whereas the train losses are computed from the base model.

This makes sense: when a new run is started, or when a run is --resumed, a new EMA is created, and it uses a sliding decay constant that starts from zero and trends towards its final value of 0.9999 after a few thousand iterations. This means the resumed EMAs lose their smoothness; they basically turn back into base models and only slowly become a smoothed EMA again over many epochs.

I think there's an easy way to fix this without having to include the base model in the checkpoints: just add a flag to the EMA __init__() function that tells it how many iterations the base model has already trained for, so it knows what to set the decay constant to. Then training should align much more closely with training before --resume. The decay constant equation is on L199 here:

yolov5/utils/torch_utils.py

Lines 194 to 199 in bf6f415

def __init__(self, model, decay=0.9999, device=''):
    # Create EMA
    self.ema = deepcopy(model.module if is_parallel(model) else model)  # FP32 EMA
    self.ema.eval()
    self.updates = 0  # number of EMA updates
    self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # decay exponential ramp (to help early epochs)
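
A rough sketch of that idea (hypothetical, not the merged implementation; the `updates` argument and the simplified `update()` method are assumptions) could look like this:

# Minimal EMA sketch: seed the update counter from the checkpoint so the decay
# ramp resumes where it left off instead of restarting from zero on --resume.
import math
from copy import deepcopy

import torch


class ModelEMA:
    def __init__(self, model, decay=0.9999, updates=0):
        self.ema = deepcopy(model).eval()  # FP32 EMA copy of the model
        self.updates = updates             # restore this from the checkpoint on --resume
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))  # exponential ramp
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Blend the current model weights into the EMA with the ramped decay.
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()

On --resume the counter could then be seeded with something like ModelEMA(model, updates=start_epoch * nb), where nb is the number of batches per epoch (again an assumption about how the count would be restored).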

[results plot]

@Laughing-q (Member, Author)

@glenn-jocher Actually, my friend had this problem a few days ago, so I opened #310. Then I did some experiments on it: I saved both the base model and the EMA, and when I resumed I loaded both (and also updated self.updates), but the results seemed to differ little from loading the EMA only. From your results, I think my training was perhaps too short (about 25-35 epochs) to see the huge drop. I'm glad you found the problem, so is it fixed now?

@glenn-jocher (Member)

@Laughing-q no, the problem is not fixed, but I think smart init of the EMA as described above would help a lot.

And yes, you are right: you will only see this later in training, as the EMA decay function takes about 10-20k iterations to get close to its steady-state value of 0.9999. On COCO this might be 5-10 epochs or more, so for smaller datasets (like perhaps your example) the problem would not show up. I'll look at #310.
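
To put rough numbers on that ramp (just plugging values into the decay equation quoted above):

# Decay weight d(x) = 0.9999 * (1 - exp(-x / 2000)) at a few update counts.
import math

decay = lambda x: 0.9999 * (1 - math.exp(-x / 2000))
for x in (1000, 5000, 10000, 20000):
    print(f'{x:>6} updates -> decay {decay(x):.4f}')
# ~0.3934, 0.9178, 0.9932, 0.9999: the ramp only approaches its 0.9999 ceiling
# after roughly 10-20k updates, i.e. several COCO epochs.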

@glenn-jocher glenn-jocher changed the title resume problem --resume issues Jul 9, 2020
@glenn-jocher glenn-jocher added the TODO High priority items label Jul 9, 2020
@glenn-jocher glenn-jocher self-assigned this Jul 9, 2020
@Laughing-q (Member, Author) commented Jul 9, 2020

@glenn-jocher So what you mean is: add a flag to record the iterations and do not save the base model. In that case the EMA's smoothness may be kept, but the previous EMA model will then be used for --resume training. I still think that resuming training from the base model is the right approach, if saving both the EMA and the base model doesn't cost too much. My point is that the base model is the base model and the EMA is the EMA; it's not reasonable to train the EMA after --resume when we were training the base model before --resume. Perhaps I'm overthinking it; some experiments are still needed.
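
For concreteness, a checkpoint that keeps both could look something like the sketch below (the function and key names are assumptions for illustration, not what train.py actually saves):

# Hypothetical checkpoint layout saving both the base model and the EMA,
# plus the EMA update count, so --resume could restore all three.
import torch

def save_checkpoint(path, epoch, best_fitness, model, ema, optimizer):
    ckpt = {
        'epoch': epoch,
        'best_fitness': best_fitness,
        'model': model.state_dict(),         # base model used for training
        'ema': ema.ema.state_dict(),         # smoothed EMA used for val/inference
        'ema_updates': ema.updates,          # lets the decay ramp resume correctly
        'optimizer': optimizer.state_dict(),
    }
    torch.save(ckpt, path)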

@glenn-jocher (Member)

@Laughing-q I don't really know. I'll try to update the code today with my idea, since it's a very small change that may fix a lot. If it doesn't fix things, though, then yes, we'll have to update the code further to save everything.

@glenn-jocher (Member)

Here are updated results from the above 2 models. It seems things do eventually return to normal after 20 or so epochs, which is the right amount of time for the EMA decay to reach steady state.

Probably the correct way to think about the EMA is that in the early stages of training it trails the model: it is smoother, less noisy, and delayed in time as it follows the model's changes. In the later stages of training it is no longer trailing; it is more of a mean or centered version of the model. As the model oscillates around a local minimum, the EMA sits roughly at rest near that minimum, more or less, is how I imagine it. To further complicate things, the decay ramp I've added means that the EMA and the model behave very similarly in the very early stages of training, but this doesn't really matter for the final result.

[results plot]

@Laughing-q (Member, Author)

@glenn-jocher I see, thanks for your fix. I'll do some experiments with your fix, and I think it'll work. Thanks a lot!

@glenn-jocher (Member)

@Laughing-q ok great! You probably won't notice any effect for early epochs or small datasets, but if you train yolov5s on COCO for 10 or 20 epochs, for example, this should hopefully result in better resuming if my idea is right.

@glenn-jocher (Member)

@Laughing-q I can confirm that --resume now appears to be working near-seamlessly. I resumed a stopped run (266 here) at COCO epoch 77, and the effects are hardly noticeable. It looks like our changes worked! I'm going to remove the TODO label from this issue, as I think it is now essentially resolved.

[results plot]

@glenn-jocher glenn-jocher removed the TODO High priority items label Jul 12, 2020
@Laughing-q (Member, Author)

@glenn-jocher Some results of my experiments also show that --resume is working near-seamlessly. Thanks for your work!

@zcode86 commented Jul 14, 2020

@glenn-jocher Nice to see it. What made the difference? The EMA?

@glenn-jocher (Member)

@Laughing-q @iamastar88 Strangely enough, I had to eliminate the update. I resumed a larger 5l model and observed strange behavior after the resume. The 'seam' was fine, but the model performed much more poorly than its peers after resuming, and even 30 epochs later it was still underperforming by several mAP points.

The prior --resume functionality shown in #292 (comment) has a very large discontinuity, but it eventually recovers very well, so that after about 20-30 COCO epochs the --resumed model is training identically to its peers. Not sure why.

@Laughing-q (Member, Author)

@glenn-jocher So is there still a huge drop in mAP? If possible, could you save both the base model and the EMA and then --resume with both for the same experiment? Or maybe there are some other issues we haven't found yet.

@glenn-jocher (Member)

@Laughing-q Yes, that's a third option. We'd have to test to see its effect, of course; I just don't have time to do this myself.

@Laughing-q (Member, Author)

@glenn-jocher OK, I know you are very busy. Maybe I will do some experiments on this, but I don't have the hardware to run a huge dataset such as COCO. All I can do is run my own dataset, and my concern is that it is too small to show the effect and the difference. If my results show any effects or differences, I'll report them on this issue.

@glenn-jocher (Member) commented Jul 15, 2020

@Laughing-q Actually, some people see mAP jumps on --resume, so for some it's working better than expected (#409). It seems to be very dataset-specific.

@Laughing-q (Member, Author)


@glenn-jocher I did my experiment: I ran 250 epochs, interrupted, then --resumed and ran to 300 epochs. With your fix, saving the EMA only is almost the same as saving both the EMA and the base model, so I think your fix is right and working! Have you ever encountered this issue without your fix? In the later stages of training the EMA should no longer be trailing, so I think the regression you saw must be caused by something else, not your fix. This issue is strange. Sorry, I couldn't run the COCO dataset on yolov5l with my device, so I couldn't do that experiment. But I think there may be some other issue that we haven't found yet. If you find it, please tell me. Thanks a lot.

@glenn-jocher (Member)

Hmm, ok! Yeah, I thought I had a good fix by setting the EMA updates value; I don't know why it didn't work for me.

@Laughing-q (Member, Author)

@glenn-jocher I saw that #300 is missing from the newest code. Is this fix covered by something else, or was the LR bug solved in another way?

@Laughing-q (Member, Author)

@glenn-jocher The scheduler initialization code, which resets the LR, is now below the optimizer loading code again.

@glenn-jocher (Member)

@Laughing-q Ah, thanks for catching this. I think I may be accepting too many PRs too quickly. Often PRs are not rebased against the current master, and the PR inadvertently reverts earlier changes. It's a huge problem with no easy solution, so I'm being much more careful with larger PRs now.

Ok, about this particular issue, can you submit a new PR please? Sorry for the double effort.

@Laughing-q (Member, Author)

@glenn-jocher I've submitted a new PR; please check it.
