mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
parent
7ee13097e0
commit
14f8a97bdb
@ -0,0 +1,311 @@
|
|||||||
|
_base_ = ['configs/base.py']
|
||||||
|
|
||||||
|
# If point cloud range is changed, the models should also change their point
|
||||||
|
# cloud range accordingly
|
||||||
|
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
|
||||||
|
voxel_size = [0.2, 0.2, 8]
|
||||||
|
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
|
||||||
|
# For nuScenes we usually do 10-class detection
|
||||||
|
CLASSES = [
|
||||||
|
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
|
||||||
|
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
|
||||||
|
]
|
||||||
|
|
||||||
|
input_modality = dict(
|
||||||
|
use_lidar=False,
|
||||||
|
use_camera=True,
|
||||||
|
use_radar=False,
|
||||||
|
use_map=False,
|
||||||
|
use_external=True)
|
||||||
|
|
||||||
|
embed_dim = 256
|
||||||
|
pos_dim = embed_dim // 2
|
||||||
|
ffn_dim = embed_dim * 2
|
||||||
|
num_levels = 4
|
||||||
|
bev_h = 200
|
||||||
|
bev_w = 200
|
||||||
|
queue_length = 4 # each sequence contains `queue_length` frames.
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='BEVFormer',
|
||||||
|
use_grid_mask=True,
|
||||||
|
video_test_mode=True,
|
||||||
|
img_backbone=dict(
|
||||||
|
type='ResNet',
|
||||||
|
depth=101,
|
||||||
|
num_stages=4,
|
||||||
|
out_indices=(2, 3, 4),
|
||||||
|
frozen_stages=-1,
|
||||||
|
norm_cfg=dict(type='BN', requires_grad=False),
|
||||||
|
norm_eval=True,
|
||||||
|
style='caffe',
|
||||||
|
dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
|
||||||
|
stage_with_dcn=(False, False, True, True),
|
||||||
|
zero_init_residual=True),
|
||||||
|
img_neck=dict(
|
||||||
|
type='FPN',
|
||||||
|
in_channels=[512, 1024, 2048],
|
||||||
|
out_channels=embed_dim,
|
||||||
|
start_level=0,
|
||||||
|
add_extra_convs='on_output',
|
||||||
|
num_outs=num_levels,
|
||||||
|
relu_before_extra_convs=True),
|
||||||
|
pts_bbox_head=dict(
|
||||||
|
type='BEVFormerHead',
|
||||||
|
bev_h=bev_h,
|
||||||
|
bev_w=bev_w,
|
||||||
|
num_query=900,
|
||||||
|
num_query_one2many=1800,
|
||||||
|
one2many_gt_mul=[2, 3, 7, 7, 9, 6, 7, 6, 2, 5],
|
||||||
|
num_classes=10,
|
||||||
|
in_channels=embed_dim,
|
||||||
|
sync_cls_avg_factor=True,
|
||||||
|
with_box_refine=True,
|
||||||
|
as_two_stage=False,
|
||||||
|
transformer=dict(
|
||||||
|
type='PerceptionTransformer',
|
||||||
|
rotate_prev_bev=True,
|
||||||
|
use_shift=True,
|
||||||
|
use_can_bus=True,
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
encoder=dict(
|
||||||
|
type='BEVFormerEncoder',
|
||||||
|
num_layers=6,
|
||||||
|
pc_range=point_cloud_range,
|
||||||
|
num_points_in_pillar=4,
|
||||||
|
return_intermediate=False,
|
||||||
|
transformerlayers=dict(
|
||||||
|
type='BEVFormerLayer',
|
||||||
|
attn_cfgs=[
|
||||||
|
dict(
|
||||||
|
type='TemporalSelfAttention',
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
num_levels=1),
|
||||||
|
dict(
|
||||||
|
type='SpatialCrossAttention',
|
||||||
|
pc_range=point_cloud_range,
|
||||||
|
deformable_attention=dict(
|
||||||
|
type='MSDeformableAttention3D',
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
num_points=8,
|
||||||
|
num_levels=num_levels),
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
ffn_cfgs=dict(
|
||||||
|
type='FFN',
|
||||||
|
embed_dims=256,
|
||||||
|
feedforward_channels=ffn_dim,
|
||||||
|
num_fcs=2,
|
||||||
|
ffn_drop=0.1,
|
||||||
|
act_cfg=dict(type='ReLU', inplace=True),
|
||||||
|
),
|
||||||
|
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
|
||||||
|
'ffn', 'norm'))),
|
||||||
|
decoder=dict(
|
||||||
|
type='Detr3DTransformerDecoder',
|
||||||
|
num_layers=6,
|
||||||
|
return_intermediate=True,
|
||||||
|
transformerlayers=dict(
|
||||||
|
type='DetrTransformerDecoderLayer',
|
||||||
|
attn_cfgs=[
|
||||||
|
dict(
|
||||||
|
type='MultiheadAttention',
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
num_heads=8,
|
||||||
|
dropout=0.1),
|
||||||
|
dict(
|
||||||
|
type='CustomMSDeformableAttention',
|
||||||
|
embed_dims=embed_dim,
|
||||||
|
num_levels=1),
|
||||||
|
],
|
||||||
|
ffn_cfgs=dict(
|
||||||
|
type='FFN',
|
||||||
|
embed_dims=256,
|
||||||
|
feedforward_channels=ffn_dim,
|
||||||
|
num_fcs=2,
|
||||||
|
ffn_drop=0.1,
|
||||||
|
act_cfg=dict(type='ReLU', inplace=True),
|
||||||
|
),
|
||||||
|
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
|
||||||
|
'ffn', 'norm')))),
|
||||||
|
bbox_coder=dict(
|
||||||
|
type='NMSFreeBBoxCoder',
|
||||||
|
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
|
||||||
|
pc_range=point_cloud_range,
|
||||||
|
max_num=300,
|
||||||
|
voxel_size=voxel_size,
|
||||||
|
num_classes=10),
|
||||||
|
positional_encoding=dict(
|
||||||
|
type='LearnedPositionalEncoding',
|
||||||
|
num_feats=pos_dim,
|
||||||
|
row_num_embed=bev_h,
|
||||||
|
col_num_embed=bev_w,
|
||||||
|
),
|
||||||
|
loss_cls=dict(
|
||||||
|
type='FocalLoss',
|
||||||
|
use_sigmoid=True,
|
||||||
|
gamma=2.0,
|
||||||
|
alpha=0.25,
|
||||||
|
loss_weight=2.0),
|
||||||
|
# loss_bbox=dict(type='L1Loss', loss_weight=0.25),
|
||||||
|
# loss_bbox=dict(type='SmoothL1Loss', loss_weight=0.25),
|
||||||
|
loss_bbox=dict(type='BalancedL1Loss', loss_weight=0.25, gamma=1),
|
||||||
|
loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
|
||||||
|
# model training and testing settings
|
||||||
|
train_cfg=dict(
|
||||||
|
pts=dict(
|
||||||
|
grid_size=[512, 512, 1],
|
||||||
|
voxel_size=voxel_size,
|
||||||
|
point_cloud_range=point_cloud_range,
|
||||||
|
out_size_factor=4,
|
||||||
|
assigner=dict(
|
||||||
|
type='HungarianBBoxAssigner3D',
|
||||||
|
cls_cost=dict(type='FocalLossCost', weight=2.0),
|
||||||
|
reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
|
||||||
|
iou_cost=dict(
|
||||||
|
type='IoUCost', weight=0.0
|
||||||
|
), # Fake cost. This is just to make it compatible with DETR head.
|
||||||
|
pc_range=point_cloud_range))))
|
||||||
|
|
||||||
|
dataset_type = 'NuScenesDataset'
|
||||||
|
data_root = 'data/nuscenes/train-val/'
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='PhotoMetricDistortionMultiViewImage'),
|
||||||
|
|
||||||
|
# dict(type='RandomScaleImageMultiViewImage', scales=[0.8,0.9,1.0,1.1,1.2]),
|
||||||
|
dict(type='RandomHorizontalFlipMultiViewImage'),
|
||||||
|
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
|
||||||
|
dict(type='ObjectNameFilter', classes=CLASSES),
|
||||||
|
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
|
||||||
|
dict(type='PadMultiViewImage', size_divisor=32),
|
||||||
|
dict(type='DefaultFormatBundle3D', class_names=CLASSES),
|
||||||
|
dict(
|
||||||
|
type='Collect3D',
|
||||||
|
keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'],
|
||||||
|
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
|
||||||
|
'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
|
||||||
|
'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
|
||||||
|
'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
|
||||||
|
'prev_idx', 'next_idx', 'pcd_scale_factor', 'pcd_rotation',
|
||||||
|
'pts_filename', 'transformation_3d_flow', 'scene_token',
|
||||||
|
'can_bus'))
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
|
||||||
|
dict(type='PadMultiViewImage', size_divisor=32),
|
||||||
|
dict(
|
||||||
|
type='MultiScaleFlipAug3D',
|
||||||
|
img_scale=(1600, 900),
|
||||||
|
pts_scale_ratio=1,
|
||||||
|
flip=False,
|
||||||
|
transforms=[
|
||||||
|
dict(
|
||||||
|
type='DefaultFormatBundle3D',
|
||||||
|
class_names=CLASSES,
|
||||||
|
with_label=False),
|
||||||
|
dict(
|
||||||
|
type='Collect3D',
|
||||||
|
keys=['img'],
|
||||||
|
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
|
||||||
|
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
|
||||||
|
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
|
||||||
|
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
|
||||||
|
'pcd_trans', 'sample_idx', 'prev_idx', 'next_idx',
|
||||||
|
'pcd_scale_factor', 'pcd_rotation', 'pts_filename',
|
||||||
|
'transformation_3d_flow', 'scene_token', 'can_bus'))
|
||||||
|
])
|
||||||
|
]
|
||||||
|
|
||||||
|
data = dict(
|
||||||
|
imgs_per_gpu=1, # 8gpus, total batch size=8
|
||||||
|
workers_per_gpu=8,
|
||||||
|
pin_memory=True,
|
||||||
|
# shuffler_sampler=dict(type='DistributedGroupSampler'),
|
||||||
|
# nonshuffler_sampler=dict(type='DistributedSampler'),
|
||||||
|
train=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_source=dict(
|
||||||
|
type='Det3dSourceNuScenes',
|
||||||
|
data_root=data_root,
|
||||||
|
ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
|
||||||
|
pipeline=[
|
||||||
|
dict(
|
||||||
|
type='LoadMultiViewImageFromFiles',
|
||||||
|
to_float32=True,
|
||||||
|
backend='turbojpeg'),
|
||||||
|
dict(
|
||||||
|
type='LoadAnnotations3D',
|
||||||
|
with_bbox_3d=True,
|
||||||
|
with_label_3d=True,
|
||||||
|
with_attr_label=False)
|
||||||
|
],
|
||||||
|
classes=CLASSES,
|
||||||
|
modality=input_modality,
|
||||||
|
test_mode=False,
|
||||||
|
use_valid_flag=True,
|
||||||
|
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
|
||||||
|
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
|
||||||
|
box_type_3d='LiDAR'),
|
||||||
|
pipeline=train_pipeline,
|
||||||
|
queue_length=queue_length,
|
||||||
|
),
|
||||||
|
val=dict(
|
||||||
|
imgs_per_gpu=1,
|
||||||
|
type=dataset_type,
|
||||||
|
data_source=dict(
|
||||||
|
type='Det3dSourceNuScenes',
|
||||||
|
data_root=data_root,
|
||||||
|
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
|
||||||
|
pipeline=[
|
||||||
|
dict(
|
||||||
|
type='LoadMultiViewImageFromFiles',
|
||||||
|
to_float32=True,
|
||||||
|
backend='turbojpeg')
|
||||||
|
],
|
||||||
|
classes=CLASSES,
|
||||||
|
modality=input_modality,
|
||||||
|
test_mode=True),
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
|
||||||
|
paramwise_cfg = {'img_backbone': dict(lr_mult=0.1)}
|
||||||
|
optimizer = dict(
|
||||||
|
type='AdamW', lr=2e-4, paramwise_options=paramwise_cfg, weight_decay=0.01)
|
||||||
|
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
|
||||||
|
# learning policy
|
||||||
|
lr_config = dict(
|
||||||
|
policy='CosineAnnealing',
|
||||||
|
warmup='linear',
|
||||||
|
warmup_iters=500,
|
||||||
|
warmup_ratio=1.0 / 3,
|
||||||
|
min_lr_ratio=1e-3)
|
||||||
|
total_epochs = 24
|
||||||
|
|
||||||
|
eval_config = dict(initial=False, interval=1, gpu_collect=False)
|
||||||
|
eval_pipelines = [
|
||||||
|
dict(
|
||||||
|
mode='test',
|
||||||
|
data=data['val'],
|
||||||
|
dist_eval=True,
|
||||||
|
evaluators=[
|
||||||
|
dict(
|
||||||
|
type='NuScenesEvaluator',
|
||||||
|
classes=CLASSES,
|
||||||
|
result_names=['pts_bbox'])
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
load_from = 'https://github.com/zhiqi-li/storage/releases/download/v1.0/r101_dcn_fcos3d_pretrain.pth'
|
||||||
|
log_config = dict(
|
||||||
|
interval=50,
|
||||||
|
hooks=[dict(type='TextLoggerHook'),
|
||||||
|
dict(type='TensorboardLoggerHook')])
|
||||||
|
|
||||||
|
checkpoint_config = dict(interval=1)
|
||||||
|
cudnn_benchmark = True
|
||||||
|
find_unused_parameters = True
|
@ -58,7 +58,7 @@ model = dict(
|
|||||||
bev_w=bev_w,
|
bev_w=bev_w,
|
||||||
num_query=900,
|
num_query=900,
|
||||||
num_query_one2many=1800,
|
num_query_one2many=1800,
|
||||||
one2many_gt_mul=4,
|
one2many_gt_mul=[4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
|
||||||
num_classes=10,
|
num_classes=10,
|
||||||
in_channels=embed_dim,
|
in_channels=embed_dim,
|
||||||
sync_cls_avg_factor=True,
|
sync_cls_avg_factor=True,
|
||||||
|
@ -8,3 +8,4 @@ Pretrained on [nuScenes](https://www.nuscenes.org/) dataset.
|
|||||||
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||||
| BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) |
|
| BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) |
|
||||||
| BEVFormer-base-hybrid | [bevformer_base_r101_dcn_nuscenes_hybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_hybrid.py) | 69M | 46.1 | 53.02 | 42.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_hybrid2/epoch_23.pth) |
|
| BEVFormer-base-hybrid | [bevformer_base_r101_dcn_nuscenes_hybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_hybrid.py) | 69M | 46.1 | 53.02 | 42.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_hybrid2/epoch_23.pth) |
|
||||||
|
| BEVFormer-base-blancehybrid | [bevformer_base_r101_dcn_nuscenes_blancehybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_blancehybrid.py) | 69M | 46.1 | 53.28 | 42.63 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_blancehybrid/epoch_23.pth) |
|
||||||
|
@ -646,20 +646,22 @@ class BEVFormerHead(AnchorFreeHead):
|
|||||||
|
|
||||||
gt_bboxes_list_aux = []
|
gt_bboxes_list_aux = []
|
||||||
gt_labels_list_aux = []
|
gt_labels_list_aux = []
|
||||||
for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
|
# for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
|
||||||
gt_bboxes_list_aux.append(
|
# gt_bboxes_list_aux.append(
|
||||||
gt_bboxes.repeat(self.one2many_gt_mul, 1))
|
# gt_bboxes.repeat(self.one2many_gt_mul, 1))
|
||||||
gt_labels_list_aux.append(
|
# gt_labels_list_aux.append(
|
||||||
gt_labels.repeat(self.one2many_gt_mul))
|
# gt_labels.repeat(self.one2many_gt_mul))
|
||||||
# for classwise multiply
|
# for classwise multiply
|
||||||
# for gt_bboxes, gt_labels in zip(gt_bboxes_list,gt_labels_list):
|
for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
|
||||||
# gt_bboxes_aux = []
|
gt_bboxes_aux = []
|
||||||
# gt_labels_aux = []
|
gt_labels_aux = []
|
||||||
# for gt_bbox, gt_label in zip(gt_bboxes, gt_labels):
|
for gt_bbox, gt_label in zip(gt_bboxes, gt_labels):
|
||||||
# gt_bboxes_aux += [gt_bbox]*self.one2many_gt_mul[gt_label]
|
gt_bboxes_aux += [gt_bbox] * self.one2many_gt_mul[gt_label]
|
||||||
# gt_labels_aux += [gt_label]*self.one2many_gt_mul[gt_label]
|
gt_labels_aux += [gt_label
|
||||||
# gt_bboxes_list_aux.append(torch.stack(gt_bboxes_aux))
|
] * self.one2many_gt_mul[gt_label]
|
||||||
# gt_labels_list_aux.append(torch.stack(gt_labels_aux))
|
gt_bboxes_list_aux.append(torch.stack(gt_bboxes_aux))
|
||||||
|
gt_labels_list_aux.append(torch.stack(gt_labels_aux))
|
||||||
|
|
||||||
all_gt_bboxes_list_aux = [
|
all_gt_bboxes_list_aux = [
|
||||||
gt_bboxes_list_aux for _ in range(num_dec_layers)
|
gt_bboxes_list_aux for _ in range(num_dec_layers)
|
||||||
]
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user