
Support Mixtral-8x7B #71



Closed · wants to merge 7 commits from the mixtral-moe branch

Conversation

@yanboliang (Contributor) commented Dec 30, 2023

This is based on #57. Please check out https://github.com/yanboliang/gpt-fast/tree/mixtral-moe to try it.

Performance numbers (tokens/second):

|                     | 1 GPU | 2 GPUs | 8 GPUs |
|---------------------|-------|--------|--------|
| baseline (bfloat16) | OOM   | 78.75  | 203.69 |
| int8                | 56.04 | 99.91  | 218.48 |

Note: Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology.

How to reproduce it:

export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
# Download model weights
python scripts/download.py --repo_id $MODEL_REPO
# Convert to gpt-fast supported format
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
# Generate int8-quantized model weights
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
# Test tp=8
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model.pth
# Test single GPU + int8 model
python generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth

@facebook-github-bot added the CLA Signed label Dec 30, 2023
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Jan 1, 2024
…e_descent_tuning is enabled (#116582)

We found this perf optimization opportunity at pytorch-labs/gpt-fast#71. This would bring 5%+ perf gain for Mixtral 8x7B on gpt-fast.

Pull Request resolved: #116582
Approved by: https://github.com/lezcano
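
For reference, the perf win above hinges on the inductor's coordinate-descent tuning flag (presumably the `…e_descent_tuning` in the truncated commit title). A minimal sketch of enabling it before `torch.compile`, assuming a CUDA device; the toy `nn.Linear` below merely stands in for the real model:

```python
import torch
import torch._inductor.config as inductor_config

# Coordinate-descent tuning lets the inductor autotuner search kernel configs
# more aggressively (longer compile time, usually faster generated kernels).
inductor_config.coordinate_descent_tuning = True

# Toy stand-in for the decoding step; the real Mixtral model is compiled the same way.
layer = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
compiled = torch.compile(layer, mode="reduce-overhead", fullgraph=True)
out = compiled(torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda"))
```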
@raghukiran1224

Does the model compile without any graph breaks?

@yanboliang (Contributor, Author)

@raghukiran1224 Yes, no graph breaks!

facebook-github-bot pushed a commit to pytorch/executorch that referenced this pull request Jan 6, 2024
Summary:
Pull Request resolved: #1533

Support MoE structure, where there can be multiple experts in the FFN layer.
The change in model.py is based on pytorch-labs/gpt-fast#71.
Note that this is a functional verification with random weights.
It can be successfully exported and lowered to ExecuTorch.

TODO: test the runtime side.

Reviewed By: larryliu0820

Differential Revision: D52543030

fbshipit-source-id: 5d4220f1e8ea9eb1e4be398fe2a47bfb0b89c975
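
For context, the MoE feed-forward pattern described in that commit can be sketched roughly as follows (hypothetical sizes; the actual Mixtral block in model.py routes each token to its top-2 of 8 experts, and the `gate` router is the layer that later stays unquantized):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMoEFeedForward(nn.Module):
    """Minimal mixture-of-experts FFN: a small gate picks top-k experts per token."""

    def __init__(self, dim=128, hidden=256, num_experts=8, top_k=2):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)  # router ("gate") over experts
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
            for _ in range(num_experts)
        )
        self.top_k = top_k

    def forward(self, x):  # x: (num_tokens, dim)
        scores = self.gate(x)                                  # (num_tokens, num_experts)
        weights, idx = torch.topk(scores, self.top_k, dim=-1)  # top-k experts per token
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[:, k] == e                          # tokens routed to expert e at slot k
                if mask.any():
                    out[mask] += weights[mask, k, None] * expert(x[mask])
        return out

tokens = torch.randn(4, 128)
print(ToyMoEFeedForward()(tokens).shape)  # torch.Size([4, 128])
```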
@chauhang

@yanboliang Great to see this PR! What work remains before it can be merged? It would also help to update the main README benchmarks to include this model.

@yanboliang (Contributor, Author)

@chauhang I think we need to figure out how to structure this under gpt-fast; we probably need a separate folder. There are no other blockers, so I'll prioritize this work and hopefully we can merge it in a few days.

@yanboliang (Contributor, Author)

Closing this as it has been merged via #105.

@yanboliang closed this Feb 26, 2024
@guangy10 commented Mar 7, 2024

@yanboliang It doesn't seem like python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8 quantizes all of the weights. I end up getting a mismatched dtype error when lowering this model to ExecuTorch. After looking into model_int8.pth, I noticed that some weights are still in bfloat16. Is that expected?
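
One quick way to confirm which tensors stayed in bfloat16 (a hypothetical inspection snippet, assuming the checkpoint is a flat state dict of tensors; the path follows the reproduce steps in the description):

```python
import torch
from collections import Counter

ckpt = "checkpoints/mistralai/Mixtral-8x7B-v0.1/model_int8.pth"
state_dict = torch.load(ckpt, map_location="cpu")

# Tally dtypes across all saved tensors (expect a mix of int8 and bfloat16).
print(Counter(t.dtype for t in state_dict.values()))

# List the tensors left in bfloat16.
for name, t in state_dict.items():
    if t.dtype == torch.bfloat16:
        print(name, tuple(t.shape))
```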

@yanboliang (Contributor, Author)

@guangy10 Yes, it's expected! We don't quantize the gate networks, to preserve accuracy, since they are used to choose the experts.

if isinstance(child, nn.Linear) and name != "gate":
    setattr(module, name, WeightOnlyBit8Linear(child.in_features, child.out_features, target_dtype=target_dtype))
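
For illustration, a minimal sketch of the replacement pass that condition sits in; this `WeightOnlyBit8Linear` is a simplified stand-in for gpt-fast's int8 weight-only linear, not the actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightOnlyBit8Linear(nn.Module):
    """Simplified stand-in: int8 weights plus per-output-channel bf16 scales."""

    def __init__(self, in_features, out_features, target_dtype=torch.int8):
        super().__init__()
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=target_dtype))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, x):
        # Dequantize on the fly: cast int8 weight to the activation dtype, then rescale per channel.
        return F.linear(x, self.weight.to(dtype=x.dtype)) * self.scales

def replace_linear_weight_only_bit8(module, target_dtype=torch.int8):
    """Recursively swap nn.Linear layers, skipping the MoE router named "gate"."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and name != "gate":
            setattr(module, name, WeightOnlyBit8Linear(
                child.in_features, child.out_features, target_dtype=target_dtype))
        else:
            replace_linear_weight_only_bit8(child, target_dtype)
```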

@yanboliang deleted the mixtral-moe branch March 7, 2024 05:23