mirror of https://github.com/alibaba/EasyCV.git
Improve the performance of bevformer (#224)

* add hybrid branch (#232)

Co-authored-by: yhq <yanhaiqiang.yhq@alibaba-inc.com>

parent a36e0e32a4
commit f8c9a9a1c9
@@ -20,14 +20,16 @@ input_modality = dict(
     use_map=False,
     use_external=True)
 
-_dim_ = 256
-_pos_dim_ = _dim_ // 2
-_ffn_dim_ = _dim_ * 2
-_num_levels_ = 4
-bev_h_ = 200
-bev_w_ = 200
+embed_dim = 256
+pos_dim = embed_dim // 2
+ffn_dim = embed_dim * 2
+num_levels = 4
+bev_h = 200
+bev_w = 200
 queue_length = 4  # each sequence contains `queue_length` frames.
 
+adapt_jit = False  # set True when exporting a jit trace model or a blade model
+
 model = dict(
     type='BEVFormer',
     use_grid_mask=True,
@@ -47,18 +49,18 @@ model = dict(
     img_neck=dict(
         type='FPN',
         in_channels=[512, 1024, 2048],
-        out_channels=_dim_,
+        out_channels=embed_dim,
         start_level=0,
         add_extra_convs='on_output',
-        num_outs=_num_levels_,
+        num_outs=num_levels,
         relu_before_extra_convs=True),
     pts_bbox_head=dict(
         type='BEVFormerHead',
-        bev_h=bev_h_,
-        bev_w=bev_w_,
+        bev_h=bev_h,
+        bev_w=bev_w,
         num_query=900,
         num_classes=10,
-        in_channels=_dim_,
+        in_channels=embed_dim,
         sync_cls_avg_factor=True,
         with_box_refine=True,
         as_two_stage=False,
@@ -67,7 +69,7 @@ model = dict(
             rotate_prev_bev=True,
             use_shift=True,
             use_can_bus=True,
-            embed_dims=_dim_,
+            embed_dims=embed_dim,
             encoder=dict(
                 type='BEVFormerEncoder',
                 num_layers=6,
@@ -76,26 +78,28 @@ model = dict(
                 return_intermediate=False,
                 transformerlayers=dict(
                     type='BEVFormerLayer',
+                    adapt_jit=adapt_jit,
                     attn_cfgs=[
                         dict(
                             type='TemporalSelfAttention',
-                            embed_dims=_dim_,
+                            embed_dims=embed_dim,
                             num_levels=1),
                         dict(
                             type='SpatialCrossAttention',
                             pc_range=point_cloud_range,
                             deformable_attention=dict(
                                 type='MSDeformableAttention3D',
-                                embed_dims=_dim_,
+                                embed_dims=embed_dim,
                                 num_points=8,
-                                num_levels=_num_levels_),
-                            embed_dims=_dim_,
+                                num_levels=num_levels,
+                                adapt_jit=adapt_jit),
+                            embed_dims=embed_dim,
                         )
                     ],
                     ffn_cfgs=dict(
                         type='FFN',
                         embed_dims=256,
-                        feedforward_channels=_ffn_dim_,
+                        feedforward_channels=ffn_dim,
                         num_fcs=2,
                         ffn_drop=0.1,
                         act_cfg=dict(type='ReLU', inplace=True),
@@ -111,18 +115,19 @@ model = dict(
                     attn_cfgs=[
                         dict(
                             type='MultiheadAttention',
-                            embed_dims=_dim_,
+                            embed_dims=embed_dim,
                             num_heads=8,
                             dropout=0.1),
                         dict(
                             type='CustomMSDeformableAttention',
-                            embed_dims=_dim_,
-                            num_levels=1),
+                            embed_dims=embed_dim,
+                            num_levels=1,
+                            adapt_jit=adapt_jit),
                     ],
                     ffn_cfgs=dict(
                         type='FFN',
                         embed_dims=256,
-                        feedforward_channels=_ffn_dim_,
+                        feedforward_channels=ffn_dim,
                         num_fcs=2,
                         ffn_drop=0.1,
                         act_cfg=dict(type='ReLU', inplace=True),
@@ -138,9 +143,9 @@ model = dict(
             num_classes=10),
         positional_encoding=dict(
             type='LearnedPositionalEncoding',
-            num_feats=_pos_dim_,
-            row_num_embed=bev_h_,
-            col_num_embed=bev_w_,
+            num_feats=pos_dim,
+            row_num_embed=bev_h,
+            col_num_embed=bev_w,
         ),
         loss_cls=dict(
             type='FocalLoss',
@@ -217,6 +222,7 @@ test_pipeline = [
 data = dict(
     imgs_per_gpu=1,  # 8gpus, total batch size=8
     workers_per_gpu=4,
+    pin_memory=True,
     # shuffler_sampler=dict(type='DistributedGroupSampler'),
     # nonshuffler_sampler=dict(type='DistributedSampler'),
     train=dict(
@@ -226,7 +232,10 @@ data = dict(
             data_root=data_root,
             ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
             pipeline=[
-                dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+                dict(
+                    type='LoadMultiViewImageFromFiles',
+                    to_float32=True,
+                    backend='turbojpeg'),
                 dict(
                     type='LoadAnnotations3D',
                     with_bbox_3d=True,
@@ -251,7 +260,10 @@ data = dict(
             data_root=data_root,
             ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
             pipeline=[
-                dict(type='LoadMultiViewImageFromFiles', to_float32=True)
+                dict(
+                    type='LoadMultiViewImageFromFiles',
+                    to_float32=True,
+                    backend='turbojpeg')
             ],
             classes=CLASSES,
             modality=input_modality,
@@ -295,3 +307,12 @@ log_config = dict(
 
 checkpoint_config = dict(interval=1)
 cudnn_benchmark = True
+export = dict(
+    type='blade',
+    blade_config=dict(
+        enable_fp16=True,
+        fp16_fallback_op_ratio=0.0,
+        customize_op_black_list=[
+            'aten::select', 'aten::index', 'aten::slice', 'aten::view',
+            'aten::upsample', 'aten::clamp'
+        ]))
@@ -0,0 +1,311 @@ (new file, all lines added)
_base_ = ['configs/base.py']

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]

img_norm_cfg = dict(
    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
# For nuScenes we usually do 10-class detection
CLASSES = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

input_modality = dict(
    use_lidar=False,
    use_camera=True,
    use_radar=False,
    use_map=False,
    use_external=True)

embed_dim = 256
pos_dim = embed_dim // 2
ffn_dim = embed_dim * 2
num_levels = 4
bev_h = 200
bev_w = 200
queue_length = 4  # each sequence contains `queue_length` frames.

model = dict(
    type='BEVFormer',
    use_grid_mask=True,
    video_test_mode=True,
    img_backbone=dict(
        type='ResNet',
        depth=101,
        num_stages=4,
        out_indices=(2, 3, 4),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='caffe',
        dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, True, True),
        zero_init_residual=True),
    img_neck=dict(
        type='FPN',
        in_channels=[512, 1024, 2048],
        out_channels=embed_dim,
        start_level=0,
        add_extra_convs='on_output',
        num_outs=num_levels,
        relu_before_extra_convs=True),
    pts_bbox_head=dict(
        type='BEVFormerHead',
        bev_h=bev_h,
        bev_w=bev_w,
        num_query=900,
        num_query_one2many=1800,
        one2many_gt_mul=4,
        num_classes=10,
        in_channels=embed_dim,
        sync_cls_avg_factor=True,
        with_box_refine=True,
        as_two_stage=False,
        transformer=dict(
            type='PerceptionTransformer',
            rotate_prev_bev=True,
            use_shift=True,
            use_can_bus=True,
            embed_dims=embed_dim,
            encoder=dict(
                type='BEVFormerEncoder',
                num_layers=6,
                pc_range=point_cloud_range,
                num_points_in_pillar=4,
                return_intermediate=False,
                transformerlayers=dict(
                    type='BEVFormerLayer',
                    attn_cfgs=[
                        dict(
                            type='TemporalSelfAttention',
                            embed_dims=embed_dim,
                            num_levels=1),
                        dict(
                            type='SpatialCrossAttention',
                            pc_range=point_cloud_range,
                            deformable_attention=dict(
                                type='MSDeformableAttention3D',
                                embed_dims=embed_dim,
                                num_points=8,
                                num_levels=num_levels),
                            embed_dims=embed_dim,
                        )
                    ],
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=256,
                        feedforward_channels=ffn_dim,
                        num_fcs=2,
                        ffn_drop=0.1,
                        act_cfg=dict(type='ReLU', inplace=True),
                    ),
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm'))),
            decoder=dict(
                type='Detr3DTransformerDecoder',
                num_layers=6,
                return_intermediate=True,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=embed_dim,
                            num_heads=8,
                            dropout=0.1),
                        dict(
                            type='CustomMSDeformableAttention',
                            embed_dims=embed_dim,
                            num_levels=1),
                    ],
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=256,
                        feedforward_channels=ffn_dim,
                        num_fcs=2,
                        ffn_drop=0.1,
                        act_cfg=dict(type='ReLU', inplace=True),
                    ),
                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                                     'ffn', 'norm')))),
        bbox_coder=dict(
            type='NMSFreeBBoxCoder',
            post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
            pc_range=point_cloud_range,
            max_num=300,
            voxel_size=voxel_size,
            num_classes=10),
        positional_encoding=dict(
            type='LearnedPositionalEncoding',
            num_feats=pos_dim,
            row_num_embed=bev_h,
            col_num_embed=bev_w,
        ),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0),
        # loss_bbox=dict(type='L1Loss', loss_weight=0.25),
        # loss_bbox=dict(type='SmoothL1Loss', loss_weight=0.25),
        loss_bbox=dict(type='BalancedL1Loss', loss_weight=0.25, gamma=1),
        loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
    # model training and testing settings
    train_cfg=dict(
        pts=dict(
            grid_size=[512, 512, 1],
            voxel_size=voxel_size,
            point_cloud_range=point_cloud_range,
            out_size_factor=4,
            assigner=dict(
                type='HungarianBBoxAssigner3D',
                cls_cost=dict(type='FocalLossCost', weight=2.0),
                reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
                iou_cost=dict(
                    type='IoUCost', weight=0.0
                ),  # Fake cost. This is just to make it compatible with DETR head.
                pc_range=point_cloud_range))))

dataset_type = 'NuScenesDataset'
data_root = 'data/nuscenes/train-val/'

train_pipeline = [
    dict(type='PhotoMetricDistortionMultiViewImage'),
    # dict(type='RandomScaleImageMultiViewImage', scales=[0.8,0.9,1.0,1.1,1.2]),
    dict(type='RandomHorizontalFlipMultiViewImage'),
    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
    dict(type='ObjectNameFilter', classes=CLASSES),
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='PadMultiViewImage', size_divisor=32),
    dict(type='DefaultFormatBundle3D', class_names=CLASSES),
    dict(
        type='Collect3D',
        keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'],
        meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
                   'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
                   'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
                   'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
                   'prev_idx', 'next_idx', 'pcd_scale_factor', 'pcd_rotation',
                   'pts_filename', 'transformation_3d_flow', 'scene_token',
                   'can_bus'))
]

test_pipeline = [
    dict(type='NormalizeMultiviewImage', **img_norm_cfg),
    dict(type='PadMultiViewImage', size_divisor=32),
    dict(
        type='MultiScaleFlipAug3D',
        img_scale=(1600, 900),
        pts_scale_ratio=1,
        flip=False,
        transforms=[
            dict(
                type='DefaultFormatBundle3D',
                class_names=CLASSES,
                with_label=False),
            dict(
                type='Collect3D',
                keys=['img'],
                meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
                           'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
                           'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
                           'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
                           'pcd_trans', 'sample_idx', 'prev_idx', 'next_idx',
                           'pcd_scale_factor', 'pcd_rotation', 'pts_filename',
                           'transformation_3d_flow', 'scene_token', 'can_bus'))
        ])
]

data = dict(
    imgs_per_gpu=1,  # 8gpus, total batch size=8
    workers_per_gpu=8,
    pin_memory=True,
    # shuffler_sampler=dict(type='DistributedGroupSampler'),
    # nonshuffler_sampler=dict(type='DistributedSampler'),
    train=dict(
        type=dataset_type,
        data_source=dict(
            type='Det3dSourceNuScenes',
            data_root=data_root,
            ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
            pipeline=[
                dict(
                    type='LoadMultiViewImageFromFiles',
                    to_float32=True,
                    backend='turbojpeg'),
                dict(
                    type='LoadAnnotations3D',
                    with_bbox_3d=True,
                    with_label_3d=True,
                    with_attr_label=False)
            ],
            classes=CLASSES,
            modality=input_modality,
            test_mode=False,
            use_valid_flag=True,
            # we use box_type_3d='LiDAR' in kitti and nuscenes dataset
            # and box_type_3d='Depth' in sunrgbd and scannet dataset.
            box_type_3d='LiDAR'),
        pipeline=train_pipeline,
        queue_length=queue_length,
    ),
    val=dict(
        imgs_per_gpu=1,
        type=dataset_type,
        data_source=dict(
            type='Det3dSourceNuScenes',
            data_root=data_root,
            ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
            pipeline=[
                dict(
                    type='LoadMultiViewImageFromFiles',
                    to_float32=True,
                    backend='turbojpeg')
            ],
            classes=CLASSES,
            modality=input_modality,
            test_mode=True),
        pipeline=test_pipeline))

paramwise_cfg = {'img_backbone': dict(lr_mult=0.1)}
optimizer = dict(
    type='AdamW', lr=2e-4, paramwise_options=paramwise_cfg, weight_decay=0.01)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    min_lr_ratio=1e-3)
total_epochs = 24

eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        evaluators=[
            dict(
                type='NuScenesEvaluator',
                classes=CLASSES,
                result_names=['pts_bbox'])
        ],
    )
]

load_from = 'https://github.com/zhiqi-li/storage/releases/download/v1.0/r101_dcn_fcos3d_pretrain.pth'
log_config = dict(
    interval=50,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])

checkpoint_config = dict(interval=1)
cudnn_benchmark = True
find_unused_parameters = True
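A minimal sketch of how a config file like the one above is typically materialized into a model. `Config` and `build_model` both appear in the export tooling later in this commit; the config path and builder call here are hypothetical placeholders, not part of the commit.

# Hypothetical usage sketch, assuming EasyCV's mmcv-style config loading.
from mmcv.utils import Config

from easycv.models import build_model

cfg = Config.fromfile('path/to/bevformer_hybrid_config.py')  # placeholder path
model = build_model(cfg.model)  # resolves type='BEVFormer' via the model registry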
@@ -29,18 +29,19 @@ input_modality = dict(
     use_map=False,
     use_external=True)
 
-_dim_ = 256
-_pos_dim_ = _dim_ // 2
-_ffn_dim_ = _dim_ * 2
-_num_levels_ = 1
-bev_h_ = 50
-bev_w_ = 50
+embed_dim = 256
+pos_dim = embed_dim // 2
+ffn_dim = embed_dim * 2
+num_levels = 1
+bev_h = 50
+bev_w = 50
 queue_length = 3  # each sequence contains `queue_length` frames.
 
 model = dict(
     type='BEVFormer',
     use_grid_mask=True,
     video_test_mode=True,
+    extract_feat_serially=True,
     pretrained=dict(img='torchvision://resnet50'),
     img_backbone=dict(
         type='ResNet',
@@ -56,18 +57,18 @@ model = dict(
     img_neck=dict(
         type='FPN',
         in_channels=[2048],
-        out_channels=_dim_,
+        out_channels=embed_dim,
         start_level=0,
         add_extra_convs='on_output',
-        num_outs=_num_levels_,
+        num_outs=num_levels,
         relu_before_extra_convs=True),
     pts_bbox_head=dict(
         type='BEVFormerHead',
-        bev_h=bev_h_,
-        bev_w=bev_w_,
+        bev_h=bev_h,
+        bev_w=bev_w,
         num_query=900,
         num_classes=10,
-        in_channels=_dim_,
+        in_channels=embed_dim,
         sync_cls_avg_factor=True,
         with_box_refine=True,
         as_two_stage=False,
@@ -76,7 +77,7 @@ model = dict(
             rotate_prev_bev=True,
             use_shift=True,
             use_can_bus=True,
-            embed_dims=_dim_,
+            embed_dims=embed_dim,
             encoder=dict(
                 type='BEVFormerEncoder',
                 num_layers=3,
@@ -88,23 +89,23 @@ model = dict(
                     attn_cfgs=[
                         dict(
                             type='TemporalSelfAttention',
-                            embed_dims=_dim_,
+                            embed_dims=embed_dim,
                             num_levels=1),
                         dict(
                             type='SpatialCrossAttention',
                             pc_range=point_cloud_range,
                             deformable_attention=dict(
                                 type='MSDeformableAttention3D',
-                                embed_dims=_dim_,
+                                embed_dims=embed_dim,
                                 num_points=8,
-                                num_levels=_num_levels_),
-                            embed_dims=_dim_,
+                                num_levels=num_levels),
+                            embed_dims=embed_dim,
                         )
                     ],
                     ffn_cfgs=dict(
                         type='FFN',
                         embed_dims=256,
-                        feedforward_channels=_ffn_dim_,
+                        feedforward_channels=ffn_dim,
                         num_fcs=2,
                         ffn_drop=0.1,
                         act_cfg=dict(type='ReLU', inplace=True),
@@ -120,18 +121,18 @@ model = dict(
                     attn_cfgs=[
                         dict(
                             type='MultiheadAttention',
-                            embed_dims=_dim_,
+                            embed_dims=embed_dim,
                             num_heads=8,
                             dropout=0.1),
                         dict(
                             type='CustomMSDeformableAttention',
-                            embed_dims=_dim_,
+                            embed_dims=embed_dim,
                             num_levels=1),
                     ],
                     ffn_cfgs=dict(
                         type='FFN',
                         embed_dims=256,
-                        feedforward_channels=_ffn_dim_,
+                        feedforward_channels=ffn_dim,
                         num_fcs=2,
                         ffn_drop=0.1,
                         act_cfg=dict(type='ReLU', inplace=True),
@@ -147,9 +148,9 @@ model = dict(
             num_classes=10),
         positional_encoding=dict(
            type='LearnedPositionalEncoding',
-            num_feats=_pos_dim_,
-            row_num_embed=bev_h_,
-            col_num_embed=bev_w_,
+            num_feats=pos_dim,
+            row_num_embed=bev_h,
+            col_num_embed=bev_w,
         ),
         loss_cls=dict(
             type='FocalLoss',
@@ -179,11 +180,12 @@ dataset_type = 'NuScenesDataset'
 data_root = 'data/nuScenes/nuscenes-v1.0/'
 
 train_pipeline = [
+    dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
     dict(type='PhotoMetricDistortionMultiViewImage'),
     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
     dict(type='ObjectNameFilter', classes=CLASSES),
     dict(type='NormalizeMultiviewImage', **img_norm_cfg),
-    dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
+    # dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
     dict(type='PadMultiViewImage', size_divisor=32),
     dict(type='DefaultFormatBundle3D', class_names=CLASSES),
     dict(
@@ -228,6 +230,7 @@ test_pipeline = [
 data = dict(
     imgs_per_gpu=1,  # 8gpus, total batch size=8
     workers_per_gpu=4,
+    pin_memory=True,
     # shuffler_sampler=dict(type='DistributedGroupSampler'),
     # nonshuffler_sampler=dict(type='DistributedSampler'),
     train=dict(
@@ -237,7 +240,10 @@ data = dict(
             data_root=data_root,
             ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
             pipeline=[
-                dict(type='LoadMultiViewImageFromFiles', to_float32=True),
+                dict(
+                    type='LoadMultiViewImageFromFiles',
+                    to_float32=True,
+                    backend='turbojpeg'),
                 dict(
                     type='LoadAnnotations3D',
                     with_bbox_3d=True,
@@ -262,7 +268,10 @@ data = dict(
             data_root=data_root,
             ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
             pipeline=[
-                dict(type='LoadMultiViewImageFromFiles', to_float32=True)
+                dict(
+                    type='LoadMultiViewImageFromFiles',
+                    to_float32=True,
+                    backend='turbojpeg')
             ],
             classes=CLASSES,
             modality=input_modality,
@@ -305,3 +314,12 @@ log_config = dict(
 
 checkpoint_config = dict(interval=1)
 cudnn_benchmark = True
+export = dict(
+    export_type='blade',
+    blade_config=dict(
+        enable_fp16=True,
+        fp16_fallback_op_ratio=0.0,
+        customize_op_black_list=[
+            'aten::select', 'aten::index', 'aten::slice', 'aten::view',
+            'aten::upsample', 'aten::clamp'
+        ]))
@@ -0,0 +1,11 @@ (new file, all lines added)
_base_ = ['./bevformer_tiny_r50_nuscenes.py']

paramwise_cfg = {'img_backbone': dict(lr_mult=0.1)}
optimizer = dict(
    type='AdamW',
    lr=2.8e-4,
    paramwise_options=paramwise_cfg,
    weight_decay=0.01)

optimizer_config = dict(
    grad_clip=dict(max_norm=35, norm_type=2), loss_scale=512.)
@@ -2,22 +2,21 @@
 import copy
 import json
 import logging
 import pickle
 from collections import OrderedDict
 from distutils.version import LooseVersion
 from typing import Callable, Dict, List, Optional, Tuple
 
 import cv2
 import torch
 import torchvision
 import torchvision.transforms.functional as t_f
 from mmcv.utils import Config
 
 from easycv.file import io
-from easycv.framework.errors import ValueError
-from easycv.models import (DINO, MOCO, SWAV, YOLOX, Classification, MoBY,
-                           build_model)
+from easycv.framework.errors import NotImplementedError, ValueError
+from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
+                           MoBY, build_model)
 from easycv.utils.checkpoint import load_checkpoint
-from easycv.utils.misc import reparameterize_models
+from easycv.utils.misc import encode_str_to_tensor
 
 __all__ = [
     'export',
@@ -27,7 +26,7 @@ __all__ = [
 ]
 
 
-def export(cfg, ckpt_path, filename):
+def export(cfg, ckpt_path, filename, **kwargs):
     """ export model for inference
 
     Args:
@@ -42,20 +41,22 @@ def export(cfg, ckpt_path, filename):
         cfg.model.backbone.pretrained = False
 
     if isinstance(model, MOCO) or isinstance(model, DINO):
-        _export_moco(model, cfg, filename)
+        _export_moco(model, cfg, filename, **kwargs)
     elif isinstance(model, MoBY):
-        _export_moby(model, cfg, filename)
+        _export_moby(model, cfg, filename, **kwargs)
     elif isinstance(model, SWAV):
-        _export_swav(model, cfg, filename)
+        _export_swav(model, cfg, filename, **kwargs)
     elif isinstance(model, Classification):
-        _export_cls(model, cfg, filename)
+        _export_cls(model, cfg, filename, **kwargs)
     elif isinstance(model, YOLOX):
-        _export_yolox(model, cfg, filename)
+        _export_yolox(model, cfg, filename, **kwargs)
+    elif isinstance(model, BEVFormer):
+        _export_bevformer(model, cfg, filename, **kwargs)
     elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
-        export_jit_model(model, cfg, filename)
+        export_jit_model(model, cfg, filename, **kwargs)
         return
     else:
-        _export_common(model, cfg, filename)
+        _export_common(model, cfg, filename, **kwargs)
 
 
 def _export_common(model, cfg, filename):
@@ -179,6 +180,7 @@ def _export_yolox(model, cfg, filename):
     model.export_type = export_type
 
     if export_type != 'raw':
+        from easycv.utils.misc import reparameterize_models
         # only when we use jit or blade do we need to reparameterize the model before export
         model = reparameterize_models(model)
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -517,6 +519,100 @@ def export_jit_model(model, cfg, filename):
     torch.jit.save(model_jit, ofile)
 
 
+def _export_bevformer(model, cfg, filename, fp16=False):
+    if not cfg.adapt_jit:
+        raise ValueError(
+            '"cfg.adapt_jit" must be True when exporting a jit trace or blade model.'
+        )
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    model = copy.deepcopy(model)
+    model.eval()
+    model.to(device)
+
+    def _dummy_inputs():
+        # dummy inputs: one frame from 6 surround-view cameras
+        batch_size, queue_len, cams_num = 1, 1, 6
+        img_size = (928, 1600)
+        img = torch.rand([cams_num, 3, img_size[0], img_size[1]]).to(device)
+        can_bus = torch.rand([18]).to(device)
+        lidar2img = torch.rand([6, 4, 4]).to(device)
+        img_shape = torch.tensor([[img_size[0], img_size[1], 3]] *
+                                 cams_num).to(device)
+        dummy_scene_token = 'dummy_scene_token'
+        scene_token = encode_str_to_tensor(dummy_scene_token).to(device)
+        prev_scene_token = scene_token
+        prev_bev = torch.rand([cfg.bev_h * cfg.bev_w, 1,
+                               cfg.embed_dim]).to(device)
+        prev_pos = torch.tensor(0)
+        prev_angle = torch.tensor(0)
+        img_metas = {
+            'can_bus': can_bus,
+            'lidar2img': lidar2img,
+            'img_shape': img_shape,
+            'scene_token': scene_token,
+            'prev_bev': prev_bev,
+            'prev_pos': prev_pos,
+            'prev_angle': prev_angle,
+            'prev_scene_token': prev_scene_token
+        }
+        return img, img_metas
+
+    dummy_inputs = _dummy_inputs()
+
+    def _trace_model():
+        with torch.no_grad():
+            model.forward = model.forward_export
+            trace_model = torch.jit.trace(
+                model, copy.deepcopy(dummy_inputs), check_trace=False)
+        return trace_model
+
+    export_type = cfg.export.get('type')
+    if export_type in ['jit', 'blade']:
+        if fp16:
+            with torch.cuda.amp.autocast():
+                trace_model = _trace_model()
+        else:
+            trace_model = _trace_model()
+        torch.jit.save(trace_model, filename + '.jit')
+    else:
+        raise NotImplementedError(f'Not support export type {export_type}!')
+
+    if export_type == 'jit':
+        return
+
+    blade_config = cfg.export.get('blade_config')
+
+    from easycv.toolkit.blade import blade_env_assert, blade_optimize
+    assert blade_env_assert()
+
+    def _get_blade_model():
+        blade_model = blade_optimize(
+            speed_test_model=model,
+            model=trace_model,
+            inputs=copy.deepcopy(dummy_inputs),
+            blade_config=blade_config,
+            static_opt=False,
+            min_num_nodes=None,  # 50
+            check_inputs=False,
+            fp16=fp16)
+        return blade_model
+
+    # optimize model with blade
+    if fp16:
+        with torch.cuda.amp.autocast():
+            blade_model = _get_blade_model()
+    else:
+        blade_model = _get_blade_model()
+
+    # save blade code and graph
+    # with io.open(filename + '.blade.code.py', 'w') as ofile:
+    #     ofile.write(blade_model.forward.code)
+    # with io.open(filename + '.blade.graph.txt', 'w') as ofile:
+    #     ofile.write(blade_model.forward.graph)
+    with io.open(filename + '.blade', 'wb') as ofile:
+        torch.jit.save(blade_model, ofile)
+
+
 def replace_syncbn(backbone_cfg):
     if 'norm_cfg' in backbone_cfg.keys():
         if backbone_cfg['norm_cfg']['type'] == 'SyncBN':
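A hedged usage sketch of the new BEVFormer export path. It assumes `easycv.apis.export` as the module path for the `export` entry point and uses placeholder file paths; the `fp16` keyword is forwarded to `_export_bevformer` through `**kwargs`, and the config must set `adapt_jit = True` plus an `export` dict with type 'jit' or 'blade', as in the configs above.

# Hypothetical usage sketch, not part of this commit.
from mmcv.utils import Config

from easycv.apis.export import export  # assumed module path

cfg = Config.fromfile('bevformer_base_config.py')  # placeholder path
assert cfg.adapt_jit  # required by _export_bevformer
export(cfg, ckpt_path='bevformer.pth', filename='bevformer_export', fp16=True)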
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import numba
+# import numba
 import numpy as np
 import torch
 from mmcv.ops import nms, nms_rotated
@@ -179,7 +179,7 @@ def aligned_3d_nms(boxes, scores, classes, thresh):
     return indices
 
 
-@numba.jit(nopython=True)
+# @numba.jit(nopython=True)
 def circle_nms(dets, thresh, post_max_size=83):
     """Circular NMS.
 
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import concurrent.futures
 import copy
 import logging
+import random
 import tempfile
 from os import path as osp
@@ -14,6 +16,7 @@ from easycv.core.bbox import Box3DMode, Coord3DMode
 from easycv.datasets.registry import DATASETS
 from easycv.datasets.shared.base import BaseDataset
 from easycv.datasets.shared.pipelines import Compose
+from easycv.datasets.shared.pipelines.format import to_tensor
 from .utils import extract_result_dict
 
 
@@ -50,6 +53,7 @@ class NuScenesDataset(BaseDataset):
         self.eval_detection_configs = config_factory(self.eval_version)
         self.flag = np.zeros(
             len(self), dtype=np.uint8)  # for DistributedGroupSampler
+        self.pipeline_cfg = pipeline
 
     def _format_bbox(self, results, jsonfile_prefix=None):
         """Convert the results to the standard format.
@@ -309,6 +313,9 @@ class NuScenesDataset(BaseDataset):
         prev_scene_token = None
         prev_pos = None
         prev_angle = None
+
+        can_bus_list = []
+        lidar2img_list = []
         for i, each in enumerate(queue):
             metas_map[i] = each['img_metas'].data
             if metas_map[i]['scene_token'] != prev_scene_token:
@@ -326,28 +333,75 @@ class NuScenesDataset(BaseDataset):
                 metas_map[i]['can_bus'][-1] -= prev_angle
                 prev_pos = copy.deepcopy(tmp_pos)
                 prev_angle = copy.deepcopy(tmp_angle)
+
+            can_bus_list.append(to_tensor(metas_map[i]['can_bus']))
+            lidar2img_list.append(to_tensor(metas_map[i]['lidar2img']))
+
         queue[-1]['img'] = DC(
             torch.stack(imgs_list), cpu_only=False, stack=True)
         queue[-1]['img_metas'] = DC(metas_map, cpu_only=True)
+        queue[-1]['can_bus'] = DC(torch.stack(can_bus_list), cpu_only=False)
+        queue[-1]['lidar2img'] = DC(
+            torch.stack(lidar2img_list), cpu_only=False)
         queue = queue[-1]
         return queue
 
+    @staticmethod
+    def _get_single_data(i,
+                         data_source,
+                         pipeline,
+                         flip_flag=False,
+                         scale=None):
+        i = max(0, i)
+        try:
+            data = data_source[i]
+            data['flip_flag'] = flip_flag
+            if scale:
+                data['resize_scale'] = scale
+            data = pipeline(data)
+            if data is None or ~(data['gt_labels_3d']._data != -1).any():
+                return None
+        except Exception as e:
+            logging.error(e)
+            return None
+        return i, data
+
     def _get_queue_data(self, idx):
         queue = []
         idx_list = list(range(idx - self.queue_length, idx))
         random.shuffle(idx_list)
         idx_list = sorted(idx_list[1:])
         idx_list.append(idx)
-        for i in idx_list:
-            i = max(0, i)
-            try:
-                data = self.data_source[i]
-                data = self.pipeline(data)
-                if data is None or ~(data['gt_labels_3d']._data != -1).any():
-                    return None
-            except Exception as e:
-                return None
-            queue.append(data)
+
+        # sample the random augmentation parameters once, so every frame in
+        # the queue is transformed consistently
+        flip_flag = False
+        scale = None
+        for member in self.pipeline_cfg:
+            if member['type'] == 'RandomScaleImageMultiViewImage':
+                scales = member['scales']
+                rand_ind = np.random.permutation(range(len(scales)))[0]
+                scale = scales[rand_ind]
+            if member['type'] == 'RandomHorizontalFlipMultiViewImage':
+                flip_flag = np.random.rand() >= 0.5
+
+        with concurrent.futures.ThreadPoolExecutor(
+                max_workers=len(idx_list)) as executor:
+            threads = []
+            for i in idx_list:
+                future = executor.submit(self._get_single_data, i,
+                                         self.data_source, self.pipeline,
+                                         flip_flag, scale)
+                threads.append(future)
+
+            for future in concurrent.futures.as_completed(threads):
+                queue.append(future.result())
+
+        if None in queue:
+            return None
+
+        queue = sorted(queue, key=lambda item: item[0])
+        queue = [item[1] for item in queue]
+
         return self.union2one(queue)
 
     def __getitem__(self, idx):
@@ -358,6 +412,18 @@ class NuScenesDataset(BaseDataset):
             data_dict = self.data_source[idx]
             data_dict = self.pipeline(data_dict)
 
+            can_bus_list, lidar2img_list = [], []
+            for i in range(len(data_dict['img_metas'])):
+                can_bus_list.append(
+                    to_tensor(data_dict['img_metas'][i]._data['can_bus']))
+                lidar2img_list.append(
+                    to_tensor(
+                        data_dict['img_metas'][i]._data['lidar2img']))
+            data_dict['can_bus'] = DC(
+                torch.stack(can_bus_list), cpu_only=False)
+            data_dict['lidar2img'] = DC(
+                torch.stack(lidar2img_list), cpu_only=False)
+
             if data_dict is None:
                 idx = self._rand_another(idx)
                 continue
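A minimal, self-contained sketch of the ordering trick used by `_get_single_data` and `_get_queue_data` above: each job is submitted together with its index, and the completed results are sorted by that index, so the temporal order of the frame queue survives `as_completed()` returning futures in arbitrary order.

# Standalone illustration of the index-tag-and-sort pattern.
import concurrent.futures


def _load(i):
    return i, f'frame-{i}'  # stand-in for data_source[i] + pipeline(data)


idx_list = [3, 4, 5, 6]
with concurrent.futures.ThreadPoolExecutor(max_workers=len(idx_list)) as ex:
    futures = [ex.submit(_load, i) for i in idx_list]
    results = [f.result() for f in concurrent.futures.as_completed(futures)]

results = sorted(results, key=lambda item: item[0])  # restore frame order
frames = [item[1] for item in results]
assert frames == ['frame-3', 'frame-4', 'frame-5', 'frame-6']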
@@ -0,0 +1,186 @@ (new file, all lines added)
######################################################################
# Copyright (c) 2022 OpenPerceptionX. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################

######################################################################
# This file includes concrete implementations for the different data
# augmentation methods in transforms.py.
######################################################################

from typing import List, Tuple

import cv2
import numpy as np

# Available interpolation modes (opencv)
cv2_interp_codes = {
    'nearest': cv2.INTER_NEAREST,
    'bilinear': cv2.INTER_LINEAR,
    'bicubic': cv2.INTER_CUBIC,
    'area': cv2.INTER_AREA,
    'lanczos': cv2.INTER_LANCZOS4
}


def scale_image_multiple_view(
    imgs: List[np.ndarray],
    cam_intrinsics: List[np.ndarray],
    # cam_extrinsics: List[np.ndarray],
    lidar2img: List[np.ndarray],
    rand_scale: float,
    interpolation='bilinear'
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Resize the multiple-view images with the same randomly selected scale.
    Notably used in :class:`.transforms.RandomScaleImageMultiViewImage`.
    Args:
        imgs (list of numpy.array): Multiple-view images to be resized. len(imgs) is the number of cameras.
            img shape: [H, W, 3].
        cam_intrinsics (list of numpy.array): Intrinsic parameters of the different cameras. Transformations
            from camera to image. len(cam_intrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
        lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
            of cameras. For each camera, the shape is 4 * 4.
        rand_scale (float): resize ratio.
        interpolation (string): mode for interpolation in opencv.
    Returns:
        imgs_new (list of numpy.array): updated multiple-view images.
        cam_intrinsics_new (list of numpy.array): updated intrinsic parameters of the different cameras.
        lidar2img_new (list of numpy.array): updated transformations from lidar to images.
    """
    y_size = [int(img.shape[0] * rand_scale) for img in imgs]
    x_size = [int(img.shape[1] * rand_scale) for img in imgs]
    scale_factor = np.eye(4)
    scale_factor[0, 0] *= rand_scale
    scale_factor[1, 1] *= rand_scale
    imgs_new = [
        cv2.resize(
            img, (x_size[idx], y_size[idx]),
            interpolation=cv2_interp_codes[interpolation])
        for idx, img in enumerate(imgs)
    ]
    cam_intrinsics_new = [
        scale_factor @ cam_intrinsic for cam_intrinsic in cam_intrinsics
    ]
    lidar2img_new = [scale_factor @ l2i for l2i in lidar2img]

    return imgs_new, cam_intrinsics_new, lidar2img_new


def horizontal_flip_image_multiview(
        imgs: List[np.ndarray]) -> List[np.ndarray]:
    """Flip every image horizontally.
    Args:
        imgs (list of numpy.array): Multiple-view images to be flipped. len(imgs) is the number of cameras.
            img shape: [H, W, 3].
    Returns:
        imgs_new (list of numpy.array): Flipped multiple-view images.
    """
    imgs_new = [np.flip(img, axis=1) for img in imgs]
    return imgs_new


def vertical_flip_image_multiview(imgs: List[np.ndarray]) -> List[np.ndarray]:
    """Flip every image vertically.
    Args:
        imgs (list of numpy.array): Multiple-view images to be flipped. len(imgs) is the number of cameras.
            img shape: [H, W, 3].
    Returns:
        imgs_new (list of numpy.array): Flipped multiple-view images.
    """
    imgs_new = [np.flip(img, axis=0) for img in imgs]
    return imgs_new


def horizontal_flip_bbox(bboxes_3d: np.ndarray, dataset: str) -> np.ndarray:
    """Flip bounding boxes horizontally.
    Args:
        bboxes_3d (np.ndarray): bounding boxes of shape [N * 7], N is the number of objects.
        dataset (string): 'waymo' coordinate system or 'nuscenes' coordinate system.
    Returns:
        bboxes_3d (numpy.array): Flipped bounding boxes.
    """
    if dataset == 'nuScenes':
        bboxes_3d.tensor[:, 0::7] = -bboxes_3d.tensor[:, 0::7]
        bboxes_3d.tensor[:, 6] = -bboxes_3d.tensor[:, 6]  # + np.pi
    elif dataset == 'waymo':
        bboxes_3d[:, 1::7] = -bboxes_3d[:, 1::7]
        bboxes_3d[:, 6] = -bboxes_3d[:, 6] + np.pi
    return bboxes_3d


def horizontal_flip_cam_params(
    img_shape: np.ndarray, cam_intrinsics: List[np.ndarray],
    cam_extrinsics: List[np.ndarray], lidar2imgs: List[np.ndarray],
    dataset: str
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Flip camera parameters horizontally.
    Args:
        img_shape (numpy.array): image shape, of shape [3].
        cam_intrinsics (list of numpy.array): Intrinsic parameters of the different cameras. Transformations
            from camera to image. len(cam_intrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
        cam_extrinsics (list of numpy.array): Extrinsic parameters of the different cameras. Transformations
            from lidar to cameras. len(cam_extrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
        lidar2imgs (list of numpy.array): Transformations from lidar to images. len(lidar2imgs) is the number
            of cameras. For each camera, the shape is 4 * 4.
        dataset (string): Specify the 'waymo' coordinate system or the 'nuscenes' coordinate system.
    Returns:
        cam_intrinsics (list of numpy.array): updated intrinsic parameters of the different cameras.
        cam_extrinsics (list of numpy.array): updated extrinsic parameters of the different cameras.
        lidar2imgs (list of numpy.array): updated transformations from lidar to images.
    """
    flip_factor = np.eye(4)
    lidar2imgs = []

    w = img_shape[1]
    if dataset == 'nuScenes':
        flip_factor[0, 0] = -1
        cam_extrinsics = [l2c @ flip_factor for l2c in cam_extrinsics]
        for cam_intrinsic, l2c in zip(cam_intrinsics, cam_extrinsics):
            cam_intrinsic[0, 0] = -cam_intrinsic[0, 0]
            cam_intrinsic[0, 2] = w - cam_intrinsic[0, 2]
            lidar2imgs.append(cam_intrinsic @ l2c)
    elif dataset == 'waymo':
        flip_factor[1, 1] = -1
        cam_extrinsics = [l2c @ flip_factor for l2c in cam_extrinsics]
        for cam_intrinsic, l2c in zip(cam_intrinsics, cam_extrinsics):
            cam_intrinsic[0, 0] = -cam_intrinsic[0, 0]
            cam_intrinsic[0, 2] = w - cam_intrinsic[0, 2]
            lidar2imgs.append(cam_intrinsic @ l2c)
    else:
        assert False

    return cam_intrinsics, cam_extrinsics, lidar2imgs


def horizontal_flip_canbus(canbus: np.ndarray, dataset: str) -> np.ndarray:
    """Flip the can bus signal horizontally.
    Args:
        canbus (numpy.ndarray): can bus signal of shape [18,].
        dataset (string): 'waymo' or 'nuscenes'.
    Returns:
        canbus (numpy.array): Flipped can bus signal.
    """
    if dataset == 'nuScenes':
        # results['canbus'][1] = -results['canbus'][1]  # flip location
        # results['canbus'][-2] = -results['canbus'][-2]  # flip direction
        canbus[-1] = -canbus[-1]  # flip direction
    elif dataset == 'waymo':
        # results['canbus'][1] = -results['canbus'][-1]  # flip location
        # results['canbus'][-2] = -results['canbus'][-2]  # flip direction
        canbus[-1] = -canbus[-1]  # flip direction
    else:
        raise NotImplementedError(f'Not support {dataset} dataset')
    return canbus
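A quick numeric check (illustrative, with made-up intrinsics) of the `scale_factor` construction in `scale_image_multiple_view` above: resizing an image by a ratio s scales the focal lengths and the principal point by s, which is exactly what left-multiplying the 4x4 intrinsic matrix by diag(s, s, 1, 1) does.

# Illustrative check with hypothetical camera numbers.
import numpy as np

s = 0.5
scale_factor = np.eye(4)
scale_factor[0, 0] *= s
scale_factor[1, 1] *= s

cam_intrinsic = np.eye(4)
cam_intrinsic[0, 0] = cam_intrinsic[1, 1] = 1266.4  # fx, fy (made up)
cam_intrinsic[0, 2], cam_intrinsic[1, 2] = 800.0, 450.0  # cx, cy (made up)

scaled = scale_factor @ cam_intrinsic
assert np.allclose(scaled[0, 0], 633.2) and np.allclose(scaled[0, 2], 400.0)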
@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import concurrent.futures
+
 import mmcv
 import numpy as np
 
 from easycv.core.points import BasePoints, get_points_type
 from easycv.datasets.detection.pipelines import LoadAnnotations
 from easycv.datasets.registry import PIPELINES
+from easycv.file.image import load_image
 
 
 @PIPELINES.register_module()
@@ -17,13 +20,23 @@ class LoadMultiViewImageFromFiles(object):
     Args:
         to_float32 (bool, optional): Whether to convert the img to float32.
             Defaults to False.
-        color_type (str, optional): Color type of the file.
-            Defaults to 'unchanged'.
+        channel_order (str, optional): Channel order.
+            Defaults to 'bgr'.
+        backend (str): The image decoding backend type. Options are `cv2`, `pillow`, `turbojpeg`.
     """
 
-    def __init__(self, to_float32=False, color_type='unchanged'):
+    def __init__(self,
+                 to_float32=False,
+                 channel_order='bgr',
+                 backend='pillow'):
         self.to_float32 = to_float32
-        self.color_type = color_type
+        self.channel_order = channel_order
+        self.backend = backend
 
+    @staticmethod
+    def _load_image(img_path, idx, mode, backend):
+        img = load_image(img_path, mode=mode, backend=backend)
+        return idx, img
+
     def __call__(self, results):
         """Call function to load multi-view image from files.
@@ -45,8 +58,24 @@ class LoadMultiViewImageFromFiles(object):
         """
         filename = results['img_filename']
         # img is of shape (h, w, c, num_views)
-        img = np.stack(
-            [mmcv.imread(name, self.color_type) for name in filename], axis=-1)
+        img_list = []
+        with concurrent.futures.ThreadPoolExecutor(
+                max_workers=len(filename)) as executor:
+            threads = []
+            for idx, name in enumerate(filename):
+                future = executor.submit(self._load_image, name, idx,
+                                         self.channel_order, self.backend)
+                threads.append(future)
+
+            for future in concurrent.futures.as_completed(threads):
+                img_list.append(future.result())
+
+        img_list = sorted(img_list, key=lambda item: item[0])
+        assert len(img_list) == len(filename)
+        img_list = [item[1] for item in img_list]
+        img = np.stack(img_list, axis=-1)
+
         if self.to_float32:
             img = img.astype(np.float32)
         results['filename'] = filename
@@ -1,5 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import List, Tuple
+
 import mmcv
 import numpy as np
 from numpy import random
@@ -7,6 +9,10 @@ from numpy import random
 from easycv.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
                               LiDARInstance3DBoxes)
 from easycv.datasets.registry import PIPELINES
+from .functional import (horizontal_flip_bbox, horizontal_flip_cam_params,
+                         horizontal_flip_canbus,
+                         horizontal_flip_image_multiview,
+                         scale_image_multiple_view)
 
 
 @PIPELINES.register_module()
@@ -298,42 +304,140 @@ class PadMultiViewImage(object):
 
 @PIPELINES.register_module()
 class RandomScaleImageMultiViewImage(object):
-    """Random scale the image.
+    """Resize the multiple-view images with the same randomly selected scale.
     Args:
-        scales (List[float]): List of scales.
+        scales (tuple of float): ratios for resizing the images. Each time,
+            one ratio is selected randomly.
     """
 
-    def __init__(self, scales=[]):
+    def __init__(self, scales=[0.5, 1.0, 1.5]):
         self.scales = scales
-        assert len(self.scales) == 1
+        self.seed = 0
 
-    def __call__(self, results):
-        """Call function to pad images, masks, semantic segmentation maps.
-        Args:
-            results (dict): Result dict from loading pipeline.
-        Returns:
-            dict: Updated result dict.
-        """
+    def forward(
+        self,
+        imgs: List[np.ndarray],
+        cam_intrinsics: List[np.ndarray],
+        lidar2img: List[np.ndarray],
+        seed=None,
+        scale=1
+    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+        """
+        Args:
+            imgs (list of numpy.array): Multiple-view images to be resized. len(imgs) is the number of cameras.
+                img shape: [H, W, 3].
+            cam_intrinsics (list of numpy.array): Intrinsic parameters of the different cameras. Transformations
+                from camera to image. len(cam_intrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
+            lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
+                of cameras. For each camera, the shape is 4 * 4.
+            seed (int): Seed for generating a random number.
+        Returns:
+            imgs_new (list of numpy.array): updated multiple-view images.
+            cam_intrinsics_new (list of numpy.array): updated intrinsic parameters of the different cameras.
+            lidar2img_new (list of numpy.array): updated transformations from lidar to images.
+        """
+        rand_scale = scale
+        imgs_new, cam_intrinsic_new, lidar2img_new = scale_image_multiple_view(
+            imgs, cam_intrinsics, lidar2img, rand_scale)
+
+        return imgs_new, cam_intrinsic_new, lidar2img_new
+
+    def __call__(self, data):
+        imgs = data['img']
+        cam_intrinsics = data['cam_intrinsic']
+        lidar2img = data['lidar2img']
 
         rand_ind = np.random.permutation(range(len(self.scales)))[0]
-        rand_scale = self.scales[rand_ind]
+        scale = data[
+            'resize_scale'] if 'resize_scale' in data else self.scales[rand_ind]
 
-        y_size = [int(img.shape[0] * rand_scale) for img in results['img']]
-        x_size = [int(img.shape[1] * rand_scale) for img in results['img']]
-        scale_factor = np.eye(4)
-        scale_factor[0, 0] *= rand_scale
-        scale_factor[1, 1] *= rand_scale
-        results['img'] = [
-            mmcv.imresize(img, (x_size[idx], y_size[idx]), return_scale=False)
-            for idx, img in enumerate(results['img'])
-        ]
-        lidar2img = [scale_factor @ l2i for l2i in results['lidar2img']]
-        results['lidar2img'] = lidar2img
-        results['img_shape'] = [img.shape for img in results['img']]
-        results['ori_shape'] = [img.shape for img in results['img']]
+        imgs_new, cam_intrinsic_new, lidar2img_new = self.forward(
+            imgs, cam_intrinsics, lidar2img, None, scale)
 
-        return results
+        data['img'] = imgs_new
+        data['cam_intrinsic'] = cam_intrinsic_new
+        data['lidar2img'] = lidar2img_new
+        return data
 
     def __repr__(self):
         repr_str = self.__class__.__name__
         repr_str += f'(scales={self.scales})'
         return repr_str
 
+
+@PIPELINES.register_module()
+class RandomHorizontalFlipMultiViewImage(object):
+    """Randomly flip the multiple-view images, bounding boxes, camera parameters and can bus horizontally.
+    Supports coordinate systems like Waymo (https://waymo.com/open/data/perception/) or nuScenes
+    (https://www.nuscenes.org/public/images/data.png).
+    Args:
+        flip_ratio (float 0~1): probability of the images being flipped. Default value is 0.5.
+        dataset (string): Specify the 'waymo' coordinate system or the 'nuscenes' coordinate system.
+    """
+
+    def __init__(self, flip_ratio=0.5, dataset='nuScenes'):
+        self.flip_ratio = flip_ratio
+        self.seed = 0
+        self.dataset = dataset
+
+    def forward(
+        self,
+        imgs: List[np.ndarray],
+        bboxes_3d: np.ndarray,
+        cam_intrinsics: List[np.ndarray],
+        cam_extrinsics: List[np.ndarray],
+        lidar2imgs: List[np.ndarray],
+        canbus: np.ndarray,
+        seed=None,
+        flip_flag=True
+    ) -> Tuple[bool, List[np.ndarray], np.ndarray, List[np.ndarray],
+               List[np.ndarray], List[np.ndarray], np.ndarray]:
+        """
+        Args:
+            imgs (list of numpy.array): Multiple-view images to be flipped. len(imgs) is the number of cameras.
+                img shape: [H, W, 3].
+            bboxes_3d (np.ndarray): bounding boxes of shape [N * 7], N is the number of objects.
+            cam_intrinsics (list of numpy.array): Intrinsic parameters of the different cameras. Transformations
+                from camera to image. len(cam_intrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
+            cam_extrinsics (list of numpy.array): Extrinsic parameters of the different cameras. Transformations
+                from lidar to cameras. len(cam_extrinsics) is the number of cameras. For each camera, the shape is 4 * 4.
+            lidar2imgs (list of numpy.array): Transformations from lidar to images. len(lidar2imgs) is the number
+                of cameras. For each camera, the shape is 4 * 4.
+            canbus (numpy.array): can bus signal.
+            seed (int): Seed for generating a random number.
+        Returns:
+            imgs_new (list of numpy.array): updated multiple-view images.
+            cam_intrinsics_new (list of numpy.array): updated intrinsic parameters of the different cameras.
+            lidar2img_new (list of numpy.array): updated transformations from lidar to images.
+        """
+
+        if not flip_flag:
+            return flip_flag, imgs, bboxes_3d, cam_intrinsics, cam_extrinsics, lidar2imgs, canbus
+        else:
+            imgs_flip = horizontal_flip_image_multiview(imgs)
+            bboxes_3d_flip = horizontal_flip_bbox(bboxes_3d, self.dataset)
+            img_shape = imgs[0].shape
+            cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip = horizontal_flip_cam_params(
+                img_shape, cam_intrinsics, cam_extrinsics, lidar2imgs,
+                self.dataset)
+            canbus_flip = horizontal_flip_canbus(canbus, self.dataset)
+            return flip_flag, imgs_flip, bboxes_3d_flip, cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip, canbus_flip
+
+    def __call__(self, data):
+        imgs = data['img']
+        bboxes_3d = data['gt_bboxes_3d']
+        cam_intrinsics = data['cam_intrinsic']
+        lidar2imgs = data['lidar2img']
+        canbus = data['can_bus']
+        cam_extrinsics = data['lidar2cam']
+        flip_flag = data['flip_flag']
+
+        flip_flag, imgs_flip, bboxes_3d_flip, cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip, canbus_flip = self.forward(
+            imgs, bboxes_3d, cam_intrinsics, cam_extrinsics, lidar2imgs,
+            canbus, None, flip_flag)
+
+        data['img'] = imgs_flip
+        data['gt_bboxes_3d'] = bboxes_3d_flip
+        data['cam_intrinsic'] = cam_intrinsics_flip
+        data['lidar2img'] = lidar2imgs_flip
+        data['can_bus'] = canbus_flip
+        data['lidar2cam'] = cam_extrinsics_flip
+        return data
@ -5,61 +5,105 @@ import time
|
|||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from cv2 import IMREAD_COLOR
|
||||
from PIL import Image
|
||||
|
||||
from easycv import file
|
||||
from easycv.framework.errors import IOError
|
||||
from easycv.framework.errors import IOError, KeyError, ValueError
|
||||
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
|
||||
from .utils import is_oss_path, is_url_path
|
||||
|
||||
try:
|
||||
from turbojpeg import TurboJPEG, TJCS_RGB, TJPF_BGR
|
||||
turbo_jpeg = TurboJPEG()
|
||||
turbo_jpeg_mode = {'RGB': TJCS_RGB, 'BGR': TJPF_BGR}
|
||||
except:
|
||||
turbo_jpeg = None
|
||||
turbo_jpeg_mode = None
|
||||
|
||||
def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
|
||||
"""Return np.ndarray[unit8]
|
||||
"""
|
||||
# TODO: functions of multi tries should be in the `io.open`
|
||||
try_cnt = 0
|
||||
img = None
|
||||
while try_cnt < max_try_times:
|
||||
try:
|
||||
if is_url_path(img_path):
|
||||
from mmcv.fileio.file_client import HTTPBackend
|
||||
client = HTTPBackend()
|
||||
img_bytes = client.get(img_path)
|
||||
buff = io.BytesIO(img_bytes)
|
||||
image = Image.open(buff)
|
||||
if mode.upper() != 'BGR' and image.mode.upper() != mode.upper(
|
||||
):
|
||||
image = image.convert(mode.upper())
|
||||
img = np.asarray(image, dtype=np.uint8)
|
||||
else:
|
||||
with file.io.open(img_path, 'rb') as infile:
|
||||
# cv2.imdecode may corrupt when the img is broken
|
||||
image = Image.open(infile)
|
||||
if mode.upper() != 'BGR' and image.mode.upper(
|
||||
) != mode.upper():
|
||||
image = image.convert(mode.upper())
|
||||
img = np.asarray(image, dtype=np.uint8)
|
||||
|
||||
if mode.upper() == 'BGR':
|
||||
if image.mode.upper() != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
assert img is not None
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.warning('Read file {} fault, try count : {}'.format(
|
||||
img_path, try_cnt))
|
||||
# frequent access to oss will cause error, sleep can aviod it
|
||||
if is_oss_path(img_path):
|
||||
sleep_time = 1
|
||||
logging.warning(
|
||||
'Sleep {}s, frequent access to oss file may cause error.'.
|
||||
format(sleep_time))
|
||||
time.sleep(sleep_time)
|
||||
try_cnt += 1
|
||||
def load_image_with_pillow(content, mode='BGR', dtype=np.uint8):
|
||||
with io.BytesIO(content) as buff:
|
||||
image = Image.open(buff)
|
||||
|
||||
if img is None:
|
||||
raise IOError('Read Image Error: ' + img_path)
|
||||
if mode.upper() != 'BGR':
|
||||
if image.mode.upper() != mode.upper():
|
||||
image = image.convert(mode.upper())
|
||||
img = np.asarray(image, dtype=dtype)
|
||||
else:
|
||||
if image.mode.upper() != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
img = np.asarray(image, dtype=dtype)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def load_image_with_turbojpeg(content, mode='BGR', dtype=np.uint8):
|
||||
assert mode.upper() in turbo_jpeg_mode
|
||||
if turbo_jpeg is None or turbo_jpeg_mode is None:
|
||||
raise ValueError(
|
||||
'Please install turbojpeg by "pip install PyTurboJPEG" !')
|
||||
|
||||
img = turbo_jpeg.decode(
|
||||
content, pixel_format=turbo_jpeg_mode[mode.upper()])
|
||||
|
||||
if img.dtype != dtype:
|
||||
img = img.astype(dtype)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def load_image_with_cv2(content, mode='BGR', dtype=np.uint8):
|
||||
assert mode.upper() in ['BGR', 'RGB']
|
||||
|
||||
img_np = np.frombuffer(content, np.uint8)
|
||||
img = cv2.imdecode(img_np, flags=IMREAD_COLOR)
|
||||
|
||||
if mode.upper() == 'RGB':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
if img.dtype != dtype:
|
||||
img = img.astype(dtype)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def _load_image(fp, mode='BGR', dtype=np.uint8, backend='pillow'):
|
||||
if backend == 'pillow':
|
||||
img = load_image_with_pillow(fp, mode=mode, dtype=dtype)
|
||||
elif backend == 'turbojpeg':
|
||||
img = load_image_with_turbojpeg(fp, mode=mode, dtype=dtype)
|
||||
elif backend == 'cv2':
|
||||
img = load_image_with_cv2(fp, mode=mode, dtype=dtype)
|
||||
else:
|
||||
raise KeyError(
|
||||
'Only support backend in ["pillow", "turbojpeg", "cv2"]')
|
||||
return img
|
||||
|
||||
|
||||
def load_image(img_path,
|
||||
mode='BGR',
|
||||
dtype=np.uint8,
|
||||
backend='pillow',
|
||||
max_try_times=MAX_READ_IMAGE_TRY_TIMES):
|
||||
"""Load image file, return np.ndarray.
|
||||
|
||||
Args:
|
||||
img_path (str): Image file path.
|
||||
mode (str): Order of channel, candidates are `bgr` and `rgb`.
|
||||
dtype : Output data type.
|
||||
backend (str): The image decoding backend type. Options are `cv2`, `pillow`, `turbojpeg`.
|
||||
"""
|
||||
# TODO: functions of multi tries should be in the `io.open`
|
||||
img = None
|
||||
if is_url_path(img_path):
|
||||
from mmcv.fileio.file_client import HTTPBackend
|
||||
client = HTTPBackend()
|
||||
img_bytes = client.get(img_path)
|
||||
img = _load_image(img_bytes, mode=mode, dtype=dtype, backend=backend)
|
||||
else:
|
||||
with file.io.open(img_path, 'rb') as infile:
|
||||
img = _load_image(
|
||||
infile.read(), mode=mode, dtype=dtype, backend=backend)
|
||||
|
||||
return img
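
A minimal usage sketch of the loaders above (the path is a placeholder; behavior follows the code in this diff):

# hypothetical paths, for illustration only
img_bgr = load_image('data/demo.jpg', mode='BGR', backend='pillow')
img_rgb = load_image('data/demo.jpg', mode='RGB', backend='cv2')
assert img_bgr.dtype == np.uint8 and img_bgr.ndim == 3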

@ -103,7 +103,6 @@ class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.fp16_enabled = False

        self._init_layers()

@ -1 +1,2 @@
from . import detectors, utils
from . import utils
from .detectors import *

@ -3,15 +3,16 @@

import math
import warnings
from typing import Optional

import torch
import torch.nn as nn
from mmcv.cnn import constant_init, xavier_init
from mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch
from mmcv.runner.base_module import BaseModule

from easycv.models.registry import ATTENTION
from easycv.thirdparty.deformable_attention.functions import \
    MSDeformAttnFunction


@ATTENTION.register_module()
@ -36,6 +37,8 @@ class CustomMSDeformableAttention(BaseModule):
        batch_first (bool): Whether Key, Query and Value are shaped as
            (batch, n, embed_dim) or (n, batch, embed_dim). Defaults to False.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
@ -50,8 +53,10 @@ class CustomMSDeformableAttention(BaseModule):
                 im2col_step=64,
                 dropout=0.1,
                 batch_first=False,
                 add_identity=True,
                 norm_cfg=None,
                 init_cfg=None):
                 init_cfg=None,
                 adapt_jit=False):
        super().__init__(init_cfg)
        if embed_dims % num_heads != 0:
            raise ValueError(f'embed_dims must be divisible by num_heads, '
@ -60,7 +65,7 @@ class CustomMSDeformableAttention(BaseModule):
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False
        self.add_identity = add_identity

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
@ -90,6 +95,11 @@ class CustomMSDeformableAttention(BaseModule):
        self.value_proj = nn.Linear(embed_dims, embed_dims)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        self.init_weights()
        self.adapt_jit = adapt_jit
        if self.adapt_jit:
            self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn
        else:
            self.ms_deform_attn_op = MSDeformAttnFunction.apply

    def init_weights(self):
        """Default initialization for Parameters of Module."""
@ -130,19 +140,23 @@ class CustomMSDeformableAttention(BaseModule):
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        return sampling_locations

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                add_identity=True,
                flag='decoder',
                **kwargs):
    def forward(
        self,
        query: torch.Tensor,
        spatial_shapes: torch.Tensor,
        reference_points: torch.Tensor,
        level_start_index: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        identity: Optional[torch.Tensor] = None,
        query_pos: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        flag: Optional[str] = 'decoder',
        key_pos: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        cls_branches: Optional[torch.Tensor] = None,
        img_metas: Optional[str] = None,
    ):
        """Forward Function of MultiScaleDeformAttention.

        Args:
@ -212,29 +226,29 @@ class CustomMSDeformableAttention(BaseModule):
        sampling_locations = self._get_sampling_locations(
            reference_points, spatial_shapes, sampling_offsets)

        if torch.cuda.is_available() and value.is_cuda:
            from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
            if value.dtype == torch.float16:
                # for mixed precision
                output = MSDeformAttnFunction.apply(
                    value.to(torch.float32), spatial_shapes, level_start_index,
                    sampling_locations.to(torch.float32), attention_weights,
                    self.im2col_step)
                output = output.to(torch.float16)
            else:
                output = MSDeformAttnFunction.apply(value, spatial_shapes,
                                                    level_start_index,
                                                    sampling_locations,
                                                    attention_weights,
                                                    self.im2col_step)
        if value.dtype == torch.float16:
            # for mixed precision
            assert value.size(0) % min(value.size(0), self.im2col_step) == 0
            output = self.ms_deform_attn_op(
                value.to(torch.float32), spatial_shapes, level_start_index,
                sampling_locations.to(torch.float32), attention_weights,
                self.im2col_step)
            output = output.to(torch.float16)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
            output = self.ms_deform_attn_op(value, spatial_shapes,
                                            level_start_index,
                                            sampling_locations,
                                            attention_weights,
                                            self.im2col_step)
        # cpu
        # from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
        # output = multi_scale_deformable_attn_pytorch(
        #     value, spatial_shapes, sampling_locations, attention_weights)

        output = self.output_proj(output)
        if not self.batch_first:
            output = output.permute(1, 0, 2)
        if add_identity:
        if self.add_identity:
            return self.dropout(output) + identity
        else:
            return self.dropout(output)
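
The fp16 branch above can be read as a small standalone helper; this is an illustrative sketch, not code from the commit:

def _run_ms_deform_attn(op, value, spatial_shapes, level_start_index,
                        sampling_locations, attention_weights, im2col_step):
    # the CUDA kernel expects fp32, so half-precision inputs are round-tripped
    if value.dtype == torch.float16:
        out = op(value.float(), spatial_shapes, level_start_index,
                 sampling_locations.float(), attention_weights, im2col_step)
        return out.half()
    return op(value, spatial_shapes, level_start_index, sampling_locations,
              attention_weights, im2col_step)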
@ -276,7 +290,8 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
                 dropout=0.,
                 batch_first=True,
                 norm_cfg=None,
                 init_cfg=None):
                 init_cfg=None,
                 adapt_jit=False):
        super(MSDeformableAttention3D, self).__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
@ -285,8 +300,10 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
            im2col_step=im2col_step,
            dropout=dropout,
            batch_first=batch_first,
            add_identity=False,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)
            init_cfg=init_cfg,
            adapt_jit=adapt_jit)

        self.output_proj = nn.Identity()
@ -321,62 +338,3 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
                f'Last dim of reference_points must be'
                f' 2, but get {reference_points.shape[-1]} instead.')
        return sampling_locations

    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        """Forward Function of MultiScaleDeformAttention.

        Args:
            query (Tensor): Query of Transformer with shape
                (bs, num_query, embed_dims).
            key (Tensor): The key tensor with shape
                `(bs, num_key, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(bs, num_key, embed_dims)`.
            identity (Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None,
                `query` will be used.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`. Default
                None.
            reference_points (Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2),
                all elements are in range [0, 1], top-left (0, 0),
                bottom-right (1, 1), including padding area;
                or (N, Length_{query}, num_levels, 4), where the two
                additional dimensions (w, h) form reference boxes.
            key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            spatial_shapes (Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2),
                last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape ``(num_levels, )`` and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        return super().forward(
            query=query,
            key=key,
            value=value,
            identity=identity,
            query_pos=query_pos,
            key_padding_mask=key_padding_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            add_identity=False,
            **kwargs)

@ -1,5 +1,7 @@
# Modified from https://github.com/fundamentalvision/BEVFormer.
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional

import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
@ -41,7 +43,6 @@ class SpatialCrossAttention(BaseModule):
        self.init_cfg = init_cfg
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.fp16_enabled = False
        self.deformable_attention = build_attention(deformable_attention)
        self.embed_dims = embed_dims
        self.num_cams = num_cams
@ -56,20 +57,21 @@ class SpatialCrossAttention(BaseModule):
    @force_fp32(
        apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam')
    )
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                reference_points_cam=None,
                bev_mask=None,
                level_start_index=None,
                flag='encoder',
                **kwargs):
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        reference_points_cam: torch.Tensor,
        bev_mask: torch.Tensor,
        spatial_shapes: torch.Tensor,
        level_start_index: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        query_pos: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        reference_points: Optional[torch.Tensor] = None,
        flag: Optional[str] = 'encoder',
    ):
        """Forward Function of Detr3DCrossAtten.
        Args:
            query (Tensor): Query of Transformer with shape
@ -108,9 +110,13 @@ class SpatialCrossAttention(BaseModule):
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
            slots = torch.zeros_like(query)
        # if residual is None:
        #     inp_residual = query
        #     slots = torch.zeros_like(query)
        assert residual is None
        inp_residual = query
        slots = torch.zeros_like(query)

        if query_pos is not None:
            query = query + query_pos
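
The typed signature above is what allows wrapping this module with `torch.jit.script`: script mode requires explicit `Optional` annotations plus a `None` check before use. A minimal sketch (using the file's `torch`/`Optional` imports):

@torch.jit.script
def _add_query_pos(query: torch.Tensor,
                   query_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
    # without this explicit None check, script mode rejects the addition
    if query_pos is not None:
        query = query + query_pos
    return query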

@ -63,7 +63,6 @@ class TemporalSelfAttention(BaseModule):
        self.norm_cfg = norm_cfg
        self.dropout = nn.Dropout(dropout)
        self.batch_first = batch_first
        self.fp16_enabled = False

        # you'd better set dim_per_head to a power of 2
        # which is more efficient in the CUDA implementation
@ -129,8 +128,7 @@ class TemporalSelfAttention(BaseModule):
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                flag='decoder',
                **kwargs):
                flag='decoder'):
        """Forward Function of MultiScaleDeformAttention.

        Args:
@ -235,19 +233,20 @@ class TemporalSelfAttention(BaseModule):
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        if torch.cuda.is_available() and value.is_cuda:
            from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction

            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                op = MSDeformAttnFunction.apply
            else:
                op = torch.ops.custom.ms_deform_attn
            if value.dtype == torch.float16:
                output = MSDeformAttnFunction.apply(
                output = op(
                    value.to(torch.float32), spatial_shapes, level_start_index,
                    sampling_locations.to(torch.float32), attention_weights,
                    self.im2col_step)
                output = output.to(torch.float16)
            else:
                output = MSDeformAttnFunction.apply(value, spatial_shapes,
                                                    level_start_index,
                                                    sampling_locations,
                                                    attention_weights,
                                                    self.im2col_step)
                output = op(value, spatial_shapes, level_start_index,
                            sampling_locations, attention_weights,
                            self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
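
The same dispatch, factored out for clarity (a sketch; `torch.ops.custom.ms_deform_attn` must already be registered before tracing):

def _select_ms_deform_attn_op():
    # eager mode keeps the autograd Function; traced/scripted graphs need the
    # registered custom op so the call survives export
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return torch.ops.custom.ms_deform_attn
    return MSDeformAttnFunction.apply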

@ -1,14 +1,18 @@
# Modified from https://github.com/fundamentalvision/BEVFormer.
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import pickle

import numpy as np
import torch

from easycv.core.bbox import get_box_type
from easycv.core.bbox.bbox_util import bbox3d2result
from easycv.models.detection3d.detectors.mvx_two_stage import \
    MVXTwoStageDetector
from easycv.models.detection3d.utils.grid_mask import GridMask
from easycv.models.registry import MODELS
from easycv.utils.misc import decode_tensor_to_str, encode_str_to_tensor


@MODELS.register_module()
@ -16,6 +20,8 @@ class BEVFormer(MVXTwoStageDetector):
    """BEVFormer.
    Args:
        video_test_mode (bool): Decide whether to use temporal information during inference.
        extract_feat_serially (bool): Whether to extract history features one by one,
            to avoid batchnorm corruption when the batch dimension N is too large.
    """

    def __init__(self,
@ -34,7 +40,8 @@ class BEVFormer(MVXTwoStageDetector):
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 video_test_mode=False):
                 video_test_mode=False,
                 extract_feat_serially=False):

        super(BEVFormer,
              self).__init__(pts_voxel_layer, pts_voxel_encoder,
@ -45,18 +52,18 @@ class BEVFormer(MVXTwoStageDetector):
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
        self.use_grid_mask = use_grid_mask
        self.fp16_enabled = False
        self.extract_feat_serially = extract_feat_serially

        # temporal
        self.video_test_mode = video_test_mode
        self.prev_frame_info = {
            'prev_bev': None,
            'scene_token': None,
            'prev_scene_token': None,
            'prev_pos': 0,
            'prev_angle': 0,
        }

    def extract_img_feat(self, img, img_metas, len_queue=None):
    def extract_img_feat(self, img, len_queue=None):
        """Extract features of images."""
        B = img.size(0)
        if img is not None:
@ -94,10 +101,10 @@ class BEVFormer(MVXTwoStageDetector):
                    img_feat.view(B, int(BN / B), C, H, W))
        return img_feats_reshaped

    def extract_feat(self, img, img_metas=None, len_queue=None):
    def extract_feat(self, img, len_queue=None):
        """Extract features from images and points."""

        img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
        img_feats = self.extract_img_feat(img, len_queue=len_queue)

        return img_feats
@ -132,6 +139,27 @@ class BEVFormer(MVXTwoStageDetector):
        dummy_metas = None
        return self.forward_test(img=img, img_metas=[[dummy_metas]])

    def obtain_history_bev_serially(self, imgs_queue, img_metas_list):
        """Obtain history BEV features iteratively.
        Extract features one by one to avoid batchnorm corruption when the
        batch dimension N is too large.
        """
        self.eval()

        with torch.no_grad():
            prev_bev = None
            bs, len_queue, num_cams, C, H, W = imgs_queue.shape
            for i in range(len_queue):
                img_feats = self.extract_feat(
                    img=imgs_queue[:, i, ...], len_queue=None)
                img_metas = [each[i] for each in img_metas_list]
                if not img_metas[0]['prev_bev_exists']:
                    prev_bev = None
                prev_bev = self.pts_bbox_head(
                    img_feats, img_metas, prev_bev, only_bev=True)
        self.train()

        return prev_bev

    def obtain_history_bev(self, imgs_queue, img_metas_list):
        """Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
        """
@ -147,21 +175,19 @@ class BEVFormer(MVXTwoStageDetector):
            img_metas = [each[i] for each in img_metas_list]
            if not img_metas[0]['prev_bev_exists']:
                prev_bev = None
            # img_feats = self.extract_feat(img=img, img_metas=img_metas)
            img_feats = [each_scale[:, i] for each_scale in img_feats_list]
            prev_bev = self.pts_bbox_head(
                img_feats, img_metas, prev_bev, only_bev=True)
        self.train()
        return prev_bev

    def forward_train(
        self,
        img_metas=None,
        gt_bboxes_3d=None,
        gt_labels_3d=None,
        img=None,
        gt_bboxes_ignore=None,
    ):
    def forward_train(self,
                      img_metas=None,
                      gt_bboxes_3d=None,
                      gt_labels_3d=None,
                      img=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """Forward training function.
        Args:
            points (list[torch.Tensor], optional): Points of each sample.
@ -185,18 +211,24 @@ class BEVFormer(MVXTwoStageDetector):
        Returns:
            dict: Losses of different branches.
        """
        self._check_inputs(img_metas, img, kwargs)

        len_queue = img.size(1)
        prev_img = img[:, :-1, ...]
        img = img[:, -1, ...]

        prev_img_metas = copy.deepcopy(img_metas)
        prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)

        if self.extract_feat_serially:
            prev_bev = self.obtain_history_bev_serially(
                prev_img, prev_img_metas)
        else:
            prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)

        img_metas = [each[len_queue - 1] for each in img_metas]
        if not img_metas[0]['prev_bev_exists']:
            prev_bev = None
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        img_feats = self.extract_feat(img=img)
        losses = dict()
        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                            gt_labels_3d, img_metas,
@ -205,7 +237,34 @@ class BEVFormer(MVXTwoStageDetector):
        losses.update(losses_pts)
        return losses

    def _check_inputs(self, img_metas, img, kwargs):
        can_bus_in_kwargs = kwargs.get('can_bus', None) is not None
        lidar2img_in_kwargs = kwargs.get('lidar2img', None) is not None
        for batch_i in range(len(img_metas)):
            for i in range(len(img_metas[batch_i])):
                if can_bus_in_kwargs:
                    img_metas[batch_i][i]['can_bus'] = kwargs['can_bus'][
                        batch_i][i]
                else:
                    if isinstance(img_metas[batch_i][i]['can_bus'],
                                  np.ndarray):
                        img_metas[batch_i][i]['can_bus'] = torch.from_numpy(
                            img_metas[batch_i][i]['can_bus']).to(img.device)
                if lidar2img_in_kwargs:
                    img_metas[batch_i][i]['lidar2img'] = kwargs['lidar2img'][
                        batch_i][i]
                else:
                    if isinstance(img_metas[batch_i][i]['lidar2img'],
                                  np.ndarray):
                        img_metas[batch_i][i]['lidar2img'] = torch.from_numpy(
                            np.array(img_metas[batch_i][i]['lidar2img'])).to(
                                img.device)
        kwargs.pop('can_bus', None)
        kwargs.pop('lidar2img', None)

    def forward_test(self, img_metas, img=None, rescale=True, **kwargs):
        self._check_inputs(img_metas, img, kwargs)

        for var, name in [(img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
@ -213,20 +272,25 @@ class BEVFormer(MVXTwoStageDetector):
        img = [img] if img is None else img

        if img_metas[0][0]['scene_token'] != self.prev_frame_info[
                'scene_token']:
                'prev_scene_token']:
            # the first sample of each scene is truncated
            self.prev_frame_info['prev_bev'] = None
        # update idx
        self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']
        self.prev_frame_info['prev_scene_token'] = img_metas[0][0][
            'scene_token']

        # do not use temporal information
        if not self.video_test_mode:
            self.prev_frame_info['prev_bev'] = None

        # Get the delta of ego position and angle between two timestamps.
        tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
        tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
        if self.prev_frame_info['prev_bev'] is not None:
        tmp_pos = img_metas[0][0]['can_bus'][:3].clone()
        tmp_angle = img_metas[0][0]['can_bus'][-1].clone()
        # skip init dummy prev_bev
        if self.prev_frame_info['prev_bev'] is not None and not torch.equal(
                self.prev_frame_info['prev_bev'],
                self.prev_frame_info['prev_bev'].new_zeros(
                    self.prev_frame_info['prev_bev'].size())):
            img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
            img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info[
                'prev_angle']
@ -268,7 +332,7 @@ class BEVFormer(MVXTwoStageDetector):

    def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
        """Test function without augmentation."""
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        img_feats = self.extract_feat(img=img)

        bbox_list = [dict() for i in range(len(img_metas))]
        new_prev_bev, bbox_pts = self.simple_test_pts(
@ -276,3 +340,102 @@ class BEVFormer(MVXTwoStageDetector):
        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
        return new_prev_bev, bbox_list

    def forward_export(self, img, img_metas):
        error_str = 'Only support batch_size=1 and queue_length=1, please remove axis of batch_size and queue_length!'
        if len(img.shape) > 4:
            raise ValueError(error_str)
        elif len(img.shape) < 4:
            raise ValueError(
                'The length of img size must be equal to 4: [num_cameras, img_channel, img_height, img_width]'
            )

        assert len(
            img_metas['can_bus'].shape) == 1, error_str  # torch.Size([18])
        assert len(img_metas['lidar2img'].shape
                   ) == 3, error_str  # torch.Size([6, 4, 4])
        assert len(
            img_metas['img_shape'].shape) == 2, error_str  # torch.Size([6, 3])
        assert len(img_metas['prev_bev'].shape
                   ) == 3, error_str  # torch.Size([40000, 1, 256])

        img = img[
            None, None,
            ...]  # torch.Size([6, 3, 928, 1600]) -> torch.Size([1, 1, 6, 3, 928, 1600])

        box_type_3d = img_metas.get('box_type_3d', 'LiDAR')
        if isinstance(box_type_3d, torch.Tensor):
            box_type_3d = pickle.loads(box_type_3d.cpu().numpy().tobytes())
        img_metas['box_type_3d'] = get_box_type(box_type_3d)[0]
        img_metas['scene_token'] = decode_tensor_to_str(
            img_metas['scene_token'])

        # previous frame info
        self.prev_frame_info['prev_scene_token'] = decode_tensor_to_str(
            img_metas.pop('prev_scene_token', None))
        self.prev_frame_info['prev_bev'] = img_metas.pop('prev_bev', None)
        self.prev_frame_info['prev_pos'] = img_metas.pop('prev_pos', None)
        self.prev_frame_info['prev_angle'] = img_metas.pop('prev_angle', None)

        img_metas = [[img_metas]]
        outputs = self.forward_test(img_metas, img=img)
        scores_3d = outputs['pts_bbox'][0]['scores_3d']
        labels_3d = outputs['pts_bbox'][0]['labels_3d']
        boxes_3d = outputs['pts_bbox'][0]['boxes_3d'].tensor.cpu()

        # info has been updated to the current frame
        prev_bev = self.prev_frame_info['prev_bev']
        prev_pos = self.prev_frame_info['prev_pos']
        prev_angle = self.prev_frame_info['prev_angle']
        prev_scene_token = encode_str_to_tensor(
            self.prev_frame_info['prev_scene_token'])

        return scores_3d, labels_3d, boxes_3d, [
            prev_bev, prev_pos, prev_angle, prev_scene_token
        ]

    def forward_history_bev(self,
                            img,
                            can_bus,
                            lidar2img,
                            img_shape,
                            scene_token,
                            box_type_3d='LiDAR'):
        """Experimental api, for exporting a jit model to obtain history bev."""
        if isinstance(box_type_3d, torch.Tensor):
            box_type_3d = pickle.loads(box_type_3d.cpu().numpy().tobytes())

        batch_size, len_queue = img.size()[:2]
        img_metas = []
        for b_i in range(batch_size):
            img_metas.append([])
            for i in range(len_queue):
                scene_token_str = pickle.loads(
                    scene_token[b_i][i].cpu().numpy().tobytes())
                img_metas[b_i].append({
                    'scene_token': scene_token_str,
                    'can_bus': can_bus[b_i][i],
                    'lidar2img': lidar2img[b_i][i],
                    'img_shape': img_shape[b_i][i],
                    'box_type_3d': get_box_type(box_type_3d)[0],
                    'prev_bev_exists': False
                })

        prev_img = img[:, :-1, ...]
        img = img[:, -1, ...]

        prev_img_metas = copy.deepcopy(img_metas)

        if self.extract_feat_serially:
            prev_bev = self.obtain_history_bev_serially(
                prev_img, prev_img_metas)
        else:
            prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
        return prev_bev
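
`encode_str_to_tensor`/`decode_tensor_to_str` round-trip strings through pickle bytes so jit/blade models can carry them as tensors. An assumed sketch, consistent with the pickle calls in this diff (the real helpers live in `easycv.utils.misc`):

def _encode_str_to_tensor_sketch(s):
    # string -> uint8 tensor, so it can ride through a jit model's I/O
    return torch.tensor(bytearray(pickle.dumps(s)), dtype=torch.uint8)


def _decode_tensor_to_str_sketch(t):
    # inverse: uint8 tensor -> original python string
    return pickle.loads(t.cpu().numpy().tobytes())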

@ -36,6 +36,8 @@ class BEVFormerHead(AnchorFreeHead):
                 num_classes,
                 in_channels,
                 num_query=100,
                 num_query_one2many=0,
                 one2many_gt_mul=None,
                 num_reg_fcs=2,
                 with_box_refine=False,
                 as_two_stage=False,
@ -71,7 +73,6 @@ class BEVFormerHead(AnchorFreeHead):

        self.bev_h = bev_h
        self.bev_w = bev_w
        self.fp16_enabled = False
        self.with_box_refine = with_box_refine
        self.as_two_stage = as_two_stage
        if self.as_two_stage:
@ -133,13 +134,17 @@ class BEVFormerHead(AnchorFreeHead):
            sampler_cfg = dict(type='PseudoBBoxSampler')
            self.sampler = build_bbox_sampler(sampler_cfg, context=self)

        self.num_query = num_query
        # for one2many task
        self.num_query_one2many = num_query_one2many
        self.num_query_one2one = num_query
        self.one2many_gt_mul = one2many_gt_mul

        self.num_query = num_query + num_query_one2many if num_query_one2many > 0 else num_query
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.num_reg_fcs = num_reg_fcs
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)
        self.loss_iou = build_loss(loss_iou)
@ -279,6 +284,16 @@ class BEVFormerHead(AnchorFreeHead):
                prev_bev=prev_bev,
            )
        else:
            # make attn mask for one2many task
            self_attn_mask = torch.zeros([
                self.num_query,
                self.num_query,
            ]).bool().to(bev_queries.device)
            self_attn_mask[self.num_query_one2one:,
                           0:self.num_query_one2one, ] = True
            self_attn_mask[0:self.num_query_one2one,
                           self.num_query_one2one:, ] = True
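            # illustrative: with num_query_one2one = 2 and num_query = 4,
            # the mask is
            #     [[F, F, T, T],
            #      [F, F, T, T],
            #      [T, T, F, F],
            #      [T, T, F, F]]
            # True entries are blocked, so the one2one and one2many query
            # groups cannot attend to each other during self-attention.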

            outputs = self.transformer(
                mlvl_feats,
                bev_queries,
@ -292,7 +307,8 @@ class BEVFormerHead(AnchorFreeHead):
                if self.with_box_refine else None,  # noqa:E501
                cls_branches=self.cls_branches if self.as_two_stage else None,
                img_metas=img_metas,
                prev_bev=prev_bev)
                prev_bev=prev_bev,
                attn_mask=self_attn_mask)

        bev_embed, hs, init_reference, inter_references = outputs
        hs = hs.permute(0, 2, 1, 3)
@ -309,20 +325,47 @@ class BEVFormerHead(AnchorFreeHead):

            # TODO: check the shape of reference
            assert reference.shape[-1] == 3
            tmp[..., 0:2] += reference[..., 0:2]
            tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
            tmp[..., 4:5] += reference[..., 2:3]
            tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
            tmp[..., 0:1] = (
                tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
            # tmp: torch.Size([1, 900, 10])
            # tmp[..., 0:2] += reference[..., 0:2]
            # tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
            # tmp[..., 4:5] += reference[..., 2:3]
            # tmp[..., 4:5] = tmp[..., 4:5].sigmoid()

            # tmp[..., 0:1] = (
            #     tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
            #     self.pc_range[0])
            # tmp[..., 1:2] = (
            #     tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
            #     self.pc_range[1])
            # tmp[..., 4:5] = (
            #     tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
            #     self.pc_range[2])

            # remove inplace operations, metrics may be incorrect when using blade
            tmp_0_2 = tmp[..., 0:2]
            tmp_0_2_add_reference = tmp_0_2 + reference[..., 0:2]
            tmp_0_2_add_reference = tmp_0_2_add_reference.sigmoid()
            tmp_4_5 = tmp[..., 4:5]
            tmp_4_5_add_reference = tmp_4_5 + reference[..., 2:3]
            tmp_4_5_add_reference = tmp_4_5_add_reference.sigmoid()
            tmp_0_1 = tmp_0_2_add_reference[..., 0:1]
            tmp_0_1_new = (
                tmp_0_1 * (self.pc_range[3] - self.pc_range[0]) +
                self.pc_range[0])
            tmp[..., 1:2] = (
                tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
            tmp_1_2 = tmp_0_2_add_reference[..., 1:2]
            tmp_1_2_new = (
                tmp_1_2 * (self.pc_range[4] - self.pc_range[1]) +
                self.pc_range[1])
            tmp[..., 4:5] = (
                tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
            tmp_4_5_new = (
                tmp_4_5_add_reference * (self.pc_range[5] - self.pc_range[2]) +
                self.pc_range[2])

            tmp_2_4 = tmp[..., 2:4]
            tmp_5_10 = tmp[..., 5:10]
            tmp = torch.cat(
                [tmp_0_1_new, tmp_1_2_new, tmp_2_4, tmp_4_5_new, tmp_5_10],
                dim=-1)

            # TODO: check if using sigmoid
            outputs_coord = tmp
            outputs_classes.append(outputs_class)
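
The rewrite above follows one pattern: replace in-place slice assignment with slicing plus `torch.cat`, which keeps the graph functional for tracing/blade. A minimal illustration (not from the commit):

x = torch.randn(1, 900, 10)
# in-place form (can corrupt metrics under blade): x[..., 0:2] = x[..., 0:2].sigmoid()
x = torch.cat([x[..., 0:2].sigmoid(), x[..., 2:10]], dim=-1)  # functional form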
|
@ -333,12 +376,19 @@ class BEVFormerHead(AnchorFreeHead):
|
|||
|
||||
outs = {
|
||||
'bev_embed': bev_embed,
|
||||
'all_cls_scores': outputs_classes,
|
||||
'all_bbox_preds': outputs_coords,
|
||||
'all_cls_scores':
|
||||
outputs_classes[:, :, :self.num_query_one2one, :],
|
||||
'all_bbox_preds': outputs_coords[:, :, :self.num_query_one2one, :],
|
||||
'enc_cls_scores': None,
|
||||
'enc_bbox_preds': None,
|
||||
}
|
||||
|
||||
if self.num_query_one2many > 0:
|
||||
outs['all_cls_scores_aux'] = outputs_classes[:, :, self.
|
||||
num_query_one2one:, :]
|
||||
outs['all_bbox_preds_aux'] = outputs_coords[:, :, self.
|
||||
num_query_one2one:, :]
|
||||
|
||||
return outs
|
||||
|
||||
def _get_target_single(self,
|
||||
|
@ -396,6 +446,8 @@ class BEVFormerHead(AnchorFreeHead):
        bbox_weights[pos_inds] = 1.0

        # DETR
        sampling_result.pos_gt_bboxes = sampling_result.pos_gt_bboxes.type_as(
            bbox_targets)
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
                neg_inds)
@ -586,6 +638,47 @@ class BEVFormerHead(AnchorFreeHead):
            all_gt_bboxes_ignore_list)

        loss_dict = dict()

        # for one2many task
        if 'all_cls_scores_aux' in preds_dicts and self.one2many_gt_mul:
            all_cls_scores_aux = preds_dicts['all_cls_scores_aux']
            all_bbox_preds_aux = preds_dicts['all_bbox_preds_aux']

            gt_bboxes_list_aux = []
            gt_labels_list_aux = []
            for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
                gt_bboxes_list_aux.append(
                    gt_bboxes.repeat(self.one2many_gt_mul, 1))
                gt_labels_list_aux.append(
                    gt_labels.repeat(self.one2many_gt_mul))
            # for classwise multiply
            # for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
            #     gt_bboxes_aux = []
            #     gt_labels_aux = []
            #     for gt_bbox, gt_label in zip(gt_bboxes, gt_labels):
            #         gt_bboxes_aux += [gt_bbox] * self.one2many_gt_mul[gt_label]
            #         gt_labels_aux += [gt_label] * self.one2many_gt_mul[gt_label]
            #     gt_bboxes_list_aux.append(torch.stack(gt_bboxes_aux))
            #     gt_labels_list_aux.append(torch.stack(gt_labels_aux))
            all_gt_bboxes_list_aux = [
                gt_bboxes_list_aux for _ in range(num_dec_layers)
            ]
            all_gt_labels_list_aux = [
                gt_labels_list_aux for _ in range(num_dec_layers)
            ]
            losses_cls_aux, losses_bbox_aux = multi_apply(
                self.loss_single, all_cls_scores_aux, all_bbox_preds_aux,
                all_gt_bboxes_list_aux, all_gt_labels_list_aux,
                all_gt_bboxes_ignore_list)
            loss_dict['loss_cls_aux'] = losses_cls_aux[-1]
            loss_dict['loss_bbox_aux'] = losses_bbox_aux[-1]
            num_dec_layer = 0
            for loss_cls_i, loss_bbox_i in zip(losses_cls_aux[:-1],
                                               losses_bbox_aux[:-1]):
                loss_dict[f'd{num_dec_layer}.loss_cls_aux'] = loss_cls_i
                loss_dict[f'd{num_dec_layer}.loss_bbox_aux'] = loss_bbox_i
                num_dec_layer += 1

        # loss of proposal generated from encode feature map.
        if enc_cls_scores is not None:
            binary_labels_list = [

@ -24,6 +24,19 @@ from easycv.models.utils.transformer import (BaseTransformerLayer,
                                             TransformerLayerSequence)
from . import (CustomMSDeformableAttention, MSDeformableAttention3D,
               TemporalSelfAttention)
from .attentions.spatial_cross_attention import SpatialCrossAttention


@torch.jit.script
def _rotate(img: torch.Tensor, angle: torch.Tensor, center: torch.Tensor):
    """torch.jit.trace does not support torchvision.rotate"""

    img = rotate(
        img,
        float(angle.item()),
        center=[int(center[0].item()),
                int(center[1].item())])
    return img


@TRANSFORMER_LAYER.register_module()
@ -107,6 +120,7 @@ class BEVFormerLayer(BaseModule):
                 ),
                 batch_first=True,
                 init_cfg=None,
                 adapt_jit=False,
                 **kwargs):
        super(BEVFormerLayer, self).__init__(init_cfg)
@ -135,6 +149,7 @@ class BEVFormerLayer(BaseModule):
        self.attentions = ModuleList()

        index = 0
        self.adapt_jit = adapt_jit
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
@ -142,6 +157,10 @@ class BEVFormerLayer(BaseModule):
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # for export jit model
                if self.adapt_jit and isinstance(attention,
                                                 SpatialCrossAttention):
                    attention = torch.jit.script(attention)
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
@ -170,7 +189,6 @@ class BEVFormerLayer(BaseModule):
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

        self.fp16_enabled = False
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
@ -249,43 +267,42 @@ class BEVFormerLayer(BaseModule):
            if layer == 'self_attn':

                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query=query,
                    key=prev_bev,
                    value=prev_bev,
                    identity=identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor([[bev_h, bev_w]],
                                                device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                )
                attn_index += 1
                identity = query

            elif layer == 'norm':
                # fix fp16
                dtype = query.dtype
                query = self.norms[norm_index](query)
                query = query.to(dtype)
                norm_index += 1

            # spatial cross attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query=query,
                    key=key,
                    value=value,
                    residual=identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    bev_mask=kwargs.get('bev_mask'),
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                )
                attn_index += 1
                identity = query
@ -309,7 +326,6 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(Detr3DTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        self.fp16_enabled = False

    def forward(self,
                query,
@ -317,6 +333,7 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
                reference_points=None,
                reg_branches=None,
                key_padding_mask=None,
                attn_mask=None,
                **kwargs):
        """Forward function for `Detr3DTransformerDecoder`.
        Args:
@ -346,6 +363,7 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
                output,
                *args,
                reference_points=reference_points_input,
                attn_masks=[attn_mask] * layer.num_attn,
                key_padding_mask=key_padding_mask,
                **kwargs)
            output = output.permute(1, 0, 2)
@ -355,13 +373,26 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):

            assert reference_points.shape[-1] == 3

            new_reference_points = torch.zeros_like(reference_points)
            new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
                reference_points[..., :2], eps=1e-5)
            new_reference_points[...,
                                 2:3] = tmp[..., 4:5] + inverse_sigmoid(
                                     reference_points[..., 2:3], eps=1e-5)
            # new_reference_points = torch.zeros_like(
            #     reference_points)  # torch.Size([1, 900, 3])
            # new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
            #     reference_points[..., :2], eps=1e-5)
            # new_reference_points[...,
            #                      2:3] = tmp[..., 4:5] + inverse_sigmoid(
            #                          reference_points[..., 2:3], eps=1e-5)

            # new_reference_points = new_reference_points.sigmoid()

            # reference_points = new_reference_points.detach()

            # remove inplace operations, metrics may be incorrect when using blade
            new_reference_points_0_2 = tmp[..., :2] + inverse_sigmoid(
                reference_points[..., :2], eps=1e-5)
            new_reference_points_2_3 = tmp[..., 4:5] + inverse_sigmoid(
                reference_points[..., 2:3], eps=1e-5)
            new_reference_points = torch.cat(
                [new_reference_points_0_2, new_reference_points_2_3],
                dim=-1)
            new_reference_points = new_reference_points.sigmoid()

            reference_points = new_reference_points.detach()
@ -402,7 +433,6 @@ class BEVFormerEncoder(TransformerLayerSequence):

        self.num_points_in_pillar = num_points_in_pillar
        self.pc_range = pc_range
        self.fp16_enabled = False

    @staticmethod
    def get_reference_points(H,
@ -456,12 +486,9 @@ class BEVFormerEncoder(TransformerLayerSequence):
    # This function must use fp32!!!
    @force_fp32(apply_to=('reference_points', 'img_metas'))
    def point_sampling(self, reference_points, pc_range, img_metas):
        lidar2img = torch.stack([meta['lidar2img'] for meta in img_metas
                                 ]).to(reference_points.dtype)  # (B, N, 4, 4)

        lidar2img = []
        for img_meta in img_metas:
            lidar2img.append(img_meta['lidar2img'])
        lidar2img = np.asarray(lidar2img)
        lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
        reference_points = reference_points.clone()

        reference_points[..., 0:1] = reference_points[..., 0:1] * \
@ -650,7 +677,6 @@ class PerceptionTransformer(BaseModule):
        self.embed_dims = embed_dims
        self.num_feature_levels = num_feature_levels
        self.num_cams = num_cams
        self.fp16_enabled = False

        self.rotate_prev_bev = rotate_prev_bev
        self.use_shift = use_shift
@ -711,26 +737,28 @@ class PerceptionTransformer(BaseModule):
        bev_pos = bev_pos.flatten(2).permute(2, 0, 1)

        # obtain rotation angle and shift with ego motion
        delta_x = np.array(
        delta_x = torch.stack(
            [each['can_bus'][0] for each in kwargs['img_metas']])
        delta_y = np.array(
        delta_y = torch.stack(
            [each['can_bus'][1] for each in kwargs['img_metas']])
        ego_angle = np.array([
        ego_angle = torch.stack([
            each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas']
        ])
        grid_length_y = grid_length[0]
        grid_length_x = grid_length[1]
        translation_length = np.sqrt(delta_x**2 + delta_y**2)
        translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
        translation_length = torch.sqrt(delta_x**2 + delta_y**2)
        translation_angle = torch.atan2(delta_y, delta_x) / np.pi * 180
        bev_angle = ego_angle - translation_angle
        shift_y = translation_length * \
            np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
            torch.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
        shift_x = translation_length * \
            np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
        shift_y = shift_y * self.use_shift
        shift_x = shift_x * self.use_shift
        shift = bev_queries.new_tensor([shift_x, shift_y
                                        ]).permute(1, 0)  # xy, bs -> bs, xy
            torch.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w

        if not self.use_shift:
            shift_y = shift_y.new_zeros(shift_y.size())
            shift_x = shift_x.new_zeros(shift_y.size())
        shift = torch.stack([shift_x,
                             shift_y]).permute(1, 0).to(bev_queries.dtype)
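
A worked example of the shift above, with illustrative numbers (1 m of ego motion, heading 90 degrees, 0.512 m grid cells, a 200 x 200 BEV):

import math
delta_x, delta_y, ego_angle = 1.0, 0.0, 90.0
translation_length = math.hypot(delta_x, delta_y)               # 1.0
translation_angle = math.degrees(math.atan2(delta_y, delta_x))  # 0.0
bev_angle = ego_angle - translation_angle                       # 90.0
shift_x = translation_length * math.sin(math.radians(bev_angle)) / 0.512 / 200
shift_y = translation_length * math.cos(math.radians(bev_angle)) / 0.512 / 200
# shift_x ~= 0.0098, shift_y ~= 0.0: prev_bev is resampled ~1% of the grid in x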

        if prev_bev is not None:
            if prev_bev.shape[1] == bev_h * bev_w:
@ -741,19 +769,23 @@ class PerceptionTransformer(BaseModule):
                    rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
                    tmp_prev_bev = prev_bev[:, i].reshape(bev_h, bev_w,
                                                          -1).permute(2, 0, 1)
                    tmp_prev_bev = rotate(
                    tmp_prev_bev = _rotate(
                        tmp_prev_bev,
                        rotation_angle,
                        center=self.rotate_center)
                        center=torch.tensor(self.rotate_center))
                    tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                        bev_h * bev_w, 1, -1)
                    prev_bev[:, i] = tmp_prev_bev[:, 0]

        # add can bus signals
        can_bus = bev_queries.new_tensor(
            [each['can_bus'] for each in kwargs['img_metas']])  # [:, :]
        can_bus = torch.stack([
            each['can_bus'] for each in kwargs['img_metas']
        ]).to(bev_queries.dtype)
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        bev_queries = bev_queries + can_bus * self.use_can_bus
        # fix fp16
        can_bus = can_bus.to(bev_queries.dtype)
        if self.use_can_bus:
            bev_queries = bev_queries + can_bus

        feat_flatten = []
        spatial_shapes = []
@ -806,6 +838,7 @@ class PerceptionTransformer(BaseModule):
                reg_branches=None,
                cls_branches=None,
                prev_bev=None,
                attn_mask=None,
                **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
@ -873,6 +906,7 @@ class PerceptionTransformer(BaseModule):
            value=bev_embed,
            query_pos=query_pos,
            reference_points=reference_points,
            attn_mask=attn_mask,
            reg_branches=reg_branches,
            cls_branches=cls_branches,
            spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from mmcv.runner import auto_fp16
from PIL import Image
from torchvision.transforms.functional import rotate


class Grid(object):
@ -113,7 +114,7 @@ class GridMask(nn.Module):
        ww = int(1.5 * w)
        d = np.random.randint(2, h)
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        mask = torch.ones((hh, ww), dtype=torch.uint8, device=x.device)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
@ -128,19 +129,16 @@ class GridMask(nn.Module):
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = rotate(mask.unsqueeze(0), r)[0]
        mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
                    (ww - w) // 2:(ww - w) // 2 + w]

        mask = torch.from_numpy(mask).to(x.dtype).cuda()
        mask = mask.to(x.dtype)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).to(
                x.dtype).cuda()
            offset = (2 * (torch.rand(
                (h, w), device=x.device) - 0.5)).to(x.dtype)
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask
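
The mask application itself is elementwise gating; a minimal demo of the `mode == 1` branch (shapes and values are illustrative):

x = torch.randn(2, 3, 8, 8)   # (N, C, H, W) batch
mask = torch.ones(8, 8)
mask[2:4, :] = 0              # one zeroed horizontal grid stripe
mask = 1 - mask               # mode == 1 inverts: keep only the stripe
out = x * mask.expand_as(x)   # broadcasting gates every sample and channel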

@ -4,7 +4,7 @@ from .det_db_loss import DBLoss
from .face_keypoint_loss import FacePoseLoss, WingLossWithPose
from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
from .l1_loss import L1Loss
from .l1_loss import L1Loss, SmoothL1Loss
from .mse_loss import JointsMSELoss
from .ocr_rec_multi_loss import MultiLoss
from .pytorch_metric_learning import (AMSoftmaxLoss,

@ -22,5 +22,5 @@ __all__ = [
    'FocalLoss2d', 'DistributeMSELoss', 'CrossEntropyLossWithLabelSmooth',
    'AMSoftmaxLoss', 'ModelParallelSoftmaxLoss', 'ModelParallelAMSoftmaxLoss',
    'SoftTargetCrossEntropy', 'CDNCriterion', 'DNCriterion', 'DBLoss',
    'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss'
    'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss'
]

@ -1,4 +1,5 @@
import mmcv
import numpy as np
import torch
import torch.nn as nn

@ -66,3 +67,185 @@ class L1Loss(nn.Module):
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox


# @mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
    """Smooth L1 loss.
    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.
    Returns:
        torch.Tensor: Calculated loss
    """
    assert beta > 0
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()
    diff = torch.abs(pred - target)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    return loss


@LOSSES.register_module()
class SmoothL1Loss(nn.Module):
    """Smooth L1 loss.
    Args:
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.0.
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to "mean".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
        super(SmoothL1Loss, self).__init__()
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function.
        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * smooth_l1_loss(
            pred,
            target,
            weight,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox
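
A quick numeric check of the piecewise definition (illustrative values; assumes the usual mmdet-style `weighted_loss` mean reduction):

loss_fn = SmoothL1Loss(beta=1.0)
pred = torch.tensor([[0.5, 2.0]])
target = torch.tensor([[0.0, 0.0]])
# |diff| < beta: 0.5 * 0.5**2 / 1.0 = 0.125; otherwise: 2.0 - 0.5 * 1.0 = 1.5
loss = loss_fn(pred, target)  # mean of [0.125, 1.5] = 0.8125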

@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def balanced_l1_loss(pred,
                     target,
                     beta=1.0,
                     alpha=0.5,
                     gamma=1.5,
                     reduction='mean'):
    """Calculate balanced L1 loss.
    Please see the `Libra R-CNN <https://arxiv.org/pdf/1904.02701.pdf>`_
    Args:
        pred (torch.Tensor): The prediction with shape (N, 4).
        target (torch.Tensor): The learning target of the prediction with
            shape (N, 4).
        beta (float): The loss is a piecewise function of prediction and target
            and ``beta`` serves as a threshold for the difference between the
            prediction and target. Defaults to 1.0.
        alpha (float): The denominator ``alpha`` in the balanced L1 loss.
            Defaults to 0.5.
        gamma (float): The ``gamma`` in the balanced L1 loss.
            Defaults to 1.5.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
    Returns:
        torch.Tensor: The calculated loss
    """
    assert beta > 0
    if target.numel() == 0:
        return pred.sum() * 0

    assert pred.size() == target.size()

    diff = torch.abs(pred - target)
    b = np.e**(gamma / alpha) - 1
    loss = torch.where(
        diff < beta, alpha / b *
        (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
        gamma * diff + gamma / b - alpha * beta)

    return loss


@LOSSES.register_module()
class BalancedL1Loss(nn.Module):
    """Balanced L1 Loss.
    arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
    Args:
        alpha (float): The denominator ``alpha`` in the balanced L1 loss.
            Defaults to 0.5.
        gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
        beta (float, optional): The loss is a piecewise function of prediction
            and target. ``beta`` serves as a threshold for the difference
            between the prediction and target. Defaults to 1.0.
        reduction (str, optional): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of the loss. Defaults to 1.0
    """

    def __init__(self,
                 alpha=0.5,
                 gamma=1.5,
                 beta=1.0,
                 reduction='mean',
                 loss_weight=1.0):
        super(BalancedL1Loss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.beta = beta
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function of loss.
        Args:
            pred (torch.Tensor): The prediction with shape (N, 4).
            target (torch.Tensor): The learning target of the prediction with
                shape (N, 4).
            weight (torch.Tensor, optional): Sample-wise loss weight with
                shape (N, ).
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum".
        Returns:
            torch.Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * balanced_l1_loss(
            pred,
            target,
            weight,
            alpha=self.alpha,
            gamma=self.gamma,
            beta=self.beta,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_bbox
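
A small sanity check that the two pieces of `balanced_l1_loss` meet at `diff == beta` (by construction of `b`, `log(b + 1) == gamma / alpha`):

import math
alpha, gamma, beta = 0.5, 1.5, 1.0
b = math.e**(gamma / alpha) - 1
left = alpha / b * (b * beta + 1) * math.log(b * beta / beta + 1) - alpha * beta
right = gamma * beta + gamma / b - alpha * beta
assert abs(left - right) < 1e-9  # continuous at the threshold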

@ -510,10 +510,10 @@ class BaseTransformerLayer(BaseModule):

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query=query,
                    key=key,
                    value=value,
                    identity=identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],

@ -1,15 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from easycv.core.bbox import get_box_type
from easycv.datasets.registry import PIPELINES
from easycv.datasets.shared.pipelines.format import to_tensor
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.framework.errors import ValueError
from easycv.predictors.base import PredictorV2
from easycv.predictors.builder import PREDICTORS
from easycv.utils.misc import encode_str_to_tensor
from easycv.utils.registry import build_from_cfg
from .base import PredictorV2
from .builder import PREDICTORS


@PREDICTORS.register_module()
@ -40,8 +46,21 @@ class BEVFormerPredictor(PredictorV2):
                 box_type_3d='LiDAR',
                 use_camera=True,
                 score_threshold=0.1,
                 model_type=None,
                 *arg,
                 **kwargs):
        if batch_size > 1:
            raise ValueError(
                f'Only support batch_size=1 now, but get batch_size={batch_size}'
            )
        self.model_type = model_type
        if self.model_type is None:
            if model_path.endswith('jit'):
                self.model_type = 'jit'
            elif model_path.endswith('blade'):
                self.model_type = 'blade'
        self.is_jit_model = self.model_type in ['jit', 'blade']

        super(BEVFormerPredictor, self).__init__(
            model_path,
            config_file=config_file,
@ -58,6 +77,20 @@ class BEVFormerPredictor(PredictorV2):
|
|||
self.score_threshold = score_threshold
|
||||
self.result_key = 'pts_bbox'
|
||||
|
||||
# The initial prev_bev should be the weight of self.model.pts_bbox_head.bev_embedding, but the weight cannot be taken out from the blade model.
|
||||
# So we using the dummy data as the the initial value, and it will not be used, just to adapt to jit and blade models.
|
||||
# init_prev_bev = self.model.pts_bbox_head.bev_embedding.weight.clone().detach()
|
||||
# init_prev_bev = init_prev_bev[:, None, :], # [40000, 256] -> [40000, 1, 256]
|
||||
dummy_prev_bev = torch.rand(
|
||||
[self.cfg.bev_h * self.cfg.bev_w, 1,
|
||||
self.cfg.embed_dim]).to(self.device)
|
||||
self.prev_frame_info = {
|
||||
'prev_bev': dummy_prev_bev.to(self.device),
|
||||
'prev_scene_token': encode_str_to_tensor('dummy_prev_scene_token'),
|
||||
'prev_pos': torch.tensor(0),
|
||||
'prev_angle': torch.tensor(0),
|
||||
}
|
||||
|
||||
def _prepare_input_dict(self, data_info):
|
||||
from nuscenes.eval.common.utils import Quaternion, quaternion_yaw
|
||||
|
||||
|
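For reference, a quick shape check of the dummy prev_bev above; a minimal sketch assuming the base config values bev_h = bev_w = 200 and embed_dim = 256:

    import torch

    bev_h, bev_w, embed_dim = 200, 200, 256  # values taken from the base config
    dummy_prev_bev = torch.rand([bev_h * bev_w, 1, embed_dim])
    assert dummy_prev_bev.shape == (40000, 1, 256)  # the [40000, 1, 256] noted in the comment
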
@ -133,13 +166,85 @@ class BEVFormerPredictor(PredictorV2):

        Args:
            input (str): Pickle file path; the content format is the same as the infos file of nuScenes.
        """
        data_info = mmcv.load(input)
        data_info = mmcv.load(input) if isinstance(input, str) else input
        result = self._prepare_input_dict(data_info)
        return self.processor(result)
        result = self.processor(result)

        if self.is_jit_model:
            result['can_bus'] = DC(
                to_tensor(result['img_metas'][0]._data['can_bus']),
                cpu_only=False)
            result['lidar2img'] = DC(
                to_tensor(result['img_metas'][0]._data['lidar2img']),
                cpu_only=False)
            result['scene_token'] = DC(
                torch.tensor(
                    bytearray(
                        pickle.dumps(
                            result['img_metas'][0]._data['scene_token'])),
                    dtype=torch.uint8),
                cpu_only=False)
            result['img_shape'] = DC(
                to_tensor(result['img_metas'][0]._data['img_shape']),
                cpu_only=False)
        else:
            result['can_bus'] = DC(
                torch.stack(
                    [to_tensor(result['img_metas'][0]._data['can_bus'])]),
                cpu_only=False)
            result['lidar2img'] = DC(
                torch.stack(
                    [to_tensor(result['img_metas'][0]._data['lidar2img'])]),
                cpu_only=False)

        return result

    def postprocess_single(self, inputs, *args, **kwargs):
        # TODO: filter results by score_threshold
        return super().postprocess_single(inputs, *args, **kwargs)

    def prepare_model(self):
        if self.is_jit_model:
            model = torch.jit.load(self.model_path, map_location=self.device)
            return model
        return super().prepare_model()

    def forward(self, inputs):
        if self.is_jit_model:
            with torch.no_grad():
                img = inputs['img'][0][0]
                img_metas = {
                    'can_bus': inputs['can_bus'][0],
                    'lidar2img': inputs['lidar2img'][0],
                    'img_shape': inputs['img_shape'][0],
                    'scene_token': inputs['scene_token'][0],
                    'prev_bev': self.prev_frame_info['prev_bev'],
                    'prev_pos': self.prev_frame_info['prev_pos'],
                    'prev_angle': self.prev_frame_info['prev_angle'],
                    'prev_scene_token':
                    self.prev_frame_info['prev_scene_token']
                }
                inputs = (img, img_metas)
                outputs = self.model(*inputs)

                # update prev_frame_info
                self.prev_frame_info['prev_bev'] = outputs[3][0]
                self.prev_frame_info['prev_pos'] = outputs[3][1]
                self.prev_frame_info['prev_angle'] = outputs[3][2]
                self.prev_frame_info['prev_scene_token'] = outputs[3][3]

                outputs = {
                    'pts_bbox': [{
                        'scores_3d':
                        outputs[0],
                        'labels_3d':
                        outputs[1],
                        'boxes_3d':
                        self.box_type_3d(outputs[2].cpu(), outputs[2].size()[-1])
                    }],
                }
            return outputs
        return super().forward(inputs)

    def visualize(self, inputs, results, out_dir, show=False, pipeline=None):
        raise NotImplementedError

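To make the stateful inference flow concrete, here is a minimal usage sketch of the predictor with an exported jit model; the paths below are placeholders, not files shipped with this commit:

    from easycv.predictors import BEVFormerPredictor

    predictor = BEVFormerPredictor(
        model_path='work_dir/model.pth.jit',  # hypothetical path, exported with adapt_jit=True
        config_file='configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py',
        model_type='jit',
    )
    # each input is a pickle in nuScenes infos format; frames must be fed in
    # temporal order because prev_frame_info carries state across frames
    results = predictor(['inference/single_sample.pkl'])
    print(results[0]['pts_bbox'])
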
@ -7,7 +7,6 @@ import numpy as np

import torch
from torchvision.transforms import Compose

from easycv.apis.export import reparameterize_models
from easycv.core.visualization import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor

@ -198,6 +197,7 @@ class YoloXPredictor(DetectionPredictor):

            with io.open(self.model_path, 'rb') as infile:
                model = torch.jit.load(infile, self.device)
        else:
            from easycv.utils.misc import reparameterize_models
            model = super()._build_model()
            model = reparameterize_models(model)
        return model

@ -1,6 +1,5 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time
from distutils.version import LooseVersion

import torch

@ -94,7 +93,6 @@ class EVRunner(EpochBasedRunner):

        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i

@ -122,7 +120,7 @@ class EVRunner(EpochBasedRunner):

        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')

@ -10,20 +10,31 @@

# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from __future__ import absolute_import, division, print_function

import os
import subprocess
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable


def _auto_compile():
    cur_dir = os.getcwd()
    target_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    os.chdir(target_dir)
    res = subprocess.call('python setup.py build install', shell=True)
    os.chdir(cur_dir)
    return res


try:
    import MultiScaleDeformableAttention as MSDA
except ModuleNotFoundError as e:
    info_string = (
        '\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n'
        '\t`cd thirdparty/deformable_attention`\n'
        '\t`python setup.py build install`\n')
    raise ModuleNotFoundError(info_string)
    res = _auto_compile()
    if res != 0:
        info_string = (
            '\n\nAuto compile failed! Please compile MultiScaleDeformableAttention CUDA op with the following commands:\n'
            '\t`cd easycv/thirdparty/deformable_attention`\n'
            '\t`python setup.py build install`\n')
        raise ModuleNotFoundError(info_string)


class MSDeformAttnFunction(Function):

@ -14,8 +14,21 @@

*/

#include "ms_deform_attn.h"
#include <torch/script.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
  m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}

inline at::Tensor ms_deform_attn(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int64_t im2col_step) {
  return ms_deform_attn_forward(value, spatial_shapes, level_start_index,
                                sampling_loc, attn_weight, im2col_step);
}

static auto registry = torch::RegisterOperators().op("custom::ms_deform_attn", &ms_deform_attn);

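Registering the wrapper through torch::RegisterOperators exposes the CUDA kernel to TorchScript under the custom namespace, which is what lets a traced or blade model call it without the Python autograd Function. A minimal sketch of reaching the op from Python; the tensor shapes follow the MSDeformAttn convention and are illustrative assumptions:

    import torch
    import MultiScaleDeformableAttention  # noqa: importing the compiled extension registers the op

    value = torch.rand(1, 1000, 8, 32).cuda()             # (bs, num_keys, num_heads, head_dim)
    spatial_shapes = torch.tensor([[25, 40]]).cuda()      # one feature level of 25 x 40 (= 1000 keys)
    level_start_index = torch.tensor([0]).cuda()
    sampling_loc = torch.rand(1, 900, 8, 1, 8, 2).cuda()  # (bs, num_queries, heads, levels, points, 2)
    attn_weight = torch.rand(1, 900, 8, 1, 8).cuda()
    out = torch.ops.custom.ms_deform_attn(
        value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64)
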
@ -0,0 +1,260 @@

// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#include <torch/script.h>

void modulated_deformable_im2col_impl(
    const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor data_col) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_im2col_impl, data_im, data_offset,
                       data_mask, batch_size, channels, height_im, width_im,
                       height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
                       stride_h, stride_w, dilation_h, dilation_w,
                       deformable_group, data_col);
}

void modulated_deformable_col2im_impl(
    const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, Tensor grad_im) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_impl, data_col, data_offset,
                       data_mask, batch_size, channels, height_im, width_im,
                       height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
                       stride_h, stride_w, dilation_h, dilation_w,
                       deformable_group, grad_im);
}

void modulated_deformable_col2im_coord_impl(
    const Tensor data_col, const Tensor data_im, const Tensor data_offset,
    const Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    Tensor grad_offset, Tensor grad_mask) {
  DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, data_col,
                       data_im, data_offset, data_mask, batch_size, channels,
                       height_im, width_im, height_col, width_col, kernel_h,
                       kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
                       dilation_w, deformable_group, grad_offset, grad_mask);
}

void modulated_deform_conv_forward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
    const int stride_h, const int stride_w, const int pad_h, const int pad_w,
    const int dilation_h, const int dilation_w, const int group,
    const int deformable_group, const bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_impl(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into group
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}

void modulated_deform_conv_backward(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
    Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  at::DeviceGuard guard(input.device());

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
             kernel_h, kernel_w, kernel_h_, kernel_w_);
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
             channels, channels_kernel * group);

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide into group
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_impl(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_impl(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight, dWeight should accumulate across the batch and
    // group
    modulated_deformable_im2col_impl(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}

at::Tensor modulated_deform_conv(
    Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
    Tensor mask, Tensor output, Tensor columns, int64_t kernel_h,
    int64_t kernel_w, const int64_t stride_h, const int64_t stride_w,
    const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h,
    const int64_t dilation_w, const int64_t group,
    const int64_t deformable_group, const bool with_bias) {
  modulated_deform_conv_forward(input, weight, bias, ones, offset, mask,
                                output, columns, kernel_h, kernel_w, stride_h,
                                stride_w, pad_h, pad_w, dilation_h, dilation_w,
                                group, deformable_group, with_bias);
  return output;
}

TORCH_LIBRARY(mmcv, m) {
  m.def(R"SIG(modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
               Tensor mask, Tensor(a!) output, Tensor columns, *, int kernel_h, int kernel_w,
               int stride_h, int stride_w, int pad_h, int pad_w,
               int dilation_h, int dilation_w, int group,
               int deformable_group, bool with_bias) -> Tensor(a!))SIG", modulated_deform_conv);
}

// static auto registry = torch::RegisterOperators().op("mmcv::modulated_deform_conv", &modulated_deform_conv);

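Two details of the registered schema are worth spelling out: the `*` makes every integer argument keyword-only, and `Tensor(a!)` declares that `output` is written in place and aliased by the return value. That is why the Python wrapper in the next file calls the op with keyword arguments and preallocates `output`. A small sketch with illustrative shapes, assuming the extension above has been loaded:

    import torch

    x = torch.rand(1, 16, 32, 32).cuda()
    weight = torch.rand(16, 16, 3, 3).cuda()
    offset = torch.rand(1, 2 * 3 * 3, 32, 32).cuda()  # 2 * kh * kw offset channels
    mask = torch.rand(1, 3 * 3, 32, 32).cuda()        # kh * kw mask channels
    output = x.new_empty((1, 16, 32, 32))             # preallocated, reshaped and zeroed in C++
    out = torch.ops.mmcv.modulated_deform_conv(
        x, weight, x.new_empty(0), x.new_empty(0), offset, mask, output,
        x.new_empty(0),
        kernel_h=3, kernel_w=3, stride_h=1, stride_w=1, pad_h=1, pad_w=1,
        dilation_h=1, dilation_w=1, group=1, deformable_group=1, with_bias=False)
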
@ -0,0 +1,389 @@

# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single

from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log

ext_module = ext_loader.load_ext(
    '_ext',
    ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])


class ModulatedDeformConv2dFunction(Function):

    @staticmethod
    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
                 dilation, groups, deform_groups):
        input_tensors = [input, offset, mask, weight]
        if bias is not None:
            input_tensors.append(bias)
        return g.op(
            'mmcv::MMCVModulatedDeformConv2d',
            *input_tensors,
            stride_i=stride,
            padding_i=padding,
            dilation_i=dilation,
            groups_i=groups,
            deform_groups_i=deform_groups)

    @staticmethod
    def _jit_forward(input,
                     offset,
                     mask,
                     weight,
                     bias=None,
                     stride=1,
                     padding=0,
                     dilation=1,
                     groups=1,
                     deform_groups=1):
        if input is not None and input.dim() != 4:
            raise ValueError(
                f'Expected 4D tensor as input, got {input.dim()}D tensor \
instead.')
        with_bias = bias is not None
        if bias is None:
            bias = input.new_empty(0)  # fake tensor
        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
        # amp won't cast the type of model (float32), but "offset" is cast
        # to float16 by nn.Conv2d automatically, leading to the type
        # mismatch with input (when it is float32) or weight.
        # The flag for whether to use fp16 or amp is the type of "offset",
        # we cast weight and input to temporarily support fp16 and amp
        # whatever the pytorch version is.

        def _output_size(input, weight):
            channels = weight.size(0)
            output_size = (input.size(0), channels)
            for d in range(input.dim() - 2):
                in_size = input.size(d + 2)
                pad = padding[d]
                kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
                stride_ = stride[d]
                output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
            if not all(map(lambda s: s > 0, output_size)):
                raise ValueError(
                    'convolution input is too small (output would be ' +
                    'x'.join(map(str, output_size)) + ')')
            return output_size

        input = input.type_as(offset)
        weight = weight.type_as(input)
        output = input.new_empty(_output_size(input, weight))
        _bufs = [input.new_empty(0), input.new_empty(0)]
        if weight.dtype == torch.float16:
            output = torch.ops.mmcv.modulated_deform_conv(
                input.to(torch.float32),
                weight.to(torch.float32),
                bias.to(torch.float32),
                _bufs[0].to(torch.float32),
                offset.to(torch.float32),
                mask.to(torch.float32),
                output.to(torch.float32),
                _bufs[1].to(torch.float32),
                kernel_h=weight.size(2),
                kernel_w=weight.size(3),
                stride_h=stride[0],
                stride_w=stride[1],
                pad_h=padding[0],
                pad_w=padding[1],
                dilation_h=dilation[0],
                dilation_w=dilation[1],
                group=groups,
                deformable_group=deform_groups,
                with_bias=with_bias)
            output = output.to(torch.float16)
        else:
            output = torch.ops.mmcv.modulated_deform_conv(
                input,
                weight,
                bias,
                _bufs[0],
                offset,
                mask,
                output,
                _bufs[1],
                kernel_h=weight.size(2),
                kernel_w=weight.size(3),
                stride_h=stride[0],
                stride_w=stride[1],
                pad_h=padding[0],
                pad_w=padding[1],
                dilation_h=dilation[0],
                dilation_w=dilation[1],
                group=groups,
                deformable_group=deform_groups,
                with_bias=with_bias)
        return output

    @staticmethod
    def forward(ctx,
                input: torch.Tensor,
                offset: torch.Tensor,
                mask: torch.Tensor,
                weight: nn.Parameter,
                bias: Optional[nn.Parameter] = None,
                stride: int = 1,
                padding: int = 0,
                dilation: int = 1,
                groups: int = 1,
                deform_groups: int = 1) -> torch.Tensor:
        if input is not None and input.dim() != 4:
            raise ValueError(
                f'Expected 4D tensor as input, got {input.dim()}D tensor \
instead.')
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deform_groups = deform_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(0)  # fake tensor
        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
        # amp won't cast the type of model (float32), but "offset" is cast
        # to float16 by nn.Conv2d automatically, leading to the type
        # mismatch with input (when it is float32) or weight.
        # The flag for whether to use fp16 or amp is the type of "offset",
        # we cast weight and input to temporarily support fp16 and amp
        # whatever the pytorch version is.
        input = input.type_as(offset)
        weight = weight.type_as(input)
        bias = bias.type_as(input)  # type: ignore
        ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        ext_module.modulated_deform_conv_forward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            output,
            ctx._bufs[1],
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output: torch.Tensor) -> tuple:
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        grad_output = grad_output.contiguous()
        ext_module.modulated_deform_conv_backward(
            input,
            weight,
            bias,
            ctx._bufs[0],
            offset,
            mask,
            ctx._bufs[1],
            grad_input,
            grad_weight,
            grad_bias,
            grad_offset,
            grad_mask,
            grad_output,
            kernel_h=weight.size(2),
            kernel_w=weight.size(3),
            stride_h=ctx.stride[0],
            stride_w=ctx.stride[1],
            pad_h=ctx.padding[0],
            pad_w=ctx.padding[1],
            dilation_h=ctx.dilation[0],
            dilation_w=ctx.dilation[1],
            group=ctx.groups,
            deformable_group=ctx.deform_groups,
            with_bias=ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _output_size(ctx, input, weight):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = ctx.padding[d]
            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = ctx.stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                'convolution input is too small (output would be ' +
                'x'.join(map(str, output_size)) + ')')
        return output_size


modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply


class ModulatedDeformConv2d(nn.Module):

    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
                            cls_name='ModulatedDeformConv2d')
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 deform_groups: int = 1,
                 bias: Union[bool, str] = True):
        super(ModulatedDeformConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deform_groups = deform_groups
        # enable compatibility with nn.Conv2d
        self.transposed = False
        self.output_padding = _single(0)

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.init_weights()

    def init_weights(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x: torch.Tensor, offset: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return ModulatedDeformConv2dFunction._jit_forward(
                x, offset, mask, self.weight, self.bias,
                self.stride, self.padding,
                self.dilation, self.groups,
                self.deform_groups)
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)


@CONV_LAYERS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
    layers.

    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int): Same as nn.Conv2d, while tuple is not supported.
        padding (int): Same as nn.Conv2d, while tuple is not supported.
        dilation (int): Same as nn.Conv2d, while tuple is not supported.
        groups (int): Same as nn.Conv2d.
        bias (bool or str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
            False.
    """

    _version = 2

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            bias=True)
        self.init_weights()

    def init_weights(self) -> None:
        super(ModulatedDeformConv2dPack, self).init_weights()
        if hasattr(self, 'conv_offset'):
            self.conv_offset.weight.data.zero_()
            self.conv_offset.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return ModulatedDeformConv2dFunction._jit_forward(
                x, offset, mask, self.weight, self.bias,
                self.stride, self.padding,
                self.dilation, self.groups,
                self.deform_groups)
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)

        if version is None or version < 2:
            # the key is different in early versions
            # In version < 2, ModulatedDeformConvPack
            # loads previous benchmark models.
            if (prefix + 'conv_offset.weight' not in state_dict
                    and prefix[:-1] + '_offset.weight' in state_dict):
                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
                    prefix[:-1] + '_offset.weight')
            if (prefix + 'conv_offset.bias' not in state_dict
                    and prefix[:-1] + '_offset.bias' in state_dict):
                state_dict[prefix +
                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
                                                                '_offset.bias')

        if version is not None and version > 1:
            print_log(
                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
                'version 2.',
                logger='root')

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

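Because forward dispatches to _jit_forward under torch.jit.is_tracing(), the module can be traced end to end without the autograd Function. A hedged sketch, assuming the patched module is importable at this path and the recompiled extension is available:

    import torch
    from mmcv.ops.modulated_deform_conv import ModulatedDeformConv2dPack  # assumed import path

    conv = ModulatedDeformConv2dPack(16, 16, kernel_size=3, padding=1).cuda().eval()
    x = torch.rand(1, 16, 32, 32).cuda()
    with torch.no_grad():
        traced = torch.jit.trace(conv, x)  # routes through _jit_forward
        assert torch.allclose(traced(x), conv(x), atol=1e-5)
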
@ -1,11 +1,10 @@

# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import copy
import ctypes
import itertools
import logging
import os
import time
import timeit
from contextlib import contextmanager

@ -14,7 +13,6 @@ import pandas as pd

import torch
import torch_blade
import torch_blade.tensorrt
import torchvision
from torch_blade import optimize

from easycv.framework.errors import RuntimeError

@ -80,6 +78,7 @@ def opt_trt_config(

        # 'aten::select', 'aten::index', 'aten::slice', 'aten::view', 'aten::upsample'
    ],
    fp16_fallback_op_ratio=0.05,
    preserved_attributes=[],
)
BLADE_CONFIG_KEYS = list(BLADE_CONFIG_DEFAULT.keys())

@ -185,24 +184,41 @@ def computeStats(backend, timings, batch_size=1, model_name='default'):


@torch.no_grad()
def benchmark(model, inp, backend, batch_size, model_name='default', num=200):
def benchmark(model,
              inputs,
              backend,
              batch_size,
              model_name='default',
              num_iters=200,
              warmup_iters=5,
              fp16=False):
    """
    Evaluate the time and speed of different models.

    Args:
        model: input model
        inp: input of the model
        inputs: input of the model
        backend (str): backend name
        batch_size (int): image batch
        model_name (str): tested model name
        num: test forward times
        num_iters: test forward times
    """
    for _ in range(warmup_iters):
        if fp16:
            with torch.cuda.amp.autocast():
                model(*copy.deepcopy(inputs))
        else:
            model(*copy.deepcopy(inputs))

    torch.cuda.synchronize()
    timings = []
    for i in range(num):
    for i in range(num_iters):
        start_time = timeit.default_timer()
        model(*inp)
        if fp16:
            with torch.cuda.amp.autocast():
                model(*copy.deepcopy(inputs))
        else:
            model(*copy.deepcopy(inputs))
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        meas_time = end_time - start_time

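A hedged usage sketch of the reworked benchmark helper; the model and inputs are placeholders, and CUDA is required since the helper synchronizes the device:

    import torch

    model = torch.nn.Conv2d(3, 16, 3).cuda().eval()
    inputs = (torch.rand(1, 3, 224, 224).cuda(), )
    # returns the per-model stats row that blade_optimize collects into its summary table
    stats = benchmark(
        model, inputs, backend='TensorRT', batch_size=1,
        model_name='demo', num_iters=50, warmup_iters=5, fp16=False)
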
@ -246,40 +262,49 @@ def blade_optimize(speed_test_model,

                       enable_fp16=True, fp16_fallback_op_ratio=0.05),
                   backend='TensorRT',
                   batch=1,
                   warm_up_time=10,
                   warmup_iters=10,
                   compute_cost=True,
                   use_profile=False,
                   check_result=False,
                   static_opt=True):
                   static_opt=True,
                   min_num_nodes=None,
                   check_inputs=True,
                   fp16=False):

    if not static_opt:
        logging.info(
            'PAI-Blade uses dynamic optimization for the input model; the exported model is built for dynamic shape input'
        )
        with opt_trt_config(blade_config):
            opt_model = optimize(
                model,
                allow_tracing=True,
                model_inputs=tuple(inputs),
            )
        optimize_op = optimize
    else:
        logging.info(
            'PAI-Blade uses static optimization for the input model; the exported model must be used with static shape input'
        )
        from torch_blade.optimization import _static_optimize
        optimize_op = _static_optimize
    if min_num_nodes is not None:
        import torch_blade.clustering.support_fusion_group as blade_fusion
        with blade_fusion.min_group_nodes(min_num_nodes=min_num_nodes):
            with opt_trt_config(blade_config):
                opt_model = optimize_op(
                    model,
                    allow_tracing=True,
                    model_inputs=tuple(copy.deepcopy(inputs)),
                )
    else:
        with opt_trt_config(blade_config):
            opt_model = _static_optimize(
            opt_model = optimize_op(
                model,
                allow_tracing=True,
                model_inputs=tuple(inputs),
                model_inputs=tuple(copy.deepcopy(inputs)),
            )

    if compute_cost:
        logging.info('Running benchmark...')
        results = []
        inputs_t = inputs
        inputs_t = copy.deepcopy(inputs)

        # the end2end model and the scripts need different channel permutations;
        # this problem is encountered only when we use end2end export
        if (inputs_t[0].shape[-1] == 3):
        if check_inputs and (inputs_t[0].shape[-1] == 3):
            shape_length = len(inputs_t[0].shape)
            if shape_length == 4:
                inputs_t = inputs_t[0].permute(0, 3, 1, 2)

@ -290,45 +315,67 @@ def blade_optimize(speed_test_model,

            inputs_t = (torch.unsqueeze(inputs_t, 0), )

        results.append(
            benchmark(speed_test_model, inputs_t, backend, batch, 'easycv'))
            benchmark(
                speed_test_model,
                inputs_t,
                backend,
                batch,
                'easycv',
                warmup_iters=warmup_iters,
                fp16=fp16))
        results.append(
            benchmark(model, inputs, backend, batch, 'easycv script'))
        results.append(benchmark(opt_model, inputs, backend, batch, 'blade'))
            benchmark(
                model,
                copy.deepcopy(inputs),
                backend,
                batch,
                'easycv script',
                warmup_iters=warmup_iters,
                fp16=fp16))
        results.append(
            benchmark(
                opt_model,
                copy.deepcopy(inputs),
                backend,
                batch,
                'blade',
                warmup_iters=warmup_iters,
                fp16=fp16))

        logging.info('Model Summary:')
        summary = pd.DataFrame(results)
        logging.warning(summary.to_markdown())
        print(summary.to_markdown())

    if use_profile:
        torch.cuda.empty_cache()
        # warm-up
        for k in range(warm_up_time):
            test_result = opt_model(*inputs)
        for k in range(warmup_iters):
            test_result = opt_model(*copy.deepcopy(inputs))
        torch.cuda.synchronize()

        torch.cuda.synchronize()
        cu_prof_start()
        for k in range(warm_up_time):
            test_result = opt_model(*inputs)
        for k in range(warmup_iters):
            test_result = opt_model(*copy.deepcopy(inputs))
        torch.cuda.synchronize()
        cu_prof_stop()
        import torch.autograd.profiler as profiler
        with profiler.profile(use_cuda=True) as prof:
            for k in range(warm_up_time):
                test_result = opt_model(*inputs)
            for k in range(warmup_iters):
                test_result = opt_model(*copy.deepcopy(inputs))
            torch.cuda.synchronize()

        with profiler.profile(use_cuda=True) as prof:
            for k in range(warm_up_time):
                test_result = opt_model(*inputs)
            for k in range(warmup_iters):
                test_result = opt_model(*copy.deepcopy(inputs))
            torch.cuda.synchronize()

        prof_str = prof.key_averages().table(sort_by='cuda_time_total')
        print(f'{prof_str}')

    if check_result:
        output = model(*inputs)
        test_result = opt_model(*inputs)
        output = model(*copy.deepcopy(inputs))
        test_result = opt_model(*copy.deepcopy(inputs))
        check_results(output, test_result)

    return opt_model

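The repeated copy.deepcopy(inputs) above is what keeps the runs independent: a model like BEVFormer mutates parts of its inputs (the img_metas dict carrying prev_bev state), so reusing one tuple would let one forward pass leak into the next measurement. A tiny illustration of the failure mode being avoided, with a hypothetical stateful callable:

    import copy

    def stateful_model(meta):
        meta['prev'] = meta.get('prev', 0) + 1  # mutates its input in place
        return meta['prev']

    inputs = ({'prev': 0}, )
    print(stateful_model(*copy.deepcopy(inputs)))  # 1
    print(stateful_model(*copy.deepcopy(inputs)))  # 1 again: every run sees fresh inputs
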
@ -2,10 +2,14 @@

import functools
import inspect
import logging
import pickle
import warnings

import mmcv
import numpy as np
import torch

from easycv.framework.errors import ValueError


def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):

@ -99,3 +103,21 @@ def deprecated(reason):

        return new_func1

    return decorator


def encode_str_to_tensor(obj):
    if isinstance(obj, str):
        return torch.tensor(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    elif isinstance(obj, torch.Tensor):
        return obj
    else:
        raise ValueError(f'Not support type {type(obj)}')


def decode_tensor_to_str(obj):
    if isinstance(obj, torch.Tensor):
        return pickle.loads(obj.cpu().numpy().tobytes())
    elif isinstance(obj, str):
        return obj
    else:
        raise ValueError(f'Not support type {type(obj)}')

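These helpers let a Python string ride through a jit/blade graph as a uint8 tensor. A quick round trip:

    import torch

    token = encode_str_to_tensor('dummy_prev_scene_token')
    assert token.dtype == torch.uint8
    assert decode_tensor_to_str(token) == 'dummy_prev_scene_token'
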
@ -373,3 +373,17 @@ def remove_adapt_for_mmlab(cfg):

    mmlab_modules_cfg = cfg.get('mmlab_modules', [])
    adapter = MMAdapter(mmlab_modules_cfg)
    adapter.reset_mm_registry()


def fix_dc_pin_memory():
    """Fix pin memory for DataContainer."""
    from mmcv.parallel import DataContainer as DC
    from torch.utils.data._utils.pin_memory import pin_memory

    def data_container_pin_memory(self):
        if self.cpu_only:
            return self
        self._data = pin_memory(self._data)
        return self

    setattr(DC, 'pin_memory', data_container_pin_memory)

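The patch works because the DataLoader pin-memory worker calls .pin_memory() on whatever objects a batch contains; monkey-patching DataContainer simply gives it that method. It only needs to run once, before the dataloaders are built, which is what train.py does further below when cfg.data.pin_memory is set:

    from easycv.utils.mmlab_utils import fix_dc_pin_memory

    fix_dc_pin_memory()  # patch DataContainer once, before building dataloaders
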
@ -1,17 +1,21 @@

# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import re
import subprocess
import tempfile
import unittest

import numpy as np
import torch
from tests.ut_config import (IMAGENET_LABEL_TXT, PRETRAINED_MODEL_MOCO,
                             PRETRAINED_MODEL_RESNET50,
from tests.ut_config import (IMAGENET_LABEL_TXT,
                             PRETRAINED_MODEL_BEVFORMER_BASE,
                             PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50,
                             PRETRAINED_MODEL_YOLOXS_EXPORT)

import easycv
from easycv.apis.export import export
from easycv.file import io
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.test_util import clean_up, get_tmp_dir

@ -126,6 +130,41 @@ class ModelExportTest(unittest.TestCase):

        self.assertTrue(
            export_config['model']['backbone']['norm_cfg']['type'] == 'BN')

    @unittest.skipIf(torch.__version__ != '1.8.1+cu102',
                     'need another environment where mmcv has been recompiled')
    def test_export_bevformer_jit(self):
        ckpt_path = PRETRAINED_MODEL_BEVFORMER_BASE

        easycv_dir = os.path.dirname(easycv.__file__)
        if os.path.exists(os.path.join(easycv_dir, 'configs')):
            config_dir = os.path.join(easycv_dir, 'configs')
        else:
            config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
        config_file = os.path.join(
            config_dir,
            'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')

        with tempfile.TemporaryDirectory() as tmpdir:
            with io.open(config_file, 'r') as f:
                cfg_str = f.read()
            new_config_path = os.path.join(tmpdir, 'new_config.py')
            # find first adapt_jit and replace value
            res = re.search(r'adapt_jit(\s*)=(\s*)False', cfg_str)
            if res is not None:
                cfg_str_list = list(cfg_str)
                cfg_str_list[res.span()[0]:res.span()[1]] = 'adapt_jit = True'
                cfg_str = ''.join(cfg_str_list)
            with io.open(new_config_path, 'w') as f:
                f.write(cfg_str)

            cfg = mmcv_config_fromfile(new_config_path)
            cfg.export.type = 'jit'

            filename = os.path.join(tmpdir, 'model.pth')
            export(cfg, ckpt_path, filename, fp16=False)

            self.assertTrue(os.path.exists(filename + '.jit'))


if __name__ == '__main__':
    unittest.main()

@ -0,0 +1,59 @@

# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import unittest

import numpy as np
from tests.ut_config import TEST_IMAGES_DIR

from easycv.file.image import load_image


class LoadImageTest(unittest.TestCase):
    img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
    img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/unittest/local_backup/easycv_nfs/data/test_images/000000289059.jpg'

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_backend_pillow(self):
        img = load_image(
            self.img_path, mode='BGR', dtype=np.float32, backend='pillow')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.float32)
        self.assertEqual(list(img[0][0]), [145, 92, 59])

    def test_backend_cv2(self):
        img = load_image(self.img_path, mode='RGB', backend='cv2')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.uint8)
        self.assertEqual(list(img[0][0]), [59, 92, 145])

    def test_backend_turbojpeg(self):
        img = load_image(
            self.img_path, mode='RGB', dtype=np.float32, backend='turbojpeg')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.float32)
        self.assertEqual(list(img[0][0]), [59, 92, 145])

    def test_url_path_cv2(self):
        img = load_image(self.img_url, mode='BGR', backend='cv2')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.uint8)
        self.assertEqual(list(img[0][0]), [145, 92, 59])

    def test_url_path_pillow(self):
        img = load_image(self.img_url, mode='RGB', backend='pillow')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.uint8)
        self.assertEqual(list(img[0][0]), [59, 92, 145])

    def test_url_path_turbojpeg(self):
        img = load_image(self.img_url, mode='BGR', backend='turbojpeg')
        self.assertEqual(img.shape, (480, 640, 3))
        self.assertEqual(img.dtype, np.uint8)
        self.assertEqual(list(img[0][0]), [145, 92, 59])


if __name__ == '__main__':
    unittest.main()

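The expected pixel values in these tests encode the channel order: [145, 92, 59] is the BGR view of the pixel whose RGB view is [59, 92, 145]. Converting between the two is a channel flip:

    import numpy as np

    pixel_bgr = np.array([145, 92, 59], dtype=np.uint8)
    pixel_rgb = pixel_bgr[::-1]  # for whole images: img[..., ::-1]
    assert list(pixel_rgb) == [59, 92, 145]
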
@ -1,7 +1,10 @@
|
|||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
@ -9,7 +12,12 @@ from tests.ut_config import (PRETRAINED_MODEL_BEVFORMER_BASE,
|
|||
SMALL_NUSCENES_PATH)
|
||||
|
||||
import easycv
|
||||
from easycv.apis.export import export
|
||||
from easycv.core.evaluation.builder import build_evaluator
|
||||
from easycv.datasets import build_dataset
|
||||
from easycv.file import io
|
||||
from easycv.predictors import BEVFormerPredictor
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
|
||||
|
||||
class BEVFormerPredictorTest(unittest.TestCase):
|
||||
|
@ -17,7 +25,7 @@ class BEVFormerPredictorTest(unittest.TestCase):
|
|||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def _assert_results(self, results, assert_value=True):
|
||||
def _assert_results(self, results):
|
||||
res = results['pts_bbox']
|
||||
self.assertEqual(res['scores_3d'].shape, torch.Size([300]))
|
||||
self.assertEqual(res['labels_3d'].shape, torch.Size([300]))
|
||||
|
@ -40,88 +48,7 @@ class BEVFormerPredictorTest(unittest.TestCase):
|
|||
self.assertEqual(res['boxes_3d'].volume.shape, torch.Size([300]))
|
||||
self.assertEqual(res['boxes_3d'].yaw.shape, torch.Size([300]))
|
||||
|
||||
if not assert_value:
|
||||
return
|
||||
|
||||
assert_array_almost_equal(
|
||||
res['scores_3d'][:5].numpy(),
|
||||
np.array([0.982, 0.982, 0.982, 0.982, 0.981], dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(res['labels_3d'][:10].numpy(),
|
||||
np.array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5]))
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].bev[:2].numpy(),
|
||||
np.array([[9.341, -2.664, 2.034, 0.657, 1.819],
|
||||
[6.945, -18.833, 2.047, 0.661, 1.694]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].bottom_center[:2].numpy(),
|
||||
np.array([[9.341, -2.664, -1.849], [6.945, -18.833, -2.295]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].bottom_height[:5].numpy(),
|
||||
np.array([-1.849, -2.332, -2.295, -1.508, -1.204],
|
||||
dtype=np.float32),
|
||||
decimal=1)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].center[:2].numpy(),
|
||||
np.array([[9.341, -2.664, -1.849], [6.945, -18.833, -2.295]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].corners[:1][0][:3].numpy(),
|
||||
np.array([[9.91, -3.569, -1.849], [9.91, -3.569, -0.742],
|
||||
[9.273, -3.73, -0.742]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].dims[:2].numpy(),
|
||||
np.array([[2.034, 0.657, 1.107], [2.047, 0.661, 1.101]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].gravity_center[:2].numpy(),
|
||||
np.array([[9.341, -2.664, -1.295], [6.945, -18.833, -1.745]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].height[:5].numpy(),
|
||||
np.array([1.107, 1.101, 1.082, 1.098, 1.073], dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].nearest_bev[:2].numpy(),
|
||||
np.array([[9.013, -3.681, 9.67, -1.647],
|
||||
[6.615, -19.857, 7.276, -17.81]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].tensor[:1].numpy(),
|
||||
np.array([[
|
||||
9.340, -2.664, -1.849, 2.0343, 6.568e-01, 1.107, 1.819,
|
||||
-8.636e-06, 2.034e-05
|
||||
]],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].top_height[:5].numpy(),
|
||||
np.array([-0.742, -1.194, -1.25, -0.411, -0.132],
|
||||
dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].volume[:5].numpy(),
|
||||
np.array([1.478, 1.49, 1.435, 1.495, 1.47], dtype=np.float32),
|
||||
decimal=3)
|
||||
assert_array_almost_equal(
|
||||
res['boxes_3d'].yaw[:5].numpy(),
|
||||
np.array([1.819, 1.694, 1.659, 1.62, 1.641], dtype=np.float32),
|
||||
decimal=3)
|
||||
|
||||
def test_single(self):
|
||||
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
|
||||
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
|
||||
'inference/single_sample.pkl')
|
||||
def _get_config_file(self):
|
||||
easycv_dir = os.path.dirname(easycv.__file__)
|
||||
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||
config_dir = os.path.join(easycv_dir, 'configs')
|
||||
|
@ -130,7 +57,13 @@ class BEVFormerPredictorTest(unittest.TestCase):
|
|||
config_file = os.path.join(
|
||||
config_dir,
|
||||
'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
|
||||
return config_file
|
||||
|
||||
def test_single(self):
|
||||
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
|
||||
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
|
||||
'inference/single_sample.pkl')
|
||||
config_file = self._get_config_file()
|
||||
predictor = BEVFormerPredictor(
|
||||
model_path=model_path,
|
||||
config_file=config_file,
|
||||
|
@ -140,10 +73,86 @@ class BEVFormerPredictorTest(unittest.TestCase):
|
|||
for result in results:
|
||||
self._assert_results(result)
|
||||
|
||||
@unittest.skipIf(True, 'Not support batch yet')
|
||||
def test_batch(self):
|
||||
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
|
||||
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
|
||||
'inference/single_sample.pkl')
|
||||
config_file = self._get_config_file()
|
||||
predictor = BEVFormerPredictor(
|
||||
model_path=model_path, config_file=config_file, batch_size=2)
|
||||
results = predictor([single_ann_file, single_ann_file])
|
||||
self.assertEqual(len(results), 2)
|
||||
# Input the same sample continuously, the output value is different,
|
||||
# because the model will record the features of the previous sample to infer the next sample
|
||||
self._assert_results(results[0])
|
||||
self._assert_results(results[1])
|
||||
|
||||
def test_metric(self):
|
||||
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
|
||||
inputs_file = os.path.join(SMALL_NUSCENES_PATH,
|
||||
'nuscenes_infos_temporal_val.pkl')
|
||||
config_file = self._get_config_file()
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg.data.val.data_source.data_root = SMALL_NUSCENES_PATH
|
||||
cfg.data.val.data_source.ann_file = os.path.join(
|
||||
SMALL_NUSCENES_PATH, 'nuscenes_infos_temporal_val.pkl')
|
||||
cfg.data.val.pop('imgs_per_gpu', None)
|
||||
val_dataset = build_dataset(cfg.data.val)
|
||||
evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
|
||||
predictor = BEVFormerPredictor(
|
||||
model_path=model_path, config_file=config_file)
|
||||
inputs = mmcv.load(inputs_file)['infos']
|
||||
for i in range(len(inputs)):
|
||||
for k in list(inputs[i]['cams'].keys()):
|
||||
inputs[i]['cams'][k]['data_path'] = os.path.join(
|
||||
SMALL_NUSCENES_PATH, inputs[i]['cams'][k]['data_path'])
|
||||
predict_results = predictor(inputs)
|
||||
|
||||
results = {'pts_bbox': [i['pts_bbox'] for i in predict_results]}
|
||||
val_results = val_dataset.evaluate(results, evaluators)
|
||||
self.assertAlmostEqual(
|
||||
val_results['pts_bbox_NuScenes/NDS'], 0.460, delta=0.01)
|
||||
self.assertAlmostEqual(
|
||||
val_results['pts_bbox_NuScenes/mAP'], 0.41, delta=0.01)
|
||||
|
||||
|
||||
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
|
||||
'need another environment where mmcv has been recompiled')
|
||||
class BEVFormerBladePredictorTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.mkdir(self.tmp_dir)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
io.remove(self.tmp_dir)
|
||||
return super().tearDown()
|
||||
|
||||
def _replace_config(self, cfg_file):
|
||||
with io.open(cfg_file, 'r') as f:
|
||||
cfg_str = f.read()
|
||||
|
||||
new_config_path = os.path.join(self.tmp_dir, 'new_config.py')
|
||||
|
||||
# find first adapt_jit and replace value
|
||||
res = re.search(r'adapt_jit(\s*)=(\s*)False', cfg_str)
|
||||
if res is not None:
|
||||
cfg_str_list = list(cfg_str)
|
||||
cfg_str_list[res.span()[0]:res.span()[1]] = 'adapt_jit = True'
|
||||
cfg_str = ''.join(cfg_str_list)
|
||||
with io.open(new_config_path, 'w') as f:
|
||||
f.write(cfg_str)
|
||||
return new_config_path
|
||||
|
||||
def test_single(self):
|
||||
# test export blade model and bevformer predictor
|
||||
ori_ckpt = PRETRAINED_MODEL_BEVFORMER_BASE
|
||||
inputs_file = os.path.join(SMALL_NUSCENES_PATH,
|
||||
'nuscenes_infos_temporal_val.pkl')
|
||||
|
||||
easycv_dir = os.path.dirname(easycv.__file__)
|
||||
if os.path.exists(os.path.join(easycv_dir, 'configs')):
|
||||
config_dir = os.path.join(easycv_dir, 'configs')
|
||||
|
@ -152,15 +161,41 @@ class BEVFormerPredictorTest(unittest.TestCase):
|
|||
        config_file = os.path.join(
            config_dir,
            'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
        config_file = self._replace_config(config_file)
        cfg = mmcv_config_fromfile(config_file)

        filename = os.path.join(self.tmp_dir, 'model.pth')
        export(cfg, ori_ckpt, filename, fp16=False)
        blade_filename = filename + '.blade'

        self.assertTrue(os.path.exists(blade_filename))

        cfg.data.val.data_source.data_root = SMALL_NUSCENES_PATH
        cfg.data.val.data_source.ann_file = os.path.join(
            SMALL_NUSCENES_PATH, 'nuscenes_infos_temporal_val.pkl')
        cfg.data.val.pop('imgs_per_gpu', None)
        val_dataset = build_dataset(cfg.data.val)
        evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])

        predictor = BEVFormerPredictor(
            model_path=model_path, config_file=config_file, batch_size=2)
        results = predictor([single_ann_file, single_ann_file])
        self.assertEqual(len(results), 2)
        # Input the same sample continuously and the output values differ,
        # because the model caches the features of the previous sample to
        # infer the next one.
        self._assert_results(results[0])
        self._assert_results(results[1], assert_value=False)
        predictor = BEVFormerPredictor(
            model_path=blade_filename,
            config_file=config_file,
            model_type='blade',
        )

        inputs = mmcv.load(inputs_file)['infos']
        predict_results = predictor(inputs)

        results = {'pts_bbox': [i['pts_bbox'] for i in predict_results]}
        val_results = val_dataset.evaluate(results, evaluators)
        self.assertAlmostEqual(
            val_results['pts_bbox_NuScenes/NDS'], 0.460, delta=0.01)
        self.assertAlmostEqual(
            val_results['pts_bbox_NuScenes/mAP'], 0.41, delta=0.01)

    @unittest.skipIf(True, 'batch inference is not supported yet')
    def test_batch(self):
        pass


if __name__ == '__main__':
    unittest.main()

@ -34,11 +34,11 @@ from easycv.file import io
from easycv.models import build_model
from easycv.utils.collect_env import collect_env
from easycv.utils.logger import get_root_logger
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils import mmlab_utils
from easycv.utils.config_tools import traverse_replace
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
                                       mmcv_config_fromfile, rebuild_config)
from easycv.utils.dist_utils import get_device
from easycv.utils.dist_utils import get_device, is_master
from easycv.utils.setup_env import setup_multi_processes

@ -161,7 +161,7 @@ def main():
    cfg.load_from = args.load_from

    # dynamic adapt mmdet models
    dynamic_adapt_for_mmlab(cfg)
    mmlab_utils.dynamic_adapt_for_mmlab(cfg)
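    # (Presumably switched to a module-qualified call alongside the new
    # `from easycv.utils import mmlab_utils` import, which also exposes the
    # mmlab_utils.fix_dc_pin_memory helper used further below.)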

    cfg.gpus = args.gpus

@ -230,7 +230,9 @@ def main():
    assert isinstance(args.pretrained, str)
    cfg.model.pretrained = args.pretrained
    model = build_model(cfg.model)
    print(model)

    if is_master():
        print(model)
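    # Hedged sketch of the rank check `is_master` presumably performs (the
    # real helper lives in easycv.utils.dist_utils):
    #
    #   import torch.distributed as dist
    #
    #   def is_master():
    #       if not (dist.is_available() and dist.is_initialized()):
    #           return True  # single-process run: treat it as master
    #       return dist.get_rank() == 0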

    if 'stage' in cfg.model and cfg.model['stage'] == 'EDGE':
        from easycv.utils.flops_counter import get_model_info
@ -259,6 +261,8 @@ def main():
        ), 'odps config must be set in cfg file / cfg.data.train.data_source !!'
        shuffle = False

    if getattr(cfg.data, 'pin_memory', False):
        mmlab_utils.fix_dc_pin_memory()
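        # (Hedged reading of the helper name: presumably patches pin-memory
        # handling for mmcv DataContainer batches, which torch's default
        # pin_memory collate cannot pin on its own.)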
    datasets = [build_dataset(cfg.data.train)]
    data_loaders = [
        build_dataloader(
@ -268,6 +272,7 @@ def main():
|
|||
            cfg.gpus,
            dist=distributed,
            shuffle=shuffle,
            pin_memory=getattr(cfg.data, 'pin_memory', False),
            replace=getattr(cfg.data, 'sampling_replace', False),
            seed=cfg.seed,
            drop_last=getattr(cfg.data, 'drop_last', False),
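            # pin_memory is forwarded down to torch.utils.data.DataLoader; a
            # minimal standalone sketch of the effect (not EasyCV code):
            #
            #   import torch
            #   from torch.utils.data import DataLoader, TensorDataset
            #
            #   ds = TensorDataset(torch.randn(8, 3))
            #   loader = DataLoader(ds, batch_size=4, pin_memory=True)
            #   for (batch,) in loader:
            #       # pinned (page-locked) host memory enables async H2D copy
            #       batch = batch.cuda(non_blocking=True)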