mirror of https://github.com/alibaba/EasyCV.git
support obj365 (#242)
Support Objects365 pretraining and add the DINO++ model: at a model scale of about 200M parameters it reaches 63.4 mAP, the best accuracy among models of this scale.
parent
36a3c45efa
commit
654554cf65
@ -0,0 +1,142 @@
CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

# dataset settings
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(
        type='MMAutoAugment',
        policies=[
            [
                dict(
                    type='MMResize',
                    img_scale=[(720, 2000), (768, 2000), (816, 2000),
                               (864, 2000), (912, 2000), (960, 2000),
                               (1008, 2000), (1056, 2000), (1104, 2000),
                               (1152, 2000), (1200, 2000)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='MMResize',
                    # the aspect ratio of all images in the train dataset is < 7
                    # follow the original impl
                    img_scale=[(600, 6300), (750, 6300), (900, 6300)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='MMRandomCrop',
                    crop_type='absolute_range',
                    crop_size=(576, 900),
                    allow_negative_crop=True),
                dict(
                    type='MMResize',
                    img_scale=[(720, 2000), (768, 2000), (816, 2000),
                               (864, 2000), (912, 2000), (960, 2000),
                               (1008, 2000), (1056, 2000), (1104, 2000),
                               (1152, 2000), (1200, 2000)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='MMPad', size_divisor=1),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
                   'img_shape', 'pad_shape', 'scale_factor', 'flip',
                   'flip_direction', 'img_norm_cfg'))
]
test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(2000, 1200),
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='MMPad', size_divisor=1),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'ori_img_shape', 'img_shape', 'pad_shape',
                           'scale_factor', 'flip', 'flip_direction',
                           'img_norm_cfg'))
        ])
]

train_dataset = dict(
    type='DetDataset',
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        test_mode=False,
        filter_empty_gt=False,
        iscrowd=False),
    pipeline=train_pipeline)

val_dataset = dict(
    type='DetDataset',
    imgs_per_gpu=1,
    data_source=dict(
        type='DetSourceCoco',
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        test_mode=True,
        filter_empty_gt=False,
        iscrowd=True),
    pipeline=test_pipeline)

data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=train_dataset,
    val=val_dataset,
    drop_last=True)

# evaluation
eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        dist_eval=False,
        evaluators=[
            dict(type='CocoDetectionEvaluator', classes=CLASSES),
        ],
    )
]
@ -0,0 +1,192 @@
# Official Objects365 v2 class names; kept verbatim, including upstream
# spellings (e.g. 'Bakset', 'Moniter/TV'), so they match the annotation files.
CLASSES = [
    'Person', 'Sneakers', 'Chair', 'Other Shoes', 'Hat', 'Car', 'Lamp',
    'Glasses', 'Bottle', 'Desk', 'Cup', 'Street Lights', 'Cabinet/shelf',
    'Handbag/Satchel', 'Bracelet', 'Plate', 'Picture/Frame', 'Helmet', 'Book',
    'Gloves', 'Storage box', 'Boat', 'Leather Shoes', 'Flower', 'Bench',
    'Potted Plant', 'Bowl/Basin', 'Flag', 'Pillow', 'Boots', 'Vase',
    'Microphone', 'Necklace', 'Ring', 'SUV', 'Wine Glass', 'Belt',
    'Moniter/TV', 'Backpack', 'Umbrella', 'Traffic Light', 'Speaker', 'Watch',
    'Tie', 'Trash bin Can', 'Slippers', 'Bicycle', 'Stool', 'Barrel/bucket',
    'Van', 'Couch', 'Sandals', 'Bakset', 'Drum', 'Pen/Pencil', 'Bus',
    'Wild Bird', 'High Heels', 'Motorcycle', 'Guitar', 'Carpet', 'Cell Phone',
    'Bread', 'Camera', 'Canned', 'Truck', 'Traffic cone', 'Cymbal',
    'Lifesaver', 'Towel', 'Stuffed Toy', 'Candle', 'Sailboat', 'Laptop',
    'Awning', 'Bed', 'Faucet', 'Tent', 'Horse', 'Mirror', 'Power outlet',
    'Sink', 'Apple', 'Air Conditioner', 'Knife', 'Hockey Stick', 'Paddle',
    'Pickup Truck', 'Fork', 'Traffic Sign', 'Ballon', 'Tripod', 'Dog', 'Spoon',
    'Clock', 'Pot', 'Cow', 'Cake', 'Dinning Table', 'Sheep', 'Hanger',
    'Blackboard/Whiteboard', 'Napkin', 'Other Fish', 'Orange/Tangerine',
    'Toiletry', 'Keyboard', 'Tomato', 'Lantern', 'Machinery Vehicle', 'Fan',
    'Green Vegetables', 'Banana', 'Baseball Glove', 'Airplane', 'Mouse',
    'Train', 'Pumpkin', 'Soccer', 'Skiboard', 'Luggage', 'Nightstand',
    'Tea pot', 'Telephone', 'Trolley', 'Head Phone', 'Sports Car', 'Stop Sign',
    'Dessert', 'Scooter', 'Stroller', 'Crane', 'Remote', 'Refrigerator',
    'Oven', 'Lemon', 'Duck', 'Baseball Bat', 'Surveillance Camera', 'Cat',
    'Jug', 'Broccoli', 'Piano', 'Pizza', 'Elephant', 'Skateboard', 'Surfboard',
    'Gun', 'Skating and Skiing shoes', 'Gas stove', 'Donut', 'Bow Tie',
    'Carrot', 'Toilet', 'Kite', 'Strawberry', 'Other Balls', 'Shovel',
    'Pepper', 'Computer Box', 'Toilet Paper', 'Cleaning Products',
    'Chopsticks', 'Microwave', 'Pigeon', 'Baseball', 'Cutting/chopping Board',
    'Coffee Table', 'Side Table', 'Scissors', 'Marker', 'Pie', 'Ladder',
    'Snowboard', 'Cookies', 'Radiator', 'Fire Hydrant', 'Basketball', 'Zebra',
    'Grape', 'Giraffe', 'Potato', 'Sausage', 'Tricycle', 'Violin', 'Egg',
    'Fire Extinguisher', 'Candy', 'Fire Truck', 'Billards', 'Converter',
    'Bathtub', 'Wheelchair', 'Golf Club', 'Briefcase', 'Cucumber',
    'Cigar/Cigarette ', 'Paint Brush', 'Pear', 'Heavy Truck', 'Hamburger',
    'Extractor', 'Extention Cord', 'Tong', 'Tennis Racket', 'Folder',
    'American Football', 'earphone', 'Mask', 'Kettle', 'Tennis', 'Ship',
    'Swing', 'Coffee Machine', 'Slide', 'Carriage', 'Onion', 'Green beans',
    'Projector', 'Frisbee', 'Washing Machine/Drying Machine', 'Chicken',
    'Printer', 'Watermelon', 'Saxophone', 'Tissue', 'Toothbrush', 'Ice cream',
    'Hotair ballon', 'Cello', 'French Fries', 'Scale', 'Trophy', 'Cabbage',
    'Hot dog', 'Blender', 'Peach', 'Rice', 'Wallet/Purse', 'Volleyball',
    'Deer', 'Goose', 'Tape', 'Tablet', 'Cosmetics', 'Trumpet', 'Pineapple',
    'Golf Ball', 'Ambulance', 'Parking meter', 'Mango', 'Key', 'Hurdle',
    'Fishing Rod', 'Medal', 'Flute', 'Brush', 'Penguin', 'Megaphone', 'Corn',
    'Lettuce', 'Garlic', 'Swan', 'Helicopter', 'Green Onion', 'Sandwich',
    'Nuts', 'Speed Limit Sign', 'Induction Cooker', 'Broom', 'Trombone',
    'Plum', 'Rickshaw', 'Goldfish', 'Kiwi fruit', 'Router/modem', 'Poker Card',
    'Toaster', 'Shrimp', 'Sushi', 'Cheese', 'Notepaper', 'Cherry', 'Pliers',
    'CD', 'Pasta', 'Hammer', 'Cue', 'Avocado', 'Hamimelon', 'Flask',
    'Mushroon', 'Screwdriver', 'Soap', 'Recorder', 'Bear', 'Eggplant',
    'Board Eraser', 'Coconut', 'Tape Measur/ Ruler', 'Pig', 'Showerhead',
    'Globe', 'Chips', 'Steak', 'Crosswalk Sign', 'Stapler', 'Campel',
    'Formula 1 ', 'Pomegranate', 'Dishwasher', 'Crab', 'Hoverboard',
    'Meat ball', 'Rice Cooker', 'Tuba', 'Calculator', 'Papaya', 'Antelope',
    'Parrot', 'Seal', 'Buttefly', 'Dumbbell', 'Donkey', 'Lion', 'Urinal',
    'Dolphin', 'Electric Drill', 'Hair Dryer', 'Egg tart', 'Jellyfish',
    'Treadmill', 'Lighter', 'Grapefruit', 'Game board', 'Mop', 'Radish',
    'Baozi', 'Target', 'French', 'Spring Rolls', 'Monkey', 'Rabbit',
    'Pencil Case', 'Yak', 'Red Cabbage', 'Binoculars', 'Asparagus', 'Barbell',
    'Scallop', 'Noddles', 'Comb', 'Dumpling', 'Oyster', 'Table Teniis paddle',
    'Cosmetics Brush/Eyeliner Pencil', 'Chainsaw', 'Eraser', 'Lobster',
    'Durian', 'Okra', 'Lipstick', 'Cosmetics Mirror', 'Curling',
    'Table Tennis '
]

# dataset settings
data_root = 'data/objects365/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(
        type='MMAutoAugment',
        policies=[
            [
                dict(
                    type='MMResize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    keep_ratio=True)
            ],
            [
                dict(
                    type='MMResize',
                    # the aspect ratio of all images in the train dataset is < 7
                    # follow the original impl
                    img_scale=[(400, 4200), (500, 4200), (600, 4200)],
                    multiscale_mode='value',
                    keep_ratio=True),
                dict(
                    type='MMRandomCrop',
                    crop_type='absolute_range',
                    crop_size=(384, 600),
                    allow_negative_crop=True),
                dict(
                    type='MMResize',
                    img_scale=[(480, 1333), (512, 1333), (544, 1333),
                               (576, 1333), (608, 1333), (640, 1333),
                               (672, 1333), (704, 1333), (736, 1333),
                               (768, 1333), (800, 1333)],
                    multiscale_mode='value',
                    override=True,
                    keep_ratio=True)
            ]
        ]),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='MMPad', size_divisor=1),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_bboxes', 'gt_labels'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'ori_img_shape',
                   'img_shape', 'pad_shape', 'scale_factor', 'flip',
                   'flip_direction', 'img_norm_cfg'))
]
test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='MMPad', size_divisor=1),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'ori_img_shape', 'img_shape', 'pad_shape',
                           'scale_factor', 'flip', 'flip_direction',
                           'img_norm_cfg'))
        ])
]

train_dataset = dict(
    type='DetDataset',
    data_source=dict(
        type='DetSourceObjects365',
        ann_file=data_root + 'annotations/zhiyuan_objv2_fullno5k.json',
        img_prefix=data_root + 'train/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        test_mode=False,
        filter_empty_gt=False,
        iscrowd=False),
    pipeline=train_pipeline)

val_dataset = dict(
    type='DetDataset',
    imgs_per_gpu=1,
    data_source=dict(
        type='DetSourceObjects365',
        ann_file=data_root + 'annotations/zhiyuan_objv2_val5k.json',
        img_prefix=data_root + 'val/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        classes=CLASSES,
        test_mode=True,
        filter_empty_gt=False,
        iscrowd=True),
    pipeline=test_pipeline)

data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    train=train_dataset,
    val=val_dataset,
    drop_last=True)

# evaluation
eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        # dist_eval=True,
        evaluators=[
            dict(type='CocoDetectionEvaluator', classes=CLASSES),
        ],
    )
]
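The training and validation entries above point at custom split files, `zhiyuan_objv2_fullno5k.json` and `zhiyuan_objv2_val5k.json`, rather than the official Objects365 annotations. The README hunk below links the authors' download/processing tools; purely as an illustrative sketch (the authors' actual split logic is not shown in this diff, and the random 5k selection here is an assumption), such a split could be produced from the official `zhiyuan_objv2_train.json` like this:

```python
# Illustrative sketch only: carve a 5k validation subset out of the official
# Objects365 v2 train annotations. The released obj365_download_tools may
# select the 5k images differently (e.g. from a fixed image-id list).
import json
import random

with open('data/objects365/annotations/zhiyuan_objv2_train.json') as f:
    coco = json.load(f)

random.seed(0)
val_ids = {im['id'] for im in random.sample(coco['images'], 5000)}


def subset(in_val):
    # Keep images inside (or outside) the sampled 5k set, plus their boxes.
    images = [im for im in coco['images'] if (im['id'] in val_ids) == in_val]
    keep = {im['id'] for im in images}
    annos = [a for a in coco['annotations'] if a['image_id'] in keep]
    return dict(images=images, annotations=annos, categories=coco['categories'])


with open('zhiyuan_objv2_val5k.json', 'w') as f:
    json.dump(subset(True), f)
with open('zhiyuan_objv2_fullno5k.json', 'w') as f:
    json.dump(subset(False), f)
```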
@ -21,7 +21,8 @@ We present DINO (DETR with Improved deNoising anchOr boxes), a state-of-the-art e
| DINO_4sc_swinl_12e | [DINO_4sc_swinl_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_12e_coco.py) | 195M/217M | 155ms | 56.86 | 75.61 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/20220815_211633.log.json) |
| DINO_4sc_swinl_36e | [DINO_4sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_36e_coco.py) | 195M/217M | 155ms | 58.04 | 76.76 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/epoch_34.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/20220817_101416.log.json) |
| DINO_5sc_swinl_36e | [DINO_5sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_36e_coco.py) | 195M/217M | 235ms | 58.47 | 77.10 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/epoch_35.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/20220820_215711.log.json) |
| DINO++_5sc_swinl_18e | [DINO++_5sc_swinl_18e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_center_iou_18e_obj2coco.py) | 195M/218M | 325ms | 63.39 | 80.25 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_obj2coco/epoch_13.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_obj2coco/20221107_095659.log.json) |

(objects365 dataset processing tools: https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/obj365_download_tools.tar.gz)

## Citation

```latex
@ -0,0 +1,23 @@
_base_ = [
    './dino_5sc_swinl.py', '../common/dataset/autoaug_obj2coco_detection.py',
    './dino_schedule_1x.py'
]

data = dict(imgs_per_gpu=1)  # total 16 = 8(gpu_num) * 2(batch_size)

# model settings
model = dict(
    head=dict(
        dn_components=dict(dn_number=1000),
        use_centerness=True,
        use_iouaware=True,
        losses_list=['labels', 'boxes', 'centerness', 'iouaware'],
        transformer=dict(multi_encoder_memory=True),
        weight_dict=dict(loss_ce=2, loss_center=2, loss_iouaware=2)))

# learning policy
lr_config = dict(policy='step', step=[8])

total_epochs = 18

checkpoint_config = dict(interval=1)
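This config composes through mmcv-style `_base_` inheritance: the keys set here (for example `use_centerness`, `use_iouaware` and `losses_list`) are merged recursively over `dino_5sc_swinl.py`, the obj2coco dataset config above, and the 1x schedule. As a quick sanity check of the merged result — assuming EasyCV's mmcv-compatible loader, whose exact import path is an assumption here:

```python
# Hedged sketch: inspect the merged config. `mmcv_config_fromfile` is the
# assumed EasyCV entry point for loading configs with _base_ inheritance.
from easycv.utils.config_tools import mmcv_config_fromfile

cfg = mmcv_config_fromfile(
    'configs/detection/dino/dino_5sc_swinl_center_iou_18e_obj2coco.py')
print(cfg.model.head.use_centerness)  # True, set by this file
print(cfg.total_epochs)               # 18, overriding the schedule base
```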
@ -0,0 +1,30 @@
_base_ = [
    './dino_5sc_swinl.py',
    '../common/dataset/autoaug_obj365_val5k_detection.py',
    './dino_schedule_1x.py'
]

data = dict(
    imgs_per_gpu=2
)  # total 64 = 2(update_interval) * 2(node_num) * 8(gpu_num) * 2(batch_size)

# model settings
model = dict(
    head=dict(
        num_classes=365,
        dn_components=dict(dn_number=1000, dn_labelbook_size=365),
        use_centerness=True,
        use_iouaware=True,
        losses_list=['labels', 'boxes', 'centerness', 'iouaware'],
        transformer=dict(multi_encoder_memory=True),
        weight_dict=dict(loss_ce=2, loss_center=2, loss_iouaware=2)))

# optimizer
optimizer_config = dict(update_interval=1)

# learning policy
lr_config = dict(policy='step', step=[24])

total_epochs = 26

checkpoint_config = dict(interval=1)
@ -47,3 +47,4 @@ Pretrained on COCO2017 dataset. (The result has been optimized with PAI-Blade, a
| DINO_4sc_swinl_12e | [DINO_4sc_swinl_12e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_12e_coco.py) | 195M/217M | 155ms | 56.86 | 75.61 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/epoch_12.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_12e/20220815_211633.log.json) |Inference use V100 32G|
| DINO_4sc_swinl_36e | [DINO_4sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_4sc_swinl_36e_coco.py) | 195M/217M | 155ms | 58.04 | 76.76 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/epoch_34.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_4sc_swinl_36e/20220817_101416.log.json) |Inference use V100 32G|
| DINO_5sc_swinl_36e | [DINO_5sc_swinl_36e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_36e_coco.py) | 195M/217M | 235ms | 58.47 | 77.10 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/epoch_35.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_36e/20220820_215711.log.json) |Inference use V100 32G|
| DINO++_5sc_swinl_18e | [DINO++_5sc_swinl_18e](https://github.com/alibaba/EasyCV/tree/master/configs/detection/dino/dino_5sc_swinl_center_iou_18e_obj2coco.py) | 195M/218M | 325ms | 63.39 | 80.25 | [model](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_obj2coco/epoch_13.pth) - [log](https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/dino/dino_5sc_swinl_obj2coco/20221107_095659.log.json) |Inference use A100 80G|
@ -6,7 +6,7 @@ from .coco_livs import DetSourceLvis
from .coco_panoptic import DetSourceCocoPanoptic
from .crowd_human import DetSourceCrowdHuman
from .fruit import DetSourceFruit
from .object365 import DetSourceObject365
from .objects365 import DetSourceObjects365
from .pai_format import DetSourcePAI
from .pet import DetSourcePet
from .raw import DetSourceRaw
@ -15,9 +15,12 @@ from .wider_face import DetSourceWiderFace
from .wider_person import DetSourceWiderPerson

__all__ = [
    'DetSourceCoco', 'DetSourceCocoPanoptic', 'DetSourceObjects365',
    'DetSourcePAI', 'DetSourceRaw', 'DetSourceVOC', 'DetSourceVOC2007',
    'DetSourceVOC2012', 'DetSourceCoco2017'
    'DetSourceCoco', 'DetSourceCocoPanoptic', 'DetSourcePAI', 'DetSourceRaw',
    'DetSourceVOC', 'DetSourceVOC2007', 'DetSourceVOC2012',
    'DetSourceCoco2017', 'DetSourceLvis', 'DetSourceWiderPerson',
    'DetSourceAfricanWildlife', 'DetSourcePet', 'DetSourceWiderFace',
    'DetSourceCrowdHuman', 'DetSourceObject365'
    'DetSourceCrowdHuman'
]
@ -1,17 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp

from tqdm import tqdm
from xtcocotools.coco import COCO

from easycv.datasets.detection.data_sources.coco import DetSourceCoco
from easycv.datasets.registry import DATASOURCES
from .coco import DetSourceCoco

objv2_ignore_list = [
    # images exist in annotations but not in image folder.
    'patch16/objects365_v2_00908726.jpg',
    'patch6/objects365_v1_00320532.jpg',
    'patch6/objects365_v1_00320534.jpg',
]


@DATASOURCES.register_module
class DetSourceObject365(DetSourceCoco):
class DetSourceObjects365(DetSourceCoco):
    """
    Object 365 data source
    objects365 data source.
    Expected objects365 dataset folder structure:
    |- objects365
        |- annotation
            |- zhiyuan_objv2_train.json
            |- zhiyuan_objv2_val.json
        |- train
            |- patch0
                |- *****(imageID)
            |- patch1
                |- *****(imageID)
            ...
            |- patch50
                |- *****(imageID)
        |- val
            |- patch0
                |- *****(imageID)
            |- patch1
                |- *****(imageID)
            ...
            |- patch43
                |- *****(imageID)
    """

    def __init__(self,
@ -34,7 +60,7 @@ class DetSourceObject365(DetSourceCoco):
            iscrowd: set to False for training and True for validation
        """

        super(DetSourceObject365, self).__init__(
        super(DetSourceObjects365, self).__init__(
            ann_file=ann_file,
            img_prefix=img_prefix,
            pipeline=pipeline,
@ -50,6 +76,7 @@ class DetSourceObject365(DetSourceCoco):
        Returns:
            list[dict]: Annotation info from COCO api.
        """

        self.coco = COCO(ann_file)
        # The order of returned `cat_ids` will not
        # change with the order of the CLASSES
@ -57,19 +84,23 @@ class DetSourceObject365(DetSourceCoco):

        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.img_ids = self.coco.getImgIds()
        img_path = os.listdir(self.img_prefix)
        data_infos = []
        total_ann_ids = []
        for i in tqdm(self.img_ids, desc='Scaning Images'):
        for i in self.img_ids:
            info = self.coco.loadImgs([i])[0]
            filename = os.path.basename(info['file_name'])
            # Filter the information corresponding to the image
            if filename in img_path:
                info['filename'] = filename
                data_infos.append(info)
                ann_ids = self.coco.getAnnIds(imgIds=[i])
                total_ann_ids.extend(ann_ids)

            # rename filename and filter wrong data
            info['patch_name'] = osp.join(
                osp.split(osp.split(info['file_name'])[0])[-1],
                osp.split(info['file_name'])[-1])
            if info['patch_name'] in objv2_ignore_list:
                continue

            info['filename'] = info['patch_name']

            data_infos.append(info)
            ann_ids = self.coco.getAnnIds(imgIds=[i])
            total_ann_ids.extend(ann_ids)
        assert len(set(total_ann_ids)) == len(
            total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
        del total_ann_ids
        return data_infos
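Mirroring the updated unit test at the end of this diff, the renamed source can be constructed through the datasource registry (paths are placeholders and the `build_datasource` import path is assumed):

```python
# Usage sketch for the renamed DetSourceObjects365, following the unit test.
from easycv.datasets import build_datasource  # import path assumed

data_source = build_datasource(
    dict(
        type='DetSourceObjects365',
        ann_file='data/objects365/annotations/zhiyuan_objv2_val.json',
        img_prefix='data/objects365/val/',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', with_bbox=True)
        ]))
# Images listed in objv2_ignore_list are dropped in get_ann_info, and each
# record's 'filename' is rewritten to the 'patchXX/...' relative path.
```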
@ -9,147 +9,144 @@ import torch
from easycv.models.detection.utils import inverse_sigmoid


def prepare_for_cdn(dn_args, training, num_queries, num_classes, hidden_dim,
                    label_enc):
def prepare_for_cdn(dn_args, num_queries, num_classes, hidden_dim, label_enc):
    """
    A major difference between DINO and DN-DETR is that DINO processes the
    pattern embedding in its detector forward function and uses a learnable
    tgt embedding, so we change this function a little bit.
    :param dn_args: targets, dn_number, label_noise_ratio, box_noise_scale
    :param training: if it is training or inference
    :param num_queries: number of queries
    :param num_classes: number of classes
    :param hidden_dim: transformer hidden dim
    :param label_enc: encode labels in dn
    :return:
    """
    if training:
        targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
        # positive and negative dn queries
        dn_number = dn_number * 2
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
        batch_size = len(known)
        known_num = [sum(k) for k in known]
        if int(max(known_num)) == 0:
            dn_number = 1
        else:
            if dn_number >= 100:
                dn_number = dn_number // (int(max(known_num) * 2))
            elif dn_number < 1:
                dn_number = 1
        if dn_number == 0:
            dn_number = 1
        unmask_bbox = unmask_label = torch.cat(known)
        labels = torch.cat([t['labels'] for t in targets])
        boxes = torch.cat([t['boxes'] for t in targets])
        batch_idx = torch.cat([
            torch.full_like(t['labels'].long(), i)
            for i, t in enumerate(targets)
        ])

        known_indice = torch.nonzero(unmask_label + unmask_bbox)
        known_indice = known_indice.view(-1)

        known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
        known_labels = labels.repeat(2 * dn_number, 1).view(-1)
        known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
        known_bboxs = boxes.repeat(2 * dn_number, 1)
        known_labels_expaned = known_labels.clone()
        known_bbox_expand = known_bboxs.clone()

        if label_noise_ratio > 0:
            p = torch.rand_like(known_labels_expaned.float())
            chosen_indice = torch.nonzero(p < (label_noise_ratio)).view(
                -1)  # half of bbox prob
            new_label = torch.randint_like(
                chosen_indice, 0, num_classes)  # randomly put a new one here
            known_labels_expaned.scatter_(0, chosen_indice, new_label)
        single_pad = int(max(known_num))

        pad_size = int(single_pad * 2 * dn_number)
        positive_idx = torch.tensor(range(
            len(boxes))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
        positive_idx += (torch.tensor(range(dn_number)) * len(boxes) *
                         2).long().cuda().unsqueeze(1)
        positive_idx = positive_idx.flatten()
        negative_idx = positive_idx + len(boxes)
        if box_noise_scale > 0:
            known_bbox_ = torch.zeros_like(known_bboxs)
            known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
            known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

            diff = torch.zeros_like(known_bboxs)
            diff[:, :2] = known_bboxs[:, 2:] / 2
            diff[:, 2:] = known_bboxs[:, 2:] / 2

            rand_sign = torch.randint_like(
                known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
            rand_part = torch.rand_like(known_bboxs)
            rand_part[negative_idx] += 1.0
            rand_part *= rand_sign
            known_bbox_ = known_bbox_ + torch.mul(
                rand_part, diff).cuda() * box_noise_scale
            known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
            known_bbox_expand[:, :2] = (known_bbox_[:, :2] +
                                        known_bbox_[:, 2:]) / 2
            known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]

        m = known_labels_expaned.long().to('cuda')
        input_label_embed = label_enc(m)
        input_bbox_embed = inverse_sigmoid(known_bbox_expand)

        padding_label = torch.zeros(pad_size, hidden_dim).cuda()
        padding_bbox = torch.zeros(pad_size, 4).cuda()

        input_query_label = padding_label.repeat(batch_size, 1, 1)
        input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

        map_known_indice = torch.tensor([]).to('cuda')
        if len(known_num):
            map_known_indice = torch.cat([
                torch.tensor(range(num)) for num in known_num
            ])  # [1,2, 1,2,3]
            map_known_indice = torch.cat([
                map_known_indice + single_pad * i for i in range(2 * dn_number)
            ]).long()
        if len(known_bid):
            input_query_label[(known_bid.long(),
                               map_known_indice)] = input_label_embed
            input_query_bbox[(known_bid.long(),
                              map_known_indice)] = input_bbox_embed

        tgt_size = pad_size + num_queries
        attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
        # match query cannot see the reconstruct
        attn_mask[pad_size:, :pad_size] = True
        # reconstruct cannot see each other
        for i in range(dn_number):
            if i == 0:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                          single_pad * 2 * (i + 1):pad_size] = True
            if i == dn_number - 1:
                attn_mask[single_pad * 2 * i:single_pad * 2 *
                          (i + 1), :single_pad * i * 2] = True
            else:
                attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                          single_pad * 2 * (i + 1):pad_size] = True
                attn_mask[single_pad * 2 * i:single_pad * 2 *
                          (i + 1), :single_pad * 2 * i] = True

        dn_meta = {
            'pad_size': pad_size,
            'num_dn_group': dn_number,
        }
    targets, dn_number, label_noise_ratio, box_noise_scale = dn_args
    # positive and negative dn queries
    dn_number = dn_number * 2
    known = [(torch.ones_like(t['labels'])).cuda() for t in targets]
    batch_size = len(known)
    known_num = [sum(k) for k in known]
    if int(max(known_num)) == 0:
        dn_number = 1
    else:
        if dn_number >= 100:
            dn_number = dn_number // (int(max(known_num) * 2))
        elif dn_number < 1:
            dn_number = 1
    if dn_number == 0:
        dn_number = 1
    unmask_bbox = unmask_label = torch.cat(known)
    labels = torch.cat([t['labels'] for t in targets])
    boxes = torch.cat([t['boxes'] for t in targets])
    batch_idx = torch.cat([
        torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)
    ])

    input_query_label = None
    input_query_bbox = None
    attn_mask = None
    dn_meta = None
    known_indice = torch.nonzero(unmask_label + unmask_bbox)
    known_indice = known_indice.view(-1)

    known_indice = known_indice.repeat(2 * dn_number, 1).view(-1)
    known_labels = labels.repeat(2 * dn_number, 1).view(-1)
    known_bid = batch_idx.repeat(2 * dn_number, 1).view(-1)
    known_bboxs = boxes.repeat(2 * dn_number, 1)
    known_labels_expaned = known_labels.clone()
    known_bbox_expand = known_bboxs.clone()

    if label_noise_ratio > 0:
        p = torch.rand_like(known_labels_expaned.float())
        chosen_indice = torch.nonzero(p < (label_noise_ratio)).view(
            -1)  # half of bbox prob
        new_label = torch.randint_like(
            chosen_indice, 0, num_classes)  # randomly put a new one here
        known_labels_expaned.scatter_(0, chosen_indice, new_label)
    single_pad = int(max(known_num))

    pad_size = int(single_pad * 2 * dn_number)
    positive_idx = torch.tensor(range(
        len(boxes))).long().cuda().unsqueeze(0).repeat(dn_number, 1)
    positive_idx += (torch.tensor(range(dn_number)) * len(boxes) *
                     2).long().cuda().unsqueeze(1)
    positive_idx = positive_idx.flatten()
    negative_idx = positive_idx + len(boxes)
    if box_noise_scale > 0:
        known_bbox_ = torch.zeros_like(known_bboxs)
        known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
        known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

        diff = torch.zeros_like(known_bboxs)
        diff[:, :2] = known_bboxs[:, 2:] / 2
        diff[:, 2:] = known_bboxs[:, 2:] / 2

        rand_sign = torch.randint_like(
            known_bboxs, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0
        rand_part = torch.rand_like(known_bboxs)
        rand_part[negative_idx] += 1.0
        rand_part *= rand_sign
        known_bbox_ = known_bbox_ + torch.mul(rand_part,
                                              diff).cuda() * box_noise_scale
        known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)
        known_bbox_expand[:, :2] = (known_bbox_[:, :2] +
                                    known_bbox_[:, 2:]) / 2
        known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2]

    m = known_labels_expaned.long().to('cuda')
    input_label_embed = label_enc(m)
    input_bbox_embed = inverse_sigmoid(known_bbox_expand)

    padding_label = torch.zeros(pad_size, hidden_dim).cuda()
    padding_bbox = torch.zeros(pad_size, 4).cuda()

    input_query_label = padding_label.repeat(batch_size, 1, 1)
    input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

    map_known_indice = torch.tensor([]).to('cuda')
    if len(known_num):
        map_known_indice = torch.cat(
            [torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
        map_known_indice = torch.cat([
            map_known_indice + single_pad * i for i in range(2 * dn_number)
        ]).long()
    if len(known_bid):
        input_query_label[(known_bid.long(),
                           map_known_indice)] = input_label_embed
        input_query_bbox[(known_bid.long(),
                          map_known_indice)] = input_bbox_embed

    tgt_size = pad_size + num_queries
    attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0
    # match query cannot see the reconstruct
    attn_mask[pad_size:, :pad_size] = True
    # reconstruct query cannot see the match
    attn_mask[:pad_size, pad_size:] = True
    # reconstruct cannot see each other
    for i in range(dn_number):
        if i == 0:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                      single_pad * 2 * (i + 1):pad_size] = True
        if i == dn_number - 1:
            attn_mask[single_pad * 2 * i:single_pad * 2 *
                      (i + 1), :single_pad * i * 2] = True
        else:
            attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                      single_pad * 2 * (i + 1):pad_size] = True
            attn_mask[single_pad * 2 * i:single_pad * 2 *
                      (i + 1), :single_pad * 2 * i] = True

    dn_meta = {
        'pad_size': pad_size,
        'num_dn_group': dn_number,
    }

    return input_query_label, input_query_bbox, attn_mask, dn_meta


def cdn_post_process(outputs_class, outputs_coord, dn_meta, _set_aux_loss):
def cdn_post_process(outputs_class,
                     outputs_coord,
                     dn_meta,
                     _set_aux_loss,
                     outputs_center=None,
                     outputs_iou=None,
                     reference=None):
    """
    Post-process of dn after the transformer output:
    put the dn part into dn_meta.
@ -159,11 +156,32 @@ def cdn_post_process(outputs_class, outputs_coord, dn_meta, _set_aux_loss):
        output_known_coord = outputs_coord[:, :, :dn_meta['pad_size'], :]
        outputs_class = outputs_class[:, :, dn_meta['pad_size']:, :]
        outputs_coord = outputs_coord[:, :, dn_meta['pad_size']:, :]
        output_known_center = None
        output_known_iou = None
        if outputs_center is not None:
            output_known_center = outputs_center[:, :, :dn_meta['pad_size'], :]
            outputs_center = outputs_center[:, :, dn_meta['pad_size']:, :]
        if outputs_iou is not None:
            output_known_iou = outputs_iou[:, :, :dn_meta['pad_size'], :]
            outputs_iou = outputs_iou[:, :, dn_meta['pad_size']:, :]
        known_reference = reference[:, :, :dn_meta['pad_size'], :]
        reference = reference[:, :, dn_meta['pad_size']:, :]
        out = {
            'pred_logits': output_known_class[-1],
            'pred_boxes': output_known_coord[-1]
            'pred_logits':
            output_known_class[-1],
            'pred_boxes':
            output_known_coord[-1],
            'pred_centers':
            output_known_center[-1]
            if output_known_center is not None else None,
            'pred_ious':
            output_known_iou[-1] if output_known_iou is not None else None,
            'refpts':
            known_reference[-1],
        }
        out['aux_outputs'] = _set_aux_loss(output_known_class,
                                           output_known_coord)
                                           output_known_coord,
                                           output_known_center,
                                           output_known_iou, known_reference)
        dn_meta['output_known_lbs_bboxes'] = out
        return outputs_class, outputs_coord
        return outputs_class, outputs_coord, outputs_center, outputs_iou, reference
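To make the attention-mask layout from `prepare_for_cdn` concrete, here is a toy rendering of the same construction with small, assumed sizes (in practice `single_pad` and `dn_number` depend on the batch):

```python
# Toy CDN attention mask: 2 denoising groups of 2*single_pad queries each
# (pad_size = 8) plus 3 matching queries. True means attention is blocked.
import torch

single_pad, dn_number, num_queries = 2, 2, 3
pad_size = single_pad * 2 * dn_number
tgt_size = pad_size + num_queries
attn_mask = torch.zeros(tgt_size, tgt_size, dtype=torch.bool)
attn_mask[pad_size:, :pad_size] = True  # matching queries cannot see dn queries
attn_mask[:pad_size, pad_size:] = True  # dn queries cannot see matching queries
for i in range(dn_number):              # dn groups cannot see each other
    s, e = single_pad * 2 * i, single_pad * 2 * (i + 1)
    attn_mask[s:e, :s] = True
    attn_mask[s:e, e:pad_size] = True
print(attn_mask.int())
```

Note that the second assignment (`attn_mask[:pad_size, pad_size:]`) corresponds to the line this PR adds: denoising queries are now also blocked from attending to the matching queries.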
@ -43,6 +43,7 @@ class DeformableTransformer(nn.Module):
                 num_patterns=0,
                 modulate_hw_attn=False,
                 # for deformable encoder
                 multi_encoder_memory=False,
                 deformable_encoder=True,
                 deformable_decoder=True,
                 num_feature_levels=1,

@ -127,6 +128,10 @@ class DeformableTransformer(nn.Module):
            enc_layer_share=enc_layer_share,
            two_stage_type=two_stage_type)

        self.multi_encoder_memory = multi_encoder_memory
        if self.multi_encoder_memory:
            self.memory_reduce = nn.Linear(d_model * 2, d_model)

        # choose decoder layer type
        if deformable_decoder:
            decoder_layer = DeformableTransformerDecoderLayer(

@ -218,6 +223,8 @@ class DeformableTransformer(nn.Module):

        self.enc_out_class_embed = None
        self.enc_out_bbox_embed = None
        self.enc_out_center_embed = None
        self.enc_out_iou_embed = None

        # evolution of anchors
        self.dec_layer_number = dec_layer_number

@ -343,6 +350,8 @@ class DeformableTransformer(nn.Module):
            ref_token_index=enc_topk_proposals,  # bs, nq
            ref_token_coord=enc_refpoint_embed,  # bs, nq, 4
        )
        if self.multi_encoder_memory:
            memory = self.memory_reduce(torch.cat([src_flatten, memory], -1))
        #########################################################
        # End Encoder
        # - memory: bs, \sum{hw}, c

@ -729,6 +738,8 @@ class TransformerDecoder(nn.Module):
            self.query_scale = MLP(d_model, d_model, d_model, 2)
        self.bbox_embed = None
        self.class_embed = None
        self.center_embed = None
        self.iou_embed = None

        self.d_model = d_model
        self.modulate_hw_attn = modulate_hw_attn
@ -57,6 +57,9 @@ class DINOHead(nn.Module):
                 dec_pred_bbox_embed_share=True,
                 two_stage_class_embed_share=True,
                 two_stage_bbox_embed_share=True,
                 use_centerness=False,
                 use_iouaware=False,
                 losses_list=['labels', 'boxes'],
                 decoder_sa_type='sa',
                 temperatureH=20,
                 temperatureW=20,

@ -80,16 +83,19 @@ class DINOHead(nn.Module):
            num_classes,
            matcher=self.matcher,
            weight_dict=weight_dict,
            losses=['labels', 'boxes'],
            losses=losses_list,
            loss_class_type='focal_loss')
        if dn_components is not None:
            self.dn_criterion = CDNCriterion(
                num_classes,
                matcher=self.matcher,
                weight_dict=weight_dict,
                losses=['labels', 'boxes'],
                losses=losses_list,
                loss_class_type='focal_loss')
        self.postprocess = DetrPostProcess(num_select=num_select)
        self.postprocess = DetrPostProcess(
            num_select=num_select,
            use_centerness=use_centerness,
            use_iouaware=use_iouaware)
        self.transformer = build_neck(transformer)

        self.positional_encoding = PositionEmbeddingSineHW(

@ -161,15 +167,43 @@ class DINOHead(nn.Module):
        nn.init.constant_(_bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(_bbox_embed.layers[-1].bias.data, 0)

        # fcos centerness & iou-aware & tokenlabel
        self.use_centerness = use_centerness
        self.use_iouaware = use_iouaware
        if self.use_centerness:
            _center_embed = MLP(embed_dims, embed_dims, 1, 3)
        if self.use_iouaware:
            _iou_embed = MLP(embed_dims, embed_dims, 1, 3)

        if dec_pred_bbox_embed_share:
            box_embed_layerlist = [
                _bbox_embed for i in range(transformer.num_decoder_layers)
            ]
            if self.use_centerness:
                center_embed_layerlist = [
                    _center_embed
                    for i in range(transformer.num_decoder_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    _iou_embed for i in range(transformer.num_decoder_layers)
                ]
        else:
            box_embed_layerlist = [
                copy.deepcopy(_bbox_embed)
                for i in range(transformer.num_decoder_layers)
            ]
            if self.use_centerness:
                center_embed_layerlist = [
                    copy.deepcopy(_center_embed)
                    for i in range(transformer.num_decoder_layers)
                ]
            if self.use_iouaware:
                iou_embed_layerlist = [
                    copy.deepcopy(_iou_embed)
                    for i in range(transformer.num_decoder_layers)
                ]

        if dec_pred_class_embed_share:
            class_embed_layerlist = [
                _class_embed for i in range(transformer.num_decoder_layers)

@ -184,6 +218,13 @@ class DINOHead(nn.Module):
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

        if self.use_centerness:
            self.center_embed = nn.ModuleList(center_embed_layerlist)
            self.transformer.decoder.center_embed = self.center_embed
        if self.use_iouaware:
            self.iou_embed = nn.ModuleList(iou_embed_layerlist)
            self.transformer.decoder.iou_embed = self.iou_embed

        # two stage
        self.two_stage_type = two_stage_type
        self.two_stage_add_query_num = two_stage_add_query_num

@ -194,9 +235,19 @@ class DINOHead(nn.Module):
            if two_stage_bbox_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share
                self.transformer.enc_out_bbox_embed = _bbox_embed
                if self.use_centerness:
                    self.transformer.enc_out_center_embed = _center_embed
                if self.use_iouaware:
                    self.transformer.enc_out_iou_embed = _iou_embed
            else:
                self.transformer.enc_out_bbox_embed = copy.deepcopy(
                    _bbox_embed)
                if self.use_centerness:
                    self.transformer.enc_out_center_embed = copy.deepcopy(
                        _center_embed)
                if self.use_iouaware:
                    self.transformer.enc_out_iou_embed = copy.deepcopy(
                        _iou_embed)

            if two_stage_class_embed_share:
                assert dec_pred_class_embed_share and dec_pred_bbox_embed_share

@ -261,10 +312,9 @@ class DINOHead(nn.Module):

    def prepare(self, features, targets=None, mode='train'):

        if self.dn_number > 0 or targets is not None:
        if self.dn_number > 0 and targets is not None:
            input_query_label, input_query_bbox, attn_mask, dn_meta =\
                prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale),
                                training=self.training, num_queries=self.num_queries, num_classes=self.num_classes,
                prepare_for_cdn(dn_args=(targets, self.dn_number, self.dn_label_noise_ratio, self.dn_box_noise_scale), num_queries=self.num_queries, num_classes=self.num_classes,
                                hidden_dim=self.embed_dims, label_enc=self.label_enc)
        else:
            assert targets is None

@ -355,29 +405,61 @@ class DINOHead(nn.Module):
            layer_cls_embed(layer_hs)
            for layer_cls_embed, layer_hs in zip(self.class_embed, hs)
        ])

        outputs_center_list = None
        if self.use_centerness:
            outputs_center_list = torch.stack([
                layer_center_embed(layer_hs)
                for layer_center_embed, layer_hs in zip(self.center_embed, hs)
            ])

        outputs_iou_list = None
        if self.use_iouaware:
            outputs_iou_list = torch.stack([
                layer_iou_embed(layer_hs)
                for layer_iou_embed, layer_hs in zip(self.iou_embed, hs)
            ])

        reference = torch.stack(reference)[:-1][..., :2]
        if self.dn_number > 0 and dn_meta is not None:
            outputs_class, outputs_coord_list = cdn_post_process(
                outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss)
            outputs_class, outputs_coord_list, outputs_center_list, outputs_iou_list, reference = cdn_post_process(
                outputs_class, outputs_coord_list, dn_meta, self._set_aux_loss,
                outputs_center_list, outputs_iou_list, reference)
        out = {
            'pred_logits': outputs_class[-1],
            'pred_boxes': outputs_coord_list[-1]
            'pred_logits':
            outputs_class[-1],
            'pred_boxes':
            outputs_coord_list[-1],
            'pred_centers':
            outputs_center_list[-1]
            if outputs_center_list is not None else None,
            'pred_ious':
            outputs_iou_list[-1] if outputs_iou_list is not None else None,
            'refpts':
            reference[-1],
        }

        out['aux_outputs'] = self._set_aux_loss(outputs_class,
                                                outputs_coord_list)
                                                outputs_coord_list,
                                                outputs_center_list,
                                                outputs_iou_list, reference)

        # for encoder output
        if hs_enc is not None:
            # prepare intermediate outputs
            interm_coord = ref_enc[-1]
            interm_class = self.transformer.enc_out_class_embed(hs_enc[-1])
            if self.use_centerness:
                interm_center = self.transformer.enc_out_center_embed(
                    hs_enc[-1])
            if self.use_iouaware:
                interm_iou = self.transformer.enc_out_iou_embed(hs_enc[-1])
            out['interm_outputs'] = {
                'pred_logits': interm_class,
                'pred_boxes': interm_coord
            }
            out['interm_outputs_for_matching_pre'] = {
                'pred_logits': interm_class,
                'pred_boxes': init_box_proposal
                'pred_boxes': interm_coord,
                'pred_centers': interm_center if self.use_centerness else None,
                'pred_ious': interm_iou if self.use_iouaware else None,
                'refpts': init_box_proposal[..., :2],
            }

        out['dn_meta'] = dn_meta

@ -385,14 +467,28 @@ class DINOHead(nn.Module):
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
    def _set_aux_loss(self,
                      outputs_class,
                      outputs_coord,
                      outputs_center=None,
                      outputs_iou=None,
                      reference=None):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{
            'pred_logits': a,
            'pred_boxes': b
        } for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
            'pred_logits':
            a,
            'pred_boxes':
            b,
            'pred_centers':
            outputs_center[i] if outputs_center is not None else None,
            'pred_ious':
            outputs_iou[i] if outputs_iou is not None else None,
            'refpts':
            reference[i],
        } for i, (a,
                  b) in enumerate(zip(outputs_class[:-1], outputs_coord[:-1]))]

    # over-write because img_metas are needed as inputs for bbox_head.
    def forward_train(self, x, img_metas, gt_bboxes, gt_labels):
@ -1,8 +1,8 @@
#!/usr/bin/env python3
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
from .boxes import (batched_nms, bbox2result, bbox_overlaps, bboxes_iou,
                    box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, distance2bbox,
                    fp16_clamp, generalized_box_iou)
                    box_cxcywh_to_xyxy, box_iou, box_xyxy_to_cxcywh,
                    distance2bbox, fp16_clamp, generalized_box_iou)
from .generator import MlvlPointGenerator
from .misc import (accuracy, gen_encoder_output_proposals,
                   gen_sineembed_for_position, interpolate, inverse_sigmoid,
@ -12,9 +12,14 @@ from easycv.models.detection.utils import box_cxcywh_to_xyxy
class DetrPostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api."""

    def __init__(self, num_select=None) -> None:
    def __init__(self,
                 num_select=None,
                 use_centerness=False,
                 use_iouaware=False) -> None:
        super().__init__()
        self.num_select = num_select
        self.use_centerness = use_centerness
        self.use_iouaware = use_iouaware

    @torch.no_grad()
    def forward(self, outputs, target_sizes, img_metas):

@ -34,8 +39,24 @@ class DetrPostProcess(nn.Module):
            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)
            boxes = box_cxcywh_to_xyxy(out_bbox)

            # and from relative [0, 1] to absolute [0, height] coordinates
            img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                    dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]
        else:
            prob = out_logits.sigmoid()
            if self.use_centerness and self.use_iouaware:
                prob = out_logits.sigmoid(
                )**0.45 * outputs['pred_centers'].sigmoid(
                )**0.05 * outputs['pred_ious'].sigmoid()**0.5
            elif self.use_centerness:
                prob = out_logits.sigmoid() * outputs['pred_centers'].sigmoid()
            elif self.use_iouaware:
                prob = out_logits.sigmoid() * outputs['pred_ious'].sigmoid()
            else:
                prob = out_logits.sigmoid()

            topk_values, topk_indexes = torch.topk(
                prob.view(out_logits.shape[0], -1), self.num_select, dim=1)
            scores = topk_values

@ -45,11 +66,11 @@ class DetrPostProcess(nn.Module):
            boxes = torch.gather(boxes, 1,
                                 topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                dim=1).to(boxes.device)
        boxes = boxes * scale_fct[:, None, :]
            # and from relative [0, 1] to absolute [0, height] coordinates
            img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                    dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]

        results = {
            'detection_boxes': [boxes[0].cpu().numpy()],
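The DINO++ score fusion above combines the three sigmoid outputs geometrically with fixed exponents: 0.45 for classification, 0.05 for centerness and 0.5 for the predicted IoU. A toy illustration (numbers invented) of how the fusion can rerank candidates:

```python
# Toy numbers: a box with lower class confidence but better predicted
# localization quality can outrank a confident, poorly localized one.
import torch

cls = torch.tensor([0.90, 0.60])     # sigmoid classification scores
center = torch.tensor([0.30, 0.90])  # sigmoid centerness
iou = torch.tensor([0.40, 0.90])     # sigmoid IoU-awareness
fused = cls**0.45 * center**0.05 * iou**0.5
print(fused)  # ~tensor([0.568, 0.750]); the second box now ranks first
```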
@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F

from easycv.models.detection.utils import (accuracy, box_cxcywh_to_xyxy,
                                           generalized_box_iou)
                                           box_iou, generalized_box_iou)
from easycv.models.loss.focal_loss import py_sigmoid_focal_loss
from easycv.utils.dist_utils import get_dist_info, is_dist_available
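The next hunk adds two quality-estimation losses to `SetCriterion`. The centerness target it computes is the FCOS definition: with l, r, t, b the distances from a reference point to the four box sides,

```latex
\mathrm{centerness}^{*} = \sqrt{\frac{\min(l, r)}{\max(l, r)} \cdot \frac{\min(t, b)}{\max(t, b)}}
```

Reference points falling outside the matched box (any negative distance) are filtered out before the binary cross-entropy is applied, and the IoU-aware head is likewise trained with BCE against the actual IoU between predicted and target boxes.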
@ -127,6 +127,73 @@ class SetCriterion(nn.Module):

        return losses

    def loss_centerness(self, outputs, targets, indices, num_boxes):

        def ref2ltrb(ref, xyxy):
            lt = ref - xyxy[..., :2]
            rb = xyxy[..., 2:] - ref
            ltrb = torch.cat([lt, rb], dim=-1)
            return ltrb

        def compute_centerness_targets(box_targets):
            left_right = box_targets[:, [0, 2]]
            top_bottom = box_targets[:, [1, 3]]
            centerness = (left_right.min(-1)[0] / left_right.max(-1)[0]) * (
                top_bottom.min(-1)[0] / top_bottom.max(-1)[0])
            return torch.sqrt(centerness)

        assert 'pred_centers' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_centers = outputs['pred_centers'][idx]  # logits
        src_centers = src_centers.squeeze(1)
        target_boxes = torch.cat(
            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        assert 'refpts' in outputs
        src_refpts = outputs['refpts'][idx]  # sigmoided
        assert src_refpts.shape[-1] == 2

        target_boxes_xyxy = box_cxcywh_to_xyxy(target_boxes)
        target_boxes_ltrb = ref2ltrb(src_refpts, target_boxes_xyxy)
        is_in_box = torch.sum(target_boxes_ltrb >= 0, dim=-1) == 4

        src_centers = src_centers[is_in_box]
        target_boxes_ltrb = target_boxes_ltrb[is_in_box]

        target_boxes_ltrb = target_boxes_ltrb.detach()

        losses = {}
        if len(target_boxes_ltrb) == 0:
            losses['loss_center'] = src_centers.sum(
            ) * 0  # prevent unused parameters
        else:
            target_centers = compute_centerness_targets(target_boxes_ltrb)
            loss_center = F.binary_cross_entropy_with_logits(
                src_centers, target_centers, reduction='none')
            losses['loss_center'] = loss_center.sum() / num_boxes

        return losses

    def loss_iouaware(self, outputs, targets, indices, num_boxes):
        assert 'pred_ious' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_ious = outputs['pred_ious'][idx]  # logits
        src_ious = src_ious.squeeze(1)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat(
            [t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        iou = torch.diag(
            box_iou(
                box_cxcywh_to_xyxy(src_boxes),
                box_cxcywh_to_xyxy(target_boxes))[0])

        losses = {}
        loss_iouaware = F.binary_cross_entropy_with_logits(
            src_ious, iou, reduction='none')
        losses['loss_iouaware'] = loss_iouaware.sum() / num_boxes
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
@ -146,6 +213,8 @@ class SetCriterion(nn.Module):
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'centerness': self.loss_centerness,
            'iouaware': self.loss_iouaware,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

@ -156,7 +225,6 @@ class SetCriterion(nn.Module):
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc

             return_indices: used for vis. if True, the layer0-5 indices will be returned as well.
        """

@ -200,9 +268,6 @@ class SetCriterion(nn.Module):
            if return_indices:
                indices_list.append(indices)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer

@ -223,9 +288,6 @@ class SetCriterion(nn.Module):
            if return_indices:
                indices_list.append(indices)
            for loss in self.losses:
                if loss == 'masks':
                    # Intermediate masks losses are too costly to compute, we ignore them.
                    continue
                kwargs = {}
                if loss == 'labels':
                    # Logging is enabled only for the last layer
@ -321,9 +383,15 @@ class CDNCriterion(SetCriterion):
            losses.update(l_dict)
        else:
            l_dict = dict()
            l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
            l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
            l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
            if 'labels' in self.losses:
                l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
            if 'boxes' in self.losses:
                l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
            if 'centerness' in self.losses:
                l_dict['loss_center_dn'] = torch.as_tensor(0.).to('cuda')
            if 'iouaware' in self.losses:
                l_dict['loss_iouaware_dn'] = torch.as_tensor(0.).to('cuda')
            losses.update(l_dict)

        for i in range(aux_num):

@ -348,9 +416,15 @@ class CDNCriterion(SetCriterion):
                losses.update(l_dict)
            else:
                l_dict = dict()
                l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
                if 'labels' in self.losses:
                    l_dict['loss_ce_dn'] = torch.as_tensor(0.).to('cuda')
                if 'boxes' in self.losses:
                    l_dict['loss_bbox_dn'] = torch.as_tensor(0.).to('cuda')
                    l_dict['loss_giou_dn'] = torch.as_tensor(0.).to('cuda')
                if 'centerness' in self.losses:
                    l_dict['loss_center_dn'] = torch.as_tensor(0.).to('cuda')
                if 'iouaware' in self.losses:
                    l_dict['loss_iouaware_dn'] = torch.as_tensor(0.).to('cuda')
                l_dict = {
                    k + f'_{i}':
                    v * (self.weight_dict[k] if k in self.weight_dict else 1.0)
@ -62,7 +62,7 @@ class DetSourceObject365(unittest.TestCase):

        data_source = build_datasource(
            dict(
                type='DetSourceObject365',
                type='DetSourceObjects365',
                ann_file=DET_DATASET_OBJECT365 + '/val.json',
                img_prefix=DET_DATASET_OBJECT365 + '/images',
                pipeline=[