MQ-Det: Multi-modal Queried Object Detection in the Wild (NeurIPS2023)
Official PyTorch implementation of "MQ-Det: Multi-modal Queried Object Detection in the Wild": the first multi-modal queried open-set object detector.
Citation
If you find our work useful in your research, please consider citing:
@article{mqdet,
title={Multi-modal queried object detection in the wild},
author={Xu, Yifan and Zhang, Mengdan and Fu, Chaoyou and Chen, Peixian and Yang, Xiaoshan and Li, Ke and Xu, Changsheng},
journal={Advances in Neural Information Processing Systems},
year={2023}
}
Multi-modal Queried Object Detection
We introduce MQ-Det, an efficient architecture and pre-training strategy that uses both textual descriptions, which offer open-set generalization, and visual exemplars, which offer rich description granularity, as category queries. We call this Multi-modal Queried object Detection; it targets real-world detection with open-vocabulary categories and varying granularity.

Method
MQ-Det incorporates vision queries into existing well-established language-queried-only detectors.
Features:
- A plug-and-play gated class-scalable perceiver module upon the frozen detector.
- A vision conditioned masked language prediction strategy.
- Compatible with most language-queried object detectors.
Preparation
Environment. Init the environment:
git clone https://github.com/YifanXu74/MQ-Det.git
cd MQ-Det
conda create -n mqdet python=3.9 -y
conda activate mqdet
bash init.sh
The environment used in the paper is python==3.9, torch==2.0.1, GCC==8.3.1, CUDA==11.7. Several potential errors and their solutions are listed in DEBUG.md.
Data. Prepare the Objects365 (for modulated pre-training), LVIS (for evaluation), and ODinW (for evaluation) benchmarks following DATA.md.
Initial Weight. MQ-Det is built upon frozen language-queried detectors. To conduct modulated pre-training, first download the corresponding pre-trained model weights.
We apply MQ-Det on GLIP and GroundingDINO:
GLIP-T:
wget https://huggingface.co/GLIPModel/GLIP/resolve/main/glip_tiny_model_o365_goldg_cc_sbu.pth -O MODEL/glip_tiny_model_o365_goldg_cc_sbu.pth
GLIP-L:
wget https://huggingface.co/GLIPModel/GLIP/resolve/main/glip_large_model.pth -O MODEL/glip_large_model.pth
GroundingDINO-T:
wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth -O MODEL/groundingdino_swint_ogc.pth
If the links fail, please manually download the corresponding weights from the following table or from the GitHub pages of GLIP/GroundingDINO.
GLIP-T | GLIP-L | GroundingDINO-T |
---|---|---|
weight | weight | weight |
The weight files should be placed as follows:
MODEL/
glip_tiny_model_o365_goldg_cc_sbu.pth
glip_large_model.pth
groundingdino_swint_ogc.pth
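Before starting pre-training, it can help to confirm that the weights are where the layout above expects them. The following is a minimal sketch, not part of the MQ-Det codebase: the helper name `missing_weights` is hypothetical, and only the file names and the `MODEL/` directory come from the layout above.

```python
from pathlib import Path

# File names copied from the layout above; the helper itself is illustrative.
EXPECTED_WEIGHTS = (
    "glip_tiny_model_o365_goldg_cc_sbu.pth",
    "glip_large_model.pth",
    "groundingdino_swint_ogc.pth",
)


def missing_weights(model_dir="MODEL", expected=EXPECTED_WEIGHTS):
    """Return the expected weight files that are not present in model_dir."""
    root = Path(model_dir)
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_weights()
    if missing:
        print("Missing weights:", ", ".join(missing))
    else:
        print("All initial weights found.")
```

Only the weight for the detector you plan to use is strictly needed, so pass a shorter `expected` tuple if you work with a single model.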
Model Zoo
The table reports the finetuning-free performance with 5 vision queries.
Model | LVIS MiniVal | LVIS Val v1.0 | ODinW-13 | ODinW-35 | Config | Weight |
---|---|---|---|---|---|---|
MQ-GLIP-T | 30.4 | 22.6 | 45.6 | 20.8 | config | weight |
MQ-GLIP-L | 43.4 | 34.7 | 54.1 | 23.9 | config | weight |
Vision Query Extraction
Take MQ-GLIP-T as an example.
If you wish to extract vision queries from a custom dataset, specify DATASETS.TRAIN in the config file.
Some examples from our implementation are given below.
Objects365 for modulated pre-training:
python tools/extract_vision_query.py --config_file configs/pretrain/mq-glip-t.yaml --dataset objects365 --add_name tiny
This will generate a query bank file in MODEL/object365_query_5000_sel_tiny.pth
LVIS for downstream tasks:
python tools/extract_vision_query.py --config_file configs/pretrain/mq-glip-t.yaml --dataset lvis --num_vision_queries 5 --add_name tiny
This will generate a query bank file in MODEL/lvis_query_5_pool7_sel_tiny.pth.
ODinW for downstream tasks:
python tools/extract_vision_query.py --config_file configs/pretrain/mq-glip-t.yaml --dataset odinw-13 --num_vision_queries 5 --add_name tiny
This will generate query bank files for each dataset in ODinW in MODEL/{dataset}_query_5_pool7_sel_tiny.pth.
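After running the extraction commands, the generated banks can be listed by their `--add_name` suffix. This is a hedged sketch: the glob pattern is inferred only from the three example paths above (all of which end in `_sel_<add_name>.pth`), and `find_query_banks` is a hypothetical helper, not a function from the repository.

```python
from pathlib import Path


def find_query_banks(add_name, model_dir="MODEL"):
    """List query bank files in model_dir whose names end in _sel_<add_name>.pth.

    The pattern is an assumption based on the example file names in this README.
    """
    pattern = f"*_query_*_sel_{add_name}.pth"
    return sorted(p.name for p in Path(model_dir).glob(pattern))


if __name__ == "__main__":
    for name in find_query_banks("tiny"):
        print(name)
```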
Some parameters related to query extraction:
- DATASETS.FEW_SHOT: if set to k>0, the dataset is subsampled to k-shot for each category when the dataset is initialized. This is completed before training. Not used during pre-training.
- VISION_QUERY.MAX_QUERY_NUMBER: the maximum number of vision queries for each category when extracting the query bank. Note that query extraction is conducted before training and evaluation.
- VISION_QUERY.NUM_QUERY_PER_CLASS: controls how many queries are provided for each category during one forward pass in training and evaluation.
Usually, we set:
- VISION_QUERY.MAX_QUERY_NUMBER=5000, VISION_QUERY.NUM_QUERY_PER_CLASS=5, DATASETS.FEW_SHOT=0 during pre-training.
- VISION_QUERY.MAX_QUERY_NUMBER=5, VISION_QUERY.NUM_QUERY_PER_CLASS=5, DATASETS.FEW_SHOT=5 during few-shot (5-shot) fine-tuning.
Command-line arguments of the extraction script:
- --num_vision_queries: the number of vision queries for each category; it can be an arbitrary number. This sets both VISION_QUERY.MAX_QUERY_NUMBER and DATASETS.FEW_SHOT to num_vision_queries. Note that here DATASETS.FEW_SHOT is only used to accelerate the extraction process.
- --add_name: only a mark to distinguish different models. For training/evaluating with MQ-GLIP-T/MQ-GLIP-L/MQ-GroundingDINO, we set --add_name to 'tiny'/'large'/'gd'.
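The relationship between `--num_vision_queries` and the config keys described above can be written down as plain data. The sketch below only restates the text; the function names and the dictionary layout are hypothetical and are not how `tools/extract_vision_query.py` is necessarily implemented.

```python
def overrides_for_num_vision_queries(num_vision_queries):
    """Config keys that --num_vision_queries is described as setting."""
    return {
        "VISION_QUERY.MAX_QUERY_NUMBER": num_vision_queries,
        "DATASETS.FEW_SHOT": num_vision_queries,
    }


# Typical settings quoted in the text above.
PRETRAIN = {
    "VISION_QUERY.MAX_QUERY_NUMBER": 5000,
    "VISION_QUERY.NUM_QUERY_PER_CLASS": 5,
    "DATASETS.FEW_SHOT": 0,
}
FEW_SHOT_5 = {
    "VISION_QUERY.MAX_QUERY_NUMBER": 5,
    "VISION_QUERY.NUM_QUERY_PER_CLASS": 5,
    "DATASETS.FEW_SHOT": 5,
}


def to_opts(overrides):
    """Flatten a dict into the 'KEY VALUE KEY VALUE' form used on the command line."""
    return " ".join(f"{key} {value}" for key, value in overrides.items())


if __name__ == "__main__":
    print(to_opts(overrides_for_num_vision_queries(5)))
```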
Modulated Training
Take MQ-GLIP-T as an example.
python -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py --config-file configs/pretrain/mq-glip-t.yaml --use-tensorboard OUTPUT_DIR 'OUTPUT/MQ-GLIP-TINY/'
Before pre-training, first extract the vision queries following the instructions above.
To pre-train on custom datasets, please specify DATASETS.TRAIN and VISION_SUPPORT.SUPPORT_BANK_PATH in the config file.
Finetuning-free Evaluation
Take MQ-GLIP-T as an example.
LVIS Evaluation
MiniVal:
python -m torch.distributed.launch --nproc_per_node=4 \
tools/test_grounding_net.py \
--config-file configs/pretrain/mq-glip-t.yaml \
--additional_model_config configs/vision_query_5shot/lvis_minival.yaml \
VISION_QUERY.QUERY_BANK_PATH MODEL/lvis_query_5_pool7_sel_tiny.pth \
MODEL.WEIGHT ${model_weight_path} \
TEST.IMS_PER_BATCH 4
Val 1.0:
python -m torch.distributed.launch --nproc_per_node=4 \
tools/test_grounding_net.py \
--config-file configs/pretrain/mq-glip-t.yaml \
--additional_model_config configs/vision_query_5shot/lvis_val.yaml \
VISION_QUERY.QUERY_BANK_PATH MODEL/lvis_query_5_pool7_sel_tiny.pth \
MODEL.WEIGHT ${model_weight_path} \
TEST.IMS_PER_BATCH 4
Please follow the instructions above to extract the corresponding vision queries. Note that --nproc_per_node must equal TEST.IMS_PER_BATCH.
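Because `--nproc_per_node` and `TEST.IMS_PER_BATCH` have to match, one way to avoid a mismatch is to derive both from a single value. The following is a minimal sketch that assembles the LVIS commands shown above; `lvis_eval_command` and its parameters are hypothetical names, and only the flags, paths, and config keys are taken from this README.

```python
import shlex


def lvis_eval_command(model_weight_path, num_gpus=4, split="minival",
                      bank="MODEL/lvis_query_5_pool7_sel_tiny.pth"):
    """Build the LVIS evaluation command with matching GPU count and batch size."""
    if split not in ("minival", "val"):
        raise ValueError("split must be 'minival' or 'val'")
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={num_gpus}",
        "tools/test_grounding_net.py",
        "--config-file", "configs/pretrain/mq-glip-t.yaml",
        "--additional_model_config",
        f"configs/vision_query_5shot/lvis_{split}.yaml",
        "VISION_QUERY.QUERY_BANK_PATH", bank,
        "MODEL.WEIGHT", model_weight_path,
        # Same value as --nproc_per_node, as required above.
        "TEST.IMS_PER_BATCH", str(num_gpus),
    ]


if __name__ == "__main__":
    # Print the command instead of running it; pass the list to
    # subprocess.run(...) if you want to launch it from Python.
    print(shlex.join(lvis_eval_command("path/to/weight.pth", num_gpus=4)))
```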
ODinW / Custom Dataset Evaluation
python tools/eval_odinw.py --config_file configs/pretrain/mq-glip-t.yaml \
--opts 'MODEL.WEIGHT ${model_weight_path}' \
--setting finetuning-free \
--add_name tiny \
--log_path 'OUTPUT/odinw_log/'
The results are stored at OUTPUT/odinw_log/.
If you wish to use custom vision queries or datasets, specify --task_config and --custom_bank_path. The task_config should be like the ones in ODinW configs. The custom_bank_path should be extracted following the instructions above. For example,
python tools/eval_odinw.py --config_file configs/pretrain/mq-glip-t.yaml \
--opts 'MODEL.WEIGHT ${model_weight_path}' \
--setting finetuning-free \
--add_name tiny \
--log_path 'OUTPUT/custom_log/' \
--task_config ${custom_config_path} \
--custom_bank_path ${custom_bank_path}
Fine-Tuning
Take MQ-GLIP-T as an example.
python tools/eval_odinw.py --config_file configs/pretrain/mq-glip-t.yaml \
--opts 'MODEL.WEIGHT ${model_weight_path}' \
--setting 3-shot \
--add_name tiny \
--log_path 'OUTPUT/odinw_log/'
This command first automatically extracts the vision query bank from the (few-shot) training set, then conducts fine-tuning.
If you wish to use custom vision queries, add 'VISION_QUERY.QUERY_BANK_PATH custom_bank_path' to the --opts argument, and also modify the dataset_configs in tools/eval_odinw.py.
If VISION_QUERY.QUERY_BANK_PATH is set to '', the model automatically extracts the vision query bank from the (few-shot) training set before fine-tuning.
Single-Modal Evaluation
Here we describe how to use single-modal queries, i.e., visual exemplars only or textual descriptions only.
Use the same commands as in Finetuning-free Evaluation, but set the following hyper-parameters.
To solely use vision queries, add hyper-parameters:
VISION_QUERY.MASK_DURING_INFERENCE True VISION_QUERY.TEXT_DROPOUT 1.0
To solely use language queries, add hyper-parameters:
VISION_QUERY.ENABLED FALSE
For example, to solely use vision queries,
python -m torch.distributed.launch --nproc_per_node=4 \
tools/test_grounding_net.py \
--config-file configs/pretrain/mq-glip-t.yaml \
--additional_model_config configs/vision_query_5shot/lvis_minival.yaml \
VISION_QUERY.QUERY_BANK_PATH MODEL/lvis_query_5_pool7_sel_tiny.pth \
MODEL.WEIGHT ${model_weight_path} \
TEST.IMS_PER_BATCH 4 \
VISION_QUERY.MASK_DURING_INFERENCE True VISION_QUERY.TEXT_DROPOUT 1.0
python tools/eval_odinw.py --config_file configs/pretrain/mq-glip-t.yaml \
--opts 'MODEL.WEIGHT ${model_weight_path} VISION_QUERY.MASK_DURING_INFERENCE True VISION_QUERY.TEXT_DROPOUT 1.0' \
--setting finetuning-free \
--add_name tiny \
--log_path 'OUTPUT/odinw_log/'
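The two single-modal settings above differ only in the config overrides appended to the evaluation command. As a hedged sketch, the mapping can be kept in one place; `single_modal_overrides` and its mode names are hypothetical, while the keys and values are the ones listed in this section.

```python
def single_modal_overrides(mode):
    """Return the extra config overrides for a query mode.

    mode: 'vision' (vision queries only), 'language' (language queries only),
    or 'both' (default multi-modal evaluation, no extra overrides).
    """
    if mode == "vision":
        return [
            "VISION_QUERY.MASK_DURING_INFERENCE", "True",
            "VISION_QUERY.TEXT_DROPOUT", "1.0",
        ]
    if mode == "language":
        return ["VISION_QUERY.ENABLED", "FALSE"]
    if mode == "both":
        return []
    raise ValueError("mode must be 'vision', 'language', or 'both'")


if __name__ == "__main__":
    # For tools/eval_odinw.py, join the overrides into the --opts string.
    print(" ".join(single_modal_overrides("vision")))
```

For `tools/test_grounding_net.py` the returned items can be appended to the command as separate arguments; for `tools/eval_odinw.py` they go inside the quoted `--opts` string, as in the examples above.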