add bevformer benchmark and fix classification predict bug (#240)

pull/243/head
yhq 2022-11-24 18:25:27 +08:00 committed by GitHub
parent f8c9a9a1c9
commit a9b67f0509
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 6 deletions

View File

@ -7,3 +7,4 @@ Pretrained on [nuScenes](https://www.nuscenes.org/) dataset.
| Algorithm | Config | Params<br/> | Train memory<br/>(GB) | NDS | mAP | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) |
| BEVFormer-base-hybrid | [bevformer_base_r101_dcn_nuscenes_hybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_hybrid.py) | 69M | 46.1 | 53.02 | 42.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_hybrid2/epoch_23.pth) |

View File

@ -11,8 +11,6 @@ from mmcv.cnn import constant_init, xavier_init
from mmcv.runner.base_module import BaseModule
from easycv.models.registry import ATTENTION
from easycv.thirdparty.deformable_attention.functions import \
MSDeformAttnFunction
@ATTENTION.register_module()
@ -99,6 +97,7 @@ class CustomMSDeformableAttention(BaseModule):
if self.adapt_jit:
self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn
else:
from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
self.ms_deform_attn_op = MSDeformAttnFunction.apply
def init_weights(self):

View File

@ -35,10 +35,10 @@ class ClassificationPredictor(PredictorV2):
device=None,
save_results=False,
save_path=None,
pipelines=[],
pipelines=None,
topk=1,
pil_input=True,
label_map_path=[],
label_map_path=None,
*args,
**kwargs):
super(ClassificationPredictor, self).__init__(
@ -59,7 +59,12 @@ class ClassificationPredictor(PredictorV2):
self.INPUT_IMAGE_MODE = 'RGB'
if label_map_path is None:
class_list = self.cfg.get('CLASSES', [])
if 'CLASSES' in self.cfg:
class_list = self.cfg.get('CLASSES', [])
elif 'class_list' in self.cfg:
class_list = self.cfg.get('class_list', [])
else:
class_list = []
else:
with io.open(label_map_path, 'r') as f:
class_list = f.readlines()
@ -85,7 +90,9 @@ class ClassificationPredictor(PredictorV2):
img = img.convert(self.INPUT_IMAGE_MODE.upper())
results['filename'] = input
else:
assert isinstance(input, ImageFile.ImageFile)
if isinstance(input, np.ndarray):
input = Image.fromarray(input)
# assert isinstance(input, ImageFile.ImageFile)
img = input
results['filename'] = None
results['img'] = img