Model checkpointed using torch.save() unable to be loaded using torch.load() #12042
PyTorch version:
Python version:
cc @ezyang
Would it be possible to upload the checkpoint file somewhere, so we can look at it? (Or, if you can provide a script which generates the checkpoint file, that would work too.)
Thanks! Looking into it.
Also, it looks like you're training ImageNet; can you make the code available to repro, if possible?
Which part in particular? Here's the model: https://github.com/diux-dev/imagenet18/blob/086675a6df3d468e89c651ae4c75f31e5b3f381d/training/resnet.py
Here's how the code is launched: https://github.com/diux-dev/imagenet18/tree/59a8f25171fb8cede51db9187a32fc8f802384a0
An easier-to-use version of the code is here: https://github.com/stanford-futuredata/pytorch-distributed/blob/master/train.py
You can reproduce using … Also, to answer your previous question, you're right that this isn't a .tar file -- it was just named with a …
I can reproduce the failure on load. Still investigating.
One data point: the file descriptor at the time of error is misaligned:
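A minimal way to gather this kind of data point, assuming a local copy of the failing checkpoint (the path below is a placeholder), is to hand `torch.load` an explicit file object and print its offset when the load fails:

```python
import os
import torch

ckpt_path = "model_best.pth.tar"  # placeholder path to the failing checkpoint

with open(ckpt_path, "rb") as f:
    try:
        state = torch.load(f)
    except Exception as e:
        # How far did the reader get, relative to the file size?
        # A surprising offset here is a hint that the file was truncated
        # or corrupted at write time.
        size = os.path.getsize(ckpt_path)
        print(f"torch.load failed at offset {f.tell()} of {size} bytes: {e!r}")
        raise
```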
@deepakn94 Does this repro if you run it on only one node? Basically, I want the model to be as small as possible while still reproducing the error.
The model serializes and deserializes fine when run on one node.
Here's another data point: serialization and deserialization seem to work fine for 4 nodes when using PyTorch 0.4.0.
OK, I'm reading the serialization code, and I think I see an incorrect use of the …
Please recompile PyTorch with the following patch, which will fix write-time corruption.
I am not 100% sure this will fix the problem; I need to audit the rest of the sites now.
Okay, thanks. There's no way to salvage the existing checkpoints, right?
If the patch above fixes the problem, no, they're irretrievably corrupted.
Okay. I'll be a little busy for the next day or two, but will check this patch over the weekend.
Should I apply this patch to current master? Or to the old commit we were using? Also, it seems like PyTorch 0.4.0 on 16 machines doesn't work.
I authored this patch on master, but it should backport to older versions too. Perhaps it would be better to backport to the old commit to get a cleaner test.
That unfortunately didn't work. I applied the patch to the old commit (…).
Uploaded checkpoint here: https://s3.amazonaws.com/distributed-pytorch-imagenet-runs/imagenet-16-new/run1/model_best.pth.tar
I can't read the updated checkpoint (no permissions). I have a more complete patch which also fixes an underrun on reads, but it doesn't catch any more write side errors, so it must be a different bug.
… cases. Previously, doRead/doWrite were functions that could return partial reads/writes, and we checked for this case inconsistently in the call sites of serialization.cpp. Now, these functions do NOT return the amount of bytes read/written, and instead handle the necessary checking loop themselves. Fixes pytorch#12042. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
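For context on the failure mode the commit message describes: low-level write calls may write fewer bytes than requested, and a caller that treats a single call as complete silently truncates the output. The actual fix lives in serialization.cpp; the sketch below is only a Python analogue of the retry-until-done pattern, using `os.write` on a raw file descriptor:

```python
import os

def write_all(fd: int, data: bytes) -> None:
    """Write every byte of `data` to `fd`, looping over partial writes.

    os.write returns the number of bytes actually written, which can be
    less than len(data); ignoring that return value is exactly the kind
    of bug that produces a short, corrupted file.
    """
    view = memoryview(data)
    while view:
        written = os.write(fd, view)
        view = view[written:]
```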
I updated the permissions on the checkpoint.
Thanks. Confirmed that it still seems to be a write side bug. I guess I'll have to figure something else out... Any luck minimizing the repro?
This shouldn't be closed, right? I'm working on a smaller repro -- I suspect that running this on any distributed setup causes this, but will confirm sometime over the weekend.
The closing was an accident, sorry about that; reopened the issue.
So... I tried saving and loading on the small distributed setup we have in our test suite, and I got a very similar looking error: "EOFError: Ran out of input". I'll work on debugging this case. The branch I'm testing off of is https://github.com/ezyang/pytorch/tree/test/semicharmed-kind-of-life using …
EDIT: Never mind! I forgot to seek the file back to the beginning before reading it out again.
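The gotcha in the EDIT is easy to hit whenever the same file object is used for both saving and loading; a minimal sketch of the correct pattern, using an in-memory buffer as a stand-in for the checkpoint file:

```python
import io
import torch

buf = io.BytesIO()
torch.save({"weight": torch.randn(3, 3)}, buf)

# Without this seek, torch.load starts reading at the end of the buffer
# and fails with "EOFError: Ran out of input".
buf.seek(0)

state = torch.load(buf)
```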
If the problem only appears in the distributed setting... are you sure that all processes aren't writing to the same file at the same time? That would corrupt it for sure.
The script seems to only write from …
Yup, I don't think that's the problem -- only the "master" worker should write the checkpoint. The bug seems to be non-deterministic, because I do have a single 4-machine run that succeeded (along with perhaps 20 failures).
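For reference, rank-0-only checkpointing in a `torch.distributed` job usually looks roughly like the sketch below; the actual training script may differ, and the barrier is only there so other workers don't race ahead and read a half-written file:

```python
import torch
import torch.distributed as dist

def save_checkpoint(state, path):
    # Only the rank-0 ("master") worker writes the file, so multiple
    # processes never write to the same path at the same time.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(state, path)
    # Keep the other workers from reading the checkpoint before the
    # write has finished.
    if dist.is_initialized():
        dist.barrier()
```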
@deepakn94 I haven't tried to get the script to run for me, but another thing to try: when you get to the save point, save the model multiple times; like, 8 times. We can then compare them and see if they're all corrupted identically, or some of them are ok, etc.
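A sketch of that suggestion, assuming `state` is whatever dict is normally passed to `torch.save` (the names here are placeholders):

```python
import torch

def save_copies(state, n=8, prefix="model_best"):
    # Write the same state several times; identical corruption across
    # copies would point at the data itself, differing corruption at
    # the write path.
    for i in range(n):
        torch.save(state, f"{prefix}.{i}.pth.tar")
```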
Links are of the form https://s3.amazonaws.com/distributed-pytorch-imagenet-runs/multi-checkpoint/model_best.0.pth.tar (replace 0 with numbers from 0 to 7)
This is actually interesting; looks like some of the checkpoints are corrupted identically, but most are different (and one of the eight checkpoints is not corrupted).
I'm not going to get around to looking at this until the work week, but my plan is to do a binary comparison on the checkpoints and see where they diverge, and what kind of corruption is happening, and then check which particular part of the serialization code was writing out that part of the file.
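A small helper along those lines, assuming the checkpoints have been downloaded locally (paths are placeholders), which reports the byte offset where two files first differ:

```python
def first_divergence(path_a: str, path_b: str, chunk: int = 1 << 20):
    """Return the offset of the first differing byte, or None if identical."""
    offset = 0
    with open(path_a, "rb") as fa, open(path_b, "rb") as fb:
        while True:
            a, b = fa.read(chunk), fb.read(chunk)
            if a != b:
                for i, (x, y) in enumerate(zip(a, b)):
                    if x != y:
                        return offset + i
                # One file is a strict prefix of the other.
                return offset + min(len(a), len(b))
            if not a:  # both files exhausted
                return None
            offset += len(a)

# Example: first_divergence("model_best.0.pth.tar", "model_best.1.pth.tar")
```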
Sounds good. Let me know if there's anything else you need from my side! (Is the easier-to-reproduce test case still useful?)
Sigh, I think I found the problem. It seems like …
Aw man, that sounds like a good one for the docs. Very happy you figured it out :)
Verified that this is indeed the case -- closing this. Thanks for all the help!
@deepakn94 If you don't mind me asking, what change did you make to solve the problem? IIUC, you were writing to a network filesystem for the checkpoints; did you just make them stop writing to NFS?
I added a …
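The actual change is elided above, so purely as an illustration of one common mitigation when checkpoints go to a network filesystem (not necessarily what was done here): write to a temporary file, flush and fsync it, and only then rename it into place, so a reader never sees a half-written checkpoint.

```python
import os
import torch

def save_checkpoint_durably(state, path):
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        torch.save(state, f)
        f.flush()
        os.fsync(f.fileno())  # force the bytes out before renaming
    os.replace(tmp, path)     # atomic rename on the same filesystem
```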
Note that the launch utility …
I have a similar error when I only load a pretrained model. The problem does not occur if only one process is loading the model.
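One pattern that avoids several processes racing to read (or download) the same pretrained file is to let rank 0 go first and have everyone else wait on a barrier. A sketch, using torchvision's resnet50 purely as a stand-in for whatever pretrained model is being loaded, and assuming the process group is already initialized:

```python
import torch.distributed as dist
import torchvision

def load_pretrained():
    # Rank 0 downloads/loads first and populates the local cache...
    if dist.get_rank() == 0:
        model = torchvision.models.resnet50(pretrained=True)
    dist.barrier()
    # ...then the remaining ranks read the already-complete file.
    if dist.get_rank() != 0:
        model = torchvision.models.resnet50(pretrained=True)
    return model
```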
I have created a PyTorch model checkpoint using `torch.save`; however, I'm unable to load this model using `torch.load`. I run into the following error:

The model was saved using code like this:

The model was trained across multiple `p3.16xlarge` instances.