mirror of https://github.com/alibaba/EasyCV.git
parent 8c93caa2d9
commit 9517bb80ff
@@ -3,7 +3,7 @@ test_cfg = {}
 optimizer_config = dict()  # grad_clip, coalesce, bucket_size_mb
 # yapf:disable
 log_config = dict(
-    interval=10,
+    interval=50,
     hooks=[
         dict(type='TextLoggerHook'),
         # dict(type='TensorboardLoggerHook')
@@ -119,3 +119,14 @@ val_dataset = dict(

 data = dict(
     imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
+
+# evaluation
+eval_config = dict(interval=1, gpu_collect=False)
+eval_pipelines = [
+    dict(
+        mode='test',
+        evaluators=[
+            dict(type='CocoDetectionEvaluator', classes=CLASSES),
+        ],
+    )
+]
@@ -16,7 +16,6 @@ model = dict(
     transformer=dict(
         type='DABDetrTransformer',
         in_channels=2048,
-        num_queries=300,
         d_model=256,
         nhead=8,
         num_encoder_layers=6,
@@ -27,8 +26,6 @@ model = dict(
         normalize_before=False,
         return_intermediate_dec=True,
         query_dim=4,
-        random_refpoints_xy=False,
-        num_patterns=0,
         keep_query_pos=False,
         query_scale_type='cond_elewise',
         modulate_hw_attn=True,
@@ -40,16 +37,18 @@ model = dict(
         embed_dims=256,
         query_dim=4,
         iter_update=True,
         num_queries=300,
         num_select=300,
+        random_refpoints_xy=False,
+        num_patterns=0,
+        bbox_embed_diff_each_layer=False,
-        cost_dict={
-            'cost_class': 2,
-            'cost_bbox': 5,
-            'cost_giou': 2,
-        },
-        weight_dict={
-            'loss_ce': 1,
-            'loss_bbox': 5,
-            'loss_giou': 2
-        },
-    ))
+        cost_dict=dict(
+            cost_class=2,
+            cost_bbox=5,
+            cost_giou=2,
+        ),
+        weight_dict=dict(
+            loss_ce=1,
+            loss_bbox=5,
+            loss_giou=2,
+        )))
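One note on the cost_dict/weight_dict rewrite above: the brace literals and the `dict(...)` constructor build identical mappings, so this part of the change is purely stylistic. A quick check:

```python
assert {'cost_class': 2, 'cost_bbox': 5, 'cost_giou': 2} == dict(
    cost_class=2, cost_bbox=5, cost_giou=2)
```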
@@ -1,28 +1,8 @@
-_base_ = ['./dab_detr.py', './coco_detection.py', 'configs/base.py']
-
-CLASSES = [
-    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
-    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
-    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
-    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
-    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
-    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
-    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
-    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
-    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
-    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-    'hair drier', 'toothbrush'
-]
-
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(type='TensorboardLoggerHook')
-    ])
+_base_ = [
+    './dab_detr.py', '../_base_/dataset/autoaug_coco_detection.py',
+    'configs/base.py'
+]

 checkpoint_config = dict(interval=10)
 # optimizer
 paramwise_options = {'backbone': dict(lr_mult=0.1, weight_decay_mult=1.0)}
@@ -37,16 +17,4 @@ lr_config = dict(policy='step', step=[40])

 total_epochs = 50

-# evaluation
-# eval_config = dict(initial=True, interval=1, gpu_collect=False)
-eval_config = dict(interval=1, gpu_collect=False)
-eval_pipelines = [
-    dict(
-        mode='test',
-        evaluators=[
-            dict(type='CocoDetectionEvaluator', classes=CLASSES),
-        ],
-    )
-]
-
 find_unused_parameters = False
@@ -0,0 +1,7 @@
+_base_ = './dab_detr_r50_8x2_50e_coco.py'
+
+# model settings
+model = dict(
+    head=dict(
+        dn_components=dict(
+            scalar=5, label_noise_scale=0.2, box_noise_scale=0.4)))
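For reference, `dn_components` feeds the `prepare_for_dn` function added later in this commit (`dn_components.py`): `scalar` is the number of noised copies made of each ground-truth object, `label_noise_scale` is the fraction of copies whose class label is replaced by a random one, and `box_noise_scale` scales the random jitter applied to box centers and sizes. A minimal, self-contained sketch of that noising step (illustrative only, not the EasyCV API):

```python
import torch

scalar, label_noise_scale, box_noise_scale = 5, 0.2, 0.4
num_classes = 80
labels = torch.tensor([3, 17])                # two GT labels
boxes = torch.tensor([[0.5, 0.5, 0.2, 0.3],   # two GT boxes, cxcywh in [0, 1]
                      [0.3, 0.6, 0.1, 0.1]])

known_labels = labels.repeat(scalar)          # `scalar` noised copies per GT
known_boxes = boxes.repeat(scalar, 1)

# label noise: flip ~20% of the copies to a random class
p = torch.rand_like(known_labels.float())
flip = torch.nonzero(p < label_noise_scale).view(-1)
known_labels[flip] = torch.randint(0, num_classes, (len(flip),))

# box noise: shift centers by up to w/2, h/2 and rescale w, h, times 0.4
diff = torch.zeros_like(known_boxes)
diff[:, :2] = known_boxes[:, 2:] / 2
diff[:, 2:] = known_boxes[:, 2:]
known_boxes = (known_boxes +
               (torch.rand_like(known_boxes) * 2 - 1.0) * diff *
               box_noise_scale).clamp(0.0, 1.0)
```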
@@ -0,0 +1,4 @@
+_base_ = './dn_detr_r50_8x2_50e_coco.py'
+
+# model settings
+model = dict(backbone=dict(strides=(1, 2, 2, 1), dilations=(1, 1, 1, 2)))
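This is the usual "DC5" variant: giving the last ResNet stage stride 1 and dilation 2 doubles the C5 feature resolution while preserving the receptive field. A sketch of the same effect using torchvision's ResNet, which exposes it as `replace_stride_with_dilation` (EasyCV's own ResNet takes the explicit tuples above; the torchvision call is only for illustration):

```python
import torch
from torchvision.models import resnet50

r50 = resnet50(weights=None)
r50_dc5 = resnet50(weights=None,
                   replace_stride_with_dilation=[False, False, True])

x = torch.randn(1, 3, 800, 800)

def c5(m, x):
    # stem + the four residual stages, i.e. everything before pooling
    x = m.maxpool(m.relu(m.bn1(m.conv1(x))))
    return m.layer4(m.layer3(m.layer2(m.layer1(x))))

print(c5(r50, x).shape)      # torch.Size([1, 2048, 25, 25])  -> stride 32
print(c5(r50_dc5, x).shape)  # torch.Size([1, 2048, 50, 50])  -> stride 16
```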
@@ -1,121 +0,0 @@
-CLASSES = [
-    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
-    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
-    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
-    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
-    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
-    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
-    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
-    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
-    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
-    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-    'hair drier', 'toothbrush'
-]
-
-# dataset settings
-data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
-
-train_pipeline = [
-    dict(type='MMRandomFlip', flip_ratio=0.5),
-    dict(
-        type='MMAutoAugment',
-        policies=[[
-            dict(
-                type='MMResize',
-                img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
-                           (608, 1333), (640, 1333), (672, 1333), (704, 1333),
-                           (736, 1333), (768, 1333), (800, 1333)],
-                multiscale_mode='value',
-                keep_ratio=True)
-        ],
-                  [
-                      dict(
-                          type='MMResize',
-                          img_scale=[(400, 1333), (500, 1333), (600, 1333)],
-                          multiscale_mode='value',
-                          keep_ratio=True),
-                      dict(
-                          type='MMRandomCrop',
-                          crop_type='absolute_range',
-                          crop_size=(384, 600),
-                          allow_negative_crop=True),
-                      dict(
-                          type='MMResize',
-                          img_scale=[(480, 1333), (512, 1333), (544, 1333),
-                                     (576, 1333), (608, 1333), (640, 1333),
-                                     (672, 1333), (704, 1333), (736, 1333),
-                                     (768, 1333), (800, 1333)],
-                          multiscale_mode='value',
-                          override=True,
-                          keep_ratio=True)
-                  ]]),
-    dict(type='MMNormalize', **img_norm_cfg),
-    dict(type='MMPad', size_divisor=1),
-    dict(type='DefaultFormatBundle'),
-    dict(
-        type='Collect',
-        keys=['img', 'gt_bboxes', 'gt_labels'],
-        meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
-                   'img_shape', 'pad_shape', 'scale_factor', 'flip',
-                   'flip_direction', 'img_norm_cfg'))
-]
-test_pipeline = [
-    dict(
-        type='MMMultiScaleFlipAug',
-        img_scale=(1333, 800),
-        flip=False,
-        transforms=[
-            dict(type='MMResize', keep_ratio=True),
-            dict(type='MMRandomFlip'),
-            dict(type='MMNormalize', **img_norm_cfg),
-            dict(type='MMPad', size_divisor=1),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(
-                type='Collect',
-                keys=['img'],
-                meta_keys=('filename', 'ori_filename', 'ori_shape',
-                           'ori_img_shape', 'img_shape', 'pad_shape',
-                           'scale_factor', 'flip', 'flip_direction',
-                           'img_norm_cfg'))
-        ])
-]
-
-train_dataset = dict(
-    type='DetDataset',
-    data_source=dict(
-        type='DetSourceCoco',
-        ann_file=data_root + 'annotations/instances_train2017.json',
-        img_prefix=data_root + 'train2017/',
-        pipeline=[
-            dict(type='LoadImageFromFile'),
-            dict(type='LoadAnnotations', with_bbox=True)
-        ],
-        classes=CLASSES,
-        test_mode=False,
-        filter_empty_gt=True,
-        iscrowd=False),
-    pipeline=train_pipeline)
-
-val_dataset = dict(
-    type='DetDataset',
-    imgs_per_gpu=1,
-    data_source=dict(
-        type='DetSourceCoco',
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=[
-            dict(type='LoadImageFromFile'),
-            dict(type='LoadAnnotations', with_bbox=True)
-        ],
-        classes=CLASSES,
-        test_mode=True,
-        filter_empty_gt=False,
-        iscrowd=True),
-    pipeline=test_pipeline)
-
-data = dict(
-    imgs_per_gpu=2, workers_per_gpu=2, train=train_dataset, val=val_dataset)
@@ -1,8 +1,7 @@
 # model settings
 model = dict(
     type='Detection',
-    pretrained=
-    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/pretrained_models/easycv/resnet/torchvision/resnet50.pth',
+    pretrained=True,
     backbone=dict(
         type='ResNet',
         depth=50,
@@ -32,14 +31,13 @@ model = dict(
         in_channels=2048,
         embed_dims=256,
         eos_coef=0.1,
-        cost_dict={
-            'cost_class': 1,
-            'cost_bbox': 5,
-            'cost_giou': 2,
-        },
-        weight_dict={
-            'loss_ce': 1,
-            'loss_bbox': 5,
-            'loss_giou': 2
-        },
-    ))
+        cost_dict=dict(
+            cost_class=1,
+            cost_bbox=5,
+            cost_giou=2,
+        ),
+        weight_dict=dict(
+            loss_ce=1,
+            loss_bbox=5,
+            loss_giou=2,
+        )))
@@ -1,28 +1,8 @@
-_base_ = ['./detr.py', './coco_detection.py', 'configs/base.py']
-
-CLASSES = [
-    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
-    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
-    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
-    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
-    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
-    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
-    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
-    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
-    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
-    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-    'hair drier', 'toothbrush'
-]
-
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(type='TensorboardLoggerHook')
-    ])
+_base_ = [
+    './detr.py', '../_base_/dataset/autoaug_coco_detection.py',
+    'configs/base.py'
+]

 checkpoint_config = dict(interval=10)
 # optimizer
 paramwise_options = {'backbone': dict(lr_mult=0.1, weight_decay_mult=1.0)}
@@ -37,16 +17,4 @@ lr_config = dict(policy='step', step=[100])

 total_epochs = 150

-# evaluation
-# eval_config = dict(initial=True, interval=1, gpu_collect=False)
-eval_config = dict(interval=1, gpu_collect=False)
-eval_pipelines = [
-    dict(
-        mode='test',
-        evaluators=[
-            dict(type='CocoDetectionEvaluator', classes=CLASSES),
-        ],
-    )
-]
-
 find_unused_parameters = False
@@ -43,8 +43,7 @@ lr_config = dict(
 total_epochs = 12

 # evaluation
-eval_config = dict(initial=True, interval=1, gpu_collect=False)
-# eval_config = dict(interval=1, gpu_collect=False)
+eval_config = dict(interval=1, gpu_collect=False)
 eval_pipelines = [
     dict(
         mode='test',
@@ -24,9 +24,11 @@ Pretrained on COCO2017 dataset.
 | Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
 | ---------- | ------------------------------------------------------------ | ------------------------ | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | FCOS-r50 | [fcos-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py) | 23M/32M | 85.8ms | 38.58 | 57.18 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/20220621_121315.log.json) |

 ## DETR

 | Algorithm | Config | Params<br/>(backbone/total) | inference time(V100)<br/>(ms/img) | bbox_mAP<sup>val<br/><sub>0.5:0.95</sub> | AP<sup>val<br/><sub>50</sub> | Download |
 | ---------- | ------------------------------------------------------------ | ------------------------ | --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | DETR-r50 | [detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/detr/detr_r50_8x2_150e_coco.py) | 23M/41M | 48.5ms | 39.92 | 60.52 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/epoch_150.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/detr/20220609_101243.log.json) |
-| DAB-DETR-r50 | [dab-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 42.52 | 63.03 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/20220610_122811.log.json) |
+| DAB-DETR-r50 | [dab-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 42.52 | 63.03 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/20220610_122811.log.json) |
+| DN-DETR-r50 | [dn-detr-r50](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py) | 23M/43M | 58.5ms | 44.39 | 64.66 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/20220713_105127.log.json) |
@@ -128,8 +128,9 @@ class Classification(BaseModel):
                     strict=False,
                     logger=logger)
             else:
-                print_log('load model from init weights')
-                self.backbone.init_weights()
+                raise ValueError(
+                    'default_pretrained_model_path for {} not found'.format(
+                        self.backbone.__class__.__name__))
         else:
             print_log('load model from init weights')
             self.backbone.init_weights()
@@ -5,17 +5,13 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from easycv.models.builder import HEADS, build_neck
-from easycv.models.detection.utils import (HungarianMatcher, accuracy,
-                                           box_cxcywh_to_xyxy,
-                                           box_xyxy_to_cxcywh,
-                                           generalized_box_iou,
-                                           inverse_sigmoid)
-from easycv.models.loss.focal_loss import py_sigmoid_focal_loss
-from easycv.models.utils import (MLP, get_world_size,
-                                 is_dist_avail_and_initialized)
+from easycv.models.detection.utils import (HungarianMatcher, SetCriterion,
+                                           box_xyxy_to_cxcywh, inverse_sigmoid)
+from easycv.models.utils import MLP
+from .dn_components import dn_post_process, prepare_for_dn


 @HEADS.register_module()
@@ -27,15 +23,17 @@ class DABDETRHead(nn.Module):
         num_classes (int): Number of categories excluding the background.
     """

     _version = 2

     def __init__(self,
                  num_classes,
                  embed_dims,
                  query_dim=4,
                  iter_update=True,
                  num_queries=300,
                  num_select=300,
+                 random_refpoints_xy=False,
+                 num_patterns=0,
+                 bbox_embed_diff_each_layer=False,
+                 dn_components=None,
                  transformer=None,
                  cost_dict={
                      'cost_class': 1,
@@ -57,7 +55,9 @@ class DABDETRHead(nn.Module):
             num_classes,
             matcher=self.matcher,
             weight_dict=weight_dict,
-            losses=['labels', 'boxes', 'cardinality'])
+            losses=['labels', 'boxes'],
+            loss_class_type='focal_loss',
+            dn_components=dn_components)
         self.postprocess = PostProcess(num_select=num_select)
         self.transformer = build_neck(transformer)

@@ -71,8 +71,31 @@ class DABDETRHead(nn.Module):
         self.transformer.decoder.bbox_embed = self.bbox_embed

         self.num_classes = num_classes
         self.num_queries = num_queries
         self.embed_dims = embed_dims
+        self.query_dim = query_dim
+        self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer
+        self.dn_components = dn_components
+
+        self.query_embed = nn.Embedding(num_queries, query_dim)
+        self.random_refpoints_xy = random_refpoints_xy
+        if random_refpoints_xy:
+            self.query_embed.weight.data[:, :2].uniform_(0, 1)
+            self.query_embed.weight.data[:, :2] = inverse_sigmoid(
+                self.query_embed.weight.data[:, :2])
+            self.query_embed.weight.data[:, :2].requires_grad = False
+
+        self.num_patterns = num_patterns
+        if not isinstance(num_patterns, int):
+            Warning('num_patterns should be int but {}'.format(
+                type(num_patterns)))
+            self.num_patterns = 0
+        if self.num_patterns > 0:
+            self.patterns = nn.Embedding(self.num_patterns, embed_dims)
+
+        if self.dn_components:
+            # leave one dim for indicator
+            self.label_enc = nn.Embedding(num_classes + 1, embed_dims - 1)

     def init_weights(self):
         self.transformer.init_weights()
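A short note on the `random_refpoints_xy` branch above: `inverse_sigmoid` is the logit function, so writing `inverse_sigmoid(u)` into the embedding weight means the decoder's `sigmoid(refpoint)` recovers the uniform `(cx, cy)` exactly, and freezing those two columns keeps the anchor centers fixed while `w, h` stay learnable. A sketch of the identity:

```python
import torch

u = torch.rand(300, 2)           # uniform cx, cy in (0, 1)
logit = torch.log(u / (1 - u))   # inverse_sigmoid, up to clamping
assert torch.allclose(torch.sigmoid(logit), u, atol=1e-6)
```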
@@ -92,7 +115,45 @@ class DABDETRHead(nn.Module):
         nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
         nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

-    def forward(self, feats, img_metas):
+    def prepare(self, feats, targets=None, mode='train'):
+        bs = feats[0].shape[0]
+        query_embed = self.query_embed.weight
+        if self.dn_components:
+            # default pipeline
+            self.dn_components['num_patterns'] = self.num_patterns
+            self.dn_components['targets'] = targets
+            # prepare for dn
+            tgt, query_embed, attn_mask, mask_dict = prepare_for_dn(
+                mode, self.dn_components, query_embed, bs, self.num_queries,
+                self.num_classes, self.embed_dims, self.label_enc)
+            if self.num_patterns > 0:
+                l = tgt.shape[0]
+                tgt[l - self.num_queries * self.num_patterns:] += \
+                    self.patterns.weight[:, None, None, :].repeat(1, self.num_queries, bs, 1).flatten(0, 1)
+            return query_embed, tgt, attn_mask, mask_dict
+        else:
+            query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+            if self.num_patterns == 0:
+                tgt = torch.zeros(
+                    self.num_queries,
+                    bs,
+                    self.embed_dims,
+                    device=query_embed.device)
+            else:
+                tgt = self.patterns.weight[:, None, None, :].repeat(
+                    1, self.num_queries, bs,
+                    1).flatten(0, 1)  # n_q*n_pat, bs, d_model
+                query_embed = query_embed.repeat(self.num_patterns, 1,
+                                                 1)  # n_q*n_pat, bs, d_model
+            return query_embed, tgt, None, None
+
+    def forward(self,
+                feats,
+                img_metas,
+                query_embed=None,
+                tgt=None,
+                attn_mask=None,
+                mask_dict=None):
         """Forward function.
         Args:
             feats (tuple[Tensor]): Features from the upstream network, each is
@@ -109,7 +170,9 @@ class DABDETRHead(nn.Module):
                 normalized coordinate format (cx, cy, w, h) and shape \
                 [nb_dec, bs, num_query, 4].
         """
-        feats = self.transformer(feats, img_metas)
+
+        feats = self.transformer(
+            feats, img_metas, query_embed, tgt, attn_mask=attn_mask)

         hs, reference = feats
         outputs_class = self.class_embed(hs)
@@ -128,6 +191,10 @@ class DABDETRHead(nn.Module):
             outputs_coords.append(outputs_coord)
         outputs_coord = torch.stack(outputs_coords)

+        if mask_dict is not None:
+            # dn post process
+            outputs_class, outputs_coord = dn_post_process(
+                outputs_class, outputs_coord, mask_dict)
         out = {
             'pred_logits': outputs_class[-1],
             'pred_boxes': outputs_coord[-1]
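`dn_post_process` (added later in this commit in `dn_components.py`) simply splits the first `pad_size` queries, which carry the denoising groups, off the front of the decoder outputs before the regular loss is computed. A shape-only sketch:

```python
import torch

nb_dec, bs, pad_size, num_queries, num_classes = 6, 2, 10, 300, 80
outputs_class = torch.randn(nb_dec, bs, pad_size + num_queries, num_classes)

dn_part = outputs_class[:, :, :pad_size, :]        # denoising queries
matching_part = outputs_class[:, :, pad_size:, :]  # regular Hungarian matching
assert matching_part.shape[2] == num_queries
```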
@@ -164,27 +231,45 @@ class DABDETRHead(nn.Module):
         Returns:
             dict[str, Tensor]: A dictionary of loss components.
         """
-        outputs = self.forward(x, img_metas)
-
         # prepare ground truth
         for i in range(len(img_metas)):
             img_h, img_w, _ = img_metas[i]['img_shape']
             # DETR regresses the relative position of boxes (cxcywh) in the image.
             # Thus the learning target should be normalized by the image size, and
             # the box format should be converted from the default x1y1x2y2 to cxcywh.
-            factor = outputs['pred_boxes'].new_tensor(
-                [img_w, img_h, img_w, img_h]).unsqueeze(0)
+            factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w,
+                                              img_h]).unsqueeze(0)
             gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor

         targets = []
         for gt_label, gt_bbox in zip(gt_labels, gt_bboxes):
             targets.append({'labels': gt_label, 'boxes': gt_bbox})

-        losses = self.criterion(outputs, targets)
+        query_embed, tgt, attn_mask, mask_dict = self.prepare(
+            x, targets=targets, mode='train')
+
+        outputs = self.forward(
+            x,
+            img_metas,
+            query_embed=query_embed,
+            tgt=tgt,
+            attn_mask=attn_mask,
+            mask_dict=mask_dict)
+
+        losses = self.criterion(outputs, targets, mask_dict)

         return losses

     def forward_test(self, x, img_metas):
-        outputs = self.forward(x, img_metas)
+        query_embed, tgt, attn_mask, mask_dict = self.prepare(x, mode='test')
+
+        outputs = self.forward(
+            x,
+            img_metas,
+            query_embed=query_embed,
+            tgt=tgt,
+            attn_mask=attn_mask,
+            mask_dict=mask_dict)

         ori_shape_list = []
         for i in range(len(img_metas)):
@@ -242,192 +327,3 @@ class PostProcess(nn.Module):
         }

         return results
-
-
-class SetCriterion(nn.Module):
-    """ This class computes the loss for Conditional DETR.
-    The process happens in two steps:
-        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
-        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
-    """
-
-    def __init__(self, num_classes, matcher, weight_dict, losses):
-        """ Create the criterion.
-        Parameters:
-            num_classes: number of object categories, omitting the special no-object category
-            matcher: module able to compute a matching between targets and proposals
-            weight_dict: dict containing as key the names of the losses and as values their relative weight.
-            losses: list of all the losses to be applied. See get_loss for list of available losses.
-        """
-        super().__init__()
-        self.num_classes = num_classes
-        self.matcher = matcher
-        self.weight_dict = weight_dict
-        self.losses = losses
-
-    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
-        """Classification loss (Binary focal loss)
-        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
-        """
-        assert 'pred_logits' in outputs
-        src_logits = outputs['pred_logits']
-
-        idx = self._get_src_permutation_idx(indices)
-        target_classes_o = torch.cat(
-            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
-        target_classes = torch.full(
-            src_logits.shape[:2],
-            self.num_classes,
-            dtype=torch.int64,
-            device=src_logits.device)
-        target_classes[idx] = target_classes_o
-
-        target_classes_onehot = torch.zeros([
-            src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
-        ],
-                                            dtype=src_logits.dtype,
-                                            layout=src_logits.layout,
-                                            device=src_logits.device)
-        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
-
-        target_classes_onehot = target_classes_onehot[:, :, :-1]
-        loss_ce = py_sigmoid_focal_loss(
-            src_logits,
-            target_classes_onehot.long(),
-            alpha=0.25,
-            gamma=2,
-            reduction='none').mean(1).sum() / num_boxes
-        loss_ce = loss_ce * src_logits.shape[1] * self.weight_dict['loss_ce']
-        losses = {'loss_ce': loss_ce}
-
-        if log:
-            # TODO this should probably be a separate loss, not hacked in this one here
-            losses['class_error'] = 100 - accuracy(src_logits[idx],
-                                                   target_classes_o)[0]
-        return losses
-
-    @torch.no_grad()
-    def loss_cardinality(self, outputs, targets, indices, num_boxes):
-        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
-        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
-        """
-        pred_logits = outputs['pred_logits']
-        device = pred_logits.device
-        tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
-                                      device=device)
-        # Count the number of predictions that are NOT "no-object" (which is the last class)
-        card_pred = (pred_logits.argmax(-1) !=
-                     pred_logits.shape[-1] - 1).sum(1)
-        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
-        losses = {'cardinality_error': card_err}
-        return losses
-
-    def loss_boxes(self, outputs, targets, indices, num_boxes):
-        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
-        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
-        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
-        """
-        assert 'pred_boxes' in outputs
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs['pred_boxes'][idx]
-        target_boxes = torch.cat(
-            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
-
-        losses = {}
-        losses['loss_bbox'] = loss_bbox.sum(
-        ) / num_boxes * self.weight_dict['loss_bbox']
-
-        loss_giou = 1 - torch.diag(
-            generalized_box_iou(
-                box_cxcywh_to_xyxy(src_boxes),
-                box_cxcywh_to_xyxy(target_boxes)))
-        losses['loss_giou'] = loss_giou.sum(
-        ) / num_boxes * self.weight_dict['loss_giou']
-
-        return losses
-
-    def _get_src_permutation_idx(self, indices):
-        # permute predictions following indices
-        batch_idx = torch.cat(
-            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
-
-    def _get_tgt_permutation_idx(self, indices):
-        # permute targets following indices
-        batch_idx = torch.cat(
-            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
-
-    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
-        loss_map = {
-            'labels': self.loss_labels,
-            'cardinality': self.loss_cardinality,
-            'boxes': self.loss_boxes,
-        }
-        assert loss in loss_map, f'do you really want to compute {loss} loss?'
-        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
-
-    def forward(self, outputs, targets, return_indices=False):
-        """ This performs the loss computation.
-        Parameters:
-            outputs: dict of tensors, see the output specification of the model for the format
-            targets: list of dicts, such that len(targets) == batch_size.
-                     The expected keys in each dict depend on the losses applied, see each loss' doc
-
-            return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
-        """
-
-        outputs_without_aux = {
-            k: v
-            for k, v in outputs.items() if k != 'aux_outputs'
-        }
-
-        # Retrieve the matching between the outputs of the last layer and the targets
-        indices = self.matcher(outputs_without_aux, targets)
-        if return_indices:
-            indices0_copy = indices
-            indices_list = []
-
-        # Compute the average number of target boxes across all nodes, for normalization purposes
-        num_boxes = sum(len(t['labels']) for t in targets)
-        num_boxes = torch.as_tensor([num_boxes],
-                                    dtype=torch.float,
-                                    device=next(iter(outputs.values())).device)
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_boxes)
-        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
-
-        # Compute all the requested losses
-        losses = {}
-        for loss in self.losses:
-            losses.update(
-                self.get_loss(loss, outputs, targets, indices, num_boxes))
-
-        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
-        if 'aux_outputs' in outputs:
-            for i, aux_outputs in enumerate(outputs['aux_outputs']):
-                indices = self.matcher(aux_outputs, targets)
-                if return_indices:
-                    indices_list.append(indices)
-                for loss in self.losses:
-                    if loss == 'masks':
-                        # Intermediate masks losses are too costly to compute, we ignore them.
-                        continue
-                    kwargs = {}
-                    if loss == 'labels':
-                        # Logging is enabled only for the last layer
-                        kwargs = {'log': False}
-                    l_dict = self.get_loss(loss, aux_outputs, targets, indices,
-                                           num_boxes, **kwargs)
-                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
-                    losses.update(l_dict)
-
-        if return_indices:
-            indices_list.append(indices0_copy)
-            return losses, indices_list
-
-        return losses
@@ -33,10 +33,7 @@ class DABDetrTransformer(nn.Module):

     def __init__(self,
                  in_channels=1024,
-                 num_queries=300,
                  query_dim=4,
-                 random_refpoints_xy=False,
-                 num_patterns=0,
                  d_model=512,
                  nhead=8,
                  num_encoder_layers=6,
@@ -58,27 +55,11 @@ class DABDetrTransformer(nn.Module):
         ]

         self.input_proj = nn.Conv2d(in_channels, d_model, kernel_size=1)
-        self.query_embed = nn.Embedding(num_queries, query_dim)
         self.positional_encoding = PositionEmbeddingSineHW(
             d_model // 2,
             temperatureH=temperatureH,
             temperatureW=temperatureW,
             normalize=True)
-        self.random_refpoints_xy = random_refpoints_xy
-        if random_refpoints_xy:
-            self.query_embed.weight.data[:, :2].uniform_(0, 1)
-            self.query_embed.weight.data[:, :2] = inverse_sigmoid(
-                self.query_embed.weight.data[:, :2])
-            self.query_embed.weight.data[:, :2].requires_grad = False
-
-        self.num_queries = num_queries
-        self.num_patterns = num_patterns
-        if not isinstance(num_patterns, int):
-            Warning('num_patterns should be int but {}'.format(
-                type(num_patterns)))
-            self.num_patterns = 0
-        if self.num_patterns > 0:
-            self.patterns = nn.Embedding(self.num_patterns, d_model)

         encoder_layer = TransformerEncoderLayer(d_model, nhead,
                                                 dim_feedforward, dropout,
@@ -116,14 +97,13 @@ class DABDetrTransformer(nn.Module):

     def init_weights(self):
         for p in self.named_parameters():
-            if 'input_proj' in p[0] or 'query_embed' in p[
-                    0] or 'positional_encoding' in p[0] or 'patterns' in p[
-                        0] or 'bbox_embed' in p[0]:
+            if 'input_proj' in p[0] or 'positional_encoding' in p[
+                    0] or 'bbox_embed' in p[0]:
                 continue
             if p[1].dim() > 1:
                 nn.init.xavier_uniform_(p[1])

-    def forward(self, src, img_metas):
+    def forward(self, src, img_metas, query_embed, tgt, attn_mask=None):
         src = src[0]

         # construct binary masks which used for the transformer.
@@ -143,30 +123,18 @@ class DABDetrTransformer(nn.Module):
         # position encoding
         pos_embed = self.positional_encoding(mask)  # [bs, embed_dim, h, w]
         # outs_dec: [nb_dec, bs, num_query, embed_dim]
-        query_embed = self.query_embed.weight

         # flatten NxCxHxW to HWxNxC
         src = src.flatten(2).permute(2, 0, 1)
         pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
-        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
         mask = mask.flatten(1)

         memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

-        num_queries = query_embed.shape[0]
-        if self.num_patterns == 0:
-            tgt = torch.zeros(
-                num_queries, bs, self.d_model, device=query_embed.device)
-        else:
-            tgt = self.patterns.weight[:, None, None, :].repeat(
-                1, self.num_queries, bs,
-                1).flatten(0, 1)  # n_q*n_pat, bs, d_model
-            query_embed = query_embed.repeat(self.num_patterns, 1,
-                                             1)  # n_q*n_pat, bs, d_model
-
         hs, references = self.decoder(
             tgt,
             memory,
+            tgt_mask=attn_mask,
             memory_key_padding_mask=mask,
             pos=pos_embed,
             refpoints_unsigmoid=query_embed)
@@ -0,0 +1,167 @@
+# ------------------------------------------------------------------------
+# DN-DETR
+# Copyright (c) 2022 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+
+import torch
+
+from easycv.models.detection.utils import inverse_sigmoid
+
+
+def prepare_for_dn(mode, dn_args, embedweight, batch_size, num_queries,
+                   num_classes, hidden_dim, label_enc):
+    """
+    prepare for dn components in forward function
+    Args:
+        dn_args: (targets, args.scalar, args.label_noise_scale, args.box_noise_scale, args.num_patterns) from engine input
+        embedweight: positional queries as anchor
+        training: whether it is training or inference
+        num_queries: number of queries
+        num_classes: number of classes
+        hidden_dim: transformer hidden dimension
+        label_enc: label encoding embedding
+
+    Returns: input_query_label, input_query_bbox, attn_mask, mask_dict
+    """
+    if mode == 'train':
+        targets, scalar, label_noise_scale, box_noise_scale, num_patterns = dn_args[
+            'targets'], dn_args['scalar'], dn_args[
+                'label_noise_scale'], dn_args['box_noise_scale'], dn_args[
+                    'num_patterns']
+    else:
+        num_patterns = dn_args['num_patterns']
+
+    if num_patterns == 0:
+        num_patterns = 1
+    indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda()
+    tgt = label_enc(torch.tensor(num_classes).cuda()).repeat(
+        num_queries * num_patterns, 1)
+    tgt = torch.cat([tgt, indicator0], dim=1)
+    refpoint_emb = embedweight.repeat(num_patterns, 1)
+    if mode == 'train':
+        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
+        know_idx = [torch.nonzero(t) for t in known]
+        known_num = [sum(k) for k in known]
+        # you can uncomment this to use a fixed number of dn queries
+        # if int(max(known_num))>0:
+        #     scalar=scalar//int(max(known_num))
+
+        # can be modified to selectively denoise some labels or boxes; also known label prediction
+        unmask_bbox = unmask_label = torch.cat(known)
+        labels = torch.cat([t['labels'] for t in targets])
+        boxes = torch.cat([t['boxes'] for t in targets])
+        batch_idx = torch.cat([
+            torch.full_like(t['labels'].long(), i)
+            for i, t in enumerate(targets)
+        ])
+
+        known_indice = torch.nonzero(unmask_label + unmask_bbox)
+        known_indice = known_indice.view(-1)
+
+        # add noise
+        known_indice = known_indice.repeat(scalar, 1).view(-1)
+        known_labels = labels.repeat(scalar, 1).view(-1)
+        known_bid = batch_idx.repeat(scalar, 1).view(-1)
+        known_bboxs = boxes.repeat(scalar, 1)
+        known_labels_expaned = known_labels.clone()
+        known_bbox_expand = known_bboxs.clone()
+
+        # noise on the label
+        if label_noise_scale > 0:
+            p = torch.rand_like(known_labels_expaned.float())
+            chosen_indice = torch.nonzero(p < (label_noise_scale)).view(
+                -1)  # usually half of bbox noise
+            new_label = torch.randint_like(
+                chosen_indice, 0, num_classes)  # randomly put a new one here
+            known_labels_expaned.scatter_(0, chosen_indice, new_label)
+        # noise on the box
+        if box_noise_scale > 0:
+            diff = torch.zeros_like(known_bbox_expand)
+            diff[:, :2] = known_bbox_expand[:, 2:] / 2
+            diff[:, 2:] = known_bbox_expand[:, 2:]
+            known_bbox_expand += torch.mul(
+                (torch.rand_like(known_bbox_expand) * 2 - 1.0),
+                diff).cuda() * box_noise_scale
+            known_bbox_expand = known_bbox_expand.clamp(min=0.0, max=1.0)
+
+        m = known_labels_expaned.long().to('cuda')
+        input_label_embed = label_enc(m)
+        # add dn part indicator
+        indicator1 = torch.ones([input_label_embed.shape[0], 1]).cuda()
+        input_label_embed = torch.cat([input_label_embed, indicator1], dim=1)
+        input_bbox_embed = inverse_sigmoid(known_bbox_expand)
+        single_pad = int(max(known_num))
+        pad_size = int(single_pad * scalar)
+        padding_label = torch.zeros(pad_size, hidden_dim).cuda()
+        padding_bbox = torch.zeros(pad_size, 4).cuda()
+        input_query_label = torch.cat([padding_label, tgt],
+                                      dim=0).repeat(batch_size, 1, 1)
+        input_query_bbox = torch.cat([padding_bbox, refpoint_emb],
+                                     dim=0).repeat(batch_size, 1, 1)
+
+        # map in order
+        map_known_indice = torch.tensor([]).to('cuda')
+        if len(known_num):
+            map_known_indice = torch.cat([
+                torch.tensor(range(num)) for num in known_num
+            ])  # [1,2, 1,2,3]
+            map_known_indice = torch.cat([
+                map_known_indice + single_pad * i for i in range(scalar)
+            ]).long()
+        if len(known_bid):
+            input_query_label[(known_bid.long(),
+                               map_known_indice)] = input_label_embed
+            input_query_bbox[(known_bid.long(),
+                              map_known_indice)] = input_bbox_embed
+
+        tgt_size = pad_size + num_queries * num_patterns
+        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
+        # match query cannot see the reconstruct
+        attn_mask[pad_size:, :pad_size] = True
+        # reconstruct cannot see each other
+        for i in range(scalar):
+            if i == 0:
+                attn_mask[single_pad * i:single_pad * (i + 1),
+                          single_pad * (i + 1):pad_size] = True
+            if i == scalar - 1:
+                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad *
+                          i] = True
+            else:
+                attn_mask[single_pad * i:single_pad * (i + 1),
+                          single_pad * (i + 1):pad_size] = True
+                attn_mask[single_pad * i:single_pad * (i + 1), :single_pad *
+                          i] = True
+        mask_dict = {
+            'known_indice': torch.as_tensor(known_indice).long(),
+            'batch_idx': torch.as_tensor(batch_idx).long(),
+            'map_known_indice': torch.as_tensor(map_known_indice).long(),
+            'known_lbs_bboxes': (known_labels, known_bboxs),
+            'know_idx': know_idx,
+            'pad_size': pad_size
+        }
+    else:  # no dn for inference
+        input_query_label = tgt.repeat(batch_size, 1, 1)
+        input_query_bbox = refpoint_emb.repeat(batch_size, 1, 1)
+        attn_mask = None
+        mask_dict = None
+
+    input_query_label = input_query_label.transpose(0, 1)
+    input_query_bbox = input_query_bbox.transpose(0, 1)
+
+    return input_query_label, input_query_bbox, attn_mask, mask_dict
+
+
+def dn_post_process(outputs_class, outputs_coord, mask_dict):
+    """
+    post process of dn after output from the transformer
+    put the dn part in the mask_dict
+    """
+    if mask_dict and mask_dict['pad_size'] > 0:
+        output_known_class = outputs_class[:, :, :mask_dict['pad_size'], :]
+        output_known_coord = outputs_coord[:, :, :mask_dict['pad_size'], :]
+        outputs_class = outputs_class[:, :, mask_dict['pad_size']:, :]
+        outputs_coord = outputs_coord[:, :, mask_dict['pad_size']:, :]
+        mask_dict['output_known_lbs_bboxes'] = (output_known_class,
+                                                output_known_coord)
+    return outputs_class, outputs_coord
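To make the attention mask built above concrete: a standalone CPU sketch (no EasyCV imports) of the block structure for `scalar=2` denoising groups with `single_pad=3` slots per group and 4 matching queries, where every `True` entry is *blocked*. The two assignments inside the loop are equivalent to the branchy `i == 0` / `i == scalar - 1` / `else` cases above, since out-of-range slices are simply empty:

```python
import torch

scalar, single_pad, num_queries = 2, 3, 4
pad_size = scalar * single_pad
tgt_size = pad_size + num_queries
attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool)

# matching queries must not see any denoising queries
attn_mask[pad_size:, :pad_size] = True
# denoising groups must not see each other
for i in range(scalar):
    attn_mask[single_pad * i:single_pad * (i + 1), :single_pad * i] = True
    attn_mask[single_pad * i:single_pad * (i + 1),
              single_pad * (i + 1):pad_size] = True

print(attn_mask.int())  # block-diagonal over dn groups, open within each group
```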
@@ -44,8 +44,9 @@ class Detection(BaseModel):
                     strict=False,
                     logger=logger)
             else:
-                print_log('load model from init weights')
-                self.backbone.init_weights()
+                raise ValueError(
+                    'default_pretrained_model_path for {} not found'.format(
+                        self.backbone.__class__.__name__))
         else:
             print_log('load model from init weights')
             self.backbone.init_weights()
@@ -6,12 +6,10 @@ import torch.nn as nn
 import torch.nn.functional as F

 from easycv.models.builder import HEADS, build_neck
-from easycv.models.detection.utils import (HungarianMatcher, accuracy,
-                                           box_cxcywh_to_xyxy,
-                                           box_xyxy_to_cxcywh,
-                                           generalized_box_iou)
-from easycv.models.utils import (MLP, get_world_size,
-                                 is_dist_avail_and_initialized)
+from easycv.models.detection.utils import (HungarianMatcher, SetCriterion,
+                                           box_cxcywh_to_xyxy,
+                                           box_xyxy_to_cxcywh)
+from easycv.models.utils import MLP


 @HEADS.register_module()
@@ -50,7 +48,7 @@ class DETRHead(nn.Module):
             matcher=self.matcher,
             weight_dict=weight_dict,
             eos_coef=eos_coef,
-            losses=['labels', 'boxes', 'cardinality'])
+            losses=['labels', 'boxes'])
         self.postprocess = PostProcess()
         self.transformer = build_neck(transformer)

@@ -120,21 +118,22 @@ class DETRHead(nn.Module):
         Returns:
             dict[str, Tensor]: A dictionary of loss components.
         """
-        outputs = self.forward(x, img_metas)
-
         # prepare ground truth
         for i in range(len(img_metas)):
             img_h, img_w, _ = img_metas[i]['img_shape']
             # DETR regresses the relative position of boxes (cxcywh) in the image.
             # Thus the learning target should be normalized by the image size, and
             # the box format should be converted from the default x1y1x2y2 to cxcywh.
-            factor = outputs['pred_boxes'].new_tensor(
-                [img_w, img_h, img_w, img_h]).unsqueeze(0)
+            factor = gt_bboxes[i].new_tensor([img_w, img_h, img_w,
+                                              img_h]).unsqueeze(0)
             gt_bboxes[i] = box_xyxy_to_cxcywh(gt_bboxes[i]) / factor

         targets = []
         for gt_label, gt_bbox in zip(gt_labels, gt_bboxes):
             targets.append({'labels': gt_label, 'boxes': gt_bbox})

+        outputs = self.forward(x, img_metas)
+
         losses = self.criterion(outputs, targets)

         return losses
@@ -188,171 +187,3 @@ class PostProcess(nn.Module):
         }

         return results
-
-
-class SetCriterion(nn.Module):
-    """ This class computes the loss for DETR.
-    The process happens in two steps:
-        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
-        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
-    """
-
-    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
-        """ Create the criterion.
-        Parameters:
-            num_classes: number of object categories, omitting the special no-object category
-            matcher: module able to compute a matching between targets and proposals
-            weight_dict: dict containing as key the names of the losses and as values their relative weight.
-            eos_coef: relative classification weight applied to the no-object category
-            losses: list of all the losses to be applied. See get_loss for list of available losses.
-        """
-        super().__init__()
-        self.num_classes = num_classes
-        self.matcher = matcher
-        self.weight_dict = weight_dict
-        self.eos_coef = eos_coef
-        self.losses = losses
-        empty_weight = torch.ones(self.num_classes + 1)
-        empty_weight[-1] = self.eos_coef
-        self.register_buffer('empty_weight', empty_weight)
-
-    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
-        """Classification loss (NLL)
-        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
-        """
-        assert 'pred_logits' in outputs
-        src_logits = outputs['pred_logits']
-
-        idx = self._get_src_permutation_idx(indices)
-        target_classes_o = torch.cat(
-            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
-        target_classes = torch.full(
-            src_logits.shape[:2],
-            self.num_classes,
-            dtype=torch.int64,
-            device=src_logits.device)
-        target_classes[idx] = target_classes_o
-
-        loss_ce = F.cross_entropy(
-            src_logits.transpose(1, 2), target_classes,
-            self.empty_weight) * self.weight_dict['loss_ce']
-        losses = {'loss_ce': loss_ce}
-
-        if log:
-            # TODO this should probably be a separate loss, not hacked in this one here
-            losses['class_error'] = 100 - accuracy(src_logits[idx],
-                                                   target_classes_o)[0]
-        return losses
-
-    @torch.no_grad()
-    def loss_cardinality(self, outputs, targets, indices, num_boxes):
-        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
-        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
-        """
-        pred_logits = outputs['pred_logits']
-        device = pred_logits.device
-        tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
-                                      device=device)
-        # Count the number of predictions that are NOT "no-object" (which is the last class)
-        card_pred = (pred_logits.argmax(-1) !=
-                     pred_logits.shape[-1] - 1).sum(1)
-        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
-        losses = {'cardinality_error': card_err}
-        return losses
-
-    def loss_boxes(self, outputs, targets, indices, num_boxes):
-        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
-        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
-        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
-        """
-        assert 'pred_boxes' in outputs
-        idx = self._get_src_permutation_idx(indices)
-        src_boxes = outputs['pred_boxes'][idx]
-        target_boxes = torch.cat(
-            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
-
-        losses = {}
-        losses['loss_bbox'] = loss_bbox.sum(
-        ) / num_boxes * self.weight_dict['loss_bbox']
-
-        loss_giou = 1 - torch.diag(
-            generalized_box_iou(
-                box_cxcywh_to_xyxy(src_boxes),
-                box_cxcywh_to_xyxy(target_boxes)))
-        losses['loss_giou'] = loss_giou.sum(
-        ) / num_boxes * self.weight_dict['loss_giou']
-        return losses
-
-    def _get_src_permutation_idx(self, indices):
-        # permute predictions following indices
-        batch_idx = torch.cat(
-            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
-        src_idx = torch.cat([src for (src, _) in indices])
-        return batch_idx, src_idx
-
-    def _get_tgt_permutation_idx(self, indices):
-        # permute targets following indices
-        batch_idx = torch.cat(
-            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
-        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
-        return batch_idx, tgt_idx
-
-    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
-        loss_map = {
-            'labels': self.loss_labels,
-            'cardinality': self.loss_cardinality,
-            'boxes': self.loss_boxes
-        }
-        assert loss in loss_map, f'do you really want to compute {loss} loss?'
-        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
-
-    def forward(self, outputs, targets):
-        """ This performs the loss computation.
-        Parameters:
-            outputs: dict of tensors, see the output specification of the model for the format
-            targets: list of dicts, such that len(targets) == batch_size.
-                     The expected keys in each dict depend on the losses applied, see each loss' doc
-        """
-        outputs_without_aux = {
-            k: v
-            for k, v in outputs.items() if k != 'aux_outputs'
-        }
-
-        # Retrieve the matching between the outputs of the last layer and the targets
-        indices = self.matcher(outputs_without_aux, targets)
-
-        # Compute the average number of target boxes across all nodes, for normalization purposes
-        num_boxes = sum(len(t['labels']) for t in targets)
-        num_boxes = torch.as_tensor([num_boxes],
-                                    dtype=torch.float,
-                                    device=next(iter(outputs.values())).device)
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_boxes)
-        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
-
-        # Compute all the requested losses
-        losses = {}
-        for loss in self.losses:
-            losses.update(
-                self.get_loss(loss, outputs, targets, indices, num_boxes))
-
-        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
-        if 'aux_outputs' in outputs:
-            for i, aux_outputs in enumerate(outputs['aux_outputs']):
-                indices = self.matcher(aux_outputs, targets)
-                for loss in self.losses:
-                    if loss == 'masks':
-                        # Intermediate masks losses are too costly to compute, we ignore them.
-                        continue
-                    kwargs = {}
-                    if loss == 'labels':
-                        # Logging is enabled only for the last layer
-                        kwargs = {'log': False}
-                    l_dict = self.get_loss(loss, aux_outputs, targets, indices,
-                                           num_boxes, **kwargs)
-                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
-                    losses.update(l_dict)
-
-        return losses
@@ -7,3 +7,4 @@ from .generator import MlvlPointGenerator
 from .matcher import HungarianMatcher
 from .misc import (accuracy, filter_scores_and_topk, fp16_clamp, interpolate,
                    inverse_sigmoid, output_postprocess, select_single_mlvl)
+from .set_criterion import SetCriterion
@@ -13,7 +13,7 @@ class HungarianMatcher(nn.Module):
     while the others are un-matched (and thus treated as non-objects).
     """

-    def __init__(self, cost_dict, cost_class_type=None):
+    def __init__(self, cost_dict, cost_class_type='ce_cost'):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
@@ -51,7 +51,7 @@ class HungarianMatcher(nn.Module):
         if self.cost_class_type == 'focal_loss_cost':
             out_prob = outputs['pred_logits'].flatten(
                 0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
-        else:
+        elif self.cost_class_type == 'ce_cost':
             out_prob = outputs['pred_logits'].flatten(0, 1).softmax(
                 -1)  # [batch_size * num_queries, num_classes]

@@ -72,7 +72,7 @@ class HungarianMatcher(nn.Module):
             pos_cost_class = pos_cost_class * (-(out_prob + 1e-8).log())
             cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:,
                                                                      tgt_ids]
-        else:
+        elif self.cost_class_type == 'ce_cost':
             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
             # but approximate it in 1 - proba[target class].
             # The 1 is a constant that doesn't change the matching, it can be omitted.
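For context on the `focal_loss_cost` branch above: only the positive term appears in this hunk; the negative term follows the same focal weighting. A toy sketch with the usual alpha=0.25, gamma=2.0 (the exact constants used by EasyCV's matcher are not shown in this diff, so treat them as assumptions):

```python
import torch

alpha, gamma = 0.25, 2.0
out_prob = torch.rand(6, 80)         # toy sigmoid-ed logits [num_queries, num_classes]
tgt_ids = torch.tensor([3, 17, 42])  # toy target class ids

neg_cost_class = (1 - alpha) * out_prob**gamma * (-(1 - out_prob + 1e-8).log())
pos_cost_class = alpha * (1 - out_prob)**gamma * (-(out_prob + 1e-8).log())
cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
print(cost_class.shape)              # torch.Size([6, 3]): query-by-target cost
```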
@@ -0,0 +1,388 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from easycv.models.detection.utils import (accuracy, box_cxcywh_to_xyxy,
+                                           generalized_box_iou)
+from easycv.models.loss.focal_loss import py_sigmoid_focal_loss
+from easycv.models.utils import get_world_size, is_dist_avail_and_initialized
+
+
+class SetCriterion(nn.Module):
+    """ This class computes the loss for Conditional DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+
+    def __init__(self,
+                 num_classes,
+                 matcher,
+                 weight_dict,
+                 losses,
+                 eos_coef=None,
+                 loss_class_type='ce',
+                 dn_components=None):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.losses = losses
+        self.loss_class_type = loss_class_type
+        if self.loss_class_type == 'ce':
+            empty_weight = torch.ones(self.num_classes + 1)
+            empty_weight[-1] = eos_coef
+            self.register_buffer('empty_weight', empty_weight)
+        if dn_components is not None:
+            self.dn_criterion = DNCriterion(self.weight_dict)
+
+    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
+        """Classification loss (Binary focal loss)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat(
+            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(
+            src_logits.shape[:2],
+            self.num_classes,
+            dtype=torch.int64,
+            device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        if self.loss_class_type == 'ce':
+            loss_ce = F.cross_entropy(
+                src_logits.transpose(1, 2), target_classes,
+                self.empty_weight) * self.weight_dict['loss_ce']
+        elif self.loss_class_type == 'focal_loss':
+            target_classes_onehot = torch.zeros([
+                src_logits.shape[0], src_logits.shape[1],
+                src_logits.shape[2] + 1
+            ],
+                                                dtype=src_logits.dtype,
+                                                layout=src_logits.layout,
+                                                device=src_logits.device)
+            target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+            target_classes_onehot = target_classes_onehot[:, :, :-1]
+
+            loss_ce = py_sigmoid_focal_loss(
+                src_logits,
+                target_classes_onehot.long(),
+                alpha=0.25,
+                gamma=2,
+                reduction='none').mean(1).sum() / num_boxes
+            loss_ce = loss_ce * src_logits.shape[1] * self.weight_dict[
+                'loss_ce']
+        losses = {'loss_ce': loss_ce}
+
+        if log:
+            # TODO this should probably be a separate loss, not hacked in this one here
+            losses['class_error'] = 100 - accuracy(src_logits[idx],
+                                                   target_classes_o)[0]
+        return losses
+
+    @torch.no_grad()
+    def loss_cardinality(self, outputs, targets, indices, num_boxes):
+        """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
+        """
+        pred_logits = outputs['pred_logits']
+        device = pred_logits.device
+        tgt_lengths = torch.as_tensor([len(v['labels']) for v in targets],
+                                      device=device)
+        # Count the number of predictions that are NOT "no-object" (which is the last class)
+        card_pred = (pred_logits.argmax(-1) !=
+                     pred_logits.shape[-1] - 1).sum(1)
+        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
+        losses = {'cardinality_error': card_err}
+        return losses
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat(
+            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+        losses = {}
+        losses['loss_bbox'] = loss_bbox.sum(
+        ) / num_boxes * self.weight_dict['loss_bbox']
+
+        loss_giou = 1 - torch.diag(
+            generalized_box_iou(
+                box_cxcywh_to_xyxy(src_boxes),
+                box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum(
+        ) / num_boxes * self.weight_dict['loss_giou']
+
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat(
+            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat(
+            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'cardinality': self.loss_cardinality,
+            'boxes': self.loss_boxes,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets, mask_dict=None, return_indices=False):
+        """ This performs the loss computation.
+        Parameters:
+            outputs: dict of tensors, see the output specification of the model for the format
+            targets: list of dicts, such that len(targets) == batch_size.
+                     The expected keys in each dict depend on the losses applied, see each loss' doc
+
+            return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
+        """
+
+        outputs_without_aux = {
+            k: v
+            for k, v in outputs.items() if k != 'aux_outputs'
+        }
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+        if return_indices:
+            indices0_copy = indices
+            indices_list = []
|
||||
|
||||
# Compute the average number of target boxes accross all nodes, for normalization purposes
|
||||
num_boxes = sum(len(t['labels']) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes],
|
||||
dtype=torch.float,
|
||||
device=next(iter(outputs.values())).device)
|
||||
if is_dist_avail_and_initialized():
|
||||
torch.distributed.all_reduce(num_boxes)
|
||||
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
losses = {}
|
||||
for loss in self.losses:
|
||||
losses.update(
|
||||
self.get_loss(loss, outputs, targets, indices, num_boxes))
|
||||
|
||||
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
|
||||
if 'aux_outputs' in outputs:
|
||||
for i, aux_outputs in enumerate(outputs['aux_outputs']):
|
||||
indices = self.matcher(aux_outputs, targets)
|
||||
if return_indices:
|
||||
indices_list.append(indices)
|
||||
for loss in self.losses:
|
||||
if loss == 'masks':
|
||||
# Intermediate masks losses are too costly to compute, we ignore them.
|
||||
continue
|
||||
kwargs = {}
|
||||
if loss == 'labels':
|
||||
# Logging is enabled only for the last layer
|
||||
kwargs = {'log': False}
|
||||
l_dict = self.get_loss(loss, aux_outputs, targets, indices,
|
||||
num_boxes, **kwargs)
|
||||
l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
|
||||
losses.update(l_dict)
|
||||
|
||||
if mask_dict is not None:
|
||||
# dn loss computation
|
||||
aux_num = 0
|
||||
if 'aux_outputs' in outputs:
|
||||
aux_num = len(outputs['aux_outputs'])
|
||||
dn_losses = self.dn_criterion(mask_dict, self.training, aux_num,
|
||||
0.25)
|
||||
losses.update(dn_losses)
|
||||
|
||||
if return_indices:
|
||||
indices_list.append(indices0_copy)
|
||||
return losses, indices_list
|
||||
|
||||
return losses
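# --- Hedged usage sketch (editor's addition, not part of this diff) -----------
# A minimal example of driving SetCriterion, assuming this module's imports
# (torch, nn, F, accuracy, box utils) are in scope. The identity matcher below
# is a stand-in for the Hungarian matcher used in practice; the class count,
# weights, and tensor shapes are illustrative only.
#
# class IdentityMatcher(nn.Module):
#     """Matches the first len(targets[i]['labels']) queries to targets, in order."""
#
#     @torch.no_grad()
#     def forward(self, outputs, targets):
#         return [(torch.arange(len(t['labels'])),
#                  torch.arange(len(t['labels']))) for t in targets]
#
# criterion = SetCriterion(
#     num_classes=80,
#     matcher=IdentityMatcher(),
#     weight_dict={'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2},
#     losses=['labels', 'boxes', 'cardinality'],
#     eos_coef=0.1)
# outputs = {
#     'pred_logits': torch.randn(2, 300, 81),  # (batch, queries, classes + 1)
#     'pred_boxes': torch.rand(2, 300, 4),     # normalized (cx, cy, w, h)
# }
# targets = [{'labels': torch.tensor([3, 17]), 'boxes': torch.rand(2, 4)},
#            {'labels': torch.tensor([5]), 'boxes': torch.rand(1, 4)}]
# loss_dict = criterion(outputs, targets)
# total = sum(v for k, v in loss_dict.items() if 'error' not in k)
# ------------------------------------------------------------------------------

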
class DNCriterion(nn.Module):
    """This class computes the denoising (dn) part of the loss for DN-DETR.
    The noised ground-truth queries carry their own assignments, so no
    Hungarian matching is needed: each reconstructed query is supervised
    directly against the ground truth it was generated from (class and box).
    """

    def __init__(self, weight_dict):
        """Create the criterion.
        Parameters:
            weight_dict: dict containing as keys the names of the losses and as values their relative weights
        """
        super().__init__()
        self.weight_dict = weight_dict

    def prepare_for_loss(self, mask_dict):
        """Prepare the dn components needed to calculate the loss.
        Args:
            mask_dict: a dict that contains dn information (known labels/boxes,
                their indices, and the decoder outputs for the noised queries)
        """
        output_known_class, output_known_coord = mask_dict[
            'output_known_lbs_bboxes']
        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
        map_known_indice = mask_dict['map_known_indice']

        known_indice = mask_dict['known_indice']

        batch_idx = mask_dict['batch_idx']
        bid = batch_idx[known_indice]
        if len(output_known_class) > 0:
            # gather the outputs of the noised queries for every decoder layer
            output_known_class = output_known_class.permute(
                1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
            output_known_coord = output_known_coord.permute(
                1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
        num_tgt = known_indice.numel()
        return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt

    def tgt_loss_boxes(
        self,
        src_boxes,
        tgt_boxes,
        num_tgt,
    ):
        """Compute the losses related to the bounding boxes: the L1 regression loss and the GIoU loss.
        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        if len(tgt_boxes) == 0:
            return {
                'tgt_loss_bbox': torch.as_tensor(0.).to('cuda'),
                'tgt_loss_giou': torch.as_tensor(0.).to('cuda'),
            }

        loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none')

        losses = {}
        losses['tgt_loss_bbox'] = loss_bbox.sum(
        ) / num_tgt * self.weight_dict['loss_bbox']

        loss_giou = 1 - torch.diag(
            generalized_box_iou(
                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(tgt_boxes)))
        losses['tgt_loss_giou'] = loss_giou.sum(
        ) / num_tgt * self.weight_dict['loss_giou']
        return losses

    def tgt_loss_labels(self,
                        src_logits_,
                        tgt_labels_,
                        num_tgt,
                        focal_alpha,
                        log=False):
        """Classification loss (binary focal loss).
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        if len(tgt_labels_) == 0:
            return {
                'tgt_loss_ce': torch.as_tensor(0.).to('cuda'),
                'tgt_class_error': torch.as_tensor(0.).to('cuda'),
            }

        src_logits, tgt_labels = src_logits_.unsqueeze(
            0), tgt_labels_.unsqueeze(0)

        # one-hot encode with an extra no-object column, then drop it
        target_classes_onehot = torch.zeros([
            src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1
        ],
                                            dtype=src_logits.dtype,
                                            layout=src_logits.layout,
                                            device=src_logits.device)
        target_classes_onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)

        target_classes_onehot = target_classes_onehot[:, :, :-1]
        loss_ce = py_sigmoid_focal_loss(
            src_logits,
            target_classes_onehot.long(),
            alpha=focal_alpha,
            gamma=2,
            reduction='none').mean(1).sum(
            ) / num_tgt * src_logits.shape[1] * self.weight_dict['loss_ce']

        losses = {'tgt_loss_ce': loss_ce}
        if log:
            losses['tgt_class_error'] = 100 - accuracy(src_logits_,
                                                       tgt_labels_)[0]
        return losses

    def forward(self, mask_dict, training, aux_num, focal_alpha):
        """Compute the dn loss.
        Args:
            mask_dict: a dict for dn information
            training: training or inference flag
            aux_num: number of auxiliary losses (intermediate decoder layers)
            focal_alpha: alpha parameter of the focal loss
        """
        losses = {}
        if training and 'output_known_lbs_bboxes' in mask_dict:
            known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt = self.prepare_for_loss(
                mask_dict)
            losses.update(
                self.tgt_loss_labels(output_known_class[-1], known_labels,
                                     num_tgt, focal_alpha))
            losses.update(
                self.tgt_loss_boxes(output_known_coord[-1], known_bboxs,
                                    num_tgt))
        else:
            # no dn branch at inference time: report zero losses (CUDA assumed,
            # as in the upstream implementation)
            losses['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
            losses['tgt_class_error'] = torch.as_tensor(0.).to('cuda')

        if aux_num:
            for i in range(aux_num):
                # dn aux loss
                if training and 'output_known_lbs_bboxes' in mask_dict:
                    l_dict = self.tgt_loss_labels(output_known_class[i],
                                                  known_labels, num_tgt,
                                                  focal_alpha)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
                    l_dict = self.tgt_loss_boxes(output_known_coord[i],
                                                 known_bboxs, num_tgt)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
                else:
                    l_dict = dict()
                    l_dict['tgt_loss_bbox'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_class_error'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_loss_giou'] = torch.as_tensor(0.).to('cuda')
                    l_dict['tgt_loss_ce'] = torch.as_tensor(0.).to('cuda')
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses
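
# --- Hedged reference sketch (editor's addition, not part of this diff) -------
# The GIoU terms above rely on two box helpers. The sketches below mirror the
# standard DETR-style implementations; the exact EasyCV versions may differ in
# module path, signature, and edge-case handling. Degenerate (zero-area)
# enclosing boxes are not guarded against here.


def box_cxcywh_to_xyxy_sketch(boxes):
    """(cx, cy, w, h) -> (x1, y1, x2, y2), keeping the input's normalization."""
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack(
        [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)


def generalized_box_iou_sketch(boxes1, boxes2):
    """Pairwise GIoU matrix [N, M]: IoU - (enclosing_area - union) / enclosing_area.

    torch.diag of the matched [N, N] slice yields the per-pair GIoU used in
    loss_boxes / tgt_loss_boxes above.
    """
    # pairwise intersection
    lt = torch.max(boxes1[:, None, :2], boxes2[None, :, :2])
    rb = torch.min(boxes1[:, None, 2:], boxes2[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    # pairwise union
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area1[:, None] + area2[None, :] - inter
    iou = inter / union
    # smallest enclosing box
    lt_c = torch.min(boxes1[:, None, :2], boxes2[None, :, :2])
    rb_c = torch.max(boxes1[:, None, 2:], boxes2[None, :, 2:])
    area_c = ((rb_c - lt_c).clamp(min=0)).prod(-1)
    return iou - (area_c - union) / area_c
# ------------------------------------------------------------------------------
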
@@ -73,6 +73,7 @@ class ModelExportTest(unittest.TestCase):
    def test_export_classification_jit(self):
        config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
        cfg = mmcv_config_fromfile(config_file)
        cfg.model.pretrained = False
        cfg.model.backbone = dict(
            type='ResNetJIT',
            depth=50,
@@ -42,7 +42,8 @@ class ClassificationTest(unittest.TestCase):
        batch_size = 1
        a = torch.rand(batch_size, 3, 224, 224).to('cuda')

        model = Classification(backbone=backbone, head=head).to('cuda')
        model = Classification(
            backbone=backbone, head=head, pretrained=False).to('cuda')
        model.eval()
        model_jit = torch.jit.script(model)
@@ -12,7 +12,6 @@ from easycv.datasets.utils import replace_ImageToTensor
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils.registry import build_from_cfg
@@ -26,9 +25,6 @@ class DETRTest(unittest.TestCase):
        self.cfg = mmcv_config_fromfile(config_path)

        # dynamic adapt mmdet models
        dynamic_adapt_for_mmlab(self.cfg)

        # modify model_config
        if self.cfg.model.head.get('num_select', None):
            self.cfg.model.head.num_select = 10
@@ -600,7 +596,7 @@ class DETRTest(unittest.TestCase):
            decimal=1)

    def test_dab_detr(self):
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/epoch_50.pth'
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dab_detr/dab_detr_epoch_50.pth'
        config_path = 'configs/detection/dab_detr/dab_detr_r50_8x2_50e_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        self.init_detr(model_path, config_path)
@@ -673,6 +669,80 @@ class DETRTest(unittest.TestCase):
                      ]]),
            decimal=1)

    def test_dn_detr(self):
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dn_detr/dn_detr_epoch_50.pth'
        config_path = 'configs/detection/dab_detr/dn_detr_r50_8x2_50e_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        self.init_detr(model_path, config_path)
        output = self.predict(img)

        self.assertIn('detection_boxes', output)
        self.assertIn('detection_scores', output)
        self.assertIn('detection_classes', output)
        self.assertIn('img_metas', output)
        self.assertEqual(len(output['detection_boxes'][0]), 10)
        self.assertEqual(len(output['detection_scores'][0]), 10)
        self.assertEqual(len(output['detection_classes'][0]), 10)

        self.assertListEqual(
            output['detection_classes'][0].tolist(),
            np.array([2, 13, 2, 2, 2, 2, 2, 2, 2, 2], dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0],
            np.array([
                0.8800525665283203, 0.866659939289093, 0.8665854930877686,
                0.8030595183372498, 0.7642921209335327, 0.7375038862228394,
                0.7270554304122925, 0.6710091233253479, 0.6316548585891724,
                0.6164721846580505
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0],
            np.array([[
                294.9338073730469, 115.7542495727539, 377.5517578125,
                150.59274291992188
            ],
                      [
                          220.57424926757812, 175.97023010253906,
                          456.9001770019531, 383.2597351074219
                      ],
                      [
                          479.5928649902344, 109.94012451171875,
                          523.7343139648438, 130.80604553222656
                      ],
                      [
                          398.6956787109375, 111.45973205566406,
                          434.0437316894531, 134.1909637451172
                      ],
                      [
                          166.98208618164062, 109.44792938232422,
                          210.35342407226562, 139.9746856689453
                      ],
                      [
                          609.432373046875, 113.08062744140625,
                          635.9082641601562, 136.74383544921875
                      ],
                      [
                          268.0716552734375, 105.00788879394531,
                          327.4037170410156, 128.01449584960938
                      ],
                      [
                          190.77467346191406, 107.42850494384766,
                          298.35760498046875, 156.2850341796875
                      ],
                      [
                          591.0296020507812, 110.53913116455078,
                          620.702880859375, 127.42123413085938
                      ],
                      [
                          431.6607971191406, 105.04813385009766,
                          484.4869689941406, 132.45864868164062
                      ]]),
            decimal=1)


if __name__ == '__main__':
    unittest.main()
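
# Illustrative invocation of the new case (editor's addition; the test module
# path below is assumed, not taken from this diff):
#   python -m unittest tests.test_predictors_detector.DETRTest.test_dn_detr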