Merge pull request #19 from wenmengzhou/fix/cls_predictor

Fix: classifier run inference error
pull/24/head
Chen Jiayu 2022-04-22 15:37:47 +08:00 committed by GitHub
commit cdb2366725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 5 deletions

View File

@ -48,7 +48,7 @@
You can simply install easycv with the following command:
```shell
pip install easycv
pip install pai-easycv
```
or clone the repository and then install it:

View File

@ -81,7 +81,7 @@ def _export_cls(model, cfg, filename):
else:
export_cfg = dict(export_neck=False)
export_neck = export_cfg.get('export_neck', False)
export_neck = export_cfg.get('export_neck', True)
label_map_path = cfg.get('label_map_path', None)
class_list = None
if label_map_path is not None:
@ -103,7 +103,7 @@ def _export_cls(model, cfg, filename):
print("this cls model doesn't contain cls head, we add a dummy head!")
model_config['head'] = head = dict(
type='ClsHead',
with_avg_pool=False,
with_avg_pool=True,
in_channels=model_config['backbone'].get('num_classes', 2048),
num_classes=1000,
)
@ -112,11 +112,15 @@ def _export_cls(model, cfg, filename):
if hasattr(cfg, 'test_pipeline'):
test_pipeline = cfg.test_pipeline
for pipe in test_pipeline:
if pipe['type'] == 'Collect':
pipe['keys'] = ['img']
else:
test_pipeline = [
dict(type='Resize', size=[224, 224]),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img'])
]
config = dict(

View File

@ -79,7 +79,6 @@ class ClsHead(nn.Module):
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
x1 = x[self.input_feature_index[0]]
if self.with_avg_pool and x1.dim() > 2:
assert x1.dim() == 4, \
'Tensor must has 4 dims, got: {}'.format(x1.dim())

View File

@ -5,8 +5,10 @@ import subprocess
import tempfile
import unittest
import numpy as np
import torch
from tests.ut_config import PRETRAINED_MODEL_MOCO, PRETRAINED_MODEL_RESNET50
from tests.ut_config import (IMAGENET_LABEL_TXT, PRETRAINED_MODEL_MOCO,
PRETRAINED_MODEL_RESNET50)
from easycv.apis.export import export
from easycv.utils.config_tools import mmcv_config_fromfile
@ -47,6 +49,21 @@ class ModelExportTest(unittest.TestCase):
export(cfg, ori_ckpt, target_ckpt)
self.assertTrue(os.path.exists(target_ckpt))
def test_export_classification_and_inference(self):
config_file = 'configs/classification/imagenet/imagenet_rn50_jpg.py'
cfg = mmcv_config_fromfile(config_file)
cfg.export = dict(use_jit=False)
ori_ckpt = PRETRAINED_MODEL_RESNET50
target_ckpt = f'{self.tmp_dir}/classification_export.pth'
export(cfg, ori_ckpt, target_ckpt)
self.assertTrue(os.path.exists(target_ckpt))
from easycv.predictors.classifier import TorchClassifier
classifier = TorchClassifier(
target_ckpt, label_map_path=IMAGENET_LABEL_TXT)
img = np.random.randint(0, 255, (256, 256, 3))
r = classifier.predict([img])
def test_export_cls_syncbn(self):
config_file = 'configs/classification/imagenet/imagenet_rn50_jpg.py'
cfg = mmcv_config_fromfile(config_file)