add bevformer benchmark and fix classification predict bug

pull/243/head
yhq 2022-11-24 18:25:27 +08:00 committed by GitHub
parent f8c9a9a1c9
commit a9b67f0509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 6 deletions
easycv
models/detection3d/detectors/bevformer/attentions
predictors

View File

@ -7,3 +7,4 @@ Pretrained on [nuScenes](https://www.nuscenes.org/) dataset.
| Algorithm | Config | Params<br/> | Train memory<br/>(GB) | NDS | mAP | Download | | Algorithm | Config | Params<br/> | Train memory<br/>(GB) | NDS | mAP | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | | ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) | | BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) |
| BEVFormer-base-hybrid | [bevformer_base_r101_dcn_nuscenes_hybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_hybrid.py) | 69M | 46.1 | 53.02 | 42.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_hybrid2/epoch_23.pth) |

View File

@ -11,8 +11,6 @@ from mmcv.cnn import constant_init, xavier_init
from mmcv.runner.base_module import BaseModule from mmcv.runner.base_module import BaseModule
from easycv.models.registry import ATTENTION from easycv.models.registry import ATTENTION
from easycv.thirdparty.deformable_attention.functions import \
MSDeformAttnFunction
@ATTENTION.register_module() @ATTENTION.register_module()
@ -99,6 +97,7 @@ class CustomMSDeformableAttention(BaseModule):
if self.adapt_jit: if self.adapt_jit:
self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn
else: else:
from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
self.ms_deform_attn_op = MSDeformAttnFunction.apply self.ms_deform_attn_op = MSDeformAttnFunction.apply
def init_weights(self): def init_weights(self):

View File

@ -35,10 +35,10 @@ class ClassificationPredictor(PredictorV2):
device=None, device=None,
save_results=False, save_results=False,
save_path=None, save_path=None,
pipelines=[], pipelines=None,
topk=1, topk=1,
pil_input=True, pil_input=True,
label_map_path=[], label_map_path=None,
*args, *args,
**kwargs): **kwargs):
super(ClassificationPredictor, self).__init__( super(ClassificationPredictor, self).__init__(
@ -59,7 +59,12 @@ class ClassificationPredictor(PredictorV2):
self.INPUT_IMAGE_MODE = 'RGB' self.INPUT_IMAGE_MODE = 'RGB'
if label_map_path is None: if label_map_path is None:
class_list = self.cfg.get('CLASSES', []) if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
else:
class_list = []
else: else:
with io.open(label_map_path, 'r') as f: with io.open(label_map_path, 'r') as f:
class_list = f.readlines() class_list = f.readlines()
@ -85,7 +90,9 @@ class ClassificationPredictor(PredictorV2):
img = img.convert(self.INPUT_IMAGE_MODE.upper()) img = img.convert(self.INPUT_IMAGE_MODE.upper())
results['filename'] = input results['filename'] = input
else: else:
assert isinstance(input, ImageFile.ImageFile) if isinstance(input, np.ndarray):
input = Image.fromarray(input)
# assert isinstance(input, ImageFile.ImageFile)
img = input img = input
results['filename'] = None results['filename'] = None
results['img'] = img results['img'] = img