mirror of https://github.com/alibaba/EasyCV.git
Merge pull request #19 from wenmengzhou/fix/cls_predictor
Fix: classifier run inference errorpull/24/head
commit
cdb2366725
|
@ -48,7 +48,7 @@
|
|||
You can simply install easycv with the following command:
|
||||
|
||||
```shell
|
||||
pip install easycv
|
||||
pip install pai-easycv
|
||||
```
|
||||
|
||||
or clone the repository and then install it:
|
||||
|
|
|
@ -81,7 +81,7 @@ def _export_cls(model, cfg, filename):
|
|||
else:
|
||||
export_cfg = dict(export_neck=False)
|
||||
|
||||
export_neck = export_cfg.get('export_neck', False)
|
||||
export_neck = export_cfg.get('export_neck', True)
|
||||
label_map_path = cfg.get('label_map_path', None)
|
||||
class_list = None
|
||||
if label_map_path is not None:
|
||||
|
@ -103,7 +103,7 @@ def _export_cls(model, cfg, filename):
|
|||
print("this cls model doesn't contain cls head, we add a dummy head!")
|
||||
model_config['head'] = head = dict(
|
||||
type='ClsHead',
|
||||
with_avg_pool=False,
|
||||
with_avg_pool=True,
|
||||
in_channels=model_config['backbone'].get('num_classes', 2048),
|
||||
num_classes=1000,
|
||||
)
|
||||
|
@ -112,11 +112,15 @@ def _export_cls(model, cfg, filename):
|
|||
|
||||
if hasattr(cfg, 'test_pipeline'):
|
||||
test_pipeline = cfg.test_pipeline
|
||||
for pipe in test_pipeline:
|
||||
if pipe['type'] == 'Collect':
|
||||
pipe['keys'] = ['img']
|
||||
else:
|
||||
test_pipeline = [
|
||||
dict(type='Resize', size=[224, 224]),
|
||||
dict(type='ToTensor'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
|
||||
config = dict(
|
||||
|
|
|
@ -79,7 +79,6 @@ class ClsHead(nn.Module):
|
|||
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
|
||||
x1 = x[self.input_feature_index[0]]
|
||||
|
||||
if self.with_avg_pool and x1.dim() > 2:
|
||||
assert x1.dim() == 4, \
|
||||
'Tensor must has 4 dims, got: {}'.format(x1.dim())
|
||||
|
|
|
@ -5,8 +5,10 @@ import subprocess
|
|||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tests.ut_config import PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50
|
||||
from tests.ut_config import (IMAGENET_LABEL_TXT, PRETRAINED_MODEL_MOCO,
|
||||
PRETRAINED_MODEL_RESNET50)
|
||||
|
||||
from easycv.apis.export import export
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
|
@ -47,6 +49,21 @@ class ModelExportTest(unittest.TestCase):
|
|||
export(cfg, ori_ckpt, target_ckpt)
|
||||
self.assertTrue(os.path.exists(target_ckpt))
|
||||
|
||||
def test_export_classification_and_inference(self):
|
||||
config_file = 'configs/classification/imagenet/imagenet_rn50_jpg.py'
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
cfg.export = dict(use_jit=False)
|
||||
ori_ckpt = PRETRAINED_MODEL_RESNET50
|
||||
target_ckpt = f'{self.tmp_dir}/classification_export.pth'
|
||||
export(cfg, ori_ckpt, target_ckpt)
|
||||
self.assertTrue(os.path.exists(target_ckpt))
|
||||
|
||||
from easycv.predictors.classifier import TorchClassifier
|
||||
classifier = TorchClassifier(
|
||||
target_ckpt, label_map_path=IMAGENET_LABEL_TXT)
|
||||
img = np.random.randint(0, 255, (256, 256, 3))
|
||||
r = classifier.predict([img])
|
||||
|
||||
def test_export_cls_syncbn(self):
|
||||
config_file = 'configs/classification/imagenet/imagenet_rn50_jpg.py'
|
||||
cfg = mmcv_config_fromfile(config_file)
|
||||
|
|
Loading…
Reference in New Issue