mirror of https://github.com/alibaba/EasyCV.git
add bevformer benchmark and fix classification predict bug (#240)
parent
f8c9a9a1c9
commit
a9b67f0509
|
@ -7,3 +7,4 @@ Pretrained on [nuScenes](https://www.nuscenes.org/) dataset.
|
|||
| Algorithm | Config | Params<br/> | Train memory<br/>(GB) | NDS | mAP | Download |
|
||||
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
|
||||
| BEVFormer-base | [bevformer_base_r101_dcn_nuscenes](https://github.com/alibaba/EasyCV/tree/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes.py) | 69M | 23.9 | 52.46 | 41.83 | [model](http://pai-vision-data-hz.oss-accelerate.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer/epoch_24.pth) |
|
||||
| BEVFormer-base-hybrid | [bevformer_base_r101_dcn_nuscenes_hybrid](https://github.com/alibaba/EasyCV/blob/master/configs/detection3d/bevformer/bevformer_base_r101_dcn_nuscenes_hybrid.py) | 69M | 46.1 | 53.02 | 42.48 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection3d/bevformer_base_hybrid2/epoch_23.pth) |
|
||||
|
|
|
@ -11,8 +11,6 @@ from mmcv.cnn import constant_init, xavier_init
|
|||
from mmcv.runner.base_module import BaseModule
|
||||
|
||||
from easycv.models.registry import ATTENTION
|
||||
from easycv.thirdparty.deformable_attention.functions import \
|
||||
MSDeformAttnFunction
|
||||
|
||||
|
||||
@ATTENTION.register_module()
|
||||
|
@ -99,6 +97,7 @@ class CustomMSDeformableAttention(BaseModule):
|
|||
if self.adapt_jit:
|
||||
self.ms_deform_attn_op = torch.ops.custom.ms_deform_attn
|
||||
else:
|
||||
from easycv.thirdparty.deformable_attention.functions import MSDeformAttnFunction
|
||||
self.ms_deform_attn_op = MSDeformAttnFunction.apply
|
||||
|
||||
def init_weights(self):
|
||||
|
|
|
@ -35,10 +35,10 @@ class ClassificationPredictor(PredictorV2):
|
|||
device=None,
|
||||
save_results=False,
|
||||
save_path=None,
|
||||
pipelines=[],
|
||||
pipelines=None,
|
||||
topk=1,
|
||||
pil_input=True,
|
||||
label_map_path=[],
|
||||
label_map_path=None,
|
||||
*args,
|
||||
**kwargs):
|
||||
super(ClassificationPredictor, self).__init__(
|
||||
|
@ -59,7 +59,12 @@ class ClassificationPredictor(PredictorV2):
|
|||
self.INPUT_IMAGE_MODE = 'RGB'
|
||||
|
||||
if label_map_path is None:
|
||||
class_list = self.cfg.get('CLASSES', [])
|
||||
if 'CLASSES' in self.cfg:
|
||||
class_list = self.cfg.get('CLASSES', [])
|
||||
elif 'class_list' in self.cfg:
|
||||
class_list = self.cfg.get('class_list', [])
|
||||
else:
|
||||
class_list = []
|
||||
else:
|
||||
with io.open(label_map_path, 'r') as f:
|
||||
class_list = f.readlines()
|
||||
|
@ -85,7 +90,9 @@ class ClassificationPredictor(PredictorV2):
|
|||
img = img.convert(self.INPUT_IMAGE_MODE.upper())
|
||||
results['filename'] = input
|
||||
else:
|
||||
assert isinstance(input, ImageFile.ImageFile)
|
||||
if isinstance(input, np.ndarray):
|
||||
input = Image.fromarray(input)
|
||||
# assert isinstance(input, ImageFile.ImageFile)
|
||||
img = input
|
||||
results['filename'] = None
|
||||
results['img'] = img
|
||||
|
|
Loading…
Reference in New Issue