Improve the performance of bevformer (#224)

Improve the performance of bevformer

* add hybrid branch (#232)

Co-authored-by: yhq <yanhaiqiang.yhq@alibaba-inc.com>
Cathy0908 2022-11-23 21:32:08 +08:00 committed by GitHub
parent a36e0e32a4
commit f8c9a9a1c9
37 changed files with 2847 additions and 530 deletions

View File

@ -20,14 +20,16 @@ input_modality = dict(
use_map=False,
use_external=True)
_dim_ = 256
_pos_dim_ = _dim_ // 2
_ffn_dim_ = _dim_ * 2
_num_levels_ = 4
bev_h_ = 200
bev_w_ = 200
embed_dim = 256
pos_dim = embed_dim // 2
ffn_dim = embed_dim * 2
num_levels = 4
bev_h = 200
bev_w = 200
queue_length = 4 # each sequence contains `queue_length` frames.
adapt_jit = False  # set to True when exporting a jit trace model or blade model
model = dict(
type='BEVFormer',
use_grid_mask=True,
@ -47,18 +49,18 @@ model = dict(
img_neck=dict(
type='FPN',
in_channels=[512, 1024, 2048],
out_channels=_dim_,
out_channels=embed_dim,
start_level=0,
add_extra_convs='on_output',
num_outs=_num_levels_,
num_outs=num_levels,
relu_before_extra_convs=True),
pts_bbox_head=dict(
type='BEVFormerHead',
bev_h=bev_h_,
bev_w=bev_w_,
bev_h=bev_h,
bev_w=bev_w,
num_query=900,
num_classes=10,
in_channels=_dim_,
in_channels=embed_dim,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=False,
@ -67,7 +69,7 @@ model = dict(
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
embed_dims=_dim_,
embed_dims=embed_dim,
encoder=dict(
type='BEVFormerEncoder',
num_layers=6,
@ -76,26 +78,28 @@ model = dict(
return_intermediate=False,
transformerlayers=dict(
type='BEVFormerLayer',
adapt_jit=adapt_jit,
attn_cfgs=[
dict(
type='TemporalSelfAttention',
embed_dims=_dim_,
embed_dims=embed_dim,
num_levels=1),
dict(
type='SpatialCrossAttention',
pc_range=point_cloud_range,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=_dim_,
embed_dims=embed_dim,
num_points=8,
num_levels=_num_levels_),
embed_dims=_dim_,
num_levels=num_levels,
adapt_jit=adapt_jit),
embed_dims=embed_dim,
)
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=_ffn_dim_,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
@ -111,18 +115,19 @@ model = dict(
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=_dim_,
embed_dims=embed_dim,
num_heads=8,
dropout=0.1),
dict(
type='CustomMSDeformableAttention',
embed_dims=_dim_,
num_levels=1),
embed_dims=embed_dim,
num_levels=1,
adapt_jit=adapt_jit),
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=_ffn_dim_,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
@ -138,9 +143,9 @@ model = dict(
num_classes=10),
positional_encoding=dict(
type='LearnedPositionalEncoding',
num_feats=_pos_dim_,
row_num_embed=bev_h_,
col_num_embed=bev_w_,
num_feats=pos_dim,
row_num_embed=bev_h,
col_num_embed=bev_w,
),
loss_cls=dict(
type='FocalLoss',
@ -217,6 +222,7 @@ test_pipeline = [
data = dict(
imgs_per_gpu=1, # 8gpus, total batch size=8
workers_per_gpu=4,
pin_memory=True,
# shuffler_sampler=dict(type='DistributedGroupSampler'),
# nonshuffler_sampler=dict(type='DistributedSampler'),
train=dict(
@ -226,7 +232,10 @@ data = dict(
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
pipeline=[
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg'),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
@ -251,7 +260,10 @@ data = dict(
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
pipeline=[
dict(type='LoadMultiViewImageFromFiles', to_float32=True)
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg')
],
classes=CLASSES,
modality=input_modality,
@ -295,3 +307,12 @@ log_config = dict(
checkpoint_config = dict(interval=1)
cudnn_benchmark = True
export = dict(
type='blade',
blade_config=dict(
enable_fp16=True,
fp16_fallback_op_ratio=0.0,
customize_op_black_list=[
'aten::select', 'aten::index', 'aten::slice', 'aten::view',
'aten::upsample', 'aten::clamp'
]))
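For reference, the new export section can be read back like any other config field. A minimal sketch, assuming a hypothetical config path and mmcv's Config (the same class export.py below uses):

from mmcv.utils import Config

# hypothetical path to the config shown above
cfg = Config.fromfile('configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
print(cfg.export.type)                      # 'blade'
print(cfg.export.blade_config.enable_fp16)  # True
print(cfg.adapt_jit)                        # False here; must be set to True before export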

View File

@ -0,0 +1,311 @@
_base_ = ['configs/base.py']
# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
voxel_size = [0.2, 0.2, 8]
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
# For nuScenes we usually do 10-class detection
CLASSES = [
'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]
input_modality = dict(
use_lidar=False,
use_camera=True,
use_radar=False,
use_map=False,
use_external=True)
embed_dim = 256
pos_dim = embed_dim // 2
ffn_dim = embed_dim * 2
num_levels = 4
bev_h = 200
bev_w = 200
queue_length = 4 # each sequence contains `queue_length` frames.
model = dict(
type='BEVFormer',
use_grid_mask=True,
video_test_mode=True,
img_backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(2, 3, 4),
frozen_stages=-1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
stage_with_dcn=(False, False, True, True),
zero_init_residual=True),
img_neck=dict(
type='FPN',
in_channels=[512, 1024, 2048],
out_channels=embed_dim,
start_level=0,
add_extra_convs='on_output',
num_outs=num_levels,
relu_before_extra_convs=True),
pts_bbox_head=dict(
type='BEVFormerHead',
bev_h=bev_h,
bev_w=bev_w,
num_query=900,
num_query_one2many=1800,
one2many_gt_mul=4,
num_classes=10,
in_channels=embed_dim,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=False,
transformer=dict(
type='PerceptionTransformer',
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
embed_dims=embed_dim,
encoder=dict(
type='BEVFormerEncoder',
num_layers=6,
pc_range=point_cloud_range,
num_points_in_pillar=4,
return_intermediate=False,
transformerlayers=dict(
type='BEVFormerLayer',
attn_cfgs=[
dict(
type='TemporalSelfAttention',
embed_dims=embed_dim,
num_levels=1),
dict(
type='SpatialCrossAttention',
pc_range=point_cloud_range,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=embed_dim,
num_points=8,
num_levels=num_levels),
embed_dims=embed_dim,
)
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
),
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm'))),
decoder=dict(
type='Detr3DTransformerDecoder',
num_layers=6,
return_intermediate=True,
transformerlayers=dict(
type='DetrTransformerDecoderLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=embed_dim,
num_heads=8,
dropout=0.1),
dict(
type='CustomMSDeformableAttention',
embed_dims=embed_dim,
num_levels=1),
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
),
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm')))),
bbox_coder=dict(
type='NMSFreeBBoxCoder',
post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
pc_range=point_cloud_range,
max_num=300,
voxel_size=voxel_size,
num_classes=10),
positional_encoding=dict(
type='LearnedPositionalEncoding',
num_feats=pos_dim,
row_num_embed=bev_h,
col_num_embed=bev_w,
),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
# loss_bbox=dict(type='L1Loss', loss_weight=0.25),
# loss_bbox=dict(type='SmoothL1Loss', loss_weight=0.25),
loss_bbox=dict(type='BalancedL1Loss', loss_weight=0.25, gamma=1),
loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=[512, 512, 1],
voxel_size=voxel_size,
point_cloud_range=point_cloud_range,
out_size_factor=4,
assigner=dict(
type='HungarianBBoxAssigner3D',
cls_cost=dict(type='FocalLossCost', weight=2.0),
reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
iou_cost=dict(
type='IoUCost', weight=0.0
), # Fake cost. This is just to make it compatible with DETR head.
pc_range=point_cloud_range))))
dataset_type = 'NuScenesDataset'
data_root = 'data/nuscenes/train-val/'
train_pipeline = [
dict(type='PhotoMetricDistortionMultiViewImage'),
# dict(type='RandomScaleImageMultiViewImage', scales=[0.8,0.9,1.0,1.1,1.2]),
dict(type='RandomHorizontalFlipMultiViewImage'),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=CLASSES),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='PadMultiViewImage', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=CLASSES),
dict(
type='Collect3D',
keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'],
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor', 'flip',
'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
'prev_idx', 'next_idx', 'pcd_scale_factor', 'pcd_rotation',
'pts_filename', 'transformation_3d_flow', 'scene_token',
'can_bus'))
]
test_pipeline = [
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='PadMultiViewImage', size_divisor=32),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1600, 900),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='DefaultFormatBundle3D',
class_names=CLASSES,
with_label=False),
dict(
type='Collect3D',
keys=['img'],
meta_keys=('filename', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'pcd_trans', 'sample_idx', 'prev_idx', 'next_idx',
'pcd_scale_factor', 'pcd_rotation', 'pts_filename',
'transformation_3d_flow', 'scene_token', 'can_bus'))
])
]
data = dict(
imgs_per_gpu=1, # 8gpus, total batch size=8
workers_per_gpu=8,
pin_memory=True,
# shuffler_sampler=dict(type='DistributedGroupSampler'),
# nonshuffler_sampler=dict(type='DistributedSampler'),
train=dict(
type=dataset_type,
data_source=dict(
type='Det3dSourceNuScenes',
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
pipeline=[
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg'),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
with_label_3d=True,
with_attr_label=False)
],
classes=CLASSES,
modality=input_modality,
test_mode=False,
use_valid_flag=True,
# we use box_type_3d='LiDAR' in kitti and nuscenes dataset
# and box_type_3d='Depth' in sunrgbd and scannet dataset.
box_type_3d='LiDAR'),
pipeline=train_pipeline,
queue_length=queue_length,
),
val=dict(
imgs_per_gpu=1,
type=dataset_type,
data_source=dict(
type='Det3dSourceNuScenes',
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
pipeline=[
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg')
],
classes=CLASSES,
modality=input_modality,
test_mode=True),
pipeline=test_pipeline))
paramwise_cfg = {'img_backbone': dict(lr_mult=0.1)}
optimizer = dict(
type='AdamW', lr=2e-4, paramwise_options=paramwise_cfg, weight_decay=0.01)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
min_lr_ratio=1e-3)
total_epochs = 24
eval_config = dict(initial=False, interval=1, gpu_collect=False)
eval_pipelines = [
dict(
mode='test',
data=data['val'],
dist_eval=True,
evaluators=[
dict(
type='NuScenesEvaluator',
classes=CLASSES,
result_names=['pts_bbox'])
],
)
]
load_from = 'https://github.com/zhiqi-li/storage/releases/download/v1.0/r101_dcn_fcos3d_pretrain.pth'
log_config = dict(
interval=50,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
checkpoint_config = dict(interval=1)
cudnn_benchmark = True
find_unused_parameters = True
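The num_query_one2many and one2many_gt_mul fields come from the hybrid-matching ("hybrid branch") idea: alongside the 900 one-to-one queries, an extra set of 1800 queries is supervised with the ground truth repeated several times. The sketch below is one plausible reading of that label duplication, following the usual hybrid-matching recipe; the actual assignment logic lives in BEVFormerHead and is not shown in this diff:

import torch

one2many_gt_mul = 4
gt_bboxes_3d = torch.rand(12, 9)             # 12 hypothetical boxes with 9 box params
gt_labels_3d = torch.randint(0, 10, (12,))

# the one-to-many branch sees each ground-truth box one2many_gt_mul times,
# so its auxiliary queries get denser supervision than the one-to-one queries
gt_bboxes_one2many = gt_bboxes_3d.repeat(one2many_gt_mul, 1)   # (48, 9)
gt_labels_one2many = gt_labels_3d.repeat(one2many_gt_mul)      # (48,)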

View File

@ -29,18 +29,19 @@ input_modality = dict(
use_map=False,
use_external=True)
_dim_ = 256
_pos_dim_ = _dim_ // 2
_ffn_dim_ = _dim_ * 2
_num_levels_ = 1
bev_h_ = 50
bev_w_ = 50
embed_dim = 256
pos_dim = embed_dim // 2
ffn_dim = embed_dim * 2
num_levels = 1
bev_h = 50
bev_w = 50
queue_length = 3 # each sequence contains `queue_length` frames.
model = dict(
type='BEVFormer',
use_grid_mask=True,
video_test_mode=True,
extract_feat_serially=True,
pretrained=dict(img='torchvision://resnet50'),
img_backbone=dict(
type='ResNet',
@ -56,18 +57,18 @@ model = dict(
img_neck=dict(
type='FPN',
in_channels=[2048],
out_channels=_dim_,
out_channels=embed_dim,
start_level=0,
add_extra_convs='on_output',
num_outs=_num_levels_,
num_outs=num_levels,
relu_before_extra_convs=True),
pts_bbox_head=dict(
type='BEVFormerHead',
bev_h=bev_h_,
bev_w=bev_w_,
bev_h=bev_h,
bev_w=bev_w,
num_query=900,
num_classes=10,
in_channels=_dim_,
in_channels=embed_dim,
sync_cls_avg_factor=True,
with_box_refine=True,
as_two_stage=False,
@ -76,7 +77,7 @@ model = dict(
rotate_prev_bev=True,
use_shift=True,
use_can_bus=True,
embed_dims=_dim_,
embed_dims=embed_dim,
encoder=dict(
type='BEVFormerEncoder',
num_layers=3,
@ -88,23 +89,23 @@ model = dict(
attn_cfgs=[
dict(
type='TemporalSelfAttention',
embed_dims=_dim_,
embed_dims=embed_dim,
num_levels=1),
dict(
type='SpatialCrossAttention',
pc_range=point_cloud_range,
deformable_attention=dict(
type='MSDeformableAttention3D',
embed_dims=_dim_,
embed_dims=embed_dim,
num_points=8,
num_levels=_num_levels_),
embed_dims=_dim_,
num_levels=num_levels),
embed_dims=embed_dim,
)
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=_ffn_dim_,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
@ -120,18 +121,18 @@ model = dict(
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=_dim_,
embed_dims=embed_dim,
num_heads=8,
dropout=0.1),
dict(
type='CustomMSDeformableAttention',
embed_dims=_dim_,
embed_dims=embed_dim,
num_levels=1),
],
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=_ffn_dim_,
feedforward_channels=ffn_dim,
num_fcs=2,
ffn_drop=0.1,
act_cfg=dict(type='ReLU', inplace=True),
@ -147,9 +148,9 @@ model = dict(
num_classes=10),
positional_encoding=dict(
type='LearnedPositionalEncoding',
num_feats=_pos_dim_,
row_num_embed=bev_h_,
col_num_embed=bev_w_,
num_feats=pos_dim,
row_num_embed=bev_h,
col_num_embed=bev_w,
),
loss_cls=dict(
type='FocalLoss',
@ -179,11 +180,12 @@ dataset_type = 'NuScenesDataset'
data_root = 'data/nuScenes/nuscenes-v1.0/'
train_pipeline = [
dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
dict(type='PhotoMetricDistortionMultiViewImage'),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=CLASSES),
dict(type='NormalizeMultiviewImage', **img_norm_cfg),
dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
# dict(type='RandomScaleImageMultiViewImage', scales=[0.5]),
dict(type='PadMultiViewImage', size_divisor=32),
dict(type='DefaultFormatBundle3D', class_names=CLASSES),
dict(
@ -228,6 +230,7 @@ test_pipeline = [
data = dict(
imgs_per_gpu=1, # 8gpus, total batch size=8
workers_per_gpu=4,
pin_memory=True,
# shuffler_sampler=dict(type='DistributedGroupSampler'),
# nonshuffler_sampler=dict(type='DistributedSampler'),
train=dict(
@ -237,7 +240,10 @@ data = dict(
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_train.pkl',
pipeline=[
dict(type='LoadMultiViewImageFromFiles', to_float32=True),
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg'),
dict(
type='LoadAnnotations3D',
with_bbox_3d=True,
@ -262,7 +268,10 @@ data = dict(
data_root=data_root,
ann_file=data_root + 'nuscenes_infos_temporal_val.pkl',
pipeline=[
dict(type='LoadMultiViewImageFromFiles', to_float32=True)
dict(
type='LoadMultiViewImageFromFiles',
to_float32=True,
backend='turbojpeg')
],
classes=CLASSES,
modality=input_modality,
@ -305,3 +314,12 @@ log_config = dict(
checkpoint_config = dict(interval=1)
cudnn_benchmark = True
export = dict(
export_type='blade',
blade_config=dict(
enable_fp16=True,
fp16_fallback_op_ratio=0.0,
customize_op_black_list=[
'aten::select', 'aten::index', 'aten::slice', 'aten::view',
'aten::upsample', 'aten::clamp'
]))

View File

@ -0,0 +1,11 @@
_base_ = ['./bevformer_tiny_r50_nuscenes.py']
paramwise_cfg = {'img_backbone': dict(lr_mult=0.1)}
optimizer = dict(
type='AdamW',
lr=2.8e-4,
paramwise_options=paramwise_cfg,
weight_decay=0.01)
optimizer_config = dict(
grad_clip=dict(max_norm=35, norm_type=2), loss_scale=512.)
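Two quick arithmetic notes on this variant, assuming paramwise lr_mult simply multiplies the base learning rate for matched parameters and loss_scale=512. is a static fp16 loss scale:

base_lr = 2.8e-4
backbone_lr = base_lr * 0.1   # img_backbone parameters train at 2.8e-5
loss_scale = 512.             # loss is scaled by 512 before backward, gradients unscaled afterwards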

View File

@ -2,22 +2,21 @@
import copy
import json
import logging
import pickle
from collections import OrderedDict
from distutils.version import LooseVersion
from typing import Callable, Dict, List, Optional, Tuple
import cv2
import torch
import torchvision
import torchvision.transforms.functional as t_f
from mmcv.utils import Config
from easycv.file import io
from easycv.framework.errors import ValueError
from easycv.models import (DINO, MOCO, SWAV, YOLOX, Classification, MoBY,
build_model)
from easycv.framework.errors import NotImplementedError, ValueError
from easycv.models import (DINO, MOCO, SWAV, YOLOX, BEVFormer, Classification,
MoBY, build_model)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.misc import reparameterize_models
from easycv.utils.misc import encode_str_to_tensor
__all__ = [
'export',
@ -27,7 +26,7 @@ __all__ = [
]
def export(cfg, ckpt_path, filename):
def export(cfg, ckpt_path, filename, **kwargs):
""" export model for inference
Args:
@ -42,20 +41,22 @@ def export(cfg, ckpt_path, filename):
cfg.model.backbone.pretrained = False
if isinstance(model, MOCO) or isinstance(model, DINO):
_export_moco(model, cfg, filename)
_export_moco(model, cfg, filename, **kwargs)
elif isinstance(model, MoBY):
_export_moby(model, cfg, filename)
_export_moby(model, cfg, filename, **kwargs)
elif isinstance(model, SWAV):
_export_swav(model, cfg, filename)
_export_swav(model, cfg, filename, **kwargs)
elif isinstance(model, Classification):
_export_cls(model, cfg, filename)
_export_cls(model, cfg, filename, **kwargs)
elif isinstance(model, YOLOX):
_export_yolox(model, cfg, filename)
_export_yolox(model, cfg, filename, **kwargs)
elif isinstance(model, BEVFormer):
_export_bevformer(model, cfg, filename, **kwargs)
elif hasattr(cfg, 'export') and getattr(cfg.export, 'use_jit', False):
export_jit_model(model, cfg, filename)
export_jit_model(model, cfg, filename, **kwargs)
return
else:
_export_common(model, cfg, filename)
_export_common(model, cfg, filename, **kwargs)
def _export_common(model, cfg, filename):
@ -179,6 +180,7 @@ def _export_yolox(model, cfg, filename):
model.export_type = export_type
if export_type != 'raw':
from easycv.utils.misc import reparameterize_models
# reparameterize models before export only when using jit or blade
model = reparameterize_models(model)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -517,6 +519,100 @@ def export_jit_model(model, cfg, filename):
torch.jit.save(model_jit, ofile)
def _export_bevformer(model, cfg, filename, fp16=False):
if not cfg.adapt_jit:
raise ValueError(
'"cfg.adapt_jit" must be True when export jit trace or blade model.'
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = copy.deepcopy(model)
model.eval()
model.to(device)
def _dummy_inputs():
# dummy inputs
batch_size, queue_len, cams_num = 1, 1, 6
img_size = (928, 1600)
img = torch.rand([cams_num, 3, img_size[0], img_size[1]]).to(device)
can_bus = torch.rand([18]).to(device)
lidar2img = torch.rand([6, 4, 4]).to(device)
img_shape = torch.tensor([[img_size[0], img_size[1], 3]] *
cams_num).to(device)
dummy_scene_token = 'dummy_scene_token'
scene_token = encode_str_to_tensor(dummy_scene_token).to(device)
prev_scene_token = scene_token
prev_bev = torch.rand([cfg.bev_h * cfg.bev_w, 1,
cfg.embed_dim]).to(device)
prev_pos = torch.tensor(0)
prev_angle = torch.tensor(0)
img_metas = {
'can_bus': can_bus,
'lidar2img': lidar2img,
'img_shape': img_shape,
'scene_token': scene_token,
'prev_bev': prev_bev,
'prev_pos': prev_pos,
'prev_angle': prev_angle,
'prev_scene_token': prev_scene_token
}
return img, img_metas
dummy_inputs = _dummy_inputs()
def _trace_model():
with torch.no_grad():
model.forward = model.forward_export
trace_model = torch.jit.trace(
model, copy.deepcopy(dummy_inputs), check_trace=False)
return trace_model
export_type = cfg.export.get('type')
if export_type in ['jit', 'blade']:
if fp16:
with torch.cuda.amp.autocast():
trace_model = _trace_model()
else:
trace_model = _trace_model()
torch.jit.save(trace_model, filename + '.jit')
else:
raise NotImplementedError(f'Not support export type {export_type}!')
if export_type == 'jit':
return
blade_config = cfg.export.get('blade_config')
from easycv.toolkit.blade import blade_env_assert, blade_optimize
assert blade_env_assert()
def _get_blade_model():
blade_model = blade_optimize(
speed_test_model=model,
model=trace_model,
inputs=copy.deepcopy(dummy_inputs),
blade_config=blade_config,
static_opt=False,
min_num_nodes=None, # 50
check_inputs=False,
fp16=fp16)
return blade_model
# optimize model with blade
if fp16:
with torch.cuda.amp.autocast():
blade_model = _get_blade_model()
else:
blade_model = _get_blade_model()
# save blade code and graph
# with io.open(filename + '.blade.code.py', 'w') as ofile:
# ofile.write(blade_model.forward.code)
# with io.open(filename + '.blade.graph.txt', 'w') as ofile:
# ofile.write(blade_model.forward.graph)
with io.open(filename + '.blade', 'wb') as ofile:
torch.jit.save(blade_model, ofile)
def replace_syncbn(backbone_cfg):
if 'norm_cfg' in backbone_cfg.keys():
if backbone_cfg['norm_cfg']['type'] == 'SyncBN':
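A minimal usage sketch for the new _export_bevformer path. The module path of export() and the file names are assumptions, and cfg.adapt_jit has to be set to True first, as the check above enforces:

from mmcv.utils import Config
from easycv.apis.export import export   # assumed module path for the export() defined above

cfg = Config.fromfile('configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')  # hypothetical path
cfg.adapt_jit = True                     # required by _export_bevformer
export(cfg, 'bevformer_base.pth', 'bevformer_base_export')   # hypothetical checkpoint / output names
# cfg.export.type == 'jit'    -> writes 'bevformer_base_export.jit'
# cfg.export.type == 'blade'  -> also runs blade_optimize and writes 'bevformer_base_export.blade'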

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
import numba
# import numba
import numpy as np
import torch
from mmcv.ops import nms, nms_rotated
@ -179,7 +179,7 @@ def aligned_3d_nms(boxes, scores, classes, thresh):
return indices
@numba.jit(nopython=True)
# @numba.jit(nopython=True)
def circle_nms(dets, thresh, post_max_size=83):
"""Circular NMS.

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
import concurrent.futures
import copy
import logging
import random
import tempfile
from os import path as osp
@ -14,6 +16,7 @@ from easycv.core.bbox import Box3DMode, Coord3DMode
from easycv.datasets.registry import DATASETS
from easycv.datasets.shared.base import BaseDataset
from easycv.datasets.shared.pipelines import Compose
from easycv.datasets.shared.pipelines.format import to_tensor
from .utils import extract_result_dict
@ -50,6 +53,7 @@ class NuScenesDataset(BaseDataset):
self.eval_detection_configs = config_factory(self.eval_version)
self.flag = np.zeros(
len(self), dtype=np.uint8) # for DistributedGroupSampler
self.pipeline_cfg = pipeline
def _format_bbox(self, results, jsonfile_prefix=None):
"""Convert the results to the standard format.
@ -309,6 +313,9 @@ class NuScenesDataset(BaseDataset):
prev_scene_token = None
prev_pos = None
prev_angle = None
can_bus_list = []
lidar2img_list = []
for i, each in enumerate(queue):
metas_map[i] = each['img_metas'].data
if metas_map[i]['scene_token'] != prev_scene_token:
@ -326,28 +333,75 @@ class NuScenesDataset(BaseDataset):
metas_map[i]['can_bus'][-1] -= prev_angle
prev_pos = copy.deepcopy(tmp_pos)
prev_angle = copy.deepcopy(tmp_angle)
can_bus_list.append(to_tensor(metas_map[i]['can_bus']))
lidar2img_list.append(to_tensor(metas_map[i]['lidar2img']))
queue[-1]['img'] = DC(
torch.stack(imgs_list), cpu_only=False, stack=True)
queue[-1]['img_metas'] = DC(metas_map, cpu_only=True)
queue[-1]['can_bus'] = DC(torch.stack(can_bus_list), cpu_only=False)
queue[-1]['lidar2img'] = DC(
torch.stack(lidar2img_list), cpu_only=False)
queue = queue[-1]
return queue
@staticmethod
def _get_single_data(i,
data_source,
pipeline,
flip_flag=False,
scale=None):
i = max(0, i)
try:
data = data_source[i]
data['flip_flag'] = flip_flag
if scale:
data['resize_scale'] = scale
data = pipeline(data)
if data is None or ~(data['gt_labels_3d']._data != -1).any():
return None
except Exception as e:
logging.error(e)
return None
return i, data
def _get_queue_data(self, idx):
queue = []
idx_list = list(range(idx - self.queue_length, idx))
random.shuffle(idx_list)
idx_list = sorted(idx_list[1:])
idx_list.append(idx)
for i in idx_list:
i = max(0, i)
try:
data = self.data_source[i]
data = self.pipeline(data)
if data is None or ~(data['gt_labels_3d']._data != -1).any():
return None
except Exception as e:
return None
queue.append(data)
flip_flag = False
scale = None
for member in self.pipeline_cfg:
if member['type'] == 'RandomScaleImageMultiViewImage':
scales = member['scales']
rand_ind = np.random.permutation(range(len(scales)))[0]
scale = scales[rand_ind]
if member['type'] == 'RandomHorizontalFlipMultiViewImage':
flip_flag = np.random.rand() >= 0.5
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(idx_list)) as executor:
threads = []
for i in idx_list:
future = executor.submit(self._get_single_data, i,
self.data_source, self.pipeline,
flip_flag, scale)
threads.append(future)
for future in concurrent.futures.as_completed(threads):
queue.append(future.result())
if None in queue:
return None
queue = sorted(queue, key=lambda item: item[0])
queue = [item[1] for item in queue]
return self.union2one(queue)
def __getitem__(self, idx):
@ -358,6 +412,18 @@ class NuScenesDataset(BaseDataset):
data_dict = self.data_source[idx]
data_dict = self.pipeline(data_dict)
can_bus_list, lidar2img_list = [], []
for i in range(len(data_dict['img_metas'])):
can_bus_list.append(
to_tensor(data_dict['img_metas'][i]._data['can_bus']))
lidar2img_list.append(
to_tensor(
data_dict['img_metas'][i]._data['lidar2img']))
data_dict['can_bus'] = DC(
torch.stack(can_bus_list), cpu_only=False)
data_dict['lidar2img'] = DC(
torch.stack(lidar2img_list), cpu_only=False)
if data_dict is None:
idx = self._rand_another(idx)
continue
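The _get_queue_data change above follows a common submit-with-index / sort-by-index pattern, so the frames come back in temporal order even though they are loaded concurrently. A self-contained toy version of just that pattern:

import concurrent.futures

def _load_frame(i):
    # stand-in for self._get_single_data(i, ...)
    return i, {'frame_id': i}

idx_list = [3, 4, 5, 6]
results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(idx_list)) as executor:
    futures = [executor.submit(_load_frame, i) for i in idx_list]
    for future in concurrent.futures.as_completed(futures):
        results.append(future.result())

results = sorted(results, key=lambda item: item[0])  # restore temporal order
queue = [item[1] for item in results]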

View File

@ -0,0 +1,186 @@
######################################################################
# Copyright (c) 2022 OpenPerceptionX. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
######################################################################
# This file includes concrete implementation for different data augmentation
# methods in transforms.py.
######################################################################
from typing import List, Tuple
import cv2
import numpy as np
# Available interpolation modes (opencv)
cv2_interp_codes = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
'bicubic': cv2.INTER_CUBIC,
'area': cv2.INTER_AREA,
'lanczos': cv2.INTER_LANCZOS4
}
def scale_image_multiple_view(
imgs: List[np.ndarray],
cam_intrinsics: List[np.ndarray],
# cam_extrinsics: List[np.ndarray],
lidar2img: List[np.ndarray],
rand_scale: float,
interpolation='bilinear'
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Resize the multiple-view images with the same scale selected randomly.
Notably used in :class:`.transforms.RandomScaleImageMultiViewImage_naive`
Args:
imgs (list of numpy.array): Multiple-view images to be resized. len(img) is the number of cameras.
img shape: [H, W, 3].
cam_intrinsics (list of numpy.array): Intrinsic parameters of different cameras. Transformations from camera
to image. len(cam_intrinsics) is the number of camera. For each camera, shape is 4 * 4.
cam_extrinsics (list of numpy.array): Extrinsic parameters of different cameras. Transformations from
lidar to cameras. len(cam_extrinsics) is the number of camera. For each camera, shape is 4 * 4.
lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
of camera. For each camera, shape is 4 * 4.
rand_scale (float): resize ratio
interpolation (string): mode for interpolation in opencv.
Returns:
imgs_new (list of numpy.array): Updated multiple-view images
cam_intrinsics_new (list of numpy.array): Updated intrinsic parameters of different cameras.
lidar2img_new (list of numpy.array): Updated Transformations from lidar to images.
"""
y_size = [int(img.shape[0] * rand_scale) for img in imgs]
x_size = [int(img.shape[1] * rand_scale) for img in imgs]
scale_factor = np.eye(4)
scale_factor[0, 0] *= rand_scale
scale_factor[1, 1] *= rand_scale
imgs_new = [
cv2.resize(
img, (x_size[idx], y_size[idx]),
interpolation=cv2_interp_codes[interpolation])
for idx, img in enumerate(imgs)
]
cam_intrinsics_new = [
scale_factor @ cam_intrinsic for cam_intrinsic in cam_intrinsics
]
lidar2img_new = [scale_factor @ l2i for l2i in lidar2img]
return imgs_new, cam_intrinsics_new, lidar2img_new
def horizontal_flip_image_multiview(
imgs: List[np.ndarray]) -> List[np.ndarray]:
"""Flip every image horizontally.
Args:
imgs (list of numpy.array): Multiple-view images to be flipped. len(img) is the number of cameras.
img shape: [H, W, 3].
Returns:
imgs_new (list of numpy.array): Flipped multiple-view images
"""
imgs_new = [np.flip(img, axis=1) for img in imgs]
return imgs_new
def vertical_flip_image_multiview(imgs: List[np.ndarray]) -> List[np.ndarray]:
"""Flip every image vertically.
Args:
imgs (list of numpy.array): Multiple-view images to be flipped. len(img) is the number of cameras.
img shape: [H, W, 3].
Returns:
imgs_new (list of numpy.array): Flipped multiple-view images
"""
imgs_new = [np.flip(img, axis=0) for img in imgs]
return imgs_new
def horizontal_flip_bbox(bboxes_3d: np.ndarray, dataset: str) -> np.ndarray:
"""Flip bounding boxes horizontally.
Args:
bboxes_3d (np.ndarray): bounding boxes of shape [N * 7], N is the number of objects.
dataset (string): 'waymo' coordinate system or 'nuscenes' coordinate system.
Returns:
bboxes_3d (numpy.array): Flipped bounding boxes.
"""
if dataset == 'nuScenes':
bboxes_3d.tensor[:, 0::7] = -bboxes_3d.tensor[:, 0::7]
bboxes_3d.tensor[:, 6] = -bboxes_3d.tensor[:, 6] # + np.pi
elif dataset == 'waymo':
bboxes_3d[:, 1::7] = -bboxes_3d[:, 1::7]
bboxes_3d[:, 6] = -bboxes_3d[:, 6] + np.pi
return bboxes_3d
def horizontal_flip_cam_params(
img_shape: np.ndarray, cam_intrinsics: List[np.ndarray],
cam_extrinsics: List[np.ndarray], lidar2imgs: List[np.ndarray],
dataset: str
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Flip camera parameters horizontally.
Args:
img_shape (numpy.array) of shape [3].
cam_intrinsics (list of numpy.array): Intrinsic parameters of different cameras. Transformations from camera
to image. len(cam_intrinsics) is the number of camera. For each camera, shape is 4 * 4.
cam_extrinsics (list of numpy.array): Extrinsic parameters of different cameras. Transformations from
lidar to cameras. len(cam_extrinsics) is the number of camera. For each camera, shape is 4 * 4.
lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
of camera. For each camera, shape is 4 * 4.
dataset (string): Specify 'waymo' coordinate system or 'nuscenes' coordinate system.
Returns:
cam_intrinsics (list of numpy.array): Updated intrinsic parameters of different cameras.
cam_extrinsics (list of numpy.array): Updated extrinsic parameters of different cameras.
lidar2img (list of numpy.array): Updated Transformations from lidar to images.
"""
flip_factor = np.eye(4)
lidar2imgs = []
w = img_shape[1]
if dataset == 'nuScenes':
flip_factor[0, 0] = -1
cam_extrinsics = [l2c @ flip_factor for l2c in cam_extrinsics]
for cam_intrinsic, l2c in zip(cam_intrinsics, cam_extrinsics):
cam_intrinsic[0, 0] = -cam_intrinsic[0, 0]
cam_intrinsic[0, 2] = w - cam_intrinsic[0, 2]
lidar2imgs.append(cam_intrinsic @ l2c)
elif dataset == 'waymo':
flip_factor[1, 1] = -1
cam_extrinsics = [l2c @ flip_factor for l2c in cam_extrinsics]
for cam_intrinsic, l2c in zip(cam_intrinsics, cam_extrinsics):
cam_intrinsic[0, 0] = -cam_intrinsic[0, 0]
cam_intrinsic[0, 2] = w - cam_intrinsic[0, 2]
lidar2imgs.append(cam_intrinsic @ l2c)
else:
assert False
return cam_intrinsics, cam_extrinsics, lidar2imgs
def horizontal_flip_canbus(canbus: np.ndarray, dataset: str) -> np.ndarray:
"""Flip can bus horizontally.
Args:
canbus (numpy.ndarray) of shape [18,]
dataset (string): 'waymo' or 'nuscenes'
Returns:
canbus_new (list of numpy.array): Flipped canbus.
"""
if dataset == 'nuScenes':
# results['canbus'][1] = -results['canbus'][1] # flip location
# results['canbus'][-2] = -results['canbus'][-2] # flip direction
canbus[-1] = -canbus[-1] # flip direction
elif dataset == 'waymo':
# results['canbus'][1] = -results['canbus'][-1] # flip location
# results['canbus'][-2] = -results['canbus'][-2] # flip direction
canbus[-1] = -canbus[-1] # flip direction
else:
raise NotImplementedError((f'Not support {dataset} dataset'))
return canbus
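A small numpy check of the geometry used in scale_image_multiple_view above: left-multiplying lidar2img by diag(s, s, 1, 1) scales the projected pixel coordinates by s, which is exactly what resizing the image by s does. The matrix and point below are made up:

import numpy as np

rand_scale = 0.5
scale_factor = np.eye(4)
scale_factor[0, 0] *= rand_scale
scale_factor[1, 1] *= rand_scale

lidar2img = np.array([[800.,   0., 640., 10.],
                      [  0., 800., 360., 20.],
                      [  0.,   0.,   1.,  0.],
                      [  0.,   0.,   0.,  1.]])
point = np.array([2., 1., 10., 1.])              # homogeneous lidar point

uvw = lidar2img @ point
u, v = uvw[0] / uvw[2], uvw[1] / uvw[2]

uvw_s = (scale_factor @ lidar2img) @ point
u_s, v_s = uvw_s[0] / uvw_s[2], uvw_s[1] / uvw_s[2]

assert np.isclose(u_s, u * rand_scale) and np.isclose(v_s, v * rand_scale)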

View File

@ -1,11 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
import concurrent.futures
import mmcv
import numpy as np
from easycv.core.points import BasePoints, get_points_type
from easycv.datasets.detection.pipelines import LoadAnnotations
from easycv.datasets.registry import PIPELINES
from easycv.file.image import load_image
@PIPELINES.register_module()
@ -17,13 +20,23 @@ class LoadMultiViewImageFromFiles(object):
Args:
to_float32 (bool, optional): Whether to convert the img to float32.
Defaults to False.
color_type (str, optional): Color type of the file.
Defaults to 'unchanged'.
channel_order (str, optional): Channel order.
Defaults to 'bgr'.
backend (str): The image decoding backend type. Options are `cv2`, `pillow`, `turbojpeg`.
"""
def __init__(self, to_float32=False, color_type='unchanged'):
def __init__(self,
to_float32=False,
channel_order='bgr',
backend='pillow'):
self.to_float32 = to_float32
self.color_type = color_type
self.channel_order = channel_order
self.backend = backend
@staticmethod
def _load_image(img_path, idx, mode, backend):
img = load_image(img_path, mode=mode, backend=backend)
return idx, img
def __call__(self, results):
"""Call function to load multi-view image from files.
@ -45,8 +58,24 @@ class LoadMultiViewImageFromFiles(object):
"""
filename = results['img_filename']
# img is of shape (h, w, c, num_views)
img = np.stack(
[mmcv.imread(name, self.color_type) for name in filename], axis=-1)
img_list = []
with concurrent.futures.ThreadPoolExecutor(
max_workers=len(filename)) as executor:
threads = []
for idx, name in enumerate(filename):
future = executor.submit(self._load_image, name, idx,
self.channel_order, self.backend)
threads.append(future)
for future in concurrent.futures.as_completed(threads):
img_list.append(future.result())
img_list = sorted(img_list, key=lambda item: item[0])
assert len(img_list) == len(filename)
img_list = [item[1] for item in img_list]
img = np.stack(img_list, axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename

View File

@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import List, Tuple
import mmcv
import numpy as np
from numpy import random
@ -7,6 +9,10 @@ from numpy import random
from easycv.core.bbox import (CameraInstance3DBoxes, DepthInstance3DBoxes,
LiDARInstance3DBoxes)
from easycv.datasets.registry import PIPELINES
from .functional import (horizontal_flip_bbox, horizontal_flip_cam_params,
horizontal_flip_canbus,
horizontal_flip_image_multiview,
scale_image_multiple_view)
@PIPELINES.register_module()
@ -298,42 +304,140 @@ class PadMultiViewImage(object):
@PIPELINES.register_module()
class RandomScaleImageMultiViewImage(object):
"""Random scale the image.
"""Resize the multiple-view images with the same scale selected randomly. .
Args:
scales (List[float]): List of scales.
scales (tuple of float): candidate ratios for resizing the images; one ratio is
selected randomly each call.
"""
def __init__(self, scales=[]):
def __init__(self, scales=[0.5, 1.0, 1.5]):
self.scales = scales
assert len(self.scales) == 1
self.seed = 0
def __call__(self, results):
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
def forward(
self,
imgs: List[np.ndarray],
cam_intrinsics: List[np.ndarray],
lidar2img: List[np.ndarray],
seed=None,
scale=1
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""
Args:
imgs (list of numpy.array): Multiple-view images to be resized. len(img) is the number of cameras.
img shape: [H, W, 3].
cam_intrinsics (list of numpy.array): Intrinsic parameters of different cameras. Transformations from camera
to image. len(cam_intrinsics) is the number of camera. For each camera, shape is 4 * 4.
cam_extrinsics (list of numpy.array): Extrinsic parameters of different cameras. Transformations from
lidar to cameras. len(cam_extrinsics) is the number of camera. For each camera, shape is 4 * 4.
lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
of camera. For each camera, shape is 4 * 4.
seed (int): Seed for generating random number.
Returns:
imgs_new (list of numpy.array): Updated multiple-view images
cam_intrinsics_new (list of numpy.array): Updated intrinsic parameters of different cameras.
lidar2img_new (list of numpy.array): Updated Transformations from lidar to images.
"""
rand_scale = scale
imgs_new, cam_intrinsic_new, lidar2img_new = scale_image_multiple_view(
imgs, cam_intrinsics, lidar2img, rand_scale)
return imgs_new, cam_intrinsic_new, lidar2img_new
def __call__(self, data):
imgs = data['img']
cam_intrinsics = data['cam_intrinsic']
lidar2img = data['lidar2img']
rand_ind = np.random.permutation(range(len(self.scales)))[0]
rand_scale = self.scales[rand_ind]
scale = data[
'resize_scale'] if 'resize_scale' in data else self.scales[rand_ind]
y_size = [int(img.shape[0] * rand_scale) for img in results['img']]
x_size = [int(img.shape[1] * rand_scale) for img in results['img']]
scale_factor = np.eye(4)
scale_factor[0, 0] *= rand_scale
scale_factor[1, 1] *= rand_scale
results['img'] = [
mmcv.imresize(img, (x_size[idx], y_size[idx]), return_scale=False)
for idx, img in enumerate(results['img'])
]
lidar2img = [scale_factor @ l2i for l2i in results['lidar2img']]
results['lidar2img'] = lidar2img
results['img_shape'] = [img.shape for img in results['img']]
results['ori_shape'] = [img.shape for img in results['img']]
imgs_new, cam_intrinsic_new, lidar2img_new = self.forward(
imgs, cam_intrinsics, lidar2img, None, scale)
return results
data['img'] = imgs_new
data['cam_intrinsic'] = cam_intrinsic_new
data['lidar2img'] = lidar2img_new
return data
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(size={self.scales}, '
return repr_str
@PIPELINES.register_module()
class RandomHorizontalFlipMultiViewImage(object):
"""Horizontally flip the multiple-view images with bounding boxes, camera parameters and can bus randomly. .
Support coordinate systems like Waymo (https://waymo.com/open/data/perception/) or Nuscenes (https://www.nuscenes.org/public/images/data.png).
Args:
flip_ratio (float 0~1): probability of the images being flipped. Default value is 0.5.
dataset (string): Specify 'waymo' coordinate system or 'nuscenes' coordinate system.
"""
def __init__(self, flip_ratio=0.5, dataset='nuScenes'):
self.flip_ratio = flip_ratio
self.seed = 0
self.dataset = dataset
def forward(
self,
imgs: List[np.ndarray],
bboxes_3d: np.ndarray,
cam_intrinsics: List[np.ndarray],
cam_extrinsics: List[np.ndarray],
lidar2imgs: List[np.ndarray],
canbus: np.ndarray,
seed=None,
flip_flag=True
) -> Tuple[bool, List[np.ndarray], np.ndarray, List[np.ndarray],
List[np.ndarray], List[np.ndarray], np.ndarray]:
"""
Args:
imgs (list of numpy.array): Multiple-view images to be flipped. len(img) is the number of cameras.
img shape: [H, W, 3].
bboxes_3d (np.ndarray): bounding boxes of shape [N * 7], N is the number of objects.
cam_intrinsics (list of numpy.array): Intrinsic parameters of different cameras. Transformations from camera
to image. len(cam_intrinsics) is the number of camera. For each camera, shape is 4 * 4.
cam_extrinsics (list of numpy.array): Extrinsic parameters of different cameras. Transformations from
lidar to cameras. len(cam_extrinsics) is the number of camera. For each camera, shape is 4 * 4.
lidar2img (list of numpy.array): Transformations from lidar to images. len(lidar2img) is the number
of camera. For each camera, shape is 4 * 4.
canbus (numpy.array): CAN bus signal of shape [18,].
seed (int): Seed for generating random number.
Returns:
imgs_new (list of numpy.array): Updated multiple-view images
cam_intrinsics_new (list of numpy.array): Updated intrinsic parameters of different cameras.
lidar2img_new (list of numpy.array): Updated Transformations from lidar to images.
"""
if flip_flag == False:
return flip_flag, imgs, bboxes_3d, cam_intrinsics, cam_extrinsics, lidar2imgs, canbus
else:
# flip_flag = True
imgs_flip = horizontal_flip_image_multiview(imgs)
bboxes_3d_flip = horizontal_flip_bbox(bboxes_3d, self.dataset)
img_shape = imgs[0].shape
cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip = horizontal_flip_cam_params(
img_shape, cam_intrinsics, cam_extrinsics, lidar2imgs,
self.dataset)
canbus_flip = horizontal_flip_canbus(canbus, self.dataset)
return flip_flag, imgs_flip, bboxes_3d_flip, cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip, canbus_flip
def __call__(self, data):
imgs = data['img']
bboxes_3d = data['gt_bboxes_3d']
cam_intrinsics = data['cam_intrinsic']
lidar2imgs = data['lidar2img']
canbus = data['can_bus']
cam_extrinsics = data['lidar2cam']
flip_flag = data['flip_flag']
flip_flag, imgs_flip, bboxes_3d_flip, cam_intrinsics_flip, cam_extrinsics_flip, lidar2imgs_flip, canbus_flip = self.forward(
imgs, bboxes_3d, cam_intrinsics, cam_extrinsics, lidar2imgs,
canbus, None, flip_flag)
data['img'] = imgs_flip
data['gt_bboxes_3d'] = bboxes_3d_flip
data['cam_intrinsic'] = cam_intrinsics_flip
data['lidar2img'] = lidar2imgs_flip
data['can_bus'] = canbus_flip
data['lidar2cam'] = cam_extrinsics_flip
return data

View File

@ -5,61 +5,105 @@ import time
import cv2
import numpy as np
from cv2 import IMREAD_COLOR
from PIL import Image
from easycv import file
from easycv.framework.errors import IOError
from easycv.framework.errors import IOError, KeyError, ValueError
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
from .utils import is_oss_path, is_url_path
try:
from turbojpeg import TurboJPEG, TJCS_RGB, TJPF_BGR
turbo_jpeg = TurboJPEG()
turbo_jpeg_mode = {'RGB': TJCS_RGB, 'BGR': TJPF_BGR}
except:
turbo_jpeg = None
turbo_jpeg_mode = None
def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
"""Return np.ndarray[unit8]
"""
# TODO: functions of multi tries should be in the `io.open`
try_cnt = 0
img = None
while try_cnt < max_try_times:
try:
if is_url_path(img_path):
from mmcv.fileio.file_client import HTTPBackend
client = HTTPBackend()
img_bytes = client.get(img_path)
buff = io.BytesIO(img_bytes)
image = Image.open(buff)
if mode.upper() != 'BGR' and image.mode.upper() != mode.upper(
):
image = image.convert(mode.upper())
img = np.asarray(image, dtype=np.uint8)
else:
with file.io.open(img_path, 'rb') as infile:
# cv2.imdecode may corrupt when the img is broken
image = Image.open(infile)
if mode.upper() != 'BGR' and image.mode.upper(
) != mode.upper():
image = image.convert(mode.upper())
img = np.asarray(image, dtype=np.uint8)
if mode.upper() == 'BGR':
if image.mode.upper() != 'RGB':
image = image.convert('RGB')
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
assert img is not None
break
except Exception as e:
logging.error(e)
logging.warning('Read file {} fault, try count : {}'.format(
img_path, try_cnt))
# frequent access to oss will cause errors, sleeping can avoid it
if is_oss_path(img_path):
sleep_time = 1
logging.warning(
'Sleep {}s, frequent access to oss file may cause error.'.
format(sleep_time))
time.sleep(sleep_time)
try_cnt += 1
def load_image_with_pillow(content, mode='BGR', dtype=np.uint8):
with io.BytesIO(content) as buff:
image = Image.open(buff)
if img is None:
raise IOError('Read Image Error: ' + img_path)
if mode.upper() != 'BGR':
if image.mode.upper() != mode.upper():
image = image.convert(mode.upper())
img = np.asarray(image, dtype=dtype)
else:
if image.mode.upper() != 'RGB':
image = image.convert('RGB')
img = np.asarray(image, dtype=dtype)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img
def load_image_with_turbojpeg(content, mode='BGR', dtype=np.uint8):
assert mode.upper() in turbo_jpeg_mode
if turbo_jpeg is None or turbo_jpeg_mode is None:
raise ValueError(
'Please install turbojpeg by "pip install PyTurboJPEG" !')
img = turbo_jpeg.decode(
content, pixel_format=turbo_jpeg_mode[mode.upper()])
if img.dtype != dtype:
img = img.astype(dtype)
return img
def load_image_with_cv2(content, mode='BGR', dtype=np.uint8):
assert mode.upper() in ['BGR', 'RGB']
img_np = np.frombuffer(content, np.uint8)
img = cv2.imdecode(img_np, flags=IMREAD_COLOR)
if mode.upper() == 'RGB':
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if img.dtype != dtype:
img = img.astype(dtype)
return img
def _load_image(fp, mode='BGR', dtype=np.uint8, backend='pillow'):
if backend == 'pillow':
img = load_image_with_pillow(fp, mode=mode, dtype=dtype)
elif backend == 'turbojpeg':
img = load_image_with_turbojpeg(fp, mode=mode, dtype=dtype)
elif backend == 'cv2':
img = load_image_with_cv2(fp, mode=mode, dtype=dtype)
else:
raise KeyError(
'Only support backend in ["pillow", "turbojpeg", "cv2"]')
return img
def load_image(img_path,
mode='BGR',
dtype=np.uint8,
backend='pillow',
max_try_times=MAX_READ_IMAGE_TRY_TIMES):
"""Load image file, return np.ndarray.
Args:
img_path (str): Image file path.
mode (str): Order of channel, candidates are `bgr` and `rgb`.
dtype : Output data type.
backend (str): The image decoding backend type. Options are `cv2`, `pillow`, `turbojpeg`.
"""
# TODO: functions of multi tries should be in the `io.open`
img = None
if is_url_path(img_path):
from mmcv.fileio.file_client import HTTPBackend
client = HTTPBackend()
img_bytes = client.get(img_path)
img = _load_image(img_bytes, mode=mode, dtype=dtype, backend=backend)
else:
with file.io.open(img_path, 'rb') as infile:
img = _load_image(
infile.read(), mode=mode, dtype=dtype, backend=backend)
return img
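A usage sketch for the refactored loader. The image path is hypothetical, and checking the module-level turbo_jpeg handle mirrors the try/except import guard above (the turbojpeg backend only works when PyTurboJPEG and libturbojpeg are installed):

from easycv.file.image import load_image, turbo_jpeg

img_path = 'data/nuscenes/samples/CAM_FRONT/sample.jpg'   # hypothetical path
backend = 'turbojpeg' if turbo_jpeg is not None else 'pillow'
img = load_image(img_path, mode='RGB', backend=backend)
print(img.shape, img.dtype)   # (H, W, 3), uint8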

View File

@ -103,7 +103,6 @@ class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.fp16_enabled = False
self._init_layers()

View File

@ -1 +1,2 @@
from . import detectors, utils
from . import utils
from .detectors import *

View File

@ -3,15 +3,16 @@
import math
import warnings
from typing import Optional
import torch
import torch.nn as nn
from mmcv.cnn import constant_init, xavier_init
from mmcv.ops.multi_scale_deform_attn import \
multi_scale_deformable_attn_pytorch
from mmcv.runner.base_module import BaseModule
from easycv.models.registry import ATTENTION
from easycv.thirdparty.deformable_attention.functions import \
MSDeformAttnFunction
@ATTENTION.register_module()
@ -36,6 +37,8 @@ class CustomMSDeformableAttention(BaseModule):
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
@ -50,8 +53,10 @@ class CustomMSDeformableAttention(BaseModule):
im2col_step=64,
dropout=0.1,
batch_first=False,
add_identity=True,
norm_cfg=None,
init_cfg=None):
init_cfg=None,
adapt_jit=False):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
@ -60,7 +65,7 @@ class CustomMSDeformableAttention(BaseModule):
self.norm_cfg = norm_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
self.fp16_enabled = False
self.add_identity = add_identity
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
@ -90,6 +95,11 @@ class CustomMSDeformableAttention(BaseModule):
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
self.adapt_jit = adapt_jit
if self.adapt_jit:
self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn
else:
self.ms_deform_attn_op = MSDeformAttnFunction.apply
def init_weights(self):
"""Default initialization for Parameters of Module."""
@ -130,19 +140,23 @@ class CustomMSDeformableAttention(BaseModule):
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
return sampling_locations
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
add_identity=True,
flag='decoder',
**kwargs):
def forward(
self,
query: torch.Tensor,
spatial_shapes: torch.Tensor,
reference_points: torch.Tensor,
level_start_index: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
identity: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
flag: Optional[str] = 'decoder',
key_pos: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
cls_branches: Optional[torch.Tensor] = None,
img_metas: Optional[str] = None,
):
"""Forward Function of MultiScaleDeformAttention.
Args:
@ -212,29 +226,29 @@ class CustomMSDeformableAttention(BaseModule):
sampling_locations = self._get_sampling_locations(
reference_points, spatial_shapes, sampling_offsets)
if torch.cuda.is_available() and value.is_cuda:
from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
if value.dtype == torch.float16:
# for mixed precision
output = MSDeformAttnFunction.apply(
value.to(torch.float32), spatial_shapes, level_start_index,
sampling_locations.to(torch.float32), attention_weights,
self.im2col_step)
output = output.to(torch.float16)
else:
output = MSDeformAttnFunction.apply(value, spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step)
if value.dtype == torch.float16:
# for mixed precision
assert value.size(0) % min(value.size(0), self.im2col_step) == 0
output = self.ms_deform_attn_op(
value.to(torch.float32), spatial_shapes, level_start_index,
sampling_locations.to(torch.float32), attention_weights,
self.im2col_step)
output = output.to(torch.float16)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
output = self.ms_deform_attn_op(value, spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step)
# cpu
# from mmcv.ops.multi_scale_deform_attn import multi_scale_deformable_attn_pytorch
# output = multi_scale_deformable_attn_pytorch(
# value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
if not self.batch_first:
output = output.permute(1, 0, 2)
if add_identity:
if self.add_identity:
return self.dropout(output) + identity
else:
return self.dropout(output)
@ -276,7 +290,8 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
dropout=0.,
batch_first=True,
norm_cfg=None,
init_cfg=None):
init_cfg=None,
adapt_jit=False):
super(MSDeformableAttention3D, self).__init__(
embed_dims=embed_dims,
num_heads=num_heads,
@ -285,8 +300,10 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
im2col_step=im2col_step,
dropout=dropout,
batch_first=batch_first,
add_identity=False,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
init_cfg=init_cfg,
adapt_jit=adapt_jit)
self.output_proj = nn.Identity()
@ -321,62 +338,3 @@ class MSDeformableAttention3D(CustomMSDeformableAttention):
f'Last dim of reference_points must be'
f' 2, but get {reference_points.shape[-1]} instead.')
return sampling_locations
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
( bs, num_query, embed_dims).
key (Tensor): The key tensor with shape
`(bs, num_key, embed_dims)`.
value (Tensor): The value tensor with shape
`(bs, num_key, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
return super().forward(
query=query,
key=key,
value=value,
identity=identity,
query_pos=query_pos,
key_padding_mask=key_padding_mask,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
add_identity=False,
**kwargs)

View File

@ -1,5 +1,7 @@
# Modified from https://github.com/fundamentalvision/BEVFormer.
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional
import torch
import torch.nn as nn
from mmcv.cnn import xavier_init
@ -41,7 +43,6 @@ class SpatialCrossAttention(BaseModule):
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
self.pc_range = pc_range
self.fp16_enabled = False
self.deformable_attention = build_attention(deformable_attention)
self.embed_dims = embed_dims
self.num_cams = num_cams
@ -56,20 +57,21 @@ class SpatialCrossAttention(BaseModule):
@force_fp32(
apply_to=('query', 'key', 'value', 'query_pos', 'reference_points_cam')
)
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
reference_points_cam=None,
bev_mask=None,
level_start_index=None,
flag='encoder',
**kwargs):
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
reference_points_cam: torch.Tensor,
bev_mask: torch.Tensor,
spatial_shapes: torch.Tensor,
level_start_index: torch.Tensor,
residual: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
flag: Optional[str] = 'encoder',
):
"""Forward Function of Detr3DCrossAtten.
Args:
query (Tensor): Query of Transformer with shape
@ -108,9 +110,13 @@ class SpatialCrossAttention(BaseModule):
if value is None:
value = key
if residual is None:
inp_residual = query
slots = torch.zeros_like(query)
# if residual is None:
# inp_residual = query
# slots = torch.zeros_like(query)
assert residual is None
inp_residual = query
slots = torch.zeros_like(query)
if query_pos is not None:
query = query + query_pos

View File

@ -63,7 +63,6 @@ class TemporalSelfAttention(BaseModule):
self.norm_cfg = norm_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
self.fp16_enabled = False
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
@ -129,8 +128,7 @@ class TemporalSelfAttention(BaseModule):
reference_points=None,
spatial_shapes=None,
level_start_index=None,
flag='decoder',
**kwargs):
flag='decoder'):
"""Forward Function of MultiScaleDeformAttention.
Args:
@ -235,19 +233,20 @@ class TemporalSelfAttention(BaseModule):
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
op = MSDeformAttnFunction.apply
else:
op = torch.ops.custom.ms_deform_attn
if value.dtype == torch.float16:
output = MSDeformAttnFunction.apply(
output = op(
value.to(torch.float32), spatial_shapes, level_start_index,
sampling_locations.to(torch.float32), attention_weights,
self.im2col_step)
output = output.to(torch.float16)
else:
output = MSDeformAttnFunction.apply(value, spatial_shapes,
level_start_index,
sampling_locations,
attention_weights,
self.im2col_step)
output = op(value, spatial_shapes, level_start_index,
sampling_locations, attention_weights,
self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, sampling_locations, attention_weights)
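Note: the op dispatch above keeps normal runs on the Python autograd Function while letting torch.jit.trace record the TorchScript-registered custom op (the `custom::ms_deform_attn` operator registered by the C++ extension later in this diff). A minimal sketch of that pattern, assuming the extension has been built and loaded:

import torch

def pick_ms_deform_attn_op(eager_fn):
    """Return the eager autograd op in normal runs, and the TorchScript-visible
    custom op while scripting/tracing (e.g. during jit or blade export)."""
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        return eager_fn  # e.g. MSDeformAttnFunction.apply
    return torch.ops.custom.ms_deform_attn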

View File

@ -1,14 +1,18 @@
# Modified from https://github.com/fundamentalvision/BEVFormer.
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import pickle
import numpy as np
import torch
from easycv.core.bbox import get_box_type
from easycv.core.bbox.bbox_util import bbox3d2result
from easycv.models.detection3d.detectors.mvx_two_stage import \
MVXTwoStageDetector
from easycv.models.detection3d.utils.grid_mask import GridMask
from easycv.models.registry import MODELS
from easycv.utils.misc import decode_tensor_to_str, encode_str_to_tensor
@MODELS.register_module()
@ -16,6 +20,8 @@ class BEVFormer(MVXTwoStageDetector):
"""BEVFormer.
Args:
video_test_mode (bool): Decide whether to use temporal information during inference.
extract_feat_serially (bool): Whether to extract history features frame by frame,
to avoid batch norm corruption when the stacked batch dimension N is too large.
"""
def __init__(self,
@ -34,7 +40,8 @@ class BEVFormer(MVXTwoStageDetector):
train_cfg=None,
test_cfg=None,
pretrained=None,
video_test_mode=False):
video_test_mode=False,
extract_feat_serially=False):
super(BEVFormer,
self).__init__(pts_voxel_layer, pts_voxel_encoder,
@ -45,18 +52,18 @@ class BEVFormer(MVXTwoStageDetector):
self.grid_mask = GridMask(
True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
self.use_grid_mask = use_grid_mask
self.fp16_enabled = False
self.extract_feat_serially = extract_feat_serially
# temporal
self.video_test_mode = video_test_mode
self.prev_frame_info = {
'prev_bev': None,
'scene_token': None,
'prev_scene_token': None,
'prev_pos': 0,
'prev_angle': 0,
}
def extract_img_feat(self, img, img_metas, len_queue=None):
def extract_img_feat(self, img, len_queue=None):
"""Extract features of images."""
B = img.size(0)
if img is not None:
@ -94,10 +101,10 @@ class BEVFormer(MVXTwoStageDetector):
img_feat.view(B, int(BN / B), C, H, W))
return img_feats_reshaped
def extract_feat(self, img, img_metas=None, len_queue=None):
def extract_feat(self, img, len_queue=None):
"""Extract features from images and points."""
img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
img_feats = self.extract_img_feat(img, len_queue=len_queue)
return img_feats
@ -132,6 +139,27 @@ class BEVFormer(MVXTwoStageDetector):
dummy_metas = None
return self.forward_test(img=img, img_metas=[[dummy_metas]])
def obtain_history_bev_serially(self, imgs_queue, img_metas_list):
"""Obtain history BEV features iteratively.
Extract features frame by frame to avoid batch norm corruption when the stacked batch dimension N is too large.
"""
self.eval()
with torch.no_grad():
prev_bev = None
bs, len_queue, num_cams, C, H, W = imgs_queue.shape
for i in range(len_queue):
img_feats = self.extract_feat(
img=imgs_queue[:, i, ...], len_queue=None)
img_metas = [each[i] for each in img_metas_list]
if not img_metas[0]['prev_bev_exists']:
prev_bev = None
prev_bev = self.pts_bbox_head(
img_feats, img_metas, prev_bev, only_bev=True)
self.train()
return prev_bev
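The serial path trades some speed for numerical stability: as far as this diff shows, `obtain_history_bev` feeds the backbone all history frames at once (a batch of bs * (queue_length - 1) * num_cams images), while `obtain_history_bev_serially` processes one frame per call (bs * num_cams images), keeping BatchNorm batches small. A minimal config sketch, assuming the base config at the top of this diff (values illustrative):

model = dict(
    type='BEVFormer',
    extract_feat_serially=True,  # run the image backbone once per history frame
    # img_backbone / img_neck / pts_bbox_head as in the base config
)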
def obtain_history_bev(self, imgs_queue, img_metas_list):
"""Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
"""
@ -147,21 +175,19 @@ class BEVFormer(MVXTwoStageDetector):
img_metas = [each[i] for each in img_metas_list]
if not img_metas[0]['prev_bev_exists']:
prev_bev = None
# img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = [each_scale[:, i] for each_scale in img_feats_list]
prev_bev = self.pts_bbox_head(
img_feats, img_metas, prev_bev, only_bev=True)
self.train()
return prev_bev
def forward_train(
self,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img=None,
gt_bboxes_ignore=None,
):
def forward_train(self,
img_metas=None,
gt_bboxes_3d=None,
gt_labels_3d=None,
img=None,
gt_bboxes_ignore=None,
**kwargs):
"""Forward training function.
Args:
points (list[torch.Tensor], optional): Points of each sample.
@ -185,18 +211,24 @@ class BEVFormer(MVXTwoStageDetector):
Returns:
dict: Losses of different branches.
"""
self._check_inputs(img_metas, img, kwargs)
len_queue = img.size(1)
prev_img = img[:, :-1, ...]
img = img[:, -1, ...]
prev_img_metas = copy.deepcopy(img_metas)
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
if self.extract_feat_serially:
prev_bev = self.obtain_history_bev_serially(
prev_img, prev_img_metas)
else:
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
img_metas = [each[len_queue - 1] for each in img_metas]
if not img_metas[0]['prev_bev_exists']:
prev_bev = None
img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = self.extract_feat(img=img)
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
gt_labels_3d, img_metas,
@ -205,7 +237,34 @@ class BEVFormer(MVXTwoStageDetector):
losses.update(losses_pts)
return losses
def _check_inputs(self, img_metas, img, kwargs):
can_bus_in_kwargs = kwargs.get('can_bus', None) is not None
lidar2img_in_kwargs = kwargs.get('lidar2img', None) is not None
for batch_i in range(len(img_metas)):
for i in range(len(img_metas[batch_i])):
if can_bus_in_kwargs:
img_metas[batch_i][i]['can_bus'] = kwargs['can_bus'][
batch_i][i]
else:
if isinstance(img_metas[batch_i][i]['can_bus'],
np.ndarray):
img_metas[batch_i][i]['can_bus'] = torch.from_numpy(
img_metas[batch_i][i]['can_bus']).to(img.device)
if lidar2img_in_kwargs:
img_metas[batch_i][i]['lidar2img'] = kwargs['lidar2img'][
batch_i][i]
else:
if isinstance(img_metas[batch_i][i]['lidar2img'],
np.ndarray):
img_metas[batch_i][i]['lidar2img'] = torch.from_numpy(
np.array(img_metas[batch_i][i]['lidar2img'])).to(
img.device)
kwargs.pop('can_bus', None)
kwargs.pop('lidar2img', None)
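`_check_inputs` accepts the ego and calibration inputs in two forms: numpy values already stored in `img_metas`, or batched tensors passed through kwargs (as the traced-model predictor does); both end up as tensors on the image device and the kwargs entries are popped afterwards. A compact sketch of the two call styles, shapes illustrative:

import numpy as np
import torch

meta = {'can_bus': np.zeros(18), 'lidar2img': np.stack([np.eye(4)] * 6)}
# style 1: numpy inside img_metas, converted in place to tensors on img.device
#   model.forward_train(img_metas=[[meta, meta]], img=img, ...)
# style 2: batched tensors via kwargs, copied into img_metas and then popped
#   model.forward_train(img_metas=[[meta, meta]], img=img,
#                       can_bus=torch.zeros(1, 2, 18),
#                       lidar2img=torch.eye(4).expand(1, 2, 6, 4, 4), ...)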
def forward_test(self, img_metas, img=None, rescale=True, **kwargs):
self._check_inputs(img_metas, img, kwargs)
for var, name in [(img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError('{} must be a list, but got {}'.format(
@ -213,20 +272,25 @@ class BEVFormer(MVXTwoStageDetector):
img = [img] if img is None else img
if img_metas[0][0]['scene_token'] != self.prev_frame_info[
'scene_token']:
'prev_scene_token']:
# the first sample of each scene is truncated
self.prev_frame_info['prev_bev'] = None
# update idx
self.prev_frame_info['scene_token'] = img_metas[0][0]['scene_token']
self.prev_frame_info['prev_scene_token'] = img_metas[0][0][
'scene_token']
# do not use temporal information
if not self.video_test_mode:
self.prev_frame_info['prev_bev'] = None
# Get the delta of ego position and angle between two timestamps.
tmp_pos = copy.deepcopy(img_metas[0][0]['can_bus'][:3])
tmp_angle = copy.deepcopy(img_metas[0][0]['can_bus'][-1])
if self.prev_frame_info['prev_bev'] is not None:
tmp_pos = img_metas[0][0]['can_bus'][:3].clone()
tmp_angle = img_metas[0][0]['can_bus'][-1].clone()
# skip init dummy prev_bev
if self.prev_frame_info['prev_bev'] is not None and not torch.equal(
self.prev_frame_info['prev_bev'],
self.prev_frame_info['prev_bev'].new_zeros(
self.prev_frame_info['prev_bev'].size())):
img_metas[0][0]['can_bus'][:3] -= self.prev_frame_info['prev_pos']
img_metas[0][0]['can_bus'][-1] -= self.prev_frame_info[
'prev_angle']
@ -268,7 +332,7 @@ class BEVFormer(MVXTwoStageDetector):
def simple_test(self, img_metas, img=None, prev_bev=None, rescale=False):
"""Test function without augmentaiton."""
img_feats = self.extract_feat(img=img, img_metas=img_metas)
img_feats = self.extract_feat(img=img)
bbox_list = [dict() for i in range(len(img_metas))]
new_prev_bev, bbox_pts = self.simple_test_pts(
@ -276,3 +340,102 @@ class BEVFormer(MVXTwoStageDetector):
for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
result_dict['pts_bbox'] = pts_bbox
return new_prev_bev, bbox_list
def forward_export(self, img, img_metas):
error_str = 'Only batch_size=1 and queue_length=1 are supported, please remove the batch_size and queue_length axes!'
if len(img.shape) > 4:
raise ValueError(error_str)
elif len(img.shape) < 4:
raise ValueError(
'The length of img.shape must be equal to 4: [num_cameras, img_channel, img_height, img_width]'
)
assert len(
img_metas['can_bus'].shape) == 1, error_str # torch.Size([18])
assert len(img_metas['lidar2img'].shape
) == 3, error_str # torch.Size([6, 4, 4])
assert len(
img_metas['img_shape'].shape) == 2, error_str # torch.Size([6, 3])
assert len(img_metas['prev_bev'].shape
) == 3, error_str # torch.Size([40000, 1, 256])
img = img[
None, None,
...] # torch.Size([6, 3, 928, 1600]) -> torch.Size([1, 1, 6, 3, 928, 1600])
box_type_3d = img_metas.get('box_type_3d', 'LiDAR')
if isinstance(box_type_3d, torch.Tensor):
box_type_3d = pickle.loads(box_type_3d.cpu().numpy().tobytes())
img_metas['box_type_3d'] = get_box_type(box_type_3d)[0]
img_metas['scene_token'] = decode_tensor_to_str(
img_metas['scene_token'])
# previous frame info
self.prev_frame_info['prev_scene_token'] = decode_tensor_to_str(
img_metas.pop('prev_scene_token', None))
self.prev_frame_info['prev_bev'] = img_metas.pop('prev_bev', None)
self.prev_frame_info['prev_pos'] = img_metas.pop('prev_pos', None)
self.prev_frame_info['prev_angle'] = img_metas.pop('prev_angle', None)
img_metas = [[img_metas]]
outputs = self.forward_test(img_metas, img=img)
scores_3d = outputs['pts_bbox'][0]['scores_3d']
labels_3d = outputs['pts_bbox'][0]['labels_3d']
boxes_3d = outputs['pts_bbox'][0]['boxes_3d'].tensor.cpu()
# info has been updated to the current frame
prev_bev = self.prev_frame_info['prev_bev']
prev_pos = self.prev_frame_info['prev_pos']
prev_angle = self.prev_frame_info['prev_angle']
prev_scene_token = encode_str_to_tensor(
self.prev_frame_info['prev_scene_token'])
return scores_3d, labels_3d, boxes_3d, [
prev_bev, prev_pos, prev_angle, prev_scene_token
]
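A hedged sketch of driving the exported graph (paths and shapes are illustrative): one sample, one frame, and the previous-frame state threaded through as plain tensors, mirroring the contract asserted above.

import torch
from easycv.utils.misc import encode_str_to_tensor

traced = torch.jit.load('bevformer_export.jit', map_location='cuda')  # hypothetical path
img = torch.rand(6, 3, 928, 1600, device='cuda')  # (num_cams, C, H, W)
img_metas = {
    'can_bus': torch.zeros(18, device='cuda'),
    'lidar2img': torch.eye(4, device='cuda').repeat(6, 1, 1),
    'img_shape': torch.tensor([[928, 1600, 3]] * 6, device='cuda'),
    'scene_token': encode_str_to_tensor('scene-0001'),
    'prev_bev': torch.rand(200 * 200, 1, 256, device='cuda'),  # bev_h * bev_w, 1, embed_dim
    'prev_pos': torch.tensor(0),
    'prev_angle': torch.tensor(0),
    'prev_scene_token': encode_str_to_tensor('dummy_prev_scene_token'),
}
scores_3d, labels_3d, boxes_3d, prev_info = traced(img, img_metas)
# prev_info = [prev_bev, prev_pos, prev_angle, prev_scene_token]; feed it back
# through img_metas on the next frame to keep the temporal state.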
def forward_history_bev(self,
img,
can_bus,
lidar2img,
img_shape,
scene_token,
box_type_3d='LiDAR'):
"""Experimental api, for export jit model to obtain history bev.
"""
if isinstance(box_type_3d, torch.Tensor):
box_type_3d = pickle.loads(box_type_3d.cpu().numpy().tobytes())
batch_size, len_queue = img.size()[:2]
img_metas = []
for b_i in range(batch_size):
img_metas.append([])
for i in range(len_queue):
scene_token_str = pickle.loads(
scene_token[b_i][i].cpu().numpy().tobytes())
img_metas[b_i].append({
'scene_token':
scene_token_str,
'can_bus':
can_bus[b_i][i],
'lidar2img':
lidar2img[b_i][i],
'img_shape':
img_shape[b_i][i],
'box_type_3d':
get_box_type(box_type_3d)[0],
'prev_bev_exists':
False
})
prev_img = img[:, :-1, ...]
img = img[:, -1, ...]
prev_img_metas = copy.deepcopy(img_metas)
if self.extract_feat_serially:
prev_bev = self.obtain_history_bev_serially(
prev_img, prev_img_metas)
else:
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
return prev_bev

View File

@ -36,6 +36,8 @@ class BEVFormerHead(AnchorFreeHead):
num_classes,
in_channels,
num_query=100,
num_query_one2many=0,
one2many_gt_mul=None,
num_reg_fcs=2,
with_box_refine=False,
as_two_stage=False,
@ -71,7 +73,6 @@ class BEVFormerHead(AnchorFreeHead):
self.bev_h = bev_h
self.bev_w = bev_w
self.fp16_enabled = False
self.with_box_refine = with_box_refine
self.as_two_stage = as_two_stage
if self.as_two_stage:
@ -133,13 +134,17 @@ class BEVFormerHead(AnchorFreeHead):
sampler_cfg = dict(type='PseudoBBoxSampler')
self.sampler = build_bbox_sampler(sampler_cfg, context=self)
self.num_query = num_query
# for one2many task
self.num_query_one2many = num_query_one2many
self.num_query_one2one = num_query
self.one2many_gt_mul = one2many_gt_mul
self.num_query = num_query + num_query_one2many if num_query_one2many > 0 else num_query
self.num_classes = num_classes
self.in_channels = in_channels
self.num_reg_fcs = num_reg_fcs
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.fp16_enabled = False
self.loss_cls = build_loss(loss_cls)
self.loss_bbox = build_loss(loss_bbox)
self.loss_iou = build_loss(loss_iou)
@ -279,6 +284,16 @@ class BEVFormerHead(AnchorFreeHead):
prev_bev=prev_bev,
)
else:
# make attn mask for one2many task
self_attn_mask = torch.zeros([
self.num_query,
self.num_query,
]).bool().to(bev_queries.device)
self_attn_mask[self.num_query_one2one:,
0:self.num_query_one2one, ] = True
self_attn_mask[0:self.num_query_one2one,
self.num_query_one2one:, ] = True
outputs = self.transformer(
mlvl_feats,
bev_queries,
@ -292,7 +307,8 @@ class BEVFormerHead(AnchorFreeHead):
if self.with_box_refine else None, # noqa:E501
cls_branches=self.cls_branches if self.as_two_stage else None,
img_metas=img_metas,
prev_bev=prev_bev)
prev_bev=prev_bev,
attn_mask=self_attn_mask)
bev_embed, hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
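The extra `num_query_one2many` queries form a hybrid (one-to-many) auxiliary matching branch; the block mask above keeps the two query groups from attending to each other in decoder self-attention. A minimal standalone sketch of the same mask (True means the position is masked out), sizes illustrative:

import torch

num_one2one, num_one2many = 900, 1800
n = num_one2one + num_one2many
self_attn_mask = torch.zeros(n, n, dtype=torch.bool)
self_attn_mask[num_one2one:, :num_one2one] = True  # aux queries cannot see the main ones
self_attn_mask[:num_one2one, num_one2one:] = True  # main queries cannot see the aux ones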
@ -309,20 +325,47 @@ class BEVFormerHead(AnchorFreeHead):
# TODO: check the shape of reference
assert reference.shape[-1] == 3
tmp[..., 0:2] += reference[..., 0:2]
tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
tmp[..., 4:5] += reference[..., 2:3]
tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
tmp[..., 0:1] = (
tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
# tmp: torch.Size([1, 900, 10])
# tmp[..., 0:2] += reference[..., 0:2]
# tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
# tmp[..., 4:5] += reference[..., 2:3]
# tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
# tmp[..., 0:1] = (
# tmp[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) +
# self.pc_range[0])
# tmp[..., 1:2] = (
# tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
# self.pc_range[1])
# tmp[..., 4:5] = (
# tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
# self.pc_range[2])
# remove inplace operations; metrics may be incorrect when using blade
tmp_0_2 = tmp[..., 0:2]
tmp_0_2_add_reference = tmp_0_2 + reference[..., 0:2]
tmp_0_2_add_reference = tmp_0_2_add_reference.sigmoid()
tmp_4_5 = tmp[..., 4:5]
tmp_4_5_add_reference = tmp_4_5 + reference[..., 2:3]
tmp_4_5_add_reference = tmp_4_5_add_reference.sigmoid()
tmp_0_1 = tmp_0_2_add_reference[..., 0:1]
tmp_0_1_new = (
tmp_0_1 * (self.pc_range[3] - self.pc_range[0]) +
self.pc_range[0])
tmp[..., 1:2] = (
tmp[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) +
tmp_1_2 = tmp_0_2_add_reference[..., 1:2]
tmp_1_2_new = (
tmp_1_2 * (self.pc_range[4] - self.pc_range[1]) +
self.pc_range[1])
tmp[..., 4:5] = (
tmp[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) +
tmp_4_5_new = (
tmp_4_5_add_reference * (self.pc_range[5] - self.pc_range[2]) +
self.pc_range[2])
tmp_2_4 = tmp[..., 2:4]
tmp_5_10 = tmp[..., 5:10]
tmp = torch.cat(
[tmp_0_1_new, tmp_1_2_new, tmp_2_4, tmp_4_5_new, tmp_5_10],
dim=-1)
# TODO: check if using sigmoid
outputs_coord = tmp
outputs_classes.append(outputs_class)
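The decode above is rewritten without in-place slice assignment because traced/blade graphs may handle such in-place views incorrectly. A hedged standalone version of the same arithmetic, with pc_range as [x_min, y_min, z_min, x_max, y_max, z_max]:

import torch

def decode_center(tmp, reference, pc_range):
    xy = (tmp[..., 0:2] + reference[..., 0:2]).sigmoid()
    z = (tmp[..., 4:5] + reference[..., 2:3]).sigmoid()
    x = xy[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    y = xy[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    z = z * (pc_range[5] - pc_range[2]) + pc_range[2]
    return torch.cat([x, y, tmp[..., 2:4], z, tmp[..., 5:10]], dim=-1)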
@ -333,12 +376,19 @@ class BEVFormerHead(AnchorFreeHead):
outs = {
'bev_embed': bev_embed,
'all_cls_scores': outputs_classes,
'all_bbox_preds': outputs_coords,
'all_cls_scores':
outputs_classes[:, :, :self.num_query_one2one, :],
'all_bbox_preds': outputs_coords[:, :, :self.num_query_one2one, :],
'enc_cls_scores': None,
'enc_bbox_preds': None,
}
if self.num_query_one2many > 0:
outs['all_cls_scores_aux'] = outputs_classes[:, :, self.
num_query_one2one:, :]
outs['all_bbox_preds_aux'] = outputs_coords[:, :, self.
num_query_one2one:, :]
return outs
def _get_target_single(self,
@ -396,6 +446,8 @@ class BEVFormerHead(AnchorFreeHead):
bbox_weights[pos_inds] = 1.0
# DETR
sampling_result.pos_gt_bboxes = sampling_result.pos_gt_bboxes.type_as(
bbox_targets)
bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
neg_inds)
@ -586,6 +638,47 @@ class BEVFormerHead(AnchorFreeHead):
all_gt_bboxes_ignore_list)
loss_dict = dict()
# for one2many task
if 'all_cls_scores_aux' in preds_dicts and self.one2many_gt_mul:
all_cls_scores_aux = preds_dicts['all_cls_scores_aux']
all_bbox_preds_aux = preds_dicts['all_bbox_preds_aux']
gt_bboxes_list_aux = []
gt_labels_list_aux = []
for gt_bboxes, gt_labels in zip(gt_bboxes_list, gt_labels_list):
gt_bboxes_list_aux.append(
gt_bboxes.repeat(self.one2many_gt_mul, 1))
gt_labels_list_aux.append(
gt_labels.repeat(self.one2many_gt_mul))
# for classwise multiply
# for gt_bboxes, gt_labels in zip(gt_bboxes_list,gt_labels_list):
# gt_bboxes_aux = []
# gt_labels_aux = []
# for gt_bbox, gt_label in zip(gt_bboxes, gt_labels):
# gt_bboxes_aux += [gt_bbox]*self.one2many_gt_mul[gt_label]
# gt_labels_aux += [gt_label]*self.one2many_gt_mul[gt_label]
# gt_bboxes_list_aux.append(torch.stack(gt_bboxes_aux))
# gt_labels_list_aux.append(torch.stack(gt_labels_aux))
all_gt_bboxes_list_aux = [
gt_bboxes_list_aux for _ in range(num_dec_layers)
]
all_gt_labels_list_aux = [
gt_labels_list_aux for _ in range(num_dec_layers)
]
losses_cls_aux, losses_bbox_aux = multi_apply(
self.loss_single, all_cls_scores_aux, all_bbox_preds_aux,
all_gt_bboxes_list_aux, all_gt_labels_list_aux,
all_gt_bboxes_ignore_list)
loss_dict['loss_cls_aux'] = losses_cls_aux[-1]
loss_dict['loss_bbox_aux'] = losses_bbox_aux[-1]
num_dec_layer = 0
for loss_cls_i, loss_bbox_i in zip(losses_cls_aux[:-1],
losses_bbox_aux[:-1]):
loss_dict[f'd{num_dec_layer}.loss_cls_aux'] = loss_cls_i
loss_dict[f'd{num_dec_layer}.loss_bbox_aux'] = loss_bbox_i
num_dec_layer += 1
# loss of proposal generated from encode feature map.
if enc_cls_scores is not None:
binary_labels_list = [

View File

@ -24,6 +24,19 @@ from easycv.models.utils.transformer import (BaseTransformerLayer,
TransformerLayerSequence)
from . import (CustomMSDeformableAttention, MSDeformableAttention3D,
TemporalSelfAttention)
from .attentions.spatial_cross_attention import SpatialCrossAttention
@torch.jit.script
def _rotate(img: torch.Tensor, angle: torch.Tensor, center: torch.Tensor):
"""torch.jit.trace does not support torchvision.rotate"""
img = rotate(
img,
float(angle.item()),
center=[int(center[0].item()),
int(center[1].item())])
return img
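A short usage sketch of the scripted wrapper above (shapes illustrative); keeping angle and center as tensors lets the rotation be recorded cleanly by torch.jit.trace:

import torch

bev = torch.rand(256, 200, 200)    # (C, bev_h, bev_w)
angle = torch.tensor(3.5)          # degrees
center = torch.tensor([100, 100])  # rotate_center
bev_rotated = _rotate(bev, angle, center)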
@TRANSFORMER_LAYER.register_module()
@ -107,6 +120,7 @@ class BEVFormerLayer(BaseModule):
),
batch_first=True,
init_cfg=None,
adapt_jit=False,
**kwargs):
super(BEVFormerLayer, self).__init__(init_cfg)
@ -135,6 +149,7 @@ class BEVFormerLayer(BaseModule):
self.attentions = ModuleList()
index = 0
self.adapt_jit = adapt_jit
for operation_name in operation_order:
if operation_name in ['self_attn', 'cross_attn']:
if 'batch_first' in attn_cfgs[index]:
@ -142,6 +157,10 @@ class BEVFormerLayer(BaseModule):
else:
attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_attention(attn_cfgs[index])
# for export jit model
if self.adapt_jit and isinstance(attention,
SpatialCrossAttention):
attention = torch.jit.script(attention)
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention.operation_name = operation_name
@ -170,7 +189,6 @@ class BEVFormerLayer(BaseModule):
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
self.fp16_enabled = False
assert len(operation_order) == 6
assert set(operation_order) == set(
['self_attn', 'norm', 'cross_attn', 'ffn'])
@ -249,43 +267,42 @@ class BEVFormerLayer(BaseModule):
if layer == 'self_attn':
query = self.attentions[attn_index](
query,
prev_bev,
prev_bev,
identity if self.pre_norm else None,
query=query,
key=prev_bev,
value=prev_bev,
identity=identity if self.pre_norm else None,
query_pos=bev_pos,
key_pos=bev_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
reference_points=ref_2d,
spatial_shapes=torch.tensor([[bev_h, bev_w]],
device=query.device),
level_start_index=torch.tensor([0], device=query.device),
**kwargs)
)
attn_index += 1
identity = query
elif layer == 'norm':
# fix fp16
dtype = query.dtype
query = self.norms[norm_index](query)
query = query.to(dtype)
norm_index += 1
# spatial cross attention
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query=query,
key=key,
value=value,
residual=identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
reference_points=ref_3d,
reference_points_cam=reference_points_cam,
mask=mask,
attn_mask=attn_masks[attn_index],
bev_mask=kwargs.get('bev_mask'),
key_padding_mask=key_padding_mask,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
**kwargs)
)
attn_index += 1
identity = query
@ -309,7 +326,6 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
def __init__(self, *args, return_intermediate=False, **kwargs):
super(Detr3DTransformerDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
self.fp16_enabled = False
def forward(self,
query,
@ -317,6 +333,7 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
reference_points=None,
reg_branches=None,
key_padding_mask=None,
attn_mask=None,
**kwargs):
"""Forward function for `Detr3DTransformerDecoder`.
Args:
@ -346,6 +363,7 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
output,
*args,
reference_points=reference_points_input,
attn_masks=[attn_mask] * layer.num_attn,
key_padding_mask=key_padding_mask,
**kwargs)
output = output.permute(1, 0, 2)
@ -355,13 +373,26 @@ class Detr3DTransformerDecoder(TransformerLayerSequence):
assert reference_points.shape[-1] == 3
new_reference_points = torch.zeros_like(reference_points)
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
reference_points[..., :2], eps=1e-5)
new_reference_points[...,
2:3] = tmp[..., 4:5] + inverse_sigmoid(
reference_points[..., 2:3], eps=1e-5)
# new_reference_points = torch.zeros_like(
# reference_points) # torch.Size([1, 900, 3])
# new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(
# reference_points[..., :2], eps=1e-5)
# new_reference_points[...,
# 2:3] = tmp[..., 4:5] + inverse_sigmoid(
# reference_points[..., 2:3], eps=1e-5)
# new_reference_points = new_reference_points.sigmoid()
# reference_points = new_reference_points.detach()
# remove inplace operations; metrics may be incorrect when using blade
new_reference_points_0_2 = tmp[..., :2] + inverse_sigmoid(
reference_points[..., :2], eps=1e-5)
new_reference_points_2_3 = tmp[..., 4:5] + inverse_sigmoid(
reference_points[..., 2:3], eps=1e-5)
new_reference_points = torch.cat(
[new_reference_points_0_2, new_reference_points_2_3],
dim=-1)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
@ -402,7 +433,6 @@ class BEVFormerEncoder(TransformerLayerSequence):
self.num_points_in_pillar = num_points_in_pillar
self.pc_range = pc_range
self.fp16_enabled = False
@staticmethod
def get_reference_points(H,
@ -456,12 +486,9 @@ class BEVFormerEncoder(TransformerLayerSequence):
# This function must use fp32!!!
@force_fp32(apply_to=('reference_points', 'img_metas'))
def point_sampling(self, reference_points, pc_range, img_metas):
lidar2img = torch.stack([meta['lidar2img'] for meta in img_metas
]).to(reference_points.dtype) # (B, N, 4, 4)
lidar2img = []
for img_meta in img_metas:
lidar2img.append(img_meta['lidar2img'])
lidar2img = np.asarray(lidar2img)
lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
reference_points = reference_points.clone()
reference_points[..., 0:1] = reference_points[..., 0:1] * \
@ -650,7 +677,6 @@ class PerceptionTransformer(BaseModule):
self.embed_dims = embed_dims
self.num_feature_levels = num_feature_levels
self.num_cams = num_cams
self.fp16_enabled = False
self.rotate_prev_bev = rotate_prev_bev
self.use_shift = use_shift
@ -711,26 +737,28 @@ class PerceptionTransformer(BaseModule):
bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
# obtain rotation angle and shift with ego motion
delta_x = np.array(
delta_x = torch.stack(
[each['can_bus'][0] for each in kwargs['img_metas']])
delta_y = np.array(
delta_y = torch.stack(
[each['can_bus'][1] for each in kwargs['img_metas']])
ego_angle = np.array([
ego_angle = torch.stack([
each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas']
])
grid_length_y = grid_length[0]
grid_length_x = grid_length[1]
translation_length = np.sqrt(delta_x**2 + delta_y**2)
translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
translation_length = torch.sqrt(delta_x**2 + delta_y**2)
translation_angle = torch.atan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = translation_length * \
np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
torch.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * \
np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
shift_y = shift_y * self.use_shift
shift_x = shift_x * self.use_shift
shift = bev_queries.new_tensor([shift_x, shift_y
]).permute(1, 0) # xy, bs -> bs, xy
torch.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
if not self.use_shift:
shift_y = shift_y.new_zeros(shift_y.size())
shift_x = shift_x.new_zeros(shift_y.size())
shift = torch.stack([shift_x,
shift_y]).permute(1, 0).to(bev_queries.dtype)
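The block above converts the ego translation from the CAN bus into a BEV-grid offset, now fully in torch so it can be traced. A hedged numeric sketch of the same formulas (values illustrative; with pc_range spanning 102.4 m and bev_h = bev_w = 200, grid_length is 0.512 m per cell):

import numpy as np
import torch

delta_x, delta_y = torch.tensor([1.0]), torch.tensor([0.5])  # ego motion in meters
ego_angle = torch.tensor([30.0])                             # degrees
grid_length_y = grid_length_x = 0.512
bev_h = bev_w = 200
translation_length = torch.sqrt(delta_x**2 + delta_y**2)
translation_angle = torch.atan2(delta_y, delta_x) / np.pi * 180
bev_angle = ego_angle - translation_angle
shift_y = translation_length * torch.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
shift_x = translation_length * torch.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
# shift_x / shift_y are fractions of the BEV grid used to translate prev_bev.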
if prev_bev is not None:
if prev_bev.shape[1] == bev_h * bev_w:
@ -741,19 +769,23 @@ class PerceptionTransformer(BaseModule):
rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
tmp_prev_bev = prev_bev[:, i].reshape(bev_h, bev_w,
-1).permute(2, 0, 1)
tmp_prev_bev = rotate(
tmp_prev_bev = _rotate(
tmp_prev_bev,
rotation_angle,
center=self.rotate_center)
center=torch.tensor(self.rotate_center))
tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
bev_h * bev_w, 1, -1)
prev_bev[:, i] = tmp_prev_bev[:, 0]
# add can bus signals
can_bus = bev_queries.new_tensor(
[each['can_bus'] for each in kwargs['img_metas']]) # [:, :]
can_bus = torch.stack([
each['can_bus'] for each in kwargs['img_metas']
]).to(bev_queries.dtype)
can_bus = self.can_bus_mlp(can_bus)[None, :, :]
bev_queries = bev_queries + can_bus * self.use_can_bus
# fix fp16
can_bus = can_bus.to(bev_queries.dtype)
if self.use_can_bus:
bev_queries = bev_queries + can_bus
feat_flatten = []
spatial_shapes = []
@ -806,6 +838,7 @@ class PerceptionTransformer(BaseModule):
reg_branches=None,
cls_branches=None,
prev_bev=None,
attn_mask=None,
**kwargs):
"""Forward function for `Detr3DTransformer`.
Args:
@ -873,6 +906,7 @@ class PerceptionTransformer(BaseModule):
value=bev_embed,
query_pos=query_pos,
reference_points=reference_points,
attn_mask=attn_mask,
reg_branches=reg_branches,
cls_branches=cls_branches,
spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
from mmcv.runner import auto_fp16
from PIL import Image
from torchvision.transforms.functional import rotate
class Grid(object):
@ -113,7 +114,7 @@ class GridMask(nn.Module):
ww = int(1.5 * w)
d = np.random.randint(2, h)
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
mask = torch.ones((hh, ww), dtype=torch.uint8, device=x.device)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
@ -128,19 +129,16 @@ class GridMask(nn.Module):
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = rotate(mask.unsqueeze(0), r)[0]
mask = mask[(hh - h) // 2:(hh - h) // 2 + h,
(ww - w) // 2:(ww - w) // 2 + w]
mask = torch.from_numpy(mask).to(x.dtype).cuda()
mask = mask.to(x.dtype)
if self.mode == 1:
mask = 1 - mask
mask = mask.expand_as(x)
if self.offset:
offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).to(
x.dtype).cuda()
offset = (2 * (torch.rand(
(h, w), device=x.device) - 0.5)).to(x.dtype)
x = x * mask + offset * (1 - mask)
else:
x = x * mask

View File

@ -4,7 +4,7 @@ from .det_db_loss import DBLoss
from .face_keypoint_loss import FacePoseLoss, WingLossWithPose
from .focal_loss import FocalLoss, VarifocalLoss
from .iou_loss import GIoULoss, IoULoss, YOLOX_IOULoss
from .l1_loss import L1Loss
from .l1_loss import L1Loss, SmoothL1Loss
from .mse_loss import JointsMSELoss
from .ocr_rec_multi_loss import MultiLoss
from .pytorch_metric_learning import (AMSoftmaxLoss,
@ -22,5 +22,5 @@ __all__ = [
'FocalLoss2d', 'DistributeMSELoss', 'CrossEntropyLossWithLabelSmooth',
'AMSoftmaxLoss', 'ModelParallelSoftmaxLoss', 'ModelParallelAMSoftmaxLoss',
'SoftTargetCrossEntropy', 'CDNCriterion', 'DNCriterion', 'DBLoss',
'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss'
'HungarianMatcher', 'SetCriterion', 'L1Loss', 'MultiLoss', 'SmoothL1Loss'
]

View File

@ -1,4 +1,5 @@
import mmcv
import numpy as np
import torch
import torch.nn as nn
@ -66,3 +67,185 @@ class L1Loss(nn.Module):
loss_bbox = self.loss_weight * l1_loss(
pred, target, weight, reduction=reduction, avg_factor=avg_factor)
return loss_bbox
# @mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def smooth_l1_loss(pred, target, beta=1.0):
"""Smooth L1 loss.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.0.
Returns:
torch.Tensor: Calculated loss
"""
assert beta > 0
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
return loss
@LOSSES.register_module()
class SmoothL1Loss(nn.Module):
"""Smooth L1 loss.
Args:
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.0.
reduction (str, optional): The method to reduce the loss.
Options are "none", "mean" and "sum". Defaults to "mean".
loss_weight (float, optional): The weight of loss.
"""
def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
super(SmoothL1Loss, self).__init__()
self.beta = beta
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function.
Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * smooth_l1_loss(
pred,
target,
weight,
beta=self.beta,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_bbox
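A short usage sketch of the loss registered above, assuming it is imported from easycv.models.losses (the export is added in this diff):

import torch
from easycv.models.losses import SmoothL1Loss

criterion = SmoothL1Loss(beta=1.0, loss_weight=0.75)
pred = torch.rand(4, 10, requires_grad=True)
target = torch.rand(4, 10)
loss = criterion(pred, target)  # scalar, reduction='mean' by default
loss.backward()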
@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def balanced_l1_loss(pred,
target,
beta=1.0,
alpha=0.5,
gamma=1.5,
reduction='mean'):
"""Calculate balanced L1 loss.
Please see the `Libra R-CNN <https://arxiv.org/pdf/1904.02701.pdf>`_
Args:
pred (torch.Tensor): The prediction with shape (N, 4).
target (torch.Tensor): The learning target of the prediction with
shape (N, 4).
beta (float): The loss is a piecewise function of prediction and target
and ``beta`` serves as a threshold for the difference between the
prediction and target. Defaults to 1.0.
alpha (float): The denominator ``alpha`` in the balanced L1 loss.
Defaults to 0.5.
gamma (float): The ``gamma`` in the balanced L1 loss.
Defaults to 1.5.
reduction (str, optional): The method that reduces the loss to a
scalar. Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert beta > 0
if target.numel() == 0:
return pred.sum() * 0
assert pred.size() == target.size()
diff = torch.abs(pred - target)
b = np.e**(gamma / alpha) - 1
loss = torch.where(
diff < beta, alpha / b *
(b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
gamma * diff + gamma / b - alpha * beta)
return loss
@LOSSES.register_module()
class BalancedL1Loss(nn.Module):
"""Balanced L1 Loss.
arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
Args:
alpha (float): The denominator ``alpha`` in the balanced L1 loss.
Defaults to 0.5.
gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
beta (float, optional): The loss is a piecewise function of prediction
and target. ``beta`` serves as a threshold for the difference
between the prediction and target. Defaults to 1.0.
reduction (str, optional): The method that reduces the loss to a
scalar. Options are "none", "mean" and "sum".
loss_weight (float, optional): The weight of the loss. Defaults to 1.0
"""
def __init__(self,
alpha=0.5,
gamma=1.5,
beta=1.0,
reduction='mean',
loss_weight=1.0):
super(BalancedL1Loss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.beta = beta
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
pred,
target,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function of loss.
Args:
pred (torch.Tensor): The prediction with shape (N, 4).
target (torch.Tensor): The learning target of the prediction with
shape (N, 4).
weight (torch.Tensor, optional): Sample-wise loss weight with
shape (N, ).
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * balanced_l1_loss(
pred,
target,
weight,
alpha=self.alpha,
gamma=self.gamma,
beta=self.beta,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_bbox

View File

@ -510,10 +510,10 @@ class BaseTransformerLayer(BaseModule):
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query=query,
key=key,
value=value,
identity=identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],

View File

@ -1,15 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from easycv.core.bbox import get_box_type
from easycv.datasets.registry import PIPELINES
from easycv.datasets.shared.pipelines.format import to_tensor
from easycv.datasets.shared.pipelines.transforms import Compose
from easycv.framework.errors import ValueError
from easycv.predictors.base import PredictorV2
from easycv.predictors.builder import PREDICTORS
from easycv.utils.misc import encode_str_to_tensor
from easycv.utils.registry import build_from_cfg
from .base import PredictorV2
from .builder import PREDICTORS
@PREDICTORS.register_module()
@ -40,8 +46,21 @@ class BEVFormerPredictor(PredictorV2):
box_type_3d='LiDAR',
use_camera=True,
score_threshold=0.1,
model_type=None,
*arg,
**kwargs):
if batch_size > 1:
raise ValueError(
f'Only support batch_size=1 now, but get batch_size={batch_size}'
)
self.model_type = model_type
if self.model_type is None:
if model_path.endswith('jit'):
self.model_type = 'jit'
elif model_path.endswith('blade'):
self.model_type = 'blade'
self.is_jit_model = self.model_type in ['jit', 'blade']
super(BEVFormerPredictor, self).__init__(
model_path,
config_file=config_file,
@ -58,6 +77,20 @@ class BEVFormerPredictor(PredictorV2):
self.score_threshold = score_threshold
self.result_key = 'pts_bbox'
# The initial prev_bev should be the weight of self.model.pts_bbox_head.bev_embedding, but the weight cannot be taken out of the blade model.
# So we use dummy data as the initial value; it will not be used, it only adapts the inputs to the jit and blade models.
# init_prev_bev = self.model.pts_bbox_head.bev_embedding.weight.clone().detach()
# init_prev_bev = init_prev_bev[:, None, :], # [40000, 256] -> [40000, 1, 256]
dummy_prev_bev = torch.rand(
[self.cfg.bev_h * self.cfg.bev_w, 1,
self.cfg.embed_dim]).to(self.device)
self.prev_frame_info = {
'prev_bev': dummy_prev_bev.to(self.device),
'prev_scene_token': encode_str_to_tensor('dummy_prev_scene_token'),
'prev_pos': torch.tensor(0),
'prev_angle': torch.tensor(0),
}
def _prepare_input_dict(self, data_info):
from nuscenes.eval.common.utils import Quaternion, quaternion_yaw
@ -133,13 +166,85 @@ class BEVFormerPredictor(PredictorV2):
Args:
input (str): Pickle file path; the content format is the same as the nuScenes infos file.
"""
data_info = mmcv.load(input)
data_info = mmcv.load(input) if isinstance(input, str) else input
result = self._prepare_input_dict(data_info)
return self.processor(result)
result = self.processor(result)
if self.is_jit_model:
result['can_bus'] = DC(
to_tensor(result['img_metas'][0]._data['can_bus']),
cpu_only=False)
result['lidar2img'] = DC(
to_tensor(result['img_metas'][0]._data['lidar2img']),
cpu_only=False)
result['scene_token'] = DC(
torch.tensor(
bytearray(
pickle.dumps(
result['img_metas'][0]._data['scene_token'])),
dtype=torch.uint8),
cpu_only=False)
result['img_shape'] = DC(
to_tensor(result['img_metas'][0]._data['img_shape']),
cpu_only=False)
else:
result['can_bus'] = DC(
torch.stack(
[to_tensor(result['img_metas'][0]._data['can_bus'])]),
cpu_only=False)
result['lidar2img'] = DC(
torch.stack(
[to_tensor(result['img_metas'][0]._data['lidar2img'])]),
cpu_only=False)
return result
def postprocess_single(self, inputs, *args, **kwargs):
# TODO: filter results by score_threshold
return super().postprocess_single(inputs, *args, **kwargs)
def prepare_model(self):
if self.is_jit_model:
model = torch.jit.load(self.model_path, map_location=self.device)
return model
return super().prepare_model()
def forward(self, inputs):
if self.is_jit_model:
with torch.no_grad():
img = inputs['img'][0][0]
img_metas = {
'can_bus': inputs['can_bus'][0],
'lidar2img': inputs['lidar2img'][0],
'img_shape': inputs['img_shape'][0],
'scene_token': inputs['scene_token'][0],
'prev_bev': self.prev_frame_info['prev_bev'],
'prev_pos': self.prev_frame_info['prev_pos'],
'prev_angle': self.prev_frame_info['prev_angle'],
'prev_scene_token':
self.prev_frame_info['prev_scene_token']
}
inputs = (img, img_metas)
outputs = self.model(*inputs)
# update prev_frame_info
self.prev_frame_info['prev_bev'] = outputs[3][0]
self.prev_frame_info['prev_pos'] = outputs[3][1]
self.prev_frame_info['prev_angle'] = outputs[3][2]
self.prev_frame_info['prev_scene_token'] = outputs[3][3]
outputs = {
'pts_bbox': [{
'scores_3d':
outputs[0],
'labels_3d':
outputs[1],
'boxes_3d':
self.box_type_3d(outputs[2].cpu(), outputs[2].size()[-1])
}],
}
return outputs
return super().forward(inputs)
def visualize(self, inputs, results, out_dir, show=False, pipeline=None):
raise NotImplementedError
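A hedged usage sketch for the predictor with a traced model: the '.jit' / '.blade' suffix drives model_type detection, the config supplies bev_h / bev_w / embed_dim for the dummy prev_bev, and the input is a pickle in the nuScenes infos format. Paths and the import location are illustrative.

from easycv.predictors import BEVFormerPredictor  # import path assumed

predictor = BEVFormerPredictor(
    model_path='work_dirs/bevformer_export.jit',          # '.jit' selects the jit path
    config_file='configs/detection3d/bevformer_base.py',  # illustrative config
)
results = predictor(['data/nuscenes_sample_info.pkl'])    # list in, list out
# each result carries 'pts_bbox' with scores_3d / labels_3d / boxes_3d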

View File

@ -7,7 +7,6 @@ import numpy as np
import torch
from torchvision.transforms import Compose
from easycv.apis.export import reparameterize_models
from easycv.core.visualization import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor
@ -198,6 +197,7 @@ class YoloXPredictor(DetectionPredictor):
with io.open(self.model_path, 'rb') as infile:
model = torch.jit.load(infile, self.device)
else:
from easycv.utils.misc import reparameterize_models
model = super()._build_model()
model = reparameterize_models(model)
return model

View File

@ -1,6 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import time
from distutils.version import LooseVersion
import torch
@ -94,7 +93,6 @@ class EVRunner(EpochBasedRunner):
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
@ -122,7 +120,7 @@ class EVRunner(EpochBasedRunner):
self.mode = 'val'
self.data_loader = data_loader
self.call_hook('before_val_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_val_iter')

View File

@ -10,20 +10,31 @@
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
from __future__ import absolute_import, division, print_function
import os
import subprocess
import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
def _auto_compile():
cur_dir = os.getcwd()
target_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
os.chdir(target_dir)
res = subprocess.call('python setup.py build install', shell=True)
os.chdir(cur_dir)
return res
try:
import MultiScaleDeformableAttention as MSDA
except ModuleNotFoundError as e:
info_string = (
'\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n'
'\t`cd thirdparty/deformable_attention`\n'
'\t`python setup.py build install`\n')
raise ModuleNotFoundError(info_string)
res = _auto_compile()
if res != 0:
info_string = (
'\n\nAuto compile failed! Please compile MultiScaleDeformableAttention CUDA op with the following commands :\n'
'\t`cd easycv/thirdparty/deformable_attention`\n'
'\t`python setup.py build install`\n')
raise ModuleNotFoundError(info_string)
class MSDeformAttnFunction(Function):

View File

@ -14,8 +14,21 @@
*/
#include "ms_deform_attn.h"
#include <torch/script.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
}
inline at::Tensor ms_deform_attn(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
const at::Tensor &level_start_index,
const at::Tensor &sampling_loc,
const at::Tensor &attn_weight,
const int64_t im2col_step) {
return ms_deform_attn_forward(value, spatial_shapes, level_start_index,
sampling_loc, attn_weight, im2col_step);
}
static auto registry = torch::RegisterOperators().op("custom::ms_deform_attn", &ms_deform_attn);

View File

@ -0,0 +1,260 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
#include <torch/script.h>
void modulated_deformable_im2col_impl(
const Tensor data_im, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor data_col) {
DISPATCH_DEVICE_IMPL(modulated_deformable_im2col_impl, data_im, data_offset,
data_mask, batch_size, channels, height_im, width_im,
height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w,
deformable_group, data_col);
}
void modulated_deformable_col2im_impl(
const Tensor data_col, const Tensor data_offset, const Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, Tensor grad_im) {
DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_impl, data_col, data_offset,
data_mask, batch_size, channels, height_im, width_im,
height_col, width_col, kernel_h, kernel_w, pad_h, pad_w,
stride_h, stride_w, dilation_h, dilation_w,
deformable_group, grad_im);
}
void modulated_deformable_col2im_coord_impl(
const Tensor data_col, const Tensor data_im, const Tensor data_offset,
const Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
Tensor grad_offset, Tensor grad_mask) {
DISPATCH_DEVICE_IMPL(modulated_deformable_col2im_coord_impl, data_col,
data_im, data_offset, data_mask, batch_size, channels,
height_im, width_im, height_col, width_col, kernel_h,
kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h,
dilation_w, deformable_group, grad_offset, grad_mask);
}
void modulated_deform_conv_forward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int kernel_h, int kernel_w,
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
const int dilation_h, const int dilation_w, const int group,
const int deformable_group, const bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_impl(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_backward(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor columns, Tensor grad_input, Tensor grad_weight,
Tensor grad_bias, Tensor grad_offset, Tensor grad_mask, Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
kernel_h, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_impl(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_impl(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_impl(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
at::Tensor modulated_deform_conv(
Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor output, Tensor columns, int64_t kernel_h, int64_t kernel_w,
const int64_t stride_h, const int64_t stride_w, const int64_t pad_h, const int64_t pad_w,
const int64_t dilation_h, const int64_t dilation_w, const int64_t group,
const int64_t deformable_group, const bool with_bias) {
modulated_deform_conv_forward(input, weight, bias, ones, offset, mask, output, columns,
kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w,
group, deformable_group, with_bias);
return output;
}
TORCH_LIBRARY(mmcv, m) {
m.def(R"SIG(modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor ones, Tensor offset,
Tensor mask, Tensor(a!) output, Tensor columns, *, int kernel_h, int kernel_w,
int stride_h, int stride_w, int pad_h, int pad_w,
int dilation_h, int dilation_w, int group,
int deformable_group, bool with_bias) -> Tensor(a!))SIG", modulated_deform_conv);
}
// static auto registry = torch::RegisterOperators().op("mmcv::modulated_deform_conv", &modulated_deform_conv);

View File

@ -0,0 +1,389 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair, _single
from mmcv.utils import deprecated_api_warning
from ..cnn import CONV_LAYERS
from ..utils import ext_loader, print_log
ext_module = ext_loader.load_ext(
'_ext',
['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
class ModulatedDeformConv2dFunction(Function):
@staticmethod
def symbolic(g, input, offset, mask, weight, bias, stride, padding,
dilation, groups, deform_groups):
input_tensors = [input, offset, mask, weight]
if bias is not None:
input_tensors.append(bias)
return g.op(
'mmcv::MMCVModulatedDeformConv2d',
*input_tensors,
stride_i=stride,
padding_i=padding,
dilation_i=dilation,
groups_i=groups,
deform_groups_i=deform_groups)
@staticmethod
def _jit_forward(
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deform_groups=1):
if input is not None and input.dim() != 4:
raise ValueError(
f'Expected 4D tensor as input, got {input.dim()}D tensor \
instead.')
with_bias = bias is not None
if not with_bias:
bias = input.new_empty(0) # fake tensor
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
# amp won't cast the type of model (float32), but "offset" is cast
# to float16 by nn.Conv2d automatically, leading to the type
# mismatch with input (when it is float32) or weight.
# The flag for whether to use fp16 or amp is the type of "offset",
# we cast weight and input to temporarily support fp16 and amp
# whatever the pytorch version is.
def _output_size(input, weight):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
'convolution input is too small (output would be ' +
'x'.join(map(str, output_size)) + ')')
return output_size
input = input.type_as(offset)
weight = weight.type_as(input)
output = input.new_empty(
_output_size(input, weight))
_bufs = [input.new_empty(0), input.new_empty(0)]
if weight.dtype == torch.float16:
output = torch.ops.mmcv.modulated_deform_conv(
input.to(torch.float32),
weight.to(torch.float32),
bias.to(torch.float32),
_bufs[0].to(torch.float32),
offset.to(torch.float32),
mask.to(torch.float32),
output.to(torch.float32),
_bufs[1].to(torch.float32),
kernel_h=weight.size(2),
kernel_w=weight.size(3),
stride_h=stride[0],
stride_w=stride[1],
pad_h=padding[0],
pad_w=padding[1],
dilation_h=dilation[0],
dilation_w=dilation[1],
group=groups,
deformable_group=deform_groups,
with_bias=with_bias)
output = output.to(torch.float16)
else:
output = torch.ops.mmcv.modulated_deform_conv(
input,
weight,
bias,
_bufs[0],
offset,
mask,
output,
_bufs[1],
kernel_h=weight.size(2),
kernel_w=weight.size(3),
stride_h=stride[0],
stride_w=stride[1],
pad_h=padding[0],
pad_w=padding[1],
dilation_h=dilation[0],
dilation_w=dilation[1],
group=groups,
deformable_group=deform_groups,
with_bias=with_bias)
return output
@staticmethod
def forward(ctx,
input: torch.Tensor,
offset: torch.Tensor,
mask: torch.Tensor,
weight: nn.Parameter,
bias: Optional[nn.Parameter] = None,
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1) -> torch.Tensor:
if input is not None and input.dim() != 4:
raise ValueError(
f'Expected 4D tensor as input, got {input.dim()}D tensor \
instead.')
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deform_groups = deform_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor
# When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
# amp won't cast the type of model (float32), but "offset" is cast
# to float16 by nn.Conv2d automatically, leading to the type
# mismatch with input (when it is float32) or weight.
# The flag for whether to use fp16 or amp is the type of "offset",
# we cast weight and input to temporarily support fp16 and amp
# whatever the pytorch version is.
input = input.type_as(offset)
weight = weight.type_as(input)
bias = bias.type_as(input) # type: ignore
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
ext_module.modulated_deform_conv_forward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
output,
ctx._bufs[1],
kernel_h=weight.size(2),
kernel_w=weight.size(3),
stride_h=ctx.stride[0],
stride_w=ctx.stride[1],
pad_h=ctx.padding[0],
pad_w=ctx.padding[1],
dilation_h=ctx.dilation[0],
dilation_w=ctx.dilation[1],
group=ctx.groups,
deformable_group=ctx.deform_groups,
with_bias=ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output: torch.Tensor) -> tuple:
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
grad_output = grad_output.contiguous()
ext_module.modulated_deform_conv_backward(
input,
weight,
bias,
ctx._bufs[0],
offset,
mask,
ctx._bufs[1],
grad_input,
grad_weight,
grad_bias,
grad_offset,
grad_mask,
grad_output,
kernel_h=weight.size(2),
kernel_w=weight.size(3),
stride_h=ctx.stride[0],
stride_w=ctx.stride[1],
pad_h=ctx.padding[0],
pad_w=ctx.padding[1],
dilation_h=ctx.dilation[0],
dilation_w=ctx.dilation[1],
group=ctx.groups,
deformable_group=ctx.deform_groups,
with_bias=ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
None, None, None, None, None)
@staticmethod
def _output_size(ctx, input, weight):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = ctx.padding[d]
kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = ctx.stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(
'convolution input is too small (output would be ' +
'x'.join(map(str, output_size)) + ')')
return output_size
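# Sanity check of the output-size formula with hypothetical numbers:
# in_size=200, kernel=3, pad=1, dilation=1, stride=1
#   -> (200 + 2 * 1 - (1 * (3 - 1) + 1)) // 1 + 1 = 200 (spatial size preserved)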
modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
class ModulatedDeformConv2d(nn.Module):
@deprecated_api_warning({'deformable_groups': 'deform_groups'},
cls_name='ModulatedDeformConv2d')
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]],
stride: int = 1,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
deform_groups: int = 1,
bias: Union[bool, str] = True):
super(ModulatedDeformConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deform_groups = deform_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels // groups,
*self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.init_weights()
def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
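# Example of the uniform init bound with hypothetical sizes:
# in_channels=256, kernel_size=(3, 3) -> n = 256 * 3 * 3 = 2304,
# stdv = 1 / sqrt(2304) ~= 0.0208, so weights are drawn from U(-0.0208, 0.0208).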
def forward(self, x: torch.Tensor, offset: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return ModulatedDeformConv2dFunction._jit_forward(
x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.deform_groups)
return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.deform_groups)
@CONV_LAYERS.register_module('DCNv2')
class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv
layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int): Same as nn.Conv2d, while tuple is not supported.
padding (int): Same as nn.Conv2d, while tuple is not supported.
dilation (int): Same as nn.Conv2d, while tuple is not supported.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=True)
self.init_weights()
def init_weights(self) -> None:
super(ModulatedDeformConv2dPack, self).init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
out = self.conv_offset(x)
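# conv_offset predicts 3 * deform_groups * kH * kW channels per location:
# the first two thirds are the learned (x, y) sampling offsets and the last
# third is the modulation mask, squashed into (0, 1) by the sigmoid below.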
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
if torch.jit.is_scripting() or torch.jit.is_tracing():
return ModulatedDeformConv2dFunction._jit_forward(
x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.deform_groups)
return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.deform_groups)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
# the key is different in early versions
# In version < 2, ModulatedDeformConvPack
# loads previous benchmark models.
if (prefix + 'conv_offset.weight' not in state_dict
and prefix[:-1] + '_offset.weight' in state_dict):
state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
prefix[:-1] + '_offset.weight')
if (prefix + 'conv_offset.bias' not in state_dict
and prefix[:-1] + '_offset.bias' in state_dict):
state_dict[prefix +
'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
'_offset.bias')
if version is not None and version > 1:
print_log(
f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
'version 2.',
logger='root')
super()._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys,
error_msgs)
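# A minimal usage sketch for the DCNv2 wrapper above (hypothetical helper and
# shapes, not part of the original file): the pack variant predicts offsets
# and the modulation mask internally, so it is called like a plain nn.Conv2d.
# Note that the underlying deformable-conv op may require a CUDA build/device.
def _example_dcnv2_usage():
    conv = ModulatedDeformConv2dPack(
        in_channels=64, out_channels=64, kernel_size=3, padding=1,
        deform_groups=1)
    x = torch.randn(2, 64, 32, 32)
    out = conv(x)  # kernel=3, stride=1, padding=1 -> output stays (2, 64, 32, 32)
    return out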

View File

@ -1,11 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import copy
import ctypes
import itertools
import logging
import os
import time
import timeit
from contextlib import contextmanager
@ -14,7 +13,6 @@ import pandas as pd
import torch
import torch_blade
import torch_blade.tensorrt
import torchvision
from torch_blade import optimize
from easycv.framework.errors import RuntimeError
@ -80,6 +78,7 @@ def opt_trt_config(
# 'aten::select', 'aten::index', 'aten::slice', 'aten::view', 'aten::upsample'
],
fp16_fallback_op_ratio=0.05,
preserved_attributes=[],
)
BLADE_CONFIG_KEYS = list(BLADE_CONFIG_DEFAULT.keys())
@ -185,24 +184,41 @@ def computeStats(backend, timings, batch_size=1, model_name='default'):
@torch.no_grad()
def benchmark(model, inp, backend, batch_size, model_name='default', num=200):
def benchmark(model,
inputs,
backend,
batch_size,
model_name='default',
num_iters=200,
warmup_iters=5,
fp16=False):
"""
Evaluate the inference time and speed of different models.
Args:
model: input model
inp: input of the model
inputs: input of the model
backend (str): backend name
batch_size (int): image batch size
model_name (str): tested model name
num: test forward times
num_iters (int): number of timed forward passes
warmup_iters (int): number of warmup passes run before timing
fp16 (bool): whether to run forwards under torch.cuda.amp.autocast
"""
for _ in range(warmup_iters):
if fp16:
with torch.cuda.amp.autocast():
model(*copy.deepcopy(inputs))
else:
model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
timings = []
for i in range(num):
for i in range(num_iters):
start_time = timeit.default_timer()
model(*inp)
if fp16:
with torch.cuda.amp.autocast():
model(*copy.deepcopy(inputs))
else:
model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
end_time = timeit.default_timer()
meas_time = end_time - start_time
@ -246,40 +262,49 @@ def blade_optimize(speed_test_model,
enable_fp16=True, fp16_fallback_op_ratio=0.05),
backend='TensorRT',
batch=1,
warm_up_time=10,
warmup_iters=10,
compute_cost=True,
use_profile=False,
check_result=False,
static_opt=True):
static_opt=True,
min_num_nodes=None,
check_inputs=True,
fp16=False):
if not static_opt:
logging.info(
'PAI-Blade uses dynamic optimization for the input model; the exported model supports dynamic-shape inputs'
)
with opt_trt_config(blade_config):
opt_model = optimize(
model,
allow_tracing=True,
model_inputs=tuple(inputs),
)
optimize_op = optimize
else:
logging.info(
'PAI-Blade uses static optimization for the input model; the exported model must be fed static-shape inputs'
)
from torch_blade.optimization import _static_optimize
optimize_op = _static_optimize
if min_num_nodes is not None:
import torch_blade.clustering.support_fusion_group as blade_fusion
with blade_fusion.min_group_nodes(min_num_nodes=min_num_nodes):
with opt_trt_config(blade_config):
opt_model = optimize_op(
model,
allow_tracing=True,
model_inputs=tuple(copy.deepcopy(inputs)),
)
else:
with opt_trt_config(blade_config):
opt_model = _static_optimize(
opt_model = optimize_op(
model,
allow_tracing=True,
model_inputs=tuple(inputs),
model_inputs=tuple(copy.deepcopy(inputs)),
)
if compute_cost:
logging.info('Running benchmark...')
results = []
inputs_t = inputs
inputs_t = copy.deepcopy(inputs)
# The end2end model and the scripted model need different channel permutations; this only occurs when using end2end export
if (inputs_t[0].shape[-1] == 3):
if check_inputs and (inputs_t[0].shape[-1] == 3):
shape_length = len(inputs_t[0].shape)
if shape_length == 4:
inputs_t = inputs_t[0].permute(0, 3, 1, 2)
@ -290,45 +315,67 @@ def blade_optimize(speed_test_model,
inputs_t = (torch.unsqueeze(inputs_t, 0), )
results.append(
benchmark(speed_test_model, inputs_t, backend, batch, 'easycv'))
benchmark(
speed_test_model,
inputs_t,
backend,
batch,
'easycv',
warmup_iters=warmup_iters,
fp16=fp16))
results.append(
benchmark(model, inputs, backend, batch, 'easycv script'))
results.append(benchmark(opt_model, inputs, backend, batch, 'blade'))
benchmark(
model,
copy.deepcopy(inputs),
backend,
batch,
'easycv script',
warmup_iters=warmup_iters,
fp16=fp16))
results.append(
benchmark(
opt_model,
copy.deepcopy(inputs),
backend,
batch,
'blade',
warmup_iters=warmup_iters,
fp16=fp16))
logging.info('Model Summary:')
summary = pd.DataFrame(results)
logging.warning(summary.to_markdown())
print(summary.to_markdown())
if use_profile:
torch.cuda.empty_cache()
# warm-up
for k in range(warm_up_time):
test_result = opt_model(*inputs)
for k in range(warmup_iters):
test_result = opt_model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
torch.cuda.synchronize()
cu_prof_start()
for k in range(warm_up_time):
test_result = opt_model(*inputs)
for k in range(warmup_iters):
test_result = opt_model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
cu_prof_stop()
import torch.autograd.profiler as profiler
with profiler.profile(use_cuda=True) as prof:
for k in range(warm_up_time):
test_result = opt_model(*inputs)
for k in range(warmup_iters):
test_result = opt_model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
with profiler.profile(use_cuda=True) as prof:
for k in range(warm_up_time):
test_result = opt_model(*inputs)
for k in range(warmup_iters):
test_result = opt_model(*copy.deepcopy(inputs))
torch.cuda.synchronize()
prof_str = prof.key_averages().table(sort_by='cuda_time_total')
print(f'{prof_str}')
if check_result:
output = model(*inputs)
test_result = opt_model(*inputs)
output = model(*copy.deepcopy(inputs))
test_result = opt_model(*copy.deepcopy(inputs))
check_results(output, test_result)
return opt_model

View File

@ -2,10 +2,14 @@
import functools
import inspect
import logging
import pickle
import warnings
import mmcv
import numpy as np
import torch
from easycv.framework.errors import ValueError
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
@ -99,3 +103,21 @@ def deprecated(reason):
return new_func1
return decorator
def encode_str_to_tensor(obj):
if isinstance(obj, str):
return torch.tensor(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
elif isinstance(obj, torch.Tensor):
return obj
else:
raise ValueError(f'Unsupported type {type(obj)}')
def decode_tensor_to_str(obj):
if isinstance(obj, torch.Tensor):
return pickle.loads(obj.cpu().numpy().tobytes())
elif isinstance(obj, str):
return obj
else:
raise ValueError(f'Unsupported type {type(obj)}')
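# A minimal round-trip sketch for the two helpers above (illustrative only):
# strings are pickled into uint8 tensors, presumably so they can pass through
# tensor-only interfaces (e.g. traced model inputs/outputs) and be restored.
def _example_str_tensor_roundtrip():
    token = 'sample_token_0'  # hypothetical value
    encoded = encode_str_to_tensor(token)    # 1-D torch.uint8 tensor
    decoded = decode_tensor_to_str(encoded)  # back to the original string
    assert decoded == token
    return encoded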

View File

@ -373,3 +373,17 @@ def remove_adapt_for_mmlab(cfg):
mmlab_modules_cfg = cfg.get('mmlab_modules', [])
adapter = MMAdapter(mmlab_modules_cfg)
adapter.reset_mm_registry()
def fix_dc_pin_memory():
"""Fix pin memory for DataContainer."""
from mmcv.parallel import DataContainer as DC
from torch.utils.data._utils.pin_memory import pin_memory
def data_container_pin_memory(self):
if self.cpu_only:
return self
self._data = pin_memory(self._data)
return self
setattr(DC, 'pin_memory', data_container_pin_memory)
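# A minimal sketch of how the patch above is meant to be applied (it mirrors
# the train-script change later in this diff): call it once before building
# the dataloaders whenever `data.pin_memory` is enabled. The wrapper name is
# hypothetical.
def maybe_fix_pin_memory(cfg):
    if getattr(cfg.data, 'pin_memory', False):
        fix_dc_pin_memory()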

View File

@ -1,17 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import re
import subprocess
import tempfile
import unittest
import numpy as np
import torch
from tests.ut_config import (IMAGENET_LABEL_TXT, PRETRAINED_MODEL_MOCO,
PRETRAINED_MODEL_RESNET50,
from tests.ut_config import (IMAGENET_LABEL_TXT,
PRETRAINED_MODEL_BEVFORMER_BASE,
PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50,
PRETRAINED_MODEL_YOLOXS_EXPORT)
import easycv
from easycv.apis.export import export
from easycv.file import io
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.test_util import clean_up, get_tmp_dir
@ -126,6 +130,41 @@ class ModelExportTest(unittest.TestCase):
self.assertTrue(
export_config['model']['backbone']['norm_cfg']['type'] == 'BN')
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
'need another environment where mmcv has been recompiled')
def test_export_bevformer_jit(self):
ckpt_path = PRETRAINED_MODEL_BEVFORMER_BASE
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
else:
config_dir = os.path.join(os.path.dirname(easycv_dir), 'configs')
config_file = os.path.join(
config_dir,
'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
with tempfile.TemporaryDirectory() as tmpdir:
with io.open(config_file, 'r') as f:
cfg_str = f.read()
new_config_path = os.path.join(tmpdir, 'new_config.py')
# find first adapt_jit and replace value
res = re.search(r'adapt_jit(\s*)=(\s*)False', cfg_str)
if res is not None:
cfg_str_list = list(cfg_str)
cfg_str_list[res.span()[0]:res.span()[1]] = 'adapt_jit = True'
cfg_str = ''.join(cfg_str_list)
with io.open(new_config_path, 'w') as f:
f.write(cfg_str)
cfg = mmcv_config_fromfile(new_config_path)
cfg.export.type = 'jit'
filename = os.path.join(tmpdir, 'model.pth')
export(cfg, ckpt_path, filename, fp16=False)
self.assertTrue(os.path.exists(filename + '.jit'))
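# A hedged sketch of consuming the exported artifact (illustrative only): the
# test above only checks that '<filename>.jit' exists; loading it back for
# inference would look roughly like
#   jit_model = torch.jit.load(filename + '.jit', map_location='cuda')
#   jit_model.eval()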
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
import numpy as np
from tests.ut_config import TEST_IMAGES_DIR
from easycv.file.image import load_image
class LoadImageTest(unittest.TestCase):
img_path = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/unittest/local_backup/easycv_nfs/data/test_images/000000289059.jpg'
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_backend_pillow(self):
img = load_image(
self.img_path, mode='BGR', dtype=np.float32, backend='pillow')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.float32)
self.assertEqual(list(img[0][0]), [145, 92, 59])
def test_backend_cv2(self):
img = load_image(self.img_path, mode='RGB', backend='cv2')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.uint8)
self.assertEqual(list(img[0][0]), [59, 92, 145])
def test_backend_turbojpeg(self):
img = load_image(
self.img_path, mode='RGB', dtype=np.float32, backend='turbojpeg')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.float32)
self.assertEqual(list(img[0][0]), [59, 92, 145])
def test_url_path_cv2(self):
img = load_image(self.img_url, mode='BGR', backend='cv2')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.uint8)
self.assertEqual(list(img[0][0]), [145, 92, 59])
def test_url_path_pillow(self):
img = load_image(self.img_url, mode='RGB', backend='pillow')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.uint8)
self.assertEqual(list(img[0][0]), [59, 92, 145])
def test_url_path_turbojpeg(self):
img = load_image(self.img_url, mode='BGR', backend='turbojpeg')
self.assertEqual(img.shape, (480, 640, 3))
self.assertEqual(img.dtype, np.uint8)
self.assertEqual(list(img[0][0]), [145, 92, 59])
if __name__ == '__main__':
unittest.main()

View File

@ -1,7 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import tempfile
import unittest
import mmcv
import numpy as np
import torch
from numpy.testing import assert_array_almost_equal
@ -9,7 +12,12 @@ from tests.ut_config import (PRETRAINED_MODEL_BEVFORMER_BASE,
SMALL_NUSCENES_PATH)
import easycv
from easycv.apis.export import export
from easycv.core.evaluation.builder import build_evaluator
from easycv.datasets import build_dataset
from easycv.file import io
from easycv.predictors import BEVFormerPredictor
from easycv.utils.config_tools import mmcv_config_fromfile
class BEVFormerPredictorTest(unittest.TestCase):
@ -17,7 +25,7 @@ class BEVFormerPredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def _assert_results(self, results, assert_value=True):
def _assert_results(self, results):
res = results['pts_bbox']
self.assertEqual(res['scores_3d'].shape, torch.Size([300]))
self.assertEqual(res['labels_3d'].shape, torch.Size([300]))
@ -40,88 +48,7 @@ class BEVFormerPredictorTest(unittest.TestCase):
self.assertEqual(res['boxes_3d'].volume.shape, torch.Size([300]))
self.assertEqual(res['boxes_3d'].yaw.shape, torch.Size([300]))
if not assert_value:
return
assert_array_almost_equal(
res['scores_3d'][:5].numpy(),
np.array([0.982, 0.982, 0.982, 0.982, 0.981], dtype=np.float32),
decimal=3)
assert_array_almost_equal(res['labels_3d'][:10].numpy(),
np.array([5, 5, 5, 5, 5, 5, 5, 5, 5, 5]))
assert_array_almost_equal(
res['boxes_3d'].bev[:2].numpy(),
np.array([[9.341, -2.664, 2.034, 0.657, 1.819],
[6.945, -18.833, 2.047, 0.661, 1.694]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].bottom_center[:2].numpy(),
np.array([[9.341, -2.664, -1.849], [6.945, -18.833, -2.295]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].bottom_height[:5].numpy(),
np.array([-1.849, -2.332, -2.295, -1.508, -1.204],
dtype=np.float32),
decimal=1)
assert_array_almost_equal(
res['boxes_3d'].center[:2].numpy(),
np.array([[9.341, -2.664, -1.849], [6.945, -18.833, -2.295]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].corners[:1][0][:3].numpy(),
np.array([[9.91, -3.569, -1.849], [9.91, -3.569, -0.742],
[9.273, -3.73, -0.742]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].dims[:2].numpy(),
np.array([[2.034, 0.657, 1.107], [2.047, 0.661, 1.101]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].gravity_center[:2].numpy(),
np.array([[9.341, -2.664, -1.295], [6.945, -18.833, -1.745]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].height[:5].numpy(),
np.array([1.107, 1.101, 1.082, 1.098, 1.073], dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].nearest_bev[:2].numpy(),
np.array([[9.013, -3.681, 9.67, -1.647],
[6.615, -19.857, 7.276, -17.81]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].tensor[:1].numpy(),
np.array([[
9.340, -2.664, -1.849, 2.0343, 6.568e-01, 1.107, 1.819,
-8.636e-06, 2.034e-05
]],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].top_height[:5].numpy(),
np.array([-0.742, -1.194, -1.25, -0.411, -0.132],
dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].volume[:5].numpy(),
np.array([1.478, 1.49, 1.435, 1.495, 1.47], dtype=np.float32),
decimal=3)
assert_array_almost_equal(
res['boxes_3d'].yaw[:5].numpy(),
np.array([1.819, 1.694, 1.659, 1.62, 1.641], dtype=np.float32),
decimal=3)
def test_single(self):
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
'inference/single_sample.pkl')
def _get_config_file(self):
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
@ -130,7 +57,13 @@ class BEVFormerPredictorTest(unittest.TestCase):
config_file = os.path.join(
config_dir,
'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
return config_file
def test_single(self):
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
'inference/single_sample.pkl')
config_file = self._get_config_file()
predictor = BEVFormerPredictor(
model_path=model_path,
config_file=config_file,
@ -140,10 +73,86 @@ class BEVFormerPredictorTest(unittest.TestCase):
for result in results:
self._assert_results(result)
@unittest.skipIf(True, 'Not support batch yet')
def test_batch(self):
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
single_ann_file = os.path.join(SMALL_NUSCENES_PATH,
'inference/single_sample.pkl')
config_file = self._get_config_file()
predictor = BEVFormerPredictor(
model_path=model_path, config_file=config_file, batch_size=2)
results = predictor([single_ann_file, single_ann_file])
self.assertEqual(len(results), 2)
# Feeding the same sample twice yields different outputs, because the
# model caches features from the previous sample to help infer the next one
self._assert_results(results[0])
self._assert_results(results[1])
def test_metric(self):
model_path = PRETRAINED_MODEL_BEVFORMER_BASE
inputs_file = os.path.join(SMALL_NUSCENES_PATH,
'nuscenes_infos_temporal_val.pkl')
config_file = self._get_config_file()
cfg = mmcv_config_fromfile(config_file)
cfg.data.val.data_source.data_root = SMALL_NUSCENES_PATH
cfg.data.val.data_source.ann_file = os.path.join(
SMALL_NUSCENES_PATH, 'nuscenes_infos_temporal_val.pkl')
cfg.data.val.pop('imgs_per_gpu', None)
val_dataset = build_dataset(cfg.data.val)
evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
predictor = BEVFormerPredictor(
model_path=model_path, config_file=config_file)
inputs = mmcv.load(inputs_file)['infos']
for i in range(len(inputs)):
for k in list(inputs[i]['cams'].keys()):
inputs[i]['cams'][k]['data_path'] = os.path.join(
SMALL_NUSCENES_PATH, inputs[i]['cams'][k]['data_path'])
predict_results = predictor(inputs)
results = {'pts_bbox': [i['pts_bbox'] for i in predict_results]}
val_results = val_dataset.evaluate(results, evaluators)
self.assertAlmostEqual(
val_results['pts_bbox_NuScenes/NDS'], 0.460, delta=0.01)
self.assertAlmostEqual(
val_results['pts_bbox_NuScenes/mAP'], 0.41, delta=0.01)
@unittest.skipIf(torch.__version__ != '1.8.1+cu102',
'need another environment where mmcv has been recompiled')
class BEVFormerBladePredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.mkdir(self.tmp_dir)
def tearDown(self) -> None:
io.remove(self.tmp_dir)
return super().tearDown()
def _replace_config(self, cfg_file):
with io.open(cfg_file, 'r') as f:
cfg_str = f.read()
new_config_path = os.path.join(self.tmp_dir, 'new_config.py')
# find first adapt_jit and replace value
res = re.search(r'adapt_jit(\s*)=(\s*)False', cfg_str)
if res is not None:
cfg_str_list = list(cfg_str)
cfg_str_list[res.span()[0]:res.span()[1]] = 'adapt_jit = True'
cfg_str = ''.join(cfg_str_list)
with io.open(new_config_path, 'w') as f:
f.write(cfg_str)
return new_config_path
def test_single(self):
# test export blade model and bevformer predictor
ori_ckpt = PRETRAINED_MODEL_BEVFORMER_BASE
inputs_file = os.path.join(SMALL_NUSCENES_PATH,
'nuscenes_infos_temporal_val.pkl')
easycv_dir = os.path.dirname(easycv.__file__)
if os.path.exists(os.path.join(easycv_dir, 'configs')):
config_dir = os.path.join(easycv_dir, 'configs')
@ -152,15 +161,41 @@ class BEVFormerPredictorTest(unittest.TestCase):
config_file = os.path.join(
config_dir,
'detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py')
config_file = self._replace_config(config_file)
cfg = mmcv_config_fromfile(config_file)
filename = os.path.join(self.tmp_dir, 'model.pth')
export(cfg, ori_ckpt, filename, fp16=False)
blade_filename = filename + '.blade'
self.assertTrue(os.path.exists(blade_filename))
cfg.data.val.data_source.data_root = SMALL_NUSCENES_PATH
cfg.data.val.data_source.ann_file = os.path.join(
SMALL_NUSCENES_PATH, 'nuscenes_infos_temporal_val.pkl')
cfg.data.val.pop('imgs_per_gpu', None)
val_dataset = build_dataset(cfg.data.val)
evaluators = build_evaluator(cfg.eval_pipelines[0]['evaluators'][0])
predictor = BEVFormerPredictor(
model_path=model_path, config_file=config_file, batch_size=2)
results = predictor([single_ann_file, single_ann_file])
self.assertEqual(len(results), 2)
# Input the same sample continuously, the output value is different,
# because the model will record the features of the previous sample to infer the next sample
self._assert_results(results[0])
self._assert_results(results[1], assert_value=False)
model_path=blade_filename,
config_file=config_file,
model_type='blade',
)
inputs = mmcv.load(inputs_file)['infos']
predict_results = predictor(inputs)
results = {'pts_bbox': [i['pts_bbox'] for i in predict_results]}
val_results = val_dataset.evaluate(results, evaluators)
self.assertAlmostEqual(
val_results['pts_bbox_NuScenes/NDS'], 0.460, delta=0.01)
self.assertAlmostEqual(
val_results['pts_bbox_NuScenes/mAP'], 0.41, delta=0.01)
@unittest.skipIf(True, 'Not support batch yet')
def test_batch(self):
pass
if __name__ == '__main__':

View File

@ -34,11 +34,11 @@ from easycv.file import io
from easycv.models import build_model
from easycv.utils.collect_env import collect_env
from easycv.utils.logger import get_root_logger
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
from easycv.utils import mmlab_utils
from easycv.utils.config_tools import traverse_replace
from easycv.utils.config_tools import (CONFIG_TEMPLATE_ZOO,
mmcv_config_fromfile, rebuild_config)
from easycv.utils.dist_utils import get_device
from easycv.utils.dist_utils import get_device, is_master
from easycv.utils.setup_env import setup_multi_processes
@ -161,7 +161,7 @@ def main():
cfg.load_from = args.load_from
# dynamic adapt mmdet models
dynamic_adapt_for_mmlab(cfg)
mmlab_utils.dynamic_adapt_for_mmlab(cfg)
cfg.gpus = args.gpus
@ -230,7 +230,9 @@ def main():
assert isinstance(args.pretrained, str)
cfg.model.pretrained = args.pretrained
model = build_model(cfg.model)
print(model)
if is_master():
print(model)
if 'stage' in cfg.model and cfg.model['stage'] == 'EDGE':
from easycv.utils.flops_counter import get_model_info
@ -259,6 +261,8 @@ def main():
), 'odps config must be set in cfg file / cfg.data.train.data_source !!'
shuffle = False
if getattr(cfg.data, 'pin_memory', False):
mmlab_utils.fix_dc_pin_memory()
datasets = [build_dataset(cfg.data.train)]
data_loaders = [
build_dataloader(
@ -268,6 +272,7 @@ def main():
cfg.gpus,
dist=distributed,
shuffle=shuffle,
pin_memory=getattr(cfg.data, 'pin_memory', False),
replace=getattr(cfg.data, 'sampling_replace', False),
seed=cfg.seed,
drop_last=getattr(cfg.data, 'drop_last', False),