mirror of https://github.com/FoundationVision/GLEE
add video tasks eval scripts
parent 8fc3b2518b
commit 503ae924a5

@@ -244,9 +244,7 @@ ${GLEE_ROOT}

### ODinW

We follow [GLIP](https://github.com/microsoft/GLIP) to prepare the ODinW 35 dataset; run ```python3 download.py``` to download it and organize it as below.

```
${GLEE_ROOT}
@@ -260,14 +258,42 @@ ${GLEE_ROOT}
```
### TAO&BURST

TAO and BURST share the same video frames.

First, download the validation set zip files (2-TAO_VAL.zip, 2_AVA_HACS_VAL_e49d8f78098a8ffb3769617570a20903.zip) from https://motchallenge.net/tao_download.php and unzip them.

Then, download our preprocessed YTVIS-format (COCO-like) annotation files from Hugging Face:

https://huggingface.co/spaces/Junfeng5/GLEE_demo/tree/main/annotations/TAO

and organize everything as below:

```
${GLEE_ROOT}
    -- datasets
        -- TAO
            -- burst_annotations
                -- TAO_val_withlabel_ytvisformat.json
                -- val
                    -- all_classes.json
                    -- ...
            -- TAO_annotations
                -- validation_ytvisfmt.json
                -- validation.json
            -- frames
                -- val
                    -- ArgoVerse
                    -- ava
                    -- ...
```

## Updating...

### TAO

### BURST

### LV-VIS

### MOSE

assets/TEST.md

@@ -51,7 +51,125 @@ python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/images/Li
# Video Tasks (Continuously Updated)

### YouTube-VIS, OVIS

1. Run the inference scripts:

```
# YTVIS19 GLEE-Lite
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Lite/ytvis19_base.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Lite_joint.pth
# YTVIS19 GLEE-Plus
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Plus/ytvis19_Plus.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Plus_joint.pth

# OVIS GLEE-Lite
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Lite/ovis_base.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Lite_joint.pth
# OVIS GLEE-Plus
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Plus/ovis_Plus.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Plus_joint.pth
```

2. Submit the generated `results.zip` to the corresponding online evaluation servers.
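
For packaging, a minimal sketch (assuming the eval run writes `results.json` under the config's `OUTPUT_DIR/inference/`, e.g. `./GLEE_Lite_ytvis19/inference/results.json`; double-check the archive layout against the server's submission page):

```
# Sketch: zip the predicted results.json into results.zip for upload.
# Assumptions: results.json sits in <OUTPUT_DIR>/inference/ and the server
# expects the json at the root of the archive.
import zipfile
from pathlib import Path

def package_results(output_dir: str, zip_name: str = "results.zip") -> Path:
    results = Path(output_dir) / "inference" / "results.json"
    if not results.is_file():
        raise FileNotFoundError(f"no results.json found at {results}")
    archive = Path(output_dir) / zip_name
    with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
        zf.write(results, arcname="results.json")
    return archive

if __name__ == "__main__":
    print(package_results("./GLEE_Lite_ytvis19"))
```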
### TAO, BURST

#### 1. Data preparation

TAO and BURST share the same video frames.

First, download the validation set zip files (2-TAO_VAL.zip, 2_AVA_HACS_VAL_e49d8f78098a8ffb3769617570a20903.zip) from https://motchallenge.net/tao_download.php and unzip them.

Then, download our preprocessed YTVIS-format (COCO-like) annotation files from Hugging Face:

https://huggingface.co/spaces/Junfeng5/GLEE_demo/tree/main/annotations/TAO

and organize everything as below:

```
${GLEE_ROOT}
    -- datasets
        -- TAO
            -- burst_annotations
                -- TAO_val_withlabel_ytvisformat.json
                -- val
                    -- all_classes.json
                    -- ...
            -- TAO_annotations
                -- validation_ytvisfmt.json
                -- validation.json
            -- frames
                -- val
                    -- ArgoVerse
                    -- ava
                    -- ...
```
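
Optionally, a small sanity check before launching inference (a sketch; the paths simply mirror the layout above, and `GLEE_ROOT` is assumed to be the current directory):

```
# Sketch: verify the TAO/BURST data layout described above is in place.
from pathlib import Path

GLEE_ROOT = Path(".")  # adjust to your checkout
expected = [
    "datasets/TAO/burst_annotations/TAO_val_withlabel_ytvisformat.json",
    "datasets/TAO/burst_annotations/val/all_classes.json",
    "datasets/TAO/TAO_annotations/validation_ytvisfmt.json",
    "datasets/TAO/TAO_annotations/validation.json",
    "datasets/TAO/frames/val",
]
missing = [p for p in expected if not (GLEE_ROOT / p).exists()]
print("layout OK" if not missing else f"missing: {missing}")
```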

#### 2. TAO

1. Run the inference scripts:

```
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Lite/TAO_Lite.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Lite_joint.pth

python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Plus/TAO_Plus.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Plus_joint.pth
```

2. For TAO, we use TETA as the evaluation metric (for more details, please refer to https://github.com/SysCV/tet/blob/main/teta/README.md).

3. Install TETA and run the evaluation:

```
git clone https://github.com/SysCV/tet.git
cd tet/teta/
pip install -r requirements.txt
pip install -e .

# eval
python3 scripts/run_tao.py --METRICS TETA --TRACKERS_TO_EVAL TETer --GT_FOLDER /path/to/${GLEE_ROOT}/datasets/TAO/TAO_annotations/validation.json --TRACKER_SUB_FOLDER /path/to/${GLEE_ROOT}/GLEE_TAO_Lite_640p/inference/results.json
```
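
Before running TETA, it can be handy to glance at what is being evaluated. A sketch, assuming only that `results.json` is the list of per-frame detections with `video_id` and `track_id` fields produced by the TAO inference above:

```
# Sketch: summarize the TAO-format results.json passed to --TRACKER_SUB_FOLDER.
import json
from collections import defaultdict

with open("GLEE_TAO_Lite_640p/inference/results.json") as f:  # adjust path
    dets = json.load(f)

tracks_per_video = defaultdict(set)
for d in dets:
    tracks_per_video[d["video_id"]].add(d["track_id"])

print(f"{len(dets)} detections over {len(tracks_per_video)} videos")
print("avg tracks per video:",
      sum(len(v) for v in tracks_per_video.values()) / max(1, len(tracks_per_video)))
```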

#### 3. BURST

1. Run the inference scripts:

```
python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Lite/BURST_Lite.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Lite_joint.pth

python3 projects/GLEE/train_net.py --config-file projects/GLEE/configs/video/Plus/BURST_Plus.yaml --eval-only --num-gpus 8 MODEL.WEIGHTS /path/to/GLEE_model_zoo/GLEE_Plus_joint.pth
```

2. Download the evaluation tools from https://github.com/Ali2500/BURST-benchmark and https://github.com/JonathonLuiten/TrackEval:

```
mkdir burst_tools
cd burst_tools
git clone https://github.com/Ali2500/BURST-benchmark.git
git clone https://github.com/JonathonLuiten/TrackEval.git
wget https://huggingface.co/spaces/Junfeng5/GLEE_demo/resolve/main/annotations/convert_ytvis2tao.py
```

3. Run the evaluation code:

```
# First, convert the YTVIS-format results to TAO/BURST-format results
python3 convert_ytvis2tao.py --results /path/to/GLEE_BURST_Lite_720p/inference/results.json --refer /path/to/${GLEE_ROOT}/datasets/TAO/burst_annotations/val/all_classes.json

cd BURST-benchmark
export TRACKEVAL_DIR=/path/to/burst_tools/TrackEval/
python3 burstapi/eval/run.py --pred ../converted_tao_results.json --gt ../../burst_annotations/val/ --task class_guided
```

@@ -0,0 +1,29 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
DATASETS:
  TRAIN: ("BURST_video_train",)
  TEST: ("BURST_video_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 4
  SAMPLING_FRAME_RANGE: 12
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 768
  MIN_SIZE_TEST: 720
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_BURST_Lite_720p
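
For reference, `SAMPLING_FRAME_NUM` / `SAMPLING_FRAME_RANGE` control clip-based frame sampling during training: a reference frame is drawn and the clip's frames are sampled from a window of `SAMPLING_FRAME_RANGE` frames around it. A rough sketch of that semantics (illustrative only, not GLEE's actual dataloader):

```
# Illustrative sketch of clip sampling with SAMPLING_FRAME_NUM / SAMPLING_FRAME_RANGE.
import random

def sample_clip(video_len: int, frame_num: int = 4, frame_range: int = 12) -> list:
    ref = random.randrange(video_len)                     # reference frame index
    lo, hi = max(0, ref - frame_range), min(video_len - 1, ref + frame_range)
    window = list(range(lo, hi + 1))                      # candidate frames around ref
    return sorted(random.sample(window, min(frame_num, len(window))))

print(sample_clip(video_len=60))  # e.g. [17, 20, 25, 28]
```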

@@ -0,0 +1,34 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  PSEUDO_VIDEO: False
  FREEZE_WHOLE: False
  TEXT:
    ARCH: clip_frozen
DATASETS:
  TRAIN: ("lvvis_train", )
  TEST: ("lvvis_val", )
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 2
  SAMPLING_FRAME_RANGE: 5
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 480
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_Lite_LVVIS

@@ -0,0 +1,30 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
DATASETS:
  TRAIN: ("BURST_video_train",)
  TEST: ("TAO_video_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 4
  SAMPLING_FRAME_RANGE: 12
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 768
  MIN_SIZE_TEST: 640
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_TAO_Lite_640p

@@ -0,0 +1,34 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  PSEUDO_VIDEO: False
  FREEZE_WHOLE: False
  TEXT:
    ARCH: clip_frozen
DATASETS:
  TRAIN: ("ovis_train",)
  TEST: ("ovis_val",)
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.0001
  STEPS: (12000, )
  MAX_ITER: 18000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 2
  SAMPLING_FRAME_RANGE: 10
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  # MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 720
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_Lite_ovis

@@ -0,0 +1,35 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  CROSS_TRACK: False
  PSEUDO_VIDEO: False
  FREEZE_WHOLE: False
  TEXT:
    ARCH: clip_frozen
DATASETS:
  TRAIN: ("ytvis_2019_train", )
  TEST: ("ytvis_2019_val",)
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 2
  SAMPLING_FRAME_RANGE: 5
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 480
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_Lite_ytvis19

@@ -0,0 +1,41 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
DATASETS:
  TRAIN: ("BURST_video_train",)
  TEST: ("BURST_video_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 4
  SAMPLING_FRAME_RANGE: 12
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 768
  MIN_SIZE_TEST: 720
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_BURST_Plus_720p

@@ -0,0 +1,42 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
DATASETS:
  TRAIN: ("BURST_video_train",)
  TEST: ("TAO_video_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 4
  SAMPLING_FRAME_RANGE: 12
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 768
  MIN_SIZE_TEST: 640
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_TAO_Plus_640p

@@ -0,0 +1,45 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  PSEUDO_VIDEO: False
  FREEZE_WHOLE: False
  TEXT:
    ARCH: clip_frozen
DATASETS:
  TRAIN: ("ovis_train",)
  TEST: ("ovis_val",)
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.0001
  STEPS: (12000, )
  MAX_ITER: 18000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 2
  SAMPLING_FRAME_RANGE: 10
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  # MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 720
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_Plus_ovis

@@ -0,0 +1,46 @@
_BASE_: "../../images/Lite/base_clip_frozen_image_R50.yaml"
MODEL:
  BACKBONE:
    NAME: "D2SwinTransformer"
  SWIN:
    EMBED_DIM: 192
    DEPTHS: [2, 2, 18, 2]
    NUM_HEADS: [6, 12, 24, 48]
    WINDOW_SIZE: 12
    APE: False
    DROP_PATH_RATE: 0.3
    PATCH_NORM: True
    PRETRAIN_IMG_SIZE: 384
  CROSS_TRACK: False
  PSEUDO_VIDEO: False
  FREEZE_WHOLE: False
  TEXT:
    ARCH: clip_frozen
DATASETS:
  TRAIN: ("ytvis_2019_train", )
  TEST: ("ytvis_2019_val",)
SOLVER:
  IMS_PER_BATCH: 8
  BASE_LR: 0.0001
  STEPS: (6000, )
  MAX_ITER: 8000
  CHECKPOINT_PERIOD: 2000
INPUT:
  SAMPLING_FRAME_NUM: 2
  SAMPLING_FRAME_RANGE: 5
  MIN_SIZE_TRAIN_SAMPLING: "choice_by_clip"
  RANDOM_FLIP: "flip_by_clip"
  MIN_SIZE_TRAIN: (320, 352, 392, 416, 448, 480, 512, 544, 576, 608, 640)
  MAX_SIZE_TRAIN: 1333
  MIN_SIZE_TEST: 480
  CROP:
    ENABLED: True
    TYPE: "absolute_range"
    SIZE: (384, 600)
  FORMAT: "RGB"
TEST:
  EVAL_PERIOD: 100000
DATALOADER:
  FILTER_EMPTY_ANNOTATIONS: False
  NUM_WORKERS: 8
OUTPUT_DIR: ./GLEE_Plus_ytvis19

@@ -736,15 +736,211 @@ class GLEE(nn.Module):
        scale_fct = scale_fct.to(out_bbox)
        boxes = boxes * scale_fct
        return boxes

    def match_from_embds(self, tgt_embds, cur_embds):
        # Match the current clip's query embeddings to the tracked (target) embeddings
        # via cosine similarity, solved as a linear assignment problem.
        cur_embds = cur_embds / cur_embds.norm(dim=1)[:, None]
        tgt_embds = tgt_embds / tgt_embds.norm(dim=1)[:, None]
        cos_sim = torch.mm(cur_embds, tgt_embds.transpose(0, 1))

        cost_embd = 1 - cos_sim

        C = 1.0 * cost_embd
        C = C.cpu()

        indices = linear_sum_assignment(C.transpose(0, 1))  # target x current
        indices = indices[1]  # permutation that aligns current to target

        return indices

    def MinVIS_inference(self, batched_inputs, task):
        video_len = len(batched_inputs[0]['file_names'])

        clip_length = 5  # self.batch_infer_len
        batch_name_list = self.dataset_name_dicts[task]

        # split a long video into clips to form a batch input
        # if video_len > clip_length:
        num_clips = math.ceil(video_len / clip_length)
        logits_list, boxes_list, embed_list, points_list, masks_list = [], [], [], [], []
        for c in range(num_clips):
            start_idx = c * clip_length
            end_idx = (c + 1) * clip_length
            clip_inputs = [{'image': batched_inputs[0]['image'][start_idx:end_idx]}]
            clip_images = self.preprocess_video(clip_inputs)
            (clip_output, _), dist, loss = self.glee(clip_images, {}, task, batch_name_list=batch_name_list, is_train=False)
            logits_list.append(clip_output['pred_logits'])
            boxes_list.append(clip_output['pred_boxes'])
            embed_list.append(clip_output['pred_track_embed'])
            masks_list.append(clip_output['pred_masks'].cpu())  # .to(self.merge_device)
        outputs = {
            'pred_logits': torch.cat(logits_list, dim=0).detach(),
            'pred_track_embed': torch.cat(embed_list, dim=0).detach(),
            'pred_masks': torch.cat(masks_list, dim=0).detach(),
            'pred_boxes': torch.cat(boxes_list, dim=0).detach(),
        }

        # batch_name_list = self.dataset_name_dicts[task]
        pred_logits = list(torch.unbind(outputs['pred_logits']))
        pred_masks = list(torch.unbind(outputs['pred_masks'].cpu()))
        pred_embds = list(torch.unbind(outputs['pred_track_embed']))
        pred_boxes = list(torch.unbind(outputs['pred_boxes']))
        del outputs
        out_logits = []
        out_masks = []
        out_embds = []
        out_boxes = []
        out_logits.append(pred_logits[0])
        out_masks.append(pred_masks[0].cpu())
        out_embds.append(pred_embds[0])
        out_boxes.append(pred_boxes[0])

        # associate each frame with the running tracks by matching against a
        # moving average of the last three matched embeddings
        for i in range(1, len(pred_logits)):
            MA_embedding = torch.stack(out_embds[-3:]).mean(0)
            indices = self.match_from_embds(MA_embedding, pred_embds[i])
            out_logits.append(pred_logits[i][indices, :])
            out_masks.append(pred_masks[i][indices, :, :])
            out_embds.append(pred_embds[i][indices, :])
            out_boxes.append(pred_boxes[i][indices, :])

        mask_cls_result = sum(out_logits) / len(out_logits)

        out_logits = torch.stack(out_logits, dim=1)  # q numc -> q t numc

        mask_pred_result = torch.stack(out_masks, dim=1)  # q h w -> q t h w
        mask_box_result = torch.stack(out_boxes, dim=1)  # q 4 -> q t 4
        first_resize_size = (clip_images.tensor.shape[-2], clip_images.tensor.shape[-1])

        input_per_image = batched_inputs[0]
        image_size = clip_images.image_sizes[0]  # image size without padding after data augmentation

        height = input_per_image.get("height", image_size[0])  # raw image size before data augmentation
        width = input_per_image.get("width", image_size[1])
        mask_box_result = self.box_postprocess(mask_box_result, height, width)

        return self.minvis_inference_video(mask_cls_result, mask_pred_result, mask_box_result, image_size, height, width, first_resize_size, task, out_logits, batched_inputs)

    def minvis_inference_video(self, mask_cls, mask_pred, mask_box_result, img_size, output_height, output_width, first_resize_size, task, ori_logits, batched_inputs):
        if task != 'tao_video':
            if len(mask_cls) > 0:
                # keep top-k predictions
                scores = mask_cls.sigmoid()  # [300, 40]
                num_class = self.num_class[task]
                labels = torch.arange(num_class, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
                scores_per_image, topk_indices = scores.flatten(0, 1).topk(30, sorted=False)  # select top 30
                labels_per_image = labels[topk_indices]
                topk_indices = topk_indices // num_class
                mask_pred = mask_pred[topk_indices.cpu()].cpu()
                mask_box_result = mask_box_result[topk_indices]
                pred_masks = F.interpolate(
                    mask_pred, size=first_resize_size, mode="bilinear", align_corners=False
                )
                if self.is_lsj:
                    resize_ratio = img_size[0] / max(output_height, output_width)
                    crop_size = (int(output_height * resize_ratio), int(output_width * resize_ratio))
                else:
                    crop_size = img_size
                    # resize_ratio = image_size[0]/max(height, width)
                    # crop_size = (int(height*resize_ratio), int(width*resize_ratio))
                pred_masks = pred_masks[:, :, : crop_size[0], : crop_size[1]]
                pred_masks = F.interpolate(
                    pred_masks, size=(output_height, output_width), mode="bilinear", align_corners=False
                )

                masks = pred_masks > 0.

                out_scores = scores_per_image.tolist()
                out_labels = labels_per_image.tolist()
                out_masks = [m for m in masks.cpu()]

                # xyxy -> xywh
                mask_box_result[:, :, 2] = mask_box_result[:, :, 2] - mask_box_result[:, :, 0]
                mask_box_result[:, :, 3] = mask_box_result[:, :, 3] - mask_box_result[:, :, 1]

                mask_box_result = mask_box_result.cpu().long()
                out_boxes = [m for m in mask_box_result]
            else:
                out_scores = []
                out_labels = []
                out_masks = []
                out_boxes = []

            video_output = {
                "image_size": (output_height, output_width),
                "pred_scores": out_scores,
                "pred_labels": out_labels,
                "pred_masks": out_masks,
                "pred_boxes": out_boxes,
            }
        else:  # for the TAO video TETA metric
            scores = mask_cls.sigmoid()  # [300, numcls]

            topk_num = 50

            num_class = self.num_class[task]

            scores_per_video, topk_indices = scores.max(-1)[0].topk(topk_num, sorted=False)
            labels_per_video = scores[topk_indices].max(-1)[1]  # [select_num]

            mask_pred = mask_pred[topk_indices.cpu()]  # [select_num, len, H, W]
            mask_pred = mask_pred > 0

            mask_box_result = mask_box_result[topk_indices]  # [select_num, len, 4]
            # xyxy -> xywh
            mask_box_result[:, :, 2] = mask_box_result[:, :, 2] - mask_box_result[:, :, 0]
            mask_box_result[:, :, 3] = mask_box_result[:, :, 3] - mask_box_result[:, :, 1]

            ori_logits = ori_logits[topk_indices].sigmoid()  # [select_num, len, num_class]

            image_ids = batched_inputs[0]['image_ids']
            video_id = batched_inputs[0]['video_id']
            video_len = len(image_ids)
            track_ids = torch.arange(topk_num).to(scores_per_video) + topk_num * video_id

            video_results = []
            for i, image_id in enumerate(image_ids):

                # frame_logits = ori_logits[:, i]  # [topk_num, num_cls]
                # scores_per_frame, labels_per_frames = frame_logits.max(-1)

                frame_boxes = mask_box_result[:, i]
                frame_masks = mask_pred[:, i]
                mask_valid = frame_masks.flatten(1, 2).sum(-1) > 5

                frame_boxes = frame_boxes[mask_valid]
                frame_scores = scores_per_video[mask_valid]
                frame_labels = labels_per_video[mask_valid]
                frame_trackids = track_ids[mask_valid]

                # box NMS
                boxes_before_nms = box_ops.box_cxcywh_to_xyxy(frame_boxes)
                keep_indices = ops.nms(boxes_before_nms, frame_scores, 0.5)  # .tolist()

                frame_boxes = frame_boxes[keep_indices]
                frame_scores = frame_scores[keep_indices]
                frame_labels = frame_labels[keep_indices]
                frame_trackids = frame_trackids[keep_indices]

                for box, score, label, trackid in zip(frame_boxes, frame_scores, frame_labels, frame_trackids):
                    video_results.append(
                        {
                            "image_id": image_id,
                            "category_id": label.item(),
                            "bbox": box.tolist(),
                            "score": score.item(),
                            "track_id": trackid.item(),
                            "video_id": video_id,
                        }
                    )
            video_output = video_results

        return video_output
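
As a standalone illustration of the association step in `match_from_embds` (toy tensors; in the real pipeline the embeddings come from `pred_track_embed` of consecutive clips):

```
# Toy example: cosine-distance cost matrix + Hungarian assignment, as in match_from_embds.
import torch
from scipy.optimize import linear_sum_assignment

torch.manual_seed(0)
tgt = torch.randn(5, 256)   # running track embeddings (e.g. a moving average)
cur = torch.randn(5, 256)   # current clip's query embeddings

tgt = tgt / tgt.norm(dim=1, keepdim=True)
cur = cur / cur.norm(dim=1, keepdim=True)
cost = 1 - cur @ tgt.T      # rows: current queries, cols: tracks

row, col = linear_sum_assignment(cost.T.numpy())  # target x current
# col[k] is the current query matched to track k; indexing the current
# predictions with `col` keeps instance identities consistent over time.
print(col)
```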