mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose

from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
# from mmcv import Config
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.constant import CACHE_DIR
from easycv.utils.registry import build_from_cfg


class NumpyToPIL(object):
    """Convert the numpy uint8 image in results['img'] to a PIL RGB image."""

    def __call__(self, results):
        img = results['img']
        results['img'] = Image.fromarray(np.uint8(img)).convert('RGB')
        return results
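
# NumpyToPIL follows the same dict-in/dict-out convention as the registered
# EasyCV pipeline transforms, so Compose can chain it with the transforms
# built from cfg.test_pipeline. A minimal sanity check (assuming an HWC
# uint8 numpy image as input):
#
#     out = NumpyToPIL()({'img': np.zeros((224, 224, 3), dtype=np.uint8)})
#     assert out['img'].mode == 'RGB'
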
class Predictor(object):

    def __init__(self, model_path, numpy_to_pil=True):
        self.model_path = model_path
        self.numpy_to_pil = numpy_to_pil
        assert io.exists(self.model_path), f'{self.model_path} does not exist'

        with io.open(self.model_path, 'rb') as infile:
            checkpoint = torch.load(infile, map_location='cpu')

        assert 'meta' in checkpoint and 'config' in checkpoint['meta'], \
            'meta.config is missing from checkpoint'

        # the training config is embedded in the checkpoint meta; dump it to
        # a local cache file so it can be parsed as an mmcv config
        config_str = checkpoint['meta']['config']
        basename = os.path.basename(self.model_path)
        fname, _ = os.path.splitext(basename)
        self.local_config_file = os.path.join(CACHE_DIR,
                                              f'{fname}_config.json')
        if not os.path.exists(CACHE_DIR):
            os.makedirs(CACHE_DIR)
        with open(self.local_config_file, 'w') as ofile:
            ofile.write(config_str)
        self.cfg = mmcv_config_fromfile(self.local_config_file)

        # build model
        self.model = build_model(self.cfg.model)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        map_location = 'cpu' if self.device == 'cpu' else 'cuda'
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=map_location)

        self.model.to(self.device)
        self.model.eval()

        # build the test-time pipeline; optionally prepend NumpyToPIL so
        # numpy images can be fed directly
        pipeline = [
            build_from_cfg(p, PIPELINES) for p in self.cfg.test_pipeline
        ]
        if self.numpy_to_pil:
            pipeline = [NumpyToPIL()] + pipeline
        self.pipeline = Compose(pipeline)

    def preprocess(self, image_list):
        # apply the test pipeline to each image; only the 'img' field
        # is transformed
        output_imgs_list = []
        for img in image_list:
            tmp_input = {'img': img}
            tmp_results = self.pipeline(tmp_input)
            output_imgs_list.append(tmp_results['img'])

        return output_imgs_list

    def predict_batch(self, image_batch, **forward_kwargs):
        """Predict using batched data.

        Args:
            image_batch (torch.Tensor): tensor with shape [N, 3, H, W]
            forward_kwargs: additional keyword arguments passed to
                model.forward

        Returns:
            output: the output of model.forward, list or tuple
        """
        with torch.no_grad():
            output = self.model.forward(
                image_batch.to(self.device), **forward_kwargs)
        return output
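
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original file): the
# checkpoint path below is a hypothetical placeholder, and stacking the
# preprocessed images assumes the config's test pipeline ends with a
# to-tensor transform. Some models may also require extra forward_kwargs
# (e.g. a test-mode flag); check the config embedded in the checkpoint.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    predictor = Predictor('path/to/model.pth')  # hypothetical checkpoint

    # dummy uint8 images stand in for real photos; NumpyToPIL converts them
    images = [np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)]

    # preprocess applies the test pipeline to every image in the list
    batch = torch.stack(predictor.preprocess(images))  # [N, 3, H, W]

    output = predictor.predict_batch(batch)
    print(output)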