
🌏 iVideoGPT: Interactive VideoGPTs are Scalable World Models (NeurIPS 2024)

[Project Page] [Paper] [Models] [Poster] [Slides] [Blog (In Chinese)]

This repo provides official code and checkpoints for iVideoGPT, a generic and efficient world model architecture that has been pre-trained on millions of human and robotic manipulation trajectories.

[Figure: iVideoGPT architecture]

🔥 News

  • 🚩 2024.11.01: NeurIPS 2024 camera-ready version is released on arXiv.
  • 🚩 2024.09.26: iVideoGPT has been accepted by NeurIPS 2024, congrats!
  • 🚩 2024.08.31: Training code is released.
  • 🚩 2024.05.31: Project website with video samples is released.
  • 🚩 2024.05.30: Model pre-trained on Open X-Embodiment and inference code are released.
  • 🚩 2024.05.27: Our paper is released on arXiv.

🛠️ Installation

conda create -n ivideogpt python=3.9
conda activate ivideogpt
pip install -r requirements.txt

To evaluate FVD (Fréchet Video Distance), download the pretrained I3D model to pretrained_models/i3d/i3d_torchscript.pt.
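As background, FVD compares I3D features extracted from real videos with those extracted from generated videos. It fits a Gaussian to each feature set and reports the Fréchet distance between the two, so lower is better. The standard definition is below, where μ_r, Σ_r and μ_g, Σ_g are the feature mean and covariance for real and generated videos. This states the metric only; it is not a description of how this repo's evaluation code computes it.

```latex
\mathrm{FVD} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```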

🤗 Models

At the moment we provide the following pre-trained models:

| Model | Resolution | Action-conditioned | Goal-conditioned | Tokenizer Size | Transformer Size |
|---|---|---|---|---|---|
| ivideogpt-oxe-64-act-free | 64x64 | No | No | 114M | 138M |
| ivideogpt-oxe-64-act-free-medium | 64x64 | No | No | 114M | 436M |
| ivideogpt-oxe-64-goal-cond | 64x64 | No | Yes | 114M | 138M |
| ivideogpt-oxe-256-act-free | 256x256 | No | No | 310M | 138M |

If you cannot connect to Hugging Face, you can download the checkpoints manually from Tsinghua Cloud.
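If you do have network access, one way to fetch a checkpoint into the pretrained_models/ layout that the fine-tuning commands point to is `snapshot_download` from the `huggingface_hub` package. This is a minimal sketch, assuming every checkpoint lives at `thuml/<model name>` (as in the inference example) and that each repository contains the `tokenizer/` and `transformer/` subfolders those commands reference. The helper names `repo_id`, `local_dir`, and `download` are hypothetical.

```python
import os

# Organization used in this README's inference example (assumed for all checkpoints).
HF_ORG = "thuml"


def repo_id(model_name: str) -> str:
    """Map a checkpoint name from the model table to a Hugging Face repo ID."""
    return f"{HF_ORG}/{model_name}"


def local_dir(model_name: str, root: str = "pretrained_models") -> str:
    """Local folder such as pretrained_models/ivideogpt-oxe-64-act-free, whose
    tokenizer/ and transformer/ subfolders the fine-tuning commands use."""
    return os.path.join(root, model_name)


def download(model_name: str, root: str = "pretrained_models") -> str:
    """Download a checkpoint snapshot and return its local path.
    Requires `pip install huggingface_hub` and network access."""
    # Imported lazily so the path helpers above work without the package installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=repo_id(model_name), local_dir=local_dir(model_name, root))
```

For example, `download("ivideogpt-oxe-64-act-free")` would place the files under pretrained_models/ivideogpt-oxe-64-act-free.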

Notes:

  • Due to the heterogeneity of action spaces, we currently do not have an action-conditioned prediction model on OXE.
  • Pre-trained models at 256x256 resolution may not perform best due to insufficient training, but can serve as a good starting point for downstream fine-tuning.
More models on downstream tasks
| Model | Resolution | Action-conditioned | Goal-conditioned | Tokenizer Size | Transformer Size |
|---|---|---|---|---|---|
| ivideogpt-bair-64-act-free | 64x64 | No | No | 114M | 138M |
| ivideogpt-bair-64-act-cond | 64x64 | Yes | No | 114M | 138M |
| ivideogpt-robonet-64-act-cond | 64x64 | Yes | No | 114M | 138M |
  • We are sorry: the RoboNet checkpoints at 256x256 resolution were deleted by mistake during a disk cleanup. We will retrain and release them as soon as possible!
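The checkpoint names in both tables encode the dataset, the resolution, and the conditioning type, with an optional size suffix such as `medium`. The hypothetical helper below decodes that convention, which is inferred from the tables rather than taken from an official API. It can be handy when scripting, for example to choose `--resolution` or whether to pass `--action_conditioned`.

```python
from dataclasses import dataclass
from typing import Optional

# Conditioning suffixes that appear in the checkpoint tables:
# suffix -> (action-conditioned, goal-conditioned)
_CONDITIONS = {
    "act-free": (False, False),
    "act-cond": (True, False),
    "goal-cond": (False, True),
}


@dataclass(frozen=True)
class CheckpointInfo:
    dataset: str
    resolution: int
    action_conditioned: bool
    goal_conditioned: bool
    variant: Optional[str]  # e.g. "medium"; None for the default size


def parse_checkpoint_name(name: str) -> CheckpointInfo:
    """Parse names like 'ivideogpt-oxe-64-act-free-medium' (hypothetical helper)."""
    parts = name.split("-")
    if len(parts) < 5 or parts[0] != "ivideogpt":
        raise ValueError(f"unrecognized checkpoint name: {name!r}")
    dataset, resolution = parts[1], int(parts[2])
    condition = "-".join(parts[3:5])
    if condition not in _CONDITIONS:
        raise ValueError(f"unknown conditioning suffix {condition!r} in {name!r}")
    action, goal = _CONDITIONS[condition]
    # Anything after the conditioning suffix is treated as a size variant.
    variant = "-".join(parts[5:]) or None
    return CheckpointInfo(dataset, resolution, action, goal, variant)
```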

📦 Data Preparation

Open X-Embodiment: Download datasets from Open X-Embodiment and extract single episodes as .npz files:

python datasets/oxe_data_converter.py --dataset_name {dataset name, e.g. bridge} --input_path {path to downloaded OXE} --output_path {path to stored npz}
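A .npz file is a ZIP archive holding one .npy file per array, so you can check what the converter wrote without loading the data. This is a minimal standard-library sketch that reads each array's shape and dtype from its .npy header. The helper names are hypothetical, and the array names inside an episode are whatever oxe_data_converter.py writes; they are not listed here.

```python
import ast
import struct
import zipfile
from typing import Any, Dict, Tuple

NPY_MAGIC = b"\x93NUMPY"


def _read_npy_header(f) -> Tuple[Tuple[int, ...], Any]:
    """Return (shape, dtype descr) from an open .npy stream, format versions 1-3."""
    if f.read(6) != NPY_MAGIC:
        raise ValueError("not a .npy file")
    major, _minor = f.read(2)
    # v1.0 stores the header length as a little-endian uint16; v2.0/v3.0 use uint32.
    if major == 1:
        (header_len,) = struct.unpack("<H", f.read(2))
    else:
        (header_len,) = struct.unpack("<I", f.read(4))
    encoding = "utf-8" if major >= 3 else "latin1"
    # The header is a Python dict literal, e.g. {'descr': '<f4', 'fortran_order': False, 'shape': (3,), }
    header = ast.literal_eval(f.read(header_len).decode(encoding).strip())
    return tuple(header["shape"]), header["descr"]  # descr is a list for structured dtypes


def inspect_npz(path: str) -> Dict[str, Tuple[Tuple[int, ...], Any]]:
    """Map each array name in a .npz archive to (shape, dtype) without reading array data."""
    summary = {}
    with zipfile.ZipFile(path) as zf:
        for entry in zf.namelist():
            if not entry.endswith(".npy"):
                continue
            with zf.open(entry) as f:
                summary[entry[: -len(".npy")]] = _read_npy_header(f)
    return summary
```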

To replicate our pre-training on OXE, you need to extract all datasets listed under OXE_SELECT in ivideogpt/data/dataset_mixes.py.
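To extract many datasets in one go, a sketch like the following runs the converter once per dataset name with the flags shown above. It assumes you supply the names yourself (the exact structure of OXE_SELECT is not shown here, so it is not imported), that it runs from the repository root, and that passing the same --output_path for every dataset matches the layout the converter expects; check the datasets instructions if in doubt. `converter_commands` and `convert_all` are hypothetical helpers.

```python
import subprocess
import sys
from typing import Iterable, List

CONVERTER = "datasets/oxe_data_converter.py"


def converter_commands(dataset_names: Iterable[str], input_path: str,
                       output_path: str) -> List[List[str]]:
    """Build one converter invocation per dataset, mirroring the single-dataset command."""
    return [
        [sys.executable, CONVERTER,
         "--dataset_name", name,
         "--input_path", input_path,
         "--output_path", output_path]
        for name in dataset_names
    ]


def convert_all(dataset_names: Iterable[str], input_path: str, output_path: str) -> None:
    """Run the converter sequentially; raises CalledProcessError on the first failure."""
    for cmd in converter_commands(dataset_names, input_path, output_path):
        print("running:", " ".join(cmd))
        subprocess.run(cmd, check=True)
```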

See the instructions in datasets for preprocessing additional datasets.

🚀 Inference Examples

For action-free video prediction on Open X-Embodiment, run:

python inference/predict.py --pretrained_model_name_or_path "thuml/ivideogpt-oxe-64-act-free" --input_path inference/samples/fractal_sample.npz --dataset_name fractal20220817_data

See inference for more examples.

🌟 Pre-training

To pre-train iVideoGPT, adjust the arguments in the script below as needed and run:

bash ./scripts/pretrain/ivideogpt-oxe-64-act-free.sh

Scripts for the other pre-trained models are in scripts/pretrain.

🎇 Fine-tuning Video Prediction

Fine-tuning the Tokenizer

After preparing the BAIR dataset, run the following:

accelerate launch train_tokenizer.py \
    --exp_name bair_tokenizer_ft --output_dir log_vqgan --seed 0 --mixed_precision bf16 \
    --model_type ctx_vqgan \
    --train_batch_size 16 --gradient_accumulation_steps 1 --disc_start 1000005 \
    --oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
    --rand_select --video_stepsize 1 --segment_horizon 16 --segment_length 8 --context_length 1 \
    --pretrained_model_name_or_path pretrained_models/ivideogpt-oxe-64-act-free/tokenizer \
    --max_train_steps 200005

Fine-tuning the Transformer

For action-conditioned video prediction, run the following:

accelerate launch train_gpt.py \
    --exp_name bair_llama_ft --output_dir log_trm --seed 0 --mixed_precision bf16 \
    --vqgan_type ctx_vqgan \
    --pretrained_model_name_or_path {log directory of finetuned tokenizer}/unwrapped_model \
    --config_name configs/llama/config.json --load_internal_llm --action_conditioned --action_dim 4 \
    --pretrained_transformer_path pretrained_models/ivideogpt-oxe-64-act-free/transformer \
    --per_device_train_batch_size 16 --gradient_accumulation_steps 1 \
    --learning_rate 1e-4 --lr_scheduler_type cosine \
    --oxe_data_mixes_type bair --resolution 64 --dataloader_num_workers 16 \
    --video_stepsize 1 --segment_length 16 --context_length 1 \
    --use_eval_dataset --use_fvd --use_frame_metrics \
    --weight_decay 0.01 --llama_attn_drop 0.1 --embed_no_wd \
    --max_train_steps 100005

For action-free video prediction, remove the --load_internal_llm and --action_conditioned flags.
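With --segment_length 16 and --context_length 1, one plausible reading is that each training clip holds 16 frames, of which the first is observed context and the remaining 15 are predicted, with --video_stepsize as the stride between sampled frames. The sketch below shows that reading only; how the repo's data loader actually samples segments is an assumption here, and `split_segment` is a hypothetical helper.

```python
from typing import List, Sequence, Tuple, TypeVar

Frame = TypeVar("Frame")


def split_segment(frames: Sequence[Frame], start: int, segment_length: int,
                  context_length: int, video_stepsize: int = 1) -> Tuple[List[Frame], List[Frame]]:
    """Take `segment_length` frames spaced `video_stepsize` apart from `start`, then split
    them into context frames (given to the model) and future frames (to predict).
    One plausible reading of the CLI flags, not the repo's actual loader."""
    if not 0 < context_length < segment_length:
        raise ValueError("need 0 < context_length < segment_length")
    end = start + (segment_length - 1) * video_stepsize
    if start < 0 or end >= len(frames):
        raise IndexError("segment does not fit inside the episode")
    segment = [frames[start + i * video_stepsize] for i in range(segment_length)]
    return segment[:context_length], segment[context_length:]
```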

More fine-tuning scripts are in scripts/finetune.
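The transformer command above sets --per_device_train_batch_size 16 and --gradient_accumulation_steps 1. Under `accelerate launch` with data parallelism, the number of samples per optimizer step also scales with the number of processes (GPUs), and this README does not state how many GPUs were used. If you change the GPU count, a small sketch like this can help keep the effective batch size fixed; both helpers are hypothetical.

```python
def effective_batch_size(per_device_batch_size: int, gradient_accumulation_steps: int,
                         num_processes: int) -> int:
    """Samples contributing to one optimizer step under data-parallel training."""
    return per_device_batch_size * gradient_accumulation_steps * num_processes


def accumulation_steps_for(target_batch_size: int, per_device_batch_size: int,
                           num_processes: int) -> int:
    """Gradient accumulation steps needed to reach `target_batch_size` exactly."""
    per_step = per_device_batch_size * num_processes
    if target_batch_size % per_step:
        raise ValueError("target batch size is not a multiple of per-device size x processes")
    return target_batch_size // per_step
```

For example, a run with 16 samples per device on 4 GPUs (64 samples per step) could be reproduced on 2 GPUs with 2 accumulation steps.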

Evaluation

To evaluate a checkpoint without training, run:

bash ./scripts/evaluation/bair-64-act-cond.sh

Evaluation scripts for the other released checkpoints are in scripts/evaluation.
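The fine-tuning command enables --use_frame_metrics, i.e. per-frame image-quality metrics. This README does not list which metrics that flag computes. PSNR is a common one in video prediction, so the sketch below shows its standard definition, PSNR = 10 · log10(MAX² / MSE), for one pair of flattened frames with pixel values in [0, max_value]. It illustrates the metric, not the repo's evaluation code.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_value: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB between two equally sized, flattened frames."""
    if len(pred) != len(target) or not pred:
        raise ValueError("frames must be non-empty and the same size")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical frames
    return 10.0 * math.log10(max_value ** 2 / mse)
```

A video-level score is then typically an average of per-frame values over the predicted frames.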

🤖 Visual Control

Visual Model-based RL

Install the Metaworld version we used:

pip install git+https://github.com/Farama-Foundation/Metaworld.git@83ac03ca3207c0060112bfc101393ca794ebf1bd

Set the paths in mbrl/cfgs/mbpo_config.yaml to your own (only absolute paths are currently supported).
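Because only absolute paths are supported, it can help to resolve relative entries before writing them into the config. This is a minimal sketch that works on an already loaded config dict (for example from PyYAML's `yaml.safe_load`). You pass in the names of the path entries that actually appear in mbpo_config.yaml; `absolutize_paths` is a hypothetical helper, and no key names are assumed.

```python
import os
from typing import Any, Dict, Iterable


def absolutize_paths(cfg: Dict[str, Any], path_keys: Iterable[str], base_dir: str) -> Dict[str, Any]:
    """Return a copy of `cfg` in which each listed key holds an absolute path.
    '~' is expanded and relative values are resolved against `base_dir`."""
    out = dict(cfg)
    for key in path_keys:
        value = os.path.expanduser(str(out[key]))
        if not os.path.isabs(value):
            value = os.path.join(base_dir, value)
        out[key] = os.path.normpath(value)
    return out
```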

Run model-based RL with iVideoGPT:

python mbrl/train_metaworld_mbpo.py task=plate_slide num_train_frames=100002 demo=true
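MBPO (model-based policy optimization) trains the policy on a mix of real environment transitions and short rollouts imagined by a learned world model, here iVideoGPT. The toy sketch below shows only that batch-mixing step in standard-library Python. `sample_mixed_batch` and `real_ratio` are hypothetical names, not options from mbpo_config.yaml; the actual training loop is in mbrl/train_metaworld_mbpo.py.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_mixed_batch(real_buffer: Sequence[T], model_buffer: Sequence[T], batch_size: int,
                       real_ratio: float, rng: random.Random) -> List[T]:
    """Draw `batch_size` transitions with replacement: a `real_ratio` fraction from real
    environment data, the rest from rollouts generated by the learned world model."""
    n_real = int(round(batch_size * real_ratio))
    batch = rng.choices(real_buffer, k=n_real) + rng.choices(model_buffer, k=batch_size - n_real)
    rng.shuffle(batch)  # avoid ordering the batch by data source
    return batch
```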

Visual Planning

See vp for detailed instructions.

🎥 Showcases

[Figure: iVideoGPT showcases]

📜 Citation

If you find this project useful, please cite our paper as:

@inproceedings{wu2024ivideogpt,
    title={iVideoGPT: Interactive VideoGPTs are Scalable World Models}, 
    author={Jialong Wu and Shaofeng Yin and Ningya Feng and Xu He and Dong Li and Jianye Hao and Mingsheng Long},
    booktitle={Advances in Neural Information Processing Systems},
    year={2024},
}

🤝 Contact

If you have any questions, please contact [email protected].

💡 Acknowledgement

Our codebase is based on huggingface/diffusers and facebookresearch/drqv2.