diff --git a/configs/base.py b/configs/base.py index 1244a32c..4805d305 100644 --- a/configs/base.py +++ b/configs/base.py @@ -3,7 +3,7 @@ test_cfg = {} optimizer_config = dict() # grad_clip, coalesce, bucket_size_mb # yapf:disable log_config = dict( - interval=10, + interval=50, hooks=[ dict(type='TextLoggerHook'), # dict(type='TensorboardLoggerHook') diff --git a/configs/detection/dab_detr/coco_detection.py b/configs/detection/_base_/dataset/autoaug_coco_detection.py similarity index 95% rename from configs/detection/dab_detr/coco_detection.py rename to configs/detection/_base_/dataset/autoaug_coco_detection.py index 7efc45b7..e4597fa0 100644 --- a/configs/detection/dab_detr/coco_detection.py +++ b/configs/detection/_base_/dataset/autoaug_coco_detection.py @@ -119,3 +119,14 @@ val_dataset = dict( data = dict( imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset) + +# evaluation +eval_config = dict(interval=1, gpu_collect=False) +eval_pipelines = [ + dict( + mode='test', + evaluators=[ + dict(type='CocoDetectionEvaluator', classes=CLASSES), + ], + ) +] diff --git a/configs/detection/dab_detr/dab_detr.py b/configs/detection/dab_detr/dab_detr.py index 5ba39d61..41c111a3 100644 --- a/configs/detection/dab_detr/dab_detr.py +++ b/configs/detection/dab_detr/dab_detr.py @@ -16,7 +16,6 @@ model = dict( transformer=dict( type='DABDetrTransformer', in_channels=2048, - num_queries=300, d_model=256, nhead=8, num_encoder_layers=6, @@ -27,8 +26,6 @@ model = dict( normalize_before=False, return_intermediate_dec=True, query_dim=4, - random_refpoints_xy=False, - num_patterns=0, keep_query_pos=False, query_scale_type='cond_elewise', modulate_hw_attn=True, @@ -40,16 +37,18 @@ model = dict( embed_dims=256, query_dim=4, iter_update=True, + num_queries=300, num_select=300, + random_refpoints_xy=False, + num_patterns=0, bbox_embed_diff_each_layer=False, - cost_dict={ - 'cost_class': 2, - 'cost_bbox': 5, - 'cost_giou': 2, - }, - weight_dict={ - 'loss_ce': 1, - 'loss_bbox': 5, - 'loss_giou': 2 - }, - )) + cost_dict=dict( + cost_class=2, + cost_bbox=5, + cost_giou=2, + ), + weight_dict=dict( + loss_ce=1, + loss_bbox=5, + loss_giou=2, + ))) diff --git a/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py b/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py index 7890fb4c..522c3a72 100644 --- a/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py +++ b/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py @@ -1,28 +1,8 @@ -_base_ = ['./dab_detr.py', './coco_detection.py', 'configs/base.py'] - -CLASSES = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' +_base_ = [ + './dab_detr.py', '../_base_/dataset/autoaug_coco_detection.py', + 'configs/base.py' ] -log_config 
= dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook') - ]) - checkpoint_config = dict(interval=10) # optimizer paramwise_options = {'backbone': dict(lr_mult=0.1, weight_decay_mult=1.0)} @@ -37,16 +17,4 @@ lr_config = dict(policy='step', step=[40]) total_epochs = 50 -# evaluation -# eval_config = dict(initial=True, interval=1, gpu_collect=False) -eval_config = dict(interval=1, gpu_collect=False) -eval_pipelines = [ - dict( - mode='test', - evaluators=[ - dict(type='CocoDetectionEvaluator', classes=CLASSES), - ], - ) -] - find_unused_parameters = False diff --git a/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py b/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py new file mode 100644 index 00000000..29f96392 --- /dev/null +++ b/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py @@ -0,0 +1,7 @@ +_base_ = './dab_detr_r50_8x2_50e_coco.py' + +# model settings +model = dict( + head=dict( + dn_components=dict( + scalar=5, label_noise_scale=0.2, box_noise_scale=0.4))) diff --git a/configs/detection/dab_detr/dn_detr_r50_dc5_8x2_50e_coco.py b/configs/detection/dab_detr/dn_detr_r50_dc5_8x2_50e_coco.py new file mode 100644 index 00000000..20b56487 --- /dev/null +++ b/configs/detection/dab_detr/dn_detr_r50_dc5_8x2_50e_coco.py @@ -0,0 +1,4 @@ +_base_ = './dn_detr_r50_8x2_50e_coco.py' + +# model settings +model = dict(backbone=dict(strides=(1, 2, 2, 1), dilations=(1, 1, 1, 2))) diff --git a/configs/detection/detr/coco_detection.py b/configs/detection/detr/coco_detection.py deleted file mode 100644 index 7efc45b7..00000000 --- a/configs/detection/detr/coco_detection.py +++ /dev/null @@ -1,121 +0,0 @@ -CLASSES = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' -] - -# dataset settings -data_root = 'data/coco/' -img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) - -train_pipeline = [ - dict(type='MMRandomFlip', flip_ratio=0.5), - dict( - type='MMAutoAugment', - policies=[[ - dict( - type='MMResize', - img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), - (608, 1333), (640, 1333), (672, 1333), (704, 1333), - (736, 1333), (768, 1333), (800, 1333)], - multiscale_mode='value', - keep_ratio=True) - ], - [ - dict( - type='MMResize', - img_scale=[(400, 1333), (500, 1333), (600, 1333)], - multiscale_mode='value', - keep_ratio=True), - dict( - type='MMRandomCrop', - crop_type='absolute_range', - crop_size=(384, 600), - allow_negative_crop=True), - dict( - type='MMResize', - img_scale=[(480, 1333), (512, 1333), (544, 1333), - (576, 1333), (608, 1333), (640, 1333), - (672, 1333), (704, 1333), (736, 1333), - (768, 1333), (800, 1333)], - multiscale_mode='value', - override=True, 
- keep_ratio=True) - ]]), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=1), - dict(type='DefaultFormatBundle'), - dict( - type='Collect', - keys=['img', 'gt_bboxes', 'gt_labels'], - meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape', - 'img_shape', 'pad_shape', 'scale_factor', 'flip', - 'flip_direction', 'img_norm_cfg')) -] -test_pipeline = [ - dict( - type='MMMultiScaleFlipAug', - img_scale=(1333, 800), - flip=False, - transforms=[ - dict(type='MMResize', keep_ratio=True), - dict(type='MMRandomFlip'), - dict(type='MMNormalize', **img_norm_cfg), - dict(type='MMPad', size_divisor=1), - dict(type='ImageToTensor', keys=['img']), - dict( - type='Collect', - keys=['img'], - meta_keys=('filename', 'ori_filename', 'ori_shape', - 'ori_img_shape', 'img_shape', 'pad_shape', - 'scale_factor', 'flip', 'flip_direction', - 'img_norm_cfg')) - ]) -] - -train_dataset = dict( - type='DetDataset', - data_source=dict( - type='DetSourceCoco', - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=False, - filter_empty_gt=True, - iscrowd=False), - pipeline=train_pipeline) - -val_dataset = dict( - type='DetDataset', - imgs_per_gpu=1, - data_source=dict( - type='DetSourceCoco', - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', - pipeline=[ - dict(type='LoadImageFromFile'), - dict(type='LoadAnnotations', with_bbox=True) - ], - classes=CLASSES, - test_mode=True, - filter_empty_gt=False, - iscrowd=True), - pipeline=test_pipeline) - -data = dict( - imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset) diff --git a/configs/detection/detr/detr.py b/configs/detection/detr/detr.py index ae8cfc81..c2129a8c 100644 --- a/configs/detection/detr/detr.py +++ b/configs/detection/detr/detr.py @@ -1,8 +1,7 @@ # model settings model = dict( type='Detection', - pretrained= - 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth', + pretrained=True, backbone=dict( type='ResNet', depth=50, @@ -32,14 +31,13 @@ model = dict( in_channels=2048, embed_dims=256, eos_coef=0.1, - cost_dict={ - 'cost_class': 1, - 'cost_bbox': 5, - 'cost_giou': 2, - }, - weight_dict={ - 'loss_ce': 1, - 'loss_bbox': 5, - 'loss_giou': 2 - }, - )) + cost_dict=dict( + cost_class=1, + cost_bbox=5, + cost_giou=2, + ), + weight_dict=dict( + loss_ce=1, + loss_bbox=5, + loss_giou=2, + ))) diff --git a/configs/detection/detr/detr_r50_8x2_150e_coco.py b/configs/detection/detr/detr_r50_8x2_150e_coco.py index 304cf514..0863a35c 100644 --- a/configs/detection/detr/detr_r50_8x2_150e_coco.py +++ b/configs/detection/detr/detr_r50_8x2_150e_coco.py @@ -1,28 +1,8 @@ -_base_ = ['./detr.py', './coco_detection.py', 'configs/base.py'] - -CLASSES = [ - 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', - 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', - 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', - 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', - 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', - 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', - 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', - 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 
'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', - 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', - 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', - 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', - 'hair drier', 'toothbrush' +_base_ = [ + './detr.py', '../_base_/dataset/autoaug_coco_detection.py', + 'configs/base.py' ] -log_config = dict( - interval=50, - hooks=[ - dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook') - ]) - checkpoint_config = dict(interval=10) # optimizer paramwise_options = {'backbone': dict(lr_mult=0.1, weight_decay_mult=1.0)} @@ -37,16 +17,4 @@ lr_config = dict(policy='step', step=[100]) total_epochs = 150 -# evaluation -# eval_config = dict(initial=True, interval=1, gpu_collect=False) -eval_config = dict(interval=1, gpu_collect=False) -eval_pipelines = [ - dict( - mode='test', - evaluators=[ - dict(type='CocoDetectionEvaluator', classes=CLASSES), - ], - ) -] - find_unused_parameters = False diff --git a/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py b/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py index 55939f60..feaf24be 100644 --- a/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py +++ b/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py @@ -43,8 +43,7 @@ lr_config = dict( total_epochs = 12 # evaluation -eval_config = dict(initial=True, interval=1, gpu_collect=False) -# eval_config = dict(interval=1, gpu_collect=False) +eval_config = dict(interval=1, gpu_collect=False) eval_pipelines = [ dict( mode='test', diff --git a/docs/source/model_zoo_det.md b/docs/source/model_zoo_det.md index f85362d8..3adf04be 100644 --- a/docs/source/model_zoo_det.md +++ b/docs/source/model_zoo_det.md @@ -24,9 +24,11 @@ Pretrained on COCO2017 dataset. | Algorithm | Config | Params
(backbone/total) | inference time(V100)
(ms/img) | mAPval
0.5:0.95 | APval
50 | Download | | ---------- | ------------------------------------------------------------ | ------------------------ | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | | FCOS-r50 | [fcos-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py) | 23M/32M | 85.8ms | 38.58 | 57.18 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/20220621_121315.log.json) | + ## DETR | Algorithm | Config | Params
(backbone/total) | inference time(V100)
(ms/img) | bbox_mAPval
0.5:0.95 | APval
50 | Download | | ---------- | ------------------------------------------------------------ | ------------------------ | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | | DETR-r50 | [detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/detr/detr_r50_8x2_150e_coco.py) | 23M/41M | 48.5ms | 39.92 | 60.52 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/20220609_101243.log.json) | -| DAB-DETR-r50 | [dab-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 42.52 | 63.03 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/20220610_122811.log.json) | +| DAB-DETR-r50 | [dab-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 42.52 | 63.03 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/20220610_122811.log.json) | +| DN-DETR-r50 | [dn-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 44.39 | 64.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/20220713_105127.log.json) | diff --git a/easycv/models/classification/classification.py b/easycv/models/classification/classification.py index d532fa6c..ddcc9e31 100644 --- a/easycv/models/classification/classification.py +++ b/easycv/models/classification/classification.py @@ -128,8 +128,9 @@ class Classification(BaseModel): strict=False, logger=logger) else: - print_log('load model from init weights') - self.backbone.init_weights() + raise ValueError( + 'default_pretrained_model_path for {} not found'.format( + self.backbone.__class__.__name__)) else: print_log('load model from init weights') self.backbone.init_weights() diff --git a/easycv/models/detection/detectors/dab_detr/dab_detr_head.py b/easycv/models/detection/detectors/dab_detr/dab_detr_head.py index 6a3261a1..271a0506 100644 --- a/easycv/models/detection/detectors/dab_detr/dab_detr_head.py +++ b/easycv/models/detection/detectors/dab_detr/dab_detr_head.py @@ -5,17 +5,13 @@ import math import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from easycv.models.builder import HEADS, build_neck -from easycv.models.detection.utils import (HungarianMatcher, accuracy, +from easycv.models.detection.utils import (HungarianMatcher, SetCriterion, box_cxcywh_to_xyxy, - box_xyxy_to_cxcywh, - generalized_box_iou, - inverse_sigmoid) -from easycv.models.loss.focal_loss import py_sigmoid_focal_loss -from easycv.models.utils import (MLP, get_world_size, - is_dist_avail_and_initialized) + box_xyxy_to_cxcywh, inverse_sigmoid) +from easycv.models.utils import MLP +from .dn_components import
dn_post_process, prepare_for_dn @HEADS.register_module() @@ -27,15 +23,17 @@ class DABDETRHead(nn.Module): num_classes (int): Number of categories excluding the background. """ - _version = 2 - def __init__(self, num_classes, embed_dims, query_dim=4, iter_update=True, + num_queries=300, num_select=300, + random_refpoints_xy=False, + num_patterns=0, bbox_embed_diff_each_layer=False, + dn_components=None, transformer=None, cost_dict={ 'cost_class': 1, @@ -57,7 +55,9 @@ class DABDETRHead(nn.Module): num_classes, matcher=self.matcher, weight_dict=weight_dict, - losses=['labels', 'boxes', 'cardinality']) + losses=['labels', 'boxes'], + loss_class_type='focal_loss', + dn_components=dn_components) self.postprocess = PostProcess(num_select=num_select) self.transformer = build_neck(transformer) @@ -71,8 +71,31 @@ class DABDETRHead(nn.Module): self.transformer.decoder.bbox_embed = self.bbox_embed self.num_classes = num_classes + self.num_queries = num_queries + self.embed_dims = embed_dims self.query_dim = query_dim self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer + self.dn_components = dn_components + + self.query_embed = nn.Embedding(num_queries, query_dim) + self.random_refpoints_xy = random_refpoints_xy + if random_refpoints_xy: + self.query_embed.weight.data[:, :2].uniform_(0, 1) + self.query_embed.weight.data[:, :2] = inverse_sigmoid( + self.query_embed.weight.data[:, :2]) + self.query_embed.weight.data[:, :2].requires_grad = False + + self.num_patterns = num_patterns + if not isinstance(num_patterns, int): + Warning('num_patterns should be int but {}'.format( + type(num_patterns))) + self.num_patterns = 0 + if self.num_patterns > 0: + self.patterns = nn.Embedding(self.num_patterns, embed_dims) + + if self.dn_components: + # leave one dim for indicator + self.label_enc = nn.Embedding(num_classes + 1, embed_dims - 1) def init_weights(self): self.transformer.init_weights() @@ -92,7 +115,45 @@ class DABDETRHead(nn.Module): nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) - def forward(self, feats, img_metas): + def prepare(self, feats, targets=None, mode='train'): + bs = feats[0].shape[0] + query_embed = self.query_embed.weight + if self.dn_components: + # default pipeline + self.dn_components['num_patterns'] = self.num_patterns + self.dn_components['targets'] = targets + # prepare for dn + tgt, query_embed, attn_mask, mask_dict = prepare_for_dn( + mode, self.dn_components, query_embed, bs, self.num_queries, + self.num_classes, self.embed_dims, self.label_enc) + if self.num_patterns > 0: + l = tgt.shape[0] + tgt[l - self.num_queries * self.num_patterns:] += \ + self.patterns.weight[:, None, None, :].repeat(1, self.num_queries, bs, 1).flatten(0, 1) + return query_embed, tgt, attn_mask, mask_dict + else: + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if self.num_patterns == 0: + tgt = torch.zeros( + self.num_queries, + bs, + self.embed_dims, + device=query_embed.device) + else: + tgt = self.patterns.weight[:, None, None, :].repeat( + 1, self.num_queries, bs, + 1).flatten(0, 1) # n_q*n_pat, bs, d_model + query_embed = query_embed.repeat(self.num_patterns, 1, + 1) # n_q*n_pat, bs, d_model + return query_embed, tgt, None, None + + def forward(self, + feats, + img_metas, + query_embed=None, + tgt=None, + attn_mask=None, + mask_dict=None): """Forward function. 
Args: feats (tuple[Tensor]): Features from the upstream network, each is @@ -109,7 +170,9 @@ class DABDETRHead(nn.Module): normalized coordinate format (cx, cy, w, h) and shape \ [nb_dec, bs, num_query, 4]. """ - feats = self.transformer(feats, img_metas) + + feats = self.transformer( + feats, img_metas, query_embed, tgt, attn_mask=attn_mask) hs, reference = feats outputs_class = self.class_embed(hs) @@ -128,6 +191,10 @@ class DABDETRHead(nn.Module): outputs_coords.append(outputs_coord) outputs_coord = torch.stack(outputs_coords) + if mask_dict is not None: + # dn post process + outputs_class, outputs_coord = dn_post_process( + outputs_class, outputs_coord, mask_dict) out = { 'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1] @@ -164,27 +231,45 @@ class DABDETRHead(nn.Module): Returns: dict[str, Tensor]: A dictionary of loss components. """ - outputs = self.forward(x, img_metas) - + # prepare ground truth for i in range(len(img_metas)): img_h, img_w, _ = img_metas[i]['img_shape'] # DETR regress the relative position of boxes (cxcywh) in the image. # Thus the learning target should be normalized by the image size, also # the box format should be converted from defaultly x1y1x2y2 to cxcywh. - factor = outputs['pred_boxes'].new_tensor( - [img_w, img_h, img_w, img_h]).unsqueeze(0) + factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor targets = [] for gt_label, gt_bbox in zip(gt_labels, gt_bboxes): targets.append({'labels': gt_label, 'boxes': gt_bbox}) - losses = self.criterion(outputs, targets) + query_embed, tgt, attn_mask, mask_dict = self.prepare( + x, targets=targets, mode='train') + + outputs = self.forward( + x, + img_metas, + query_embed=query_embed, + tgt=tgt, + attn_mask=attn_mask, + mask_dict=mask_dict) + + losses = self.criterion(outputs, targets, mask_dict) return losses def forward_test(self, x, img_metas): - outputs = self.forward(x, img_metas) + query_embed, tgt, attn_mask, mask_dict = self.prepare(x, mode='test') + + outputs = self.forward( + x, + img_metas, + query_embed=query_embed, + tgt=tgt, + attn_mask=attn_mask, + mask_dict=mask_dict) ori_shape_list = [] for i in range(len(img_metas)): @@ -242,192 +327,3 @@ class PostProcess(nn.Module): } return results - - -class SetCriterion(nn.Module): - """ This class computes the loss for Conditional DETR. - The process happens in two steps: - 1) we compute hungarian assignment between ground truth boxes and the outputs of the model - 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) - """ - - def __init__(self, num_classes, matcher, weight_dict, losses): - """ Create the criterion. - Parameters: - num_classes: number of object categories, omitting the special no-object category - matcher: module able to compute a matching between targets and proposals - weight_dict: dict containing as key the names of the losses and as values their relative weight. - losses: list of all the losses to be applied. See get_loss for list of available losses. 
- """ - super().__init__() - self.num_classes = num_classes - self.matcher = matcher - self.weight_dict = weight_dict - self.losses = losses - - def loss_labels(self, outputs, targets, indices, num_boxes, log=True): - """Classification loss (Binary focal loss) - targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] - """ - assert 'pred_logits' in outputs - src_logits = outputs['pred_logits'] - - idx = self._get_src_permutation_idx(indices) - target_classes_o = torch.cat( - [t['labels'][J] for t, (_, J) in zip(targets, indices)]) - target_classes = torch.full( - src_logits.shape[:2], - self.num_classes, - dtype=torch.int64, - device=src_logits.device) - target_classes[idx] = target_classes_o - - target_classes_onehot = torch.zeros([ - src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1 - ], - dtype=src_logits.dtype, - layout=src_logits.layout, - device=src_logits.device) - target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) - - target_classes_onehot = target_classes_onehot[:, :, :-1] - loss_ce = py_sigmoid_focal_loss( - src_logits, - target_classes_onehot.long(), - alpha=0.25, - gamma=2, - reduction='none').mean(1).sum() / num_boxes - loss_ce = loss_ce * src_logits.shape[1] * self.weight_dict['loss_ce'] - losses = {'loss_ce': loss_ce} - - if log: - # TODO this should probably be a separate loss, not hacked in this one here - losses['class_error'] = 100 - accuracy(src_logits[idx], - target_classes_o)[0] - return losses - - @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients - """ - pred_logits = outputs['pred_logits'] - device = pred_logits.device - tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets], - device=device) - # Count the number of predictions that are NOT "no-object" (which is the last class) - card_pred = (pred_logits.argmax(-1) != - pred_logits.shape[-1] - 1).sum(1) - card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) - losses = {'cardinality_error': card_err} - return losses - - def loss_boxes(self, outputs, targets, indices, num_boxes): - """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss - targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] - The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
- """ - assert 'pred_boxes' in outputs - idx = self._get_src_permutation_idx(indices) - src_boxes = outputs['pred_boxes'][idx] - target_boxes = torch.cat( - [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) - - loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') - - losses = {} - losses['loss_bbox'] = loss_bbox.sum( - ) / num_boxes * self.weight_dict['loss_bbox'] - - loss_giou = 1 - torch.diag( - generalized_box_iou( - box_cxcywh_to_xyxy(src_boxes), - box_cxcywh_to_xyxy(target_boxes))) - losses['loss_giou'] = loss_giou.sum( - ) / num_boxes * self.weight_dict['loss_giou'] - - return losses - - def _get_src_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat( - [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) - src_idx = torch.cat([src for (src, _) in indices]) - return batch_idx, src_idx - - def _get_tgt_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat( - [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) - tgt_idx = torch.cat([tgt for (_, tgt) in indices]) - return batch_idx, tgt_idx - - def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): - loss_map = { - 'labels': self.loss_labels, - 'cardinality': self.loss_cardinality, - 'boxes': self.loss_boxes, - } - assert loss in loss_map, f'do you really want to compute {loss} loss?' - return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) - - def forward(self, outputs, targets, return_indices=False): - """ This performs the loss computation. - Parameters: - outputs: dict of tensors, see the output specification of the model for the format - targets: list of dicts, such that len(targets) == batch_size. - The expected keys in each dict depends on the losses applied, see each loss' doc - - return_indices: used for vis. if True, the layer0-5 indices will be returned as well. - """ - - outputs_without_aux = { - k: v - for k, v in outputs.items() if k != 'aux_outputs' - } - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux, targets) - if return_indices: - indices0_copy = indices - indices_list = [] - - # Compute the average number of target boxes accross all nodes, for normalization purposes - num_boxes = sum(len(t['labels']) for t in targets) - num_boxes = torch.as_tensor([num_boxes], - dtype=torch.float, - device=next(iter(outputs.values())).device) - if is_dist_avail_and_initialized(): - torch.distributed.all_reduce(num_boxes) - num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() - - # Compute all the requested losses - losses = {} - for loss in self.losses: - losses.update( - self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if 'aux_outputs' in outputs: - for i, aux_outputs in enumerate(outputs['aux_outputs']): - indices = self.matcher(aux_outputs, targets) - if return_indices: - indices_list.append(indices) - for loss in self.losses: - if loss == 'masks': - # Intermediate masks losses are too costly to compute, we ignore them. 
- continue - kwargs = {} - if loss == 'labels': - # Logging is enabled only for the last layer - kwargs = {'log': False} - l_dict = self.get_loss(loss, aux_outputs, targets, indices, - num_boxes, **kwargs) - l_dict = {k + f'_{i}': v for k, v in l_dict.items()} - losses.update(l_dict) - - if return_indices: - indices_list.append(indices0_copy) - return losses, indices_list - - return losses diff --git a/easycv/models/detection/detectors/dab_detr/dab_detr_transformer.py b/easycv/models/detection/detectors/dab_detr/dab_detr_transformer.py index 547d77d7..baa992fd 100644 --- a/easycv/models/detection/detectors/dab_detr/dab_detr_transformer.py +++ b/easycv/models/detection/detectors/dab_detr/dab_detr_transformer.py @@ -33,10 +33,7 @@ class DABDetrTransformer(nn.Module): def __init__(self, in_channels=1024, - num_queries=300, query_dim=4, - random_refpoints_xy=False, - num_patterns=0, d_model=512, nhead=8, num_encoder_layers=6, @@ -58,27 +55,11 @@ class DABDetrTransformer(nn.Module): ] self.input_proj = nn.Conv2d(in_channels, d_model, kernel_size=1) - self.query_embed = nn.Embedding(num_queries, query_dim) self.positional_encoding = PositionEmbeddingSineHW( d_model // 2, temperatureH=temperatureH, temperatureW=temperatureW, normalize=True) - self.random_refpoints_xy = random_refpoints_xy - if random_refpoints_xy: - self.query_embed.weight.data[:, :2].uniform_(0, 1) - self.query_embed.weight.data[:, :2] = inverse_sigmoid( - self.query_embed.weight.data[:, :2]) - self.query_embed.weight.data[:, :2].requires_grad = False - - self.num_queries = num_queries - self.num_patterns = num_patterns - if not isinstance(num_patterns, int): - Warning('num_patterns should be int but {}'.format( - type(num_patterns))) - self.num_patterns = 0 - if self.num_patterns > 0: - self.patterns = nn.Embedding(self.num_patterns, d_model) encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, @@ -116,14 +97,13 @@ class DABDetrTransformer(nn.Module): def init_weights(self): for p in self.named_parameters(): - if 'input_proj' in p[0] or 'query_embed' in p[ - 0] or 'positional_encoding' in p[0] or 'patterns' in p[ - 0] or 'bbox_embed' in p[0]: + if 'input_proj' in p[0] or 'positional_encoding' in p[ + 0] or 'bbox_embed' in p[0]: continue if p[1].dim() > 1: nn.init.xavier_uniform_(p[1]) - def forward(self, src, img_metas): + def forward(self, src, img_metas, query_embed, tgt, attn_mask=None): src = src[0] # construct binary masks which used for the transformer. 
@@ -143,30 +123,18 @@ class DABDetrTransformer(nn.Module): # position encoding pos_embed = self.positional_encoding(mask) # [bs, embed_dim, h, w] # outs_dec: [nb_dec, bs, num_query, embed_dim] - query_embed = self.query_embed.weight # flatten NxCxHxW to HWxNxC src = src.flatten(2).permute(2, 0, 1) pos_embed = pos_embed.flatten(2).permute(2, 0, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) mask = mask.flatten(1) memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - num_queries = query_embed.shape[0] - if self.num_patterns == 0: - tgt = torch.zeros( - num_queries, bs, self.d_model, device=query_embed.device) - else: - tgt = self.patterns.weight[:, None, None, :].repeat( - 1, self.num_queries, bs, - 1).flatten(0, 1) # n_q*n_pat, bs, d_model - query_embed = query_embed.repeat(self.num_patterns, 1, - 1) # n_q*n_pat, bs, d_model - hs, references = self.decoder( tgt, memory, + tgt_mask=attn_mask, memory_key_padding_mask=mask, pos=pos_embed, refpoints_unsigmoid=query_embed) diff --git a/easycv/models/detection/detectors/dab_detr/dn_components.py b/easycv/models/detection/detectors/dab_detr/dn_components.py new file mode 100644 index 00000000..2e6411ee --- /dev/null +++ b/easycv/models/detection/detectors/dab_detr/dn_components.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------------------ +# DN-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +import torch + +from easycv.models.detection.utils import inverse_sigmoid + + +def prepare_for_dn(mode, dn_args, embedweight, batch_size, num_queries, + num_classes, hidden_dim, label_enc): + """ + prepare for dn components in forward function + Args: + dn_args: (targets, args.scalar, args.label_noise_scale, args.box_noise_scale, args.num_patterns) from engine input + embedweight: positional queries as anchor + training: whether it is training or inference + num_queries: number of queries + num_classes: number of classes + hidden_dim: transformer hidden dimenstion + label_enc: label encoding embedding + + Returns: input_query_label, input_query_bbox, attn_mask, mask_dict + """ + if mode == 'train': + targets, scalar, label_noise_scale, box_noise_scale, num_patterns = dn_args[ + 'targets'], dn_args['scalar'], dn_args[ + 'label_noise_scale'], dn_args['box_noise_scale'], dn_args[ + 'num_patterns'] + else: + num_patterns = dn_args['num_patterns'] + + if num_patterns == 0: + num_patterns = 1 + indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda() + tgt = label_enc(torch.tensor(num_classes).cuda()).repeat( + num_queries * num_patterns, 1) + tgt = torch.cat([tgt, indicator0], dim=1) + refpoint_emb = embedweight.repeat(num_patterns, 1) + if mode == 'train': + known = [(torch.ones_like(t['labels'])).cuda() for t in targets] + know_idx = [torch.nonzero(t) for t in known] + known_num = [sum(k) for k in known] + # you can uncomment this to use fix number of dn queries + # if int(max(known_num))>0: + # scalar=scalar//int(max(known_num)) + + # can be modified to selectively denosie some label or boxes; also known label prediction + unmask_bbox = unmask_label = torch.cat(known) + labels = torch.cat([t['labels'] for t in targets]) + boxes = torch.cat([t['boxes'] for t in targets]) + batch_idx = torch.cat([ + torch.full_like(t['labels'].long(), i) + for i, t in enumerate(targets) + ]) + + known_indice = torch.nonzero(unmask_label + 
unmask_bbox) + known_indice = known_indice.view(-1) + + # add noise + known_indice = known_indice.repeat(scalar, 1).view(-1) + known_labels = labels.repeat(scalar, 1).view(-1) + known_bid = batch_idx.repeat(scalar, 1).view(-1) + known_bboxs = boxes.repeat(scalar, 1) + known_labels_expaned = known_labels.clone() + known_bbox_expand = known_bboxs.clone() + + # noise on the label + if label_noise_scale > 0: + p = torch.rand_like(known_labels_expaned.float()) + chosen_indice = torch.nonzero(p < (label_noise_scale)).view( + -1) # usually half of bbox noise + new_label = torch.randint_like( + chosen_indice, 0, num_classes) # randomly put a new one here + known_labels_expaned.scatter_(0, chosen_indice, new_label) + # noise on the box + if box_noise_scale > 0: + diff = torch.zeros_like(known_bbox_expand) + diff[:, :2] = known_bbox_expand[:, 2:] / 2 + diff[:, 2:] = known_bbox_expand[:, 2:] + known_bbox_expand += torch.mul( + (torch.rand_like(known_bbox_expand) * 2 - 1.0), + diff).cuda() * box_noise_scale + known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0) + + m = known_labels_expaned.long().to('cuda') + input_label_embed = label_enc(m) + # add dn part indicator + indicator1 = torch.ones([input_label_embed.shape[0], 1]).cuda() + input_label_embed = torch.cat([input_label_embed, indicator1], dim=1) + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + single_pad = int(max(known_num)) + pad_size = int(single_pad * scalar) + padding_label = torch.zeros(pad_size, hidden_dim).cuda() + padding_bbox = torch.zeros(pad_size, 4).cuda() + input_query_label = torch.cat([padding_label, tgt], + dim=0).repeat(batch_size, 1, 1) + input_query_bbox = torch.cat([padding_bbox, refpoint_emb], + dim=0).repeat(batch_size, 1, 1) + + # map in order + map_known_indice = torch.tensor([]).to('cuda') + if len(known_num): + map_known_indice = torch.cat([ + torch.tensor(range(num)) for num in known_num + ]) # [1,2, 1,2,3] + map_known_indice = torch.cat([ + map_known_indice + single_pad * i for i in range(scalar) + ]).long() + if len(known_bid): + input_query_label[(known_bid.long(), + map_known_indice)] = input_label_embed + input_query_bbox[(known_bid.long(), + map_known_indice)] = input_bbox_embed + + tgt_size = pad_size + num_queries * num_patterns + attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 + # match query cannot see the reconstruct + attn_mask[pad_size:, :pad_size] = True + # reconstruct cannot see each other + for i in range(scalar): + if i == 0: + attn_mask[single_pad * i:single_pad * (i + 1), + single_pad * (i + 1):pad_size] = True + if i == scalar - 1: + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * + i] = True + else: + attn_mask[single_pad * i:single_pad * (i + 1), + single_pad * (i + 1):pad_size] = True + attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * + i] = True + mask_dict = { + 'known_indice': torch.as_tensor(known_indice).long(), + 'batch_idx': torch.as_tensor(batch_idx).long(), + 'map_known_indice': torch.as_tensor(map_known_indice).long(), + 'known_lbs_bboxes': (known_labels, known_bboxs), + 'know_idx': know_idx, + 'pad_size': pad_size + } + else: # no dn for inference + input_query_label = tgt.repeat(batch_size, 1, 1) + input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1) + attn_mask = None + mask_dict = None + + input_query_label = input_query_label.transpose(0, 1) + input_query_bbox = input_query_bbox.transpose(0, 1) + + return input_query_label, input_query_bbox, attn_mask, mask_dict + + +def dn_post_process(outputs_class, outputs_coord, 
mask_dict): + """ + post process of dn after output from the transformer + put the dn part in the mask_dict + """ + if mask_dict and mask_dict['pad_size'] > 0: + output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :] + output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :] + outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :] + outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :] + mask_dict['output_known_lbs_bboxes'] = (output_known_class, + output_known_coord) + return outputs_class, outputs_coord diff --git a/easycv/models/detection/detectors/detection.py b/easycv/models/detection/detectors/detection.py index cdfc0cdf..fe91fbf8 100644 --- a/easycv/models/detection/detectors/detection.py +++ b/easycv/models/detection/detectors/detection.py @@ -44,8 +44,9 @@ class Detection(BaseModel): strict=False, logger=logger) else: - print_log('load model from init weights') - self.backbone.init_weights() + raise ValueError( + 'default_pretrained_model_path for {} not found'.format( + self.backbone.__class__.__name__)) else: print_log('load model from init weights') self.backbone.init_weights() diff --git a/easycv/models/detection/detectors/detr/detr_head.py b/easycv/models/detection/detectors/detr/detr_head.py index b355ac26..402d88f3 100644 --- a/easycv/models/detection/detectors/detr/detr_head.py +++ b/easycv/models/detection/detectors/detr/detr_head.py @@ -6,12 +6,10 @@ import torch.nn as nn import torch.nn.functional as F from easycv.models.builder import HEADS, build_neck -from easycv.models.detection.utils import (HungarianMatcher, accuracy, +from easycv.models.detection.utils import (HungarianMatcher, SetCriterion, box_cxcywh_to_xyxy, - box_xyxy_to_cxcywh, - generalized_box_iou) -from easycv.models.utils import (MLP, get_world_size, - is_dist_avail_and_initialized) + box_xyxy_to_cxcywh) +from easycv.models.utils import MLP @HEADS.register_module() @@ -50,7 +48,7 @@ class DETRHead(nn.Module): matcher=self.matcher, weight_dict=weight_dict, eos_coef=eos_coef, - losses=['labels', 'boxes', 'cardinality']) + losses=['labels', 'boxes']) self.postprocess = PostProcess() self.transformer = build_neck(transformer) @@ -120,21 +118,22 @@ class DETRHead(nn.Module): Returns: dict[str, Tensor]: A dictionary of loss components. """ - outputs = self.forward(x, img_metas) - + # prepare ground truth for i in range(len(img_metas)): img_h, img_w, _ = img_metas[i]['img_shape'] # DETR regress the relative position of boxes (cxcywh) in the image. # Thus the learning target should be normalized by the image size, also # the box format should be converted from defaultly x1y1x2y2 to cxcywh. - factor = outputs['pred_boxes'].new_tensor( - [img_w, img_h, img_w, img_h]).unsqueeze(0) + factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor targets = [] for gt_label, gt_bbox in zip(gt_labels, gt_bboxes): targets.append({'labels': gt_label, 'boxes': gt_bbox}) + outputs = self.forward(x, img_metas) + losses = self.criterion(outputs, targets) return losses @@ -188,171 +187,3 @@ class PostProcess(nn.Module): } return results - - -class SetCriterion(nn.Module): - """ This class computes the loss for DETR. 
- The process happens in two steps: - 1) we compute hungarian assignment between ground truth boxes and the outputs of the model - 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) - """ - - def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): - """ Create the criterion. - Parameters: - num_classes: number of object categories, omitting the special no-object category - matcher: module able to compute a matching between targets and proposals - weight_dict: dict containing as key the names of the losses and as values their relative weight. - eos_coef: relative classification weight applied to the no-object category - losses: list of all the losses to be applied. See get_loss for list of available losses. - """ - super().__init__() - self.num_classes = num_classes - self.matcher = matcher - self.weight_dict = weight_dict - self.eos_coef = eos_coef - self.losses = losses - empty_weight = torch.ones(self.num_classes + 1) - empty_weight[-1] = self.eos_coef - self.register_buffer('empty_weight', empty_weight) - - def loss_labels(self, outputs, targets, indices, num_boxes, log=True): - """Classification loss (NLL) - targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] - """ - assert 'pred_logits' in outputs - src_logits = outputs['pred_logits'] - - idx = self._get_src_permutation_idx(indices) - target_classes_o = torch.cat( - [t['labels'][J] for t, (_, J) in zip(targets, indices)]) - target_classes = torch.full( - src_logits.shape[:2], - self.num_classes, - dtype=torch.int64, - device=src_logits.device) - target_classes[idx] = target_classes_o - - loss_ce = F.cross_entropy( - src_logits.transpose(1, 2), target_classes, - self.empty_weight) * self.weight_dict['loss_ce'] - losses = {'loss_ce': loss_ce} - - if log: - # TODO this should probably be a separate loss, not hacked in this one here - losses['class_error'] = 100 - accuracy(src_logits[idx], - target_classes_o)[0] - return losses - - @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients - """ - pred_logits = outputs['pred_logits'] - device = pred_logits.device - tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets], - device=device) - # Count the number of predictions that are NOT "no-object" (which is the last class) - card_pred = (pred_logits.argmax(-1) != - pred_logits.shape[-1] - 1).sum(1) - card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) - losses = {'cardinality_error': card_err} - return losses - - def loss_boxes(self, outputs, targets, indices, num_boxes): - """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss - targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] - The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
- """ - assert 'pred_boxes' in outputs - idx = self._get_src_permutation_idx(indices) - src_boxes = outputs['pred_boxes'][idx] - target_boxes = torch.cat( - [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) - - loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') - - losses = {} - losses['loss_bbox'] = loss_bbox.sum( - ) / num_boxes * self.weight_dict['loss_bbox'] - - loss_giou = 1 - torch.diag( - generalized_box_iou( - box_cxcywh_to_xyxy(src_boxes), - box_cxcywh_to_xyxy(target_boxes))) - losses['loss_giou'] = loss_giou.sum( - ) / num_boxes * self.weight_dict['loss_giou'] - return losses - - def _get_src_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat( - [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) - src_idx = torch.cat([src for (src, _) in indices]) - return batch_idx, src_idx - - def _get_tgt_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat( - [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) - tgt_idx = torch.cat([tgt for (_, tgt) in indices]) - return batch_idx, tgt_idx - - def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): - loss_map = { - 'labels': self.loss_labels, - 'cardinality': self.loss_cardinality, - 'boxes': self.loss_boxes - } - assert loss in loss_map, f'do you really want to compute {loss} loss?' - return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) - - def forward(self, outputs, targets): - """ This performs the loss computation. - Parameters: - outputs: dict of tensors, see the output specification of the model for the format - targets: list of dicts, such that len(targets) == batch_size. - The expected keys in each dict depends on the losses applied, see each loss' doc - """ - outputs_without_aux = { - k: v - for k, v in outputs.items() if k != 'aux_outputs' - } - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux, targets) - - # Compute the average number of target boxes accross all nodes, for normalization purposes - num_boxes = sum(len(t['labels']) for t in targets) - num_boxes = torch.as_tensor([num_boxes], - dtype=torch.float, - device=next(iter(outputs.values())).device) - if is_dist_avail_and_initialized(): - torch.distributed.all_reduce(num_boxes) - num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() - - # Compute all the requested losses - losses = {} - for loss in self.losses: - losses.update( - self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if 'aux_outputs' in outputs: - for i, aux_outputs in enumerate(outputs['aux_outputs']): - indices = self.matcher(aux_outputs, targets) - for loss in self.losses: - if loss == 'masks': - # Intermediate masks losses are too costly to compute, we ignore them. 
- continue - kwargs = {} - if loss == 'labels': - # Logging is enabled only for the last layer - kwargs = {'log': False} - l_dict = self.get_loss(loss, aux_outputs, targets, indices, - num_boxes, **kwargs) - l_dict = {k + f'_{i}': v for k, v in l_dict.items()} - losses.update(l_dict) - - return losses diff --git a/easycv/models/detection/utils/__init__.py b/easycv/models/detection/utils/__init__.py index cedd62df..c8032d63 100644 --- a/easycv/models/detection/utils/__init__.py +++ b/easycv/models/detection/utils/__init__.py @@ -7,3 +7,4 @@ from .generator import MlvlPointGenerator from .matcher import HungarianMatcher from .misc import (accuracy, filter_scores_and_topk, fp16_clamp, interpolate, inverse_sigmoid, output_postprocess, select_single_mlvl) +from .set_criterion import SetCriterion diff --git a/easycv/models/detection/utils/matcher.py b/easycv/models/detection/utils/matcher.py index a9df8378..8d52f047 100644 --- a/easycv/models/detection/utils/matcher.py +++ b/easycv/models/detection/utils/matcher.py @@ -13,7 +13,7 @@ class HungarianMatcher(nn.Module): while the others are un-matched (and thus treated as non-objects). """ - def __init__(self, cost_dict, cost_class_type=None): + def __init__(self, cost_dict, cost_class_type='ce_cost'): """Creates the matcher Params: cost_class: This is the relative weight of the classification error in the matching cost @@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module): if self.cost_class_type == 'focal_loss_cost': out_prob = outputs['pred_logits'].flatten( 0, 1).sigmoid() # [batch_size * num_queries, num_classes] - else: + elif self.cost_class_type == 'ce_cost': out_prob = outputs['pred_logits'].flatten(0, 1).softmax( -1) # [batch_size * num_queries, num_classes] @@ -72,7 +72,7 @@ class HungarianMatcher(nn.Module): pos_cost_class = pos_cost_class * (-(out_prob + 1e-8).log()) cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - else: + elif self.cost_class_type == 'ce_cost': # Compute the classification cost. Contrary to the loss, we don't use the NLL, # but approximate it in 1 - proba[target class]. # The 1 is a constant that doesn't change the matching, it can be ommitted. diff --git a/easycv/models/detection/utils/set_criterion.py b/easycv/models/detection/utils/set_criterion.py new file mode 100644 index 00000000..6661e667 --- /dev/null +++ b/easycv/models/detection/utils/set_criterion.py @@ -0,0 +1,388 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from easycv.models.detection.utils import (accuracy, box_cxcywh_to_xyxy, + generalized_box_iou) +from easycv.models.loss.focal_loss import py_sigmoid_focal_loss +from easycv.models.utils import get_world_size, is_dist_avail_and_initialized + + +class SetCriterion(nn.Module): + """ This class computes the loss for Conditional DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, + num_classes, + matcher, + weight_dict, + losses, + eos_coef=None, + loss_class_type='ce', + dn_components=None): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. 
+ losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.loss_class_type = loss_class_type + if self.loss_class_type == 'ce': + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = eos_coef + self.register_buffer('empty_weight', empty_weight) + if dn_components is not None: + self.dn_criterion = DNCriterion(self.weight_dict) + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (Binary focal loss) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat( + [t['labels'][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full( + src_logits.shape[:2], + self.num_classes, + dtype=torch.int64, + device=src_logits.device) + target_classes[idx] = target_classes_o + + if self.loss_class_type == 'ce': + loss_ce = F.cross_entropy( + src_logits.transpose(1, 2), target_classes, + self.empty_weight) * self.weight_dict['loss_ce'] + elif self.loss_class_type == 'focal_loss': + target_classes_onehot = torch.zeros([ + src_logits.shape[0], src_logits.shape[1], + src_logits.shape[2] + 1 + ], + dtype=src_logits.dtype, + layout=src_logits.layout, + device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + target_classes_onehot = target_classes_onehot[:, :, :-1] + + loss_ce = py_sigmoid_focal_loss( + src_logits, + target_classes_onehot.long(), + alpha=0.25, + gamma=2, + reduction='none').mean(1).sum() / num_boxes + loss_ce = loss_ce * src_logits.shape[1] * self.weight_dict[ + 'loss_ce'] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], + target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets], + device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != + pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
+ """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat( + [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum( + ) / num_boxes * self.weight_dict['loss_bbox'] + + loss_giou = 1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes), + box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum( + ) / num_boxes * self.weight_dict['loss_giou'] + + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat( + [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat( + [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets, mask_dict=None, return_indices=False): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + + return_indices: used for vis. if True, the layer0-5 indices will be returned as well. + """ + + outputs_without_aux = { + k: v + for k, v in outputs.items() if k != 'aux_outputs' + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + if return_indices: + indices0_copy = indices + indices_list = [] + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t['labels']) for t in targets) + num_boxes = torch.as_tensor([num_boxes], + dtype=torch.float, + device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + losses.update( + self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + if return_indices: + indices_list.append(indices) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. 
+ continue
+ kwargs = {}
+ if loss == 'labels':
+ # Logging is enabled only for the last layer
+ kwargs = {'log': False}
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices,
+ num_boxes, **kwargs)
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+ losses.update(l_dict)
+
+ if mask_dict is not None:
+ # dn loss computation
+ aux_num = 0
+ if 'aux_outputs' in outputs:
+ aux_num = len(outputs['aux_outputs'])
+ dn_losses = self.dn_criterion(mask_dict, self.training, aux_num,
+ 0.25)
+ losses.update(dn_losses)
+
+ if return_indices:
+ indices_list.append(indices0_copy)
+ return losses, indices_list
+
+ return losses
+
+
+class DNCriterion(nn.Module):
+ """ This class computes the denoising (dn) losses for DN-DETR.
+ The noised queries are generated from known ground-truth objects, so no Hungarian matching is needed:
+ each denoising query is directly supervised with the class and box of the ground truth it was built from.
+ """
+
+ def __init__(self, weight_dict):
+ """ Create the criterion.
+ Parameters:
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
+ """
+ super().__init__()
+ self.weight_dict = weight_dict
+
+ def prepare_for_loss(self, mask_dict):
+ """
+ Prepare the dn components needed to calculate the loss.
+ Args:
+ mask_dict: a dict that contains dn information
+ """
+ output_known_class, output_known_coord = mask_dict[
+ 'output_known_lbs_bboxes']
+ known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
+ map_known_indice = mask_dict['map_known_indice']
+
+ known_indice = mask_dict['known_indice']
+
+ batch_idx = mask_dict['batch_idx']
+ bid = batch_idx[known_indice]
+ if len(output_known_class) > 0:
+ output_known_class = output_known_class.permute(
+ 1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+ output_known_coord = output_known_coord.permute(
+ 1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+ num_tgt = known_indice.numel()
+ return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt
+
+ def tgt_loss_boxes(
+ self,
+ src_boxes,
+ tgt_boxes,
+ num_tgt,
+ ):
+ """Compute the box losses for the denoising queries: the L1 regression loss and the GIoU loss.
+ src_boxes and tgt_boxes are tensors of dim [nb_target_boxes, 4]
+ in format (center_x, center_y, w, h), normalized by the image size.
+ """ + if len(tgt_boxes) == 0: + return { + 'tgt_loss_bbox': torch.as_tensor(0.).to('cuda'), + 'tgt_loss_giou': torch.as_tensor(0.).to('cuda'), + } + + loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none') + + losses = {} + losses['tgt_loss_bbox'] = loss_bbox.sum( + ) / num_tgt * self.weight_dict['loss_bbox'] + + loss_giou = 1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(tgt_boxes))) + losses['tgt_loss_giou'] = loss_giou.sum( + ) / num_tgt * self.weight_dict['loss_giou'] + return losses + + def tgt_loss_labels(self, + src_logits_, + tgt_labels_, + num_tgt, + focal_alpha, + log=False): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + if len(tgt_labels_) == 0: + return { + 'tgt_loss_ce': torch.as_tensor(0.).to('cuda'), + 'tgt_class_error': torch.as_tensor(0.).to('cuda'), + } + + src_logits, tgt_labels = src_logits_.unsqueeze( + 0), tgt_labels_.unsqueeze(0) + + target_classes_onehot = torch.zeros([ + src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1 + ], + dtype=src_logits.dtype, + layout=src_logits.layout, + device=src_logits.device) + target_classes_onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:, :, :-1] + loss_ce = py_sigmoid_focal_loss( + src_logits, + target_classes_onehot.long(), + alpha=focal_alpha, + gamma=2, + reduction='none').mean(1).sum( + ) / num_tgt * src_logits.shape[1] * self.weight_dict['loss_ce'] + + losses = {'tgt_loss_ce': loss_ce} + if log: + losses['tgt_class_error'] = 100 - accuracy(src_logits_, + tgt_labels_)[0] + return losses + + def forward(self, mask_dict, training, aux_num, focal_alpha): + """ + compute dn loss in criterion + Args: + mask_dict: a dict for dn information + training: training or inference flag + aux_num: aux loss number + focal_alpha: for focal loss + """ + losses = {} + if training and 'output_known_lbs_bboxes' in mask_dict: + known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt = self.prepare_for_loss( + mask_dict) + losses.update( + self.tgt_loss_labels(output_known_class[-1], known_labels, + num_tgt, focal_alpha)) + losses.update( + self.tgt_loss_boxes(output_known_coord[-1], known_bboxs, + num_tgt)) + else: + losses['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda') + losses['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda') + losses['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda') + losses['tgt_class_error'] = torch.as_tensor(0.).to('cuda') + + if aux_num: + for i in range(aux_num): + # dn aux loss + if training and 'output_known_lbs_bboxes' in mask_dict: + l_dict = self.tgt_loss_labels(output_known_class[i], + known_labels, num_tgt, + focal_alpha) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + l_dict = self.tgt_loss_boxes(output_known_coord[i], + known_bboxs, num_tgt) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + else: + l_dict = dict() + l_dict['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda') + l_dict['tgt_class_error'] = torch.as_tensor(0.).to('cuda') + l_dict['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda') + l_dict['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda') + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + return losses diff --git a/tests/apis/test_export.py b/tests/apis/test_export.py index c22b94da..0ce5965d 100644 --- a/tests/apis/test_export.py +++ b/tests/apis/test_export.py @@ -73,6 +73,7 @@ class 
ModelExportTest(unittest.TestCase): def test_export_classification_jit(self): config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py' cfg = mmcv_config_fromfile(config_file) + cfg.model.pretrained = False cfg.model.backbone = dict( type='ResNetJIT', depth=50, diff --git a/tests/models/classification/test_classification.py b/tests/models/classification/test_classification.py index ec341ca2..366f97b8 100644 --- a/tests/models/classification/test_classification.py +++ b/tests/models/classification/test_classification.py @@ -42,7 +42,8 @@ class ClassificationTest(unittest.TestCase): batch_size = 1 a = torch.rand(batch_size, 3, 224, 224).to('cuda') - model = Classification(backbone=backbone, head=head).to('cuda') + model = Classification( + backbone=backbone, head=head, pretrained=False).to('cuda') model.eval() model_jit = torch.jit.script(model) diff --git a/tests/models/detection/detr/test_detr.py b/tests/models/detection/detr/test_detr.py index fa3c0289..0e0552dc 100644 --- a/tests/models/detection/detr/test_detr.py +++ b/tests/models/detection/detr/test_detr.py @@ -12,7 +12,6 @@ from easycv.datasets.utils import replace_ImageToTensor from easycv.models import build_model from easycv.utils.checkpoint import load_checkpoint from easycv.utils.config_tools import mmcv_config_fromfile -from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab from easycv.utils.registry import build_from_cfg @@ -26,9 +25,6 @@ class DETRTest(unittest.TestCase): self.cfg = mmcv_config_fromfile(config_path) - # dynamic adapt mmdet models - dynamic_adapt_for_mmlab(self.cfg) - # modify model_config if self.cfg.model.head.get('num_select', None): self.cfg.model.head.num_select = 10 @@ -600,7 +596,7 @@ class DETRTest(unittest.TestCase): decimal=1) def test_dab_detr(self): - model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/epoch_50.pth' + model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth' config_path = 'configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py' img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg' self.init_detr(model_path, config_path) @@ -673,6 +669,80 @@ class DETRTest(unittest.TestCase): ]]), decimal=1) + def test_dn_detr(self): + model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth' + config_path = 'configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py' + img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg' + self.init_detr(model_path, config_path) + output = self.predict(img) + + self.assertIn('detection_boxes', output) + self.assertIn('detection_scores', output) + self.assertIn('detection_classes', output) + self.assertIn('img_metas', output) + self.assertEqual(len(output['detection_boxes'][0]), 10) + self.assertEqual(len(output['detection_scores'][0]), 10) + self.assertEqual(len(output['detection_classes'][0]), 10) + + self.assertListEqual( + output['detection_classes'][0].tolist(), + np.array([2, 13, 2, 2, 2, 2, 2, 2, 2, 2], dtype=np.int32).tolist()) + + assert_array_almost_equal( + output['detection_scores'][0], + np.array([ + 0.8800525665283203, 0.866659939289093, 0.8665854930877686, + 0.8030595183372498, 0.7642921209335327, 0.7375038862228394, + 0.7270554304122925, 0.6710091233253479, 0.6316548585891724, + 0.6164721846580505 + ], + dtype=np.float32), + decimal=2) + + assert_array_almost_equal( 
+ output['detection_boxes'][0], + np.array([[ + 294.9338073730469, 115.7542495727539, 377.5517578125, + 150.59274291992188 + ], + [ + 220.57424926757812, 175.97023010253906, + 456.9001770019531, 383.2597351074219 + ], + [ + 479.5928649902344, 109.94012451171875, + 523.7343139648438, 130.80604553222656 + ], + [ + 398.6956787109375, 111.45973205566406, + 434.0437316894531, 134.1909637451172 + ], + [ + 166.98208618164062, 109.44792938232422, + 210.35342407226562, 139.9746856689453 + ], + [ + 609.432373046875, 113.08062744140625, + 635.9082641601562, 136.74383544921875 + ], + [ + 268.0716552734375, 105.00788879394531, + 327.4037170410156, 128.01449584960938 + ], + [ + 190.77467346191406, 107.42850494384766, + 298.35760498046875, 156.2850341796875 + ], + [ + 591.0296020507812, 110.53913116455078, + 620.702880859375, 127.42123413085938 + ], + [ + 431.6607971191406, 105.04813385009766, + 484.4869689941406, 132.45864868164062 + ]]), + decimal=1) + if __name__ == '__main__': unittest.main()
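
For reference, loss_boxes and tgt_loss_boxes above follow the same pattern: an L1 term plus a GIoU term on (center_x, center_y, w, h) boxes, each summed and divided by the number of target boxes before the weight_dict weights are applied. The following is a minimal standalone sketch of that computation, assuming torchvision is available; box_convert and generalized_box_iou here come from torchvision.ops rather than this repo's utilities, and the weights and sample tensors are illustrative only.

# Minimal standalone sketch of the L1 + GIoU box loss pattern used above.
# box_convert / generalized_box_iou are torchvision.ops helpers, not the
# repo's box_cxcywh_to_xyxy / generalized_box_iou; weights are placeholders.
import torch
import torch.nn.functional as F
from torchvision.ops import box_convert, generalized_box_iou


def box_losses(src_boxes, target_boxes, num_boxes, bbox_weight=1.0, giou_weight=1.0):
    # Boxes are (center_x, center_y, w, h), normalized to [0, 1] by the image size.
    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum() / num_boxes
    giou = generalized_box_iou(
        box_convert(src_boxes, in_fmt='cxcywh', out_fmt='xyxy'),
        box_convert(target_boxes, in_fmt='cxcywh', out_fmt='xyxy'))
    # Predictions and targets are already matched one-to-one, so only the
    # diagonal of the pairwise GIoU matrix is relevant.
    loss_giou = (1 - torch.diag(giou)).sum() / num_boxes
    # bbox_weight / giou_weight stand in for weight_dict['loss_bbox'] / weight_dict['loss_giou'].
    return {
        'loss_bbox': loss_bbox * bbox_weight,
        'loss_giou': loss_giou * giou_weight,
    }


if __name__ == '__main__':
    preds = torch.tensor([[0.50, 0.50, 0.20, 0.20], [0.30, 0.40, 0.10, 0.30]])
    gts = torch.tensor([[0.52, 0.50, 0.22, 0.18], [0.28, 0.42, 0.12, 0.28]])
    print(box_losses(preds, gts, num_boxes=2))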