VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning

Xianwei Zhuang1*  Yuxin Xie1*  Yufan Deng1*  Dongchao Yang2 
Liming Liang1  Jinghan Ru1  Yuguo Yin1  Yuexian Zou1

1 Peking University, 2 The Chinese University of Hong Kong


[Demo video: VARGPTv1-1.mp4]

News

  • [2025-04-07] The technical report is released at https://arxiv.org/pdf/2504.02949.

  • [2025-04-02] We release the more powerful unified model VARGPT-v1.1 (7B+2B) at VARGPT-v1.1 and the editing datasets at VARGPT-v1.1-edit. 🔥🔥🔥🔥🔥🔥

  • [2025-04-01] We release the training (SFT and RL), inference, and evaluation code for VARGPT-v1.1 and VARGPT at VARGPT-family-training, covering multimodal understanding and generation: image captioning, visual question answering (VQA), text-to-image generation, and visual editing. 🔥🔥🔥🔥🔥🔥

What's new in VARGPT-v1.1?

Compared with VARGPT, VARGPT-v1.1 achieves comprehensive capability improvements. VARGPT-v1.1 integrates: (1) a novel training strategy combining iterative visual instruction tuning with reinforcement learning through Direct Preference Optimization (DPO), (2) an expanded training corpus containing 8.3M visual-generative instruction pairs, (3) an upgraded language backbone using Qwen2, (4) enhanced image generation resolution, and (5) emergent image editing capabilities without architectural modifications.
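For orientation, the DPO objective used in the reinforcement learning stage can be sketched as follows. This is a minimal illustrative sketch of the standard DPO loss on sequence log-probabilities, not the repository's actual implementation; the function name, argument names, and beta value are assumptions.

import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Log-ratios of the trainable policy against the frozen reference model.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # Preferred responses should gain probability mass relative to rejected
    # ones; beta controls how far the policy may drift from the reference.
    return -F.logsigmoid(beta * (chosen_ratio - rejected_ratio)).mean()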


TODO

  • Release the inference code.
  • Release the code for evaluation.
  • Release the model checkpoint.
  • Supporting stronger visual generation capabilities.
  • Release the training datasets.
  • Release the training code.
  • Release the technical report.

Hugging Face models and annotations

The following checkpoints and data are available on Hugging Face:

  • The VARGPT-v1.1 checkpoints (VARGPT-family/VARGPT-v1.1).
  • The VARGPT-v1.1-edit checkpoints for visual editing.
  • The VARGPT checkpoints.
  • The instruction data for training.

Getting Started

First, set up the environment:

pip3 install -r requirements.txt

If you have difficulty compiling flash-attention, you can install a prebuilt wheel instead, e.g. flash_attn-2.7.3+cu12torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64 from https://github.com/Dao-AILab/flash-attention/releases.
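For example, after downloading the wheel that matches your environment (the filename above encodes Python 3.9, CUDA 12, and torch 2.1):

pip3 install flash_attn-2.7.3+cu12torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl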

The code structure is described below:

VARGPTv1.1_code/
├── inference_v1_1             # Inference code for understanding and generation of VARGPT-v1.1.
├── patching_utils             # Patching utils for supporting VARGPT.
├── README.md
├── requirements.txt           # Requirements for inference code.
├── understand_eval            # Evaluation code for understanding.
├── VARGPT-family-training     # Training and inference code for VARGPT and VARGPT-v1.1 (including SFT and RL).
└── vargpt_qwen_v1_1           # Complete model architecture code of VARGPT-v1.1.

Multimodal Understanding

Inference demo for multimodal understanding. You can execute the following shell command:

python3 inference_v1_1/understanding_vargpt_v1_1.py

Or execute the following Python code:

import torch
from PIL import Image
from transformers import AutoTokenizer
from vargpt_qwen_v1_1.modeling_vargpt_qwen2_vl import VARGPTQwen2VLForConditionalGeneration
from vargpt_qwen_v1_1.prepare_vargpt_v1_1 import prepare_vargpt_qwen2vl_v1_1
from vargpt_qwen_v1_1.processing_vargpt_qwen2_vl import VARGPTQwen2VLProcessor
from patching_utils.patching import patching

model_id = "VARGPT-family/VARGPT-v1.1"

prepare_vargpt_qwen2vl_v1_1(model_id)

model = VARGPTQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)  # move the model to GPU 0

patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTQwen2VLProcessor.from_pretrained(model_id)

# Define a chat history and use `apply_chat_template` to get correctly formatted prompt
# Each value in "content" has to be a list of dicts with types ("text", "image") 
conversation = [
    {
      "role": "user",
      "content": [
          {"type": "text", "text": "Please explain the meme in detail."},
          {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "./assets/llava_bench_demo.png"
print(prompt)

raw_image = Image.open(image_file)
inputs = processor(images=[raw_image], text=prompt, return_tensors='pt').to(0, torch.float32)

output = model.generate(
    **inputs, 
    max_new_tokens=2048, 
    do_sample=False)

print(processor.decode(output[0], skip_special_tokens=True))

Multimodal Generation

Inference demo for text-to-image generation. You can execute the following shell command:

python3 inference_v1_1/generation_vargpt_v1_1.py

Or execute the following Python code:

import torch
from transformers import AutoTokenizer
from vargpt_qwen_v1_1.modeling_vargpt_qwen2_vl import VARGPTQwen2VLForConditionalGeneration
from vargpt_qwen_v1_1.prepare_vargpt_v1_1 import prepare_vargpt_qwen2vl_v1_1
from vargpt_qwen_v1_1.processing_vargpt_qwen2_vl import VARGPTQwen2VLProcessor
from patching_utils.patching import patching

model_id = "VARGPT-family/VARGPT-v1.1"

prepare_vargpt_qwen2vl_v1_1(model_id)

model = VARGPTQwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)  # move the model to GPU 0

patching(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTQwen2VLProcessor.from_pretrained(model_id)

conversation = [
    {
      "role": "user",
      "content": [
          {"type": "text", "text": "Can you depict a scene of A power metalalbum cover featuring a fantasy-style illustration witha white falcon."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

inputs = processor(text=prompt, return_tensors='pt').to(0, torch.float32)
model._IMAGE_GEN_PATH = "output.png"  # path where the generated image will be saved
output = model.generate(
    **inputs, 
    max_new_tokens=4096, 
    do_sample=False)

print(processor.decode(output[0][:-1], skip_special_tokens=True))
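The generated image is written to the path set via model._IMAGE_GEN_PATH (output.png here), while the decoded text response is printed to stdout.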

Performance Evaluation

Following lmms-eval, you can install the evaluation package by cloning the repository and running the following commands:

cd understand_eval
pip install -e .

To evaluate the model's multimodal understanding performance, you can execute the following shell command:

python3 -m accelerate.commands.launch \
    --num_processes=8 \
    --main_process_port=39535 \
    -m lmms_eval \
    --model vargpt_qwen2vl_v1_1 \
    --model_args pretrained="path/to/VARGPT_v1-1" \
    --tasks mmmu \
    --batch_size 1 \
    --log_samples \
    --log_samples_suffix llava-hf_mmmu \
    --output_path ./logs/
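The --tasks flag accepts other benchmarks supported by lmms-eval as well (for example, mme or mmbench); see the lmms-eval documentation for the full task list.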

For the GenEval benchmark, you can prepare the runtime environment according to GenEval and use prompts from Infinity, then perform sampling using our batch inference code provided in VARGPT-family-training.

For the DPG-Bench benchmark, prepare the runtime environment and prompts according to ELLA, and execute sampling using our batch inference code provided in VARGPT-family-training.

VARGPT-v1.1 Training Data Preparation

Instruction Fine-tuning Dataset Download

The following content provides detailed instructions for preparing the training data for VARGPT. The data preparation process involves downloading and processing various datasets for different stages of training.

Dataset Structure

  1. stage1-pt: Contains the 8.3M-pair pre-training instruction fine-tuning dataset for VARGPT-v1.1.

  2. stage2-sft: Includes datasets for the second stage of VARGPT instruction fine-tuning:

    • stage2-sft/llava_v1_5_mix665k: Derived entirely from the LLaVA-1.5 training data (an example record is shown after this list).
    • stage2-sft/llava_onevision_508k: Sampled from the LLaVA-OneVision dataset.
    • stage2-sft/50k_generation: Sampled from our 8.3M dataset.
  3. stage3-sft: Contains datasets for the third stage of VARGPT-v1.1 instruction fine-tuning, drawn from JourneyDB and laion-coco-aesthetic.
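Since llava_v1_5_mix665k follows the standard LLaVA instruction schema, a representative record looks like the following sketch (field values here are illustrative, not taken from the dataset):

sample = {
    "id": "000000123456",
    "image": "coco/train2017/000000123456.jpg",
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is unusual about this image?"},
        {"from": "gpt", "value": "A man is ironing clothes on a board attached to a moving taxi."},
    ],
}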

VARGPT-v1.1 Training

Our training code is implemented through LLaMA-Factory, so you can follow the LLaMA-Factory documentation for flexible configuration.

Setting Up

You can set up the training environment with the following commands:

cd VARGPT-family-training
pip install -e ".[metrics]" torch==2.1.0

Training

You can run batch evaluation and training of the model through the corresponding scripts in run_scripts:

  • You can perform SFT training using the provided demo dataset and examples:
cd VARGPT-family-training
bash run_scripts/run_vargpt_qwen2_1_1_sft.sh
  • You can perform batch evaluations, including image generation and image editing:
cd VARGPT-family-training
bash run_scripts/run_eval_vargpt_v1_1.sh
bash run_scripts/run_eval_vargpt_v1_1_edit.sh

Citation

To cite the paper and model, please use the BibTeX entries below:

@misc{zhuang2025vargptunifiedunderstandinggeneration,
      title={VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model}, 
      author={Xianwei Zhuang and Yuxin Xie and Yufan Deng and Liming Liang and Jinghan Ru and Yuguo Yin and Yuexian Zou},
      year={2025},
      eprint={2501.12327},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2501.12327}, 
}
@misc{zhuang2025vargptv11improvevisualautoregressive,
      title={VARGPT-v1.1: Improve Visual Autoregressive Large Unified Model via Iterative Instruction Tuning and Reinforcement Learning}, 
      author={Xianwei Zhuang and Yuxin Xie and Yufan Deng and Dongchao Yang and Liming Liang and Jinghan Ru and Yuguo Yin and Yuexian Zou},
      year={2025},
      eprint={2504.02949},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2504.02949}, 
}

Acknowledgments

This work is heavily based on LLaVA-1.5, VAR, Qwen2-VL, LLaVA-NeXT, lmms-eval, Show-o, LLaMA-Factory, Infinity, CLIP, transformers-hf. Thanks to all the authors for their great work.

