VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning
Xianwei Zhuang1*
Yuxin Xie1*
Yufan Deng1*
Dongchao Yang2
Liming Liang1
Jinghan Ru1
Yuguo Yin1
Yuexian Zou1
1 Peking University, 2 The Chinese University of Hong Kong
- [2025-04-7] The technical report is released at https://arxiv.org/pdf/2504.02949.
- [2025-04-2] We release the more powerful unified model VARGPT-v1.1 (7B+2B) at VARGPT-v1.1 and the editing-model datasets at VARGPT-v1.1-edit. 🔥🔥🔥🔥🔥🔥
- [2025-04-1] We release the training (SFT and RL), inference, and evaluation code of VARGPT-v1.1 and VARGPT at VARGPT-family-training for multimodal understanding and generation, including image captioning, visual question answering (VQA), text-to-image generation, and visual editing. 🔥🔥🔥🔥🔥🔥
- Release the inference code.
- Release the code for evaluation.
- Release the model checkpoint.
- Support stronger visual generation capabilities.
- Release the training datasets.
- Release the training code.
- Release the technical report.
The VARGPT-v1.1 checkpoints can be found on Hugging Face:
The VARGPT-v1.1-edit checkpoints for visual editing can be found on Hugging Face:
The VARGPT checkpoints can be found on Hugging Face:
The instruction for training data can be found on Hugging Face:
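If you want to fetch a checkpoint ahead of time instead of letting from_pretrained download it on first use, a minimal sketch with huggingface_hub is shown below; the repo id matches the model_id used in the inference demos, while local_dir is just an illustrative path:

# Optional: pre-download the VARGPT-v1.1 checkpoint from Hugging Face.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="VARGPT-family/VARGPT-v1.1",   # same id as model_id in the demos below
    local_dir="checkpoints/VARGPT-v1.1",   # illustrative local path
)

The other checkpoints and datasets listed above can be downloaded the same way with their respective repo ids.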
First, set up the environment:
pip3 install -r requirements.txt
If you run into difficulties compiling flash-attention, you can install it directly from a prebuilt wheel, e.g. flash_attn-2.7.3+cu12torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64 from https://github.com/Dao-AILab/flash-attention/releases.
Description of the code structure:
VARGPTv1.1_code/
├── inference_v1_1 # Inference code for understanding and generation of VARGPT-v1.1.
├── patching_utils # Patching utils for supporting VARGPT.
├── README.md
├── requirements.txt # Requirements for inference code.
├── understand_eval # Evaluation code for understanding.
├── VARGPT-family-training # Training and Inference code for VARGPT and VARGPT-v1.1 (including SFT and RL).
└── vargpt_qwen_v1_1 # Complete model architecture code of VARGPT-v1.1.
Inference demo for Multimodal Understanding. You can execute the following shell command:
python3 inference_v1_1/understanding_vargpt_v1_1.py
Or execute the following Python code:
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer
from vargpt_qwen_v1_1.modeling_vargpt_qwen2_vl import VARGPTQwen2VLForConditionalGeneration
from vargpt_qwen_v1_1.prepare_vargpt_v1_1 import prepare_vargpt_qwen2vl_v1_1
from vargpt_qwen_v1_1.processing_vargpt_qwen2_vl import VARGPTQwen2VLProcessor
from patching_utils.patching import patching

model_id = "VARGPT-family/VARGPT-v1.1"

# Download/prepare the VARGPT-v1.1 weights and register the architecture.
prepare_vargpt_qwen2vl_v1_1(model_id)

model = VARGPTQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)

# Apply the VARGPT patches to the loaded model.
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTQwen2VLProcessor.from_pretrained(model_id)

# Define a chat history and use `apply_chat_template` to get the correctly formatted prompt.
# Each value in "content" has to be a list of dicts with types ("text", "image").
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please explain the meme in detail."},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

image_file = "./assets/llava_bench_demo.png"
raw_image = Image.open(image_file)
inputs = processor(images=[raw_image], text=prompt, return_tensors='pt').to(0, torch.float32)

output = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
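The requests import in the snippet above is only needed if you prefer to load the test image from a URL rather than a local file. A minimal sketch, reusing the model, processor, and prompt from the example (the URL is a placeholder):

# Load the image from a URL instead of a local file; the URL below is a placeholder.
import requests
from PIL import Image

url = "https://example.com/llava_bench_demo.png"
raw_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=[raw_image], text=prompt, return_tensors='pt').to(0, torch.float32)
output = model.generate(**inputs, max_new_tokens=2048, do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))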
Inference demo for Text-to-Image Generation. You can execute the following shell command:
python3 inference_v1_1/generation_vargpt_v1_1.py
Or execute the following Python code:
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer
from vargpt_qwen_v1_1.modeling_vargpt_qwen2_vl import VARGPTQwen2VLForConditionalGeneration
from vargpt_qwen_v1_1.prepare_vargpt_v1_1 import prepare_vargpt_qwen2vl_v1_1
from vargpt_qwen_v1_1.processing_vargpt_qwen2_vl import VARGPTQwen2VLProcessor
from patching_utils.patching import patching

model_id = "VARGPT-family/VARGPT-v1.1"

# Download/prepare the VARGPT-v1.1 weights and register the architecture.
prepare_vargpt_qwen2vl_v1_1(model_id)

model = VARGPTQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)

# Apply the VARGPT patches to the loaded model.
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTQwen2VLProcessor.from_pretrained(model_id)

# A text-only conversation: the model generates the image autoregressively.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you depict a scene of a power metal album cover featuring a fantasy-style illustration with a white falcon."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

inputs = processor(text=prompt, return_tensors='pt').to(0, torch.float32)

# The generated image is written to this path.
model._IMAGE_GEN_PATH = "output.png"
output = model.generate(
    **inputs,
    max_new_tokens=4096,
    do_sample=False,
)
print(processor.decode(output[0][:-1], skip_special_tokens=True))
Following lmms-eval, you can install the evaluation package by running the following commands:
cd understand_eval
pip install -e .
To evaluate the performance of the model in multimodal understanding, you can execute the following shell command:
python3 -m accelerate.commands.launch \
--num_processes=8 \
--main_process_port=39535 \
-m lmms_eval \
--model vargpt_qwen2vl_v1_1 \
--model_args pretrained="path/to/VARGPT_v1-1" \
--tasks mmmu \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava-hf_mmmu \
--output_path ./logs/
For the GenEval benchmark, you can prepare the runtime environment according to GenEval, use the prompts from Infinity, and then perform sampling using our batch inference code provided in VARGPT-family-training.
For the DPG-Bench benchmark, prepare the runtime environment and prompts according to ELLA, and execute sampling using our batch inference code provided in VARGPT-family-training.
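The batch inference entry points live in VARGPT-family-training. As a rough sketch of what batched text-to-image sampling over a benchmark prompt list looks like with the generation API shown above (the prompt file and output directory are placeholders, and model/processor are loaded exactly as in the text-to-image demo):

# Sketch: sample one image per benchmark prompt (paths are placeholders).
import os

prompt_file = "prompts.txt"   # one GenEval / DPG-Bench prompt per line
output_dir = "samples"
os.makedirs(output_dir, exist_ok=True)

with open(prompt_file) as f:
    prompts = [line.strip() for line in f if line.strip()]

for i, text in enumerate(prompts):
    conversation = [{"role": "user", "content": [{"type": "text", "text": text}]}]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=prompt, return_tensors='pt').to(0, torch.float32)
    model._IMAGE_GEN_PATH = os.path.join(output_dir, f"{i:05d}.png")  # where the image is saved
    model.generate(**inputs, max_new_tokens=4096, do_sample=False)

This is only an illustration; the scripts provided in VARGPT-family-training handle batched sampling for the benchmarks end to end.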
The following content provides detailed instructions for preparing the training data for VARGPT. The data preparation process involves downloading and processing various datasets for different stages of training.
- stage1-pt: Contains the 8.3M pre-training instruction fine-tuning dataset for VARGPT-v1.1.
- stage2-sft: Includes datasets for the second stage of VARGPT instruction fine-tuning:
  - stage2-sft/llava_v1_5_mix665k: Derived entirely from the LLaVA-1.5 training data.
  - stage2-sft/llava_onevision_508k: Sampled from the LLaVA-OneVision dataset.
  - stage2-sft/50k_generation: Sampled from our 8.3M dataset.
- stage3-sft: Contains datasets for the third stage of VARGPT-v1.1 instruction fine-tuning, drawn from JourneyDB and laion-coco-aesthetic.
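Before launching training, it can be useful to sanity-check a downloaded annotation file. A minimal sketch that loads one file and prints a single record, assuming the file is a top-level JSON list (the path is a placeholder; the exact field layout depends on the stage, e.g. stage2-sft/llava_v1_5_mix665k is derived from LLaVA-1.5 data):

# Sketch: inspect one record of a downloaded annotation file (path is a placeholder).
import json

annotation_file = "path/to/annotations.json"

with open(annotation_file) as f:
    records = json.load(f)   # assumed to be a list of instruction records

print(f"{len(records)} records loaded")
print(json.dumps(records[0], indent=2, ensure_ascii=False))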
Our training code is built on LLaMA-Factory, so you can follow the LLaMA-Factory documentation for flexible configuration.
You can set up the training environment with the following commands:
cd VARGPT-family-training
pip install -e ".[metrics]" torch==2.1.0
You can run batch evaluation and training of the model through the corresponding scripts in run_scripts:
- You can perform SFT training using the provided demo dataset and examples:
cd VARGPT-family-training
bash run_scripts/run_vargpt_qwen2_1_1_sft.sh
- You can perform batch evaluations, including evaluations of image generation and image editing:
cd VARGPT-family-training
bash run_scripts/run_eval_vargpt_v1_1.sh
bash run_scripts/run_eval_vargpt_v1_1_edit.sh
To cite the paper and model, please use the BibTeX entries below:
@misc{zhuang2025vargptunifiedunderstandinggeneration,
title={VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model},
author={Xianwei Zhuang and Yuxin Xie and Yufan Deng and Liming Liang and Jinghan Ru and Yuguo Yin and Yuexian Zou},
year={2025},
eprint={2501.12327},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2501.12327},
}
@misc{zhuang2025vargptv11improvevisualautoregressive,
title={VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning},
author={Xianwei Zhuang and Yuxin Xie and Yufan Deng and Dongchao Yang and Liming Liang and Jinghan Ru and Yuguo Yin and Yuexian Zou},
year={2025},
eprint={2504.02949},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.02949},
}
This work is heavily based on LLaVA-1.5, VAR, Qwen2-VL, LLaVA-NeXT, lmms-eval, Show-o, LLaMA-Factory, Infinity, CLIP, transformers-hf. Thanks to all the authors for their great work.