update some predcitors, support batch inference (#195)

update some predcitors, support batch inference
This commit is contained in:
Cathy0908 2022-09-20 10:04:42 +08:00 committed by GitHub
parent bb53e066be
commit 5dfe7b2898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 671 additions and 667 deletions

3
.gitignore vendored
View File

@ -137,6 +137,3 @@ pai_jobs/easycv/resources/
*.tar.gz *.tar.gz
thirdparty/test thirdparty/test
scripts/test scripts/test
# easycv default cache dir
.easycv_cache

View File

@ -86,3 +86,13 @@ checkpoint_config = dict(interval=10)
# runtime settings # runtime settings
total_epochs = 100 total_epochs = 100
predict = dict(
type='ClassificationPredictor',
pipelines=[
dict(type='Resize', size=256),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Collect', keys=['img'])
])

View File

@ -27,7 +27,7 @@ def load_image(img_path):
def load_seg_map(seg_path, reduce_zero_label): def load_seg_map(seg_path, reduce_zero_label):
gt_semantic_seg = _load_img(seg_path, mode='RGB') gt_semantic_seg = _load_img(seg_path, mode='P')
# reduce zero_label # reduce zero_label
if reduce_zero_label: if reduce_zero_label:
# avoid using underflow conversion # avoid using underflow conversion

View File

@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import io
import logging import logging
import time import time
@ -6,10 +7,10 @@ import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from easycv.file import io from easycv import file
from easycv.framework.errors import IOError from easycv.framework.errors import IOError
from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES from easycv.utils.constant import MAX_READ_IMAGE_TRY_TIMES
from .utils import is_oss_path from .utils import is_oss_path, is_url_path
def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES): def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
@ -20,16 +21,31 @@ def load_image(img_path, mode='BGR', max_try_times=MAX_READ_IMAGE_TRY_TIMES):
img = None img = None
while try_cnt < max_try_times: while try_cnt < max_try_times:
try: try:
with io.open(img_path, 'rb') as infile: if is_url_path(img_path):
# cv2.imdecode may corrupt when the img is broken from mmcv.fileio.file_client import HTTPBackend
image = Image.open(infile) # RGB client = HTTPBackend()
img_bytes = client.get(img_path)
buff = io.BytesIO(img_bytes)
image = Image.open(buff)
if mode.upper() != 'BGR' and image.mode.upper() != mode.upper(
):
image = image.convert(mode.upper())
img = np.asarray(image, dtype=np.uint8) img = np.asarray(image, dtype=np.uint8)
if mode.upper() == 'BGR': else:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) with file.io.open(img_path, 'rb') as infile:
assert mode.upper() in ['RGB', 'BGR' # cv2.imdecode may corrupt when the img is broken
], 'Only support `RGB` and `BGR` mode!' image = Image.open(infile)
assert img is not None if mode.upper() != 'BGR' and image.mode.upper(
break ) != mode.upper():
image = image.convert(mode.upper())
img = np.asarray(image, dtype=np.uint8)
if mode.upper() == 'BGR':
if image.mode.upper() != 'RGB':
image = image.convert('RGB')
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
assert img is not None
break
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
logging.warning('Read file {} fault, try count : {}'.format( logging.warning('Read file {} fault, try count : {}'.format(

View File

@ -13,7 +13,7 @@ from tqdm import tqdm
from easycv.framework.errors import ValueError from easycv.framework.errors import ValueError
OSS_PREFIX = 'oss://' OSS_PREFIX = 'oss://'
URL_PREFIX = 'https://' URL_PREFIX = ('https://', 'http://')
def create_namedtuple(**kwargs): def create_namedtuple(**kwargs):
@ -33,6 +33,7 @@ def url_path_exists(url):
urllib.request.urlopen(url).code urllib.request.urlopen(url).code
except Exception as err: except Exception as err:
print(err) print(err)
return False
return True return True

View File

@ -9,5 +9,4 @@ from .feature_extractor import (TorchFaceAttrExtractor,
from .hand_keypoints_predictor import HandKeypointsPredictor from .hand_keypoints_predictor import HandKeypointsPredictor
from .pose_predictor import (TorchPoseTopDownPredictor, from .pose_predictor import (TorchPoseTopDownPredictor,
TorchPoseTopDownPredictorWithDetector) TorchPoseTopDownPredictorWithDetector)
from .segmentation import (Mask2formerPredictor, SegFormerPredictor, from .segmentation import Mask2formerPredictor, SegmentationPredictor
SegmentationPredictor)

View File

@ -1,19 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os import os
import pickle import pickle
import cv2
import numpy as np import numpy as np
import torch import torch
from mmcv.parallel import collate, scatter_kwargs from mmcv.parallel import collate, scatter_kwargs
from PIL import Image from PIL import Image
from torch.hub import load_state_dict_from_url
from torchvision.transforms import Compose from torchvision.transforms import Compose
from easycv.datasets.registry import PIPELINES from easycv.datasets.registry import PIPELINES
from easycv.file import io from easycv.file import io
from easycv.file.utils import is_url_path
from easycv.framework.errors import ValueError from easycv.framework.errors import ValueError
from easycv.models.builder import build_model from easycv.models.builder import build_model
from easycv.utils.checkpoint import load_checkpoint from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile from easycv.utils.config_tools import Config, mmcv_config_fromfile
from easycv.utils.constant import CACHE_DIR from easycv.utils.constant import CACHE_DIR
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab, from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab) remove_adapt_for_mmlab)
@ -107,7 +111,9 @@ class PredictorV2(object):
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically. device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results. save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True. save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
""" """
INPUT_IMAGE_MODE = 'BGR' # the image mode into the model
def __init__(self, def __init__(self,
model_path, model_path,
@ -116,30 +122,51 @@ class PredictorV2(object):
device=None, device=None,
save_results=False, save_results=False,
save_path=None, save_path=None,
mode='rgb', pipelines=None,
*args, *args,
**kwargs): **kwargs):
self.model_path = model_path self.model_path = model_path
self.batch_size = batch_size self.batch_size = batch_size
self.save_results = save_results self.save_results = save_results
self.save_path = save_path self.save_path = save_path
self.config_file = config_file
if self.save_results: if self.save_results:
assert self.save_path is not None assert self.save_path is not None
self.device = device self.device = device
if self.device is None: if self.device is None:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.cfg = None
if config_file is not None: if config_file is not None:
if isinstance(config_file, str): if isinstance(config_file, str):
self.cfg = mmcv_config_fromfile(config_file) self.cfg = mmcv_config_fromfile(config_file)
else: else:
self.cfg = config_file self.cfg = config_file
else:
self.cfg = self._load_cfg_from_ckpt(self.model_path)
if self.cfg is None:
raise ValueError('Please provide "config_file"!')
self.model = self.prepare_model() self.model = self.prepare_model()
self.pipelines = pipelines
self.processor = self.build_processor() self.processor = self.build_processor()
self._load_op = None self._load_op = None
self.mode = mode
def _load_cfg_from_ckpt(self, model_path):
if is_url_path(model_path):
ckpt = load_state_dict_from_url(model_path)
else:
with io.open(model_path, 'rb') as infile:
ckpt = torch.load(infile, map_location='cpu')
cfg = None
if 'meta' in ckpt and 'config' in ckpt['meta']:
cfg = ckpt['meta']['config']
if isinstance(cfg, dict):
cfg = Config(cfg)
elif isinstance(cfg, str):
cfg = Config(json.loads(cfg))
return cfg
def prepare_model(self): def prepare_model(self):
"""Build model from config file by default. """Build model from config file by default.
@ -152,8 +179,6 @@ class PredictorV2(object):
return model return model
def _build_model(self): def _build_model(self):
if self.cfg is None:
raise ValueError('Please provide "config_file"!')
# Use mmdet model # Use mmdet model
dynamic_adapt_for_mmlab(self.cfg) dynamic_adapt_for_mmlab(self.cfg)
model = build_model(self.cfg.model) model = build_model(self.cfg.model)
@ -165,16 +190,15 @@ class PredictorV2(object):
"""Build processor to process loaded input. """Build processor to process loaded input.
If you need custom preprocessing ops, you need to reimplement it. If you need custom preprocessing ops, you need to reimplement it.
""" """
if self.cfg is None: if self.pipelines is not None:
pipeline = [] pipelines = self.pipelines
else: else:
pipeline = [ pipelines = self.cfg.get('test_pipeline', [])
build_from_cfg(p, PIPELINES)
for p in self.cfg.get('test_pipeline', []) pipelines = [build_from_cfg(p, PIPELINES) for p in pipelines]
]
from easycv.datasets.shared.pipelines.transforms import Compose from easycv.datasets.shared.pipelines.transforms import Compose
processor = Compose(pipeline) processor = Compose(pipelines)
return processor return processor
def _load_input(self, input): def _load_input(self, input):
@ -190,10 +214,13 @@ class PredictorV2(object):
} }
""" """
if self._load_op is None: if self._load_op is None:
load_cfg = dict(type='LoadImage', mode=self.mode) load_cfg = dict(type='LoadImage', mode=self.INPUT_IMAGE_MODE)
self._load_op = build_from_cfg(load_cfg, PIPELINES) self._load_op = build_from_cfg(load_cfg, PIPELINES)
if not isinstance(input, str): if not isinstance(input, str):
if isinstance(input, np.ndarray):
# Only support RGB mode if input is np.ndarray.
input = cv2.cvtColor(input, cv2.COLOR_RGB2BGR)
sample = self._load_op({'img': input}) sample = self._load_op({'img': input})
else: else:
sample = self._load_op({'filename': input}) sample = self._load_op({'filename': input})
@ -229,8 +256,32 @@ class PredictorV2(object):
return outputs return outputs
def postprocess(self, inputs, *args, **kwargs): def postprocess(self, inputs, *args, **kwargs):
"""Process model outputs. """Process model batch outputs.
If you need add some processing ops to process model outputs, you need to reimplement it. """
outputs = []
out_i = {}
batch_size = 1
# get current batch size
for k, batch_v in inputs.items():
if batch_v is not None:
batch_size = len(batch_v)
break
for i in range(batch_size):
for k, batch_v in inputs.items():
if batch_v is not None:
out_i[k] = batch_v[i]
else:
out_i[k] = None
out_i = self.postprocess_single(out_i)
outputs.append(out_i)
return outputs
def postprocess_single(self, inputs):
"""Process outputs of single sample.
If you need add some processing ops, you need to reimplement it.
""" """
return inputs return inputs
@ -260,16 +311,22 @@ class PredictorV2(object):
results_list = [] results_list = []
for i in range(0, len(inputs), self.batch_size): for i in range(0, len(inputs), self.batch_size):
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)] batch = inputs[i:min(len(inputs), i + self.batch_size)]
batch_outputs = self.preprocess(batch) batch_outputs = self.preprocess(batch)
batch_outputs = self.forward(batch_outputs) batch_outputs = self.forward(batch_outputs)
results = self.postprocess(batch_outputs) results = self.postprocess(batch_outputs)
assert len(results) == len(
batch), f'Mismatch size {len(results)} != {len(batch)}'
if keep_inputs: if keep_inputs:
results = {'inputs': batch, 'results': results} for i in range(len(batch)):
results[i].update({'inputs': batch[i]})
# if dump, the outputs will not added to the return value to prevent taking up too much memory # if dump, the outputs will not added to the return value to prevent taking up too much memory
if self.save_results: if self.save_results:
self.dump([results], self.save_path, mode='ab+') self.dump(results, self.save_path, mode='ab+')
else: else:
results_list.append(results) if isinstance(results, list):
results_list.extend(results)
else:
results_list.append(results)
return results_list return results_list

View File

@ -3,17 +3,130 @@ import math
import numpy as np import numpy as np
import torch import torch
from PIL import Image, ImageFile
from easycv.file import io
from easycv.framework.errors import ValueError from easycv.framework.errors import ValueError
from .base import Predictor from easycv.utils.misc import deprecated
from .base import Predictor, PredictorV2
from .builder import PREDICTORS from .builder import PREDICTORS
@PREDICTORS.register_module()
class ClassificationPredictor(PredictorV2):
"""Predictor for classification.
Args:
model_path (str): Path of model path.
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
topk (int): Return top-k results. Default: 1.
pil_input (bool): Whether use PIL image. If processor need PIL input, set true, default false.
label_map_path (str): File path of saving labels list.
"""
def __init__(self,
model_path,
config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=[],
topk=1,
pil_input=True,
label_map_path=[],
*args,
**kwargs):
super(ClassificationPredictor, self).__init__(
model_path,
config_file=config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.topk = topk
self.pil_input = pil_input
# Adapt to torchvision transforms which process PIL inputs.
if self.pil_input:
self.INPUT_IMAGE_MODE = 'RGB'
if label_map_path is None:
class_list = self.cfg.get('CLASSES', [])
else:
with io.open(label_map_path, 'r') as f:
class_list = f.readlines()
self.label_map = [i.strip() for i in class_list]
def _load_input(self, input):
"""Load image from file or numpy or PIL object.
Args:
input: File path or numpy or PIL object.
Returns:
{
'filename': filename,
'img': img,
'img_shape': img_shape,
'img_fields': ['img']
}
"""
if self.pil_input:
results = {}
if isinstance(input, str):
img = Image.open(input)
if img.mode.upper() != self.INPUT_IMAGE_MODE.upper():
img = img.convert(self.INPUT_IMAGE_MODE.upper())
results['filename'] = input
else:
assert isinstance(input, ImageFile.ImageFile)
img = input
results['filename'] = None
results['img'] = img
results['img_shape'] = img.size
results['ori_shape'] = img.size
results['img_fields'] = ['img']
return results
return super()._load_input(input)
def postprocess(self, inputs, *args, **kwargs):
"""Return top-k results."""
output_prob = inputs['prob'].data.cpu()
topk_class = torch.topk(output_prob, self.topk).indices.numpy()
output_prob = output_prob.numpy()
batch_results = []
batch_size = output_prob.shape[0]
for i in range(batch_size):
result = {'class': np.squeeze(topk_class[i]).tolist()}
if isinstance(result['class'], int):
result['class'] = [result['class']]
if len(self.label_map) > 0:
result['class_name'] = [
self.label_map[i] for i in result['class']
]
result['class_probs'] = {}
for l_idx, l_name in enumerate(self.label_map):
result['class_probs'][l_name] = output_prob[i][l_idx]
batch_results.append(result)
return batch_results
try: try:
from easy_vision.python.inference.predictor import PredictorInterface from easy_vision.python.inference.predictor import PredictorInterface
except: except:
from .interface import PredictorInterface from .interface import PredictorInterface
@deprecated(reason='Please use ClassificationPredictor.')
@PREDICTORS.register_module() @PREDICTORS.register_module()
class TorchClassifier(PredictorInterface): class TorchClassifier(PredictorInterface):

View File

@ -5,9 +5,6 @@ from glob import glob
import numpy as np import numpy as np
import torch import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from torch.hub import load_state_dict_from_url
from torchvision.transforms import Compose from torchvision.transforms import Compose
from easycv.apis.export import reparameterize_models from easycv.apis.export import reparameterize_models
@ -15,16 +12,12 @@ from easycv.core.visualization import imshow_bboxes
from easycv.datasets.registry import PIPELINES from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor from easycv.datasets.utils import replace_ImageToTensor
from easycv.file import io from easycv.file import io
from easycv.file.utils import is_url_path, url_path_exists
from easycv.framework.errors import TypeError
from easycv.models import build_model from easycv.models import build_model
from easycv.models.detection.utils import postprocess from easycv.models.detection.utils import postprocess
from easycv.utils.checkpoint import load_checkpoint from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.constant import CACHE_DIR from easycv.utils.constant import CACHE_DIR
from easycv.utils.logger import get_root_logger from easycv.utils.misc import deprecated
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
from easycv.utils.registry import build_from_cfg from easycv.utils.registry import build_from_cfg
from .base import PredictorV2 from .base import PredictorV2
from .builder import PREDICTORS from .builder import PREDICTORS
@ -47,14 +40,16 @@ class DetectionPredictor(PredictorV2):
""" """
def __init__(self, def __init__(self,
model_path=None, model_path,
config_file=None, config_file=None,
batch_size=1, batch_size=1,
device=None, device=None,
save_results=False, save_results=False,
save_path=None, save_path=None,
mode='rgb', pipelines=None,
score_threshold=0.5): score_threshold=0.5,
*arg,
**kwargs):
super(DetectionPredictor, self).__init__( super(DetectionPredictor, self).__init__(
model_path, model_path,
config_file=config_file, config_file=config_file,
@ -62,194 +57,55 @@ class DetectionPredictor(PredictorV2):
device=device, device=device,
save_results=save_results, save_results=save_results,
save_path=save_path, save_path=save_path,
mode=mode, pipelines=pipelines,
) )
self.score_thresh = score_threshold self.score_thresh = score_threshold
self.CLASSES = self.cfg.get('CLASSES', None)
def build_processor(self):
if self.pipelines is not None:
pipelines = self.pipelines
elif self.cfg is None:
pipelines = []
else:
pipelines = self.cfg.get('test_pipeline', [])
# for batch inference
self.pipelines = replace_ImageToTensor(pipelines)
return super().build_processor()
def postprocess_single(self, inputs, *args, **kwargs):
if inputs['detection_scores'] is None or len(
inputs['detection_scores']) < 1:
return inputs
scores = inputs['detection_scores']
if scores is not None and self.score_thresh > 0:
keeped_ids = scores > self.score_thresh
inputs['detection_scores'] = inputs['detection_scores'][keeped_ids]
inputs['detection_boxes'] = inputs['detection_boxes'][keeped_ids]
inputs['detection_classes'] = inputs['detection_classes'][
keeped_ids]
class_names = []
for _, classes_id in enumerate(inputs['detection_classes']):
if classes_id is None:
class_names.append(None)
elif self.CLASSES is not None and len(self.CLASSES) > 0:
class_names.append(self.CLASSES[int(classes_id)])
else:
class_names.append(classes_id)
inputs['detection_class_names'] = class_names
def postprocess(self, inputs, *args, **kwargs):
for batch_index in range(self.batch_size):
this_detection_scores = inputs['detection_scores'][batch_index]
sel_ids = this_detection_scores > self.score_thresh
inputs['detection_scores'][batch_index] = inputs[
'detection_scores'][batch_index][sel_ids]
inputs['detection_boxes'][batch_index] = inputs['detection_boxes'][
batch_index][sel_ids]
inputs['detection_classes'][batch_index] = inputs[
'detection_classes'][batch_index][sel_ids]
# TODO class label remapping
return inputs return inputs
def visualize(self, img, results, show=False, out_file=None):
class DetrPredictor(PredictorInterface): """Only support show one sample now."""
"""Inference image(s) with the detector. bboxes = results['detection_boxes']
Args: labels = results['detection_class_names']
model_path (str): checkpoint model and export model are shared. img = self._load_input(img)['img']
config_path (str): If config_path is specified, both checkpoint model and export model can be used; if config_path=None, the export model is used by default.
"""
def __init__(self, model_path, config_path=None):
self.model_path = model_path
if config_path is not None:
self.cfg = mmcv_config_fromfile(config_path)
else:
logger = get_root_logger()
logger.warning('please use export model!')
if is_url_path(self.model_path) and url_path_exists(
self.model_path):
checkpoint = load_state_dict_from_url(model_path)
else:
assert io.exists(
self.model_path), f'{self.model_path} does not exists'
with io.open(self.model_path, 'rb') as infile:
checkpoint = torch.load(infile, map_location='cpu')
assert 'meta' in checkpoint and 'config' in checkpoint[
'meta'], 'meta.config is missing from checkpoint'
config_str = checkpoint['meta']['config']
if isinstance(config_str, dict):
config_str = json.dumps(config_str)
# get config
basename = os.path.basename(self.model_path)
fname, _ = os.path.splitext(basename)
self.local_config_file = os.path.join(CACHE_DIR,
f'{fname}_config.json')
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
with open(self.local_config_file, 'w') as ofile:
ofile.write(config_str)
self.cfg = mmcv_config_fromfile(self.local_config_file)
# dynamic adapt mmdet models
dynamic_adapt_for_mmlab(self.cfg)
# build model
self.model = build_model(self.cfg.model)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = 'cpu' if self.device == 'cpu' else 'cuda'
self.ckpt = load_checkpoint(
self.model, self.model_path, map_location=map_location)
self.model.to(self.device)
self.model.eval()
self.CLASSES = self.cfg.CLASSES
def predict(self, imgs):
"""
Args:
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
Returns:
If imgs is a list or tuple, the same length list type results
will be returned, otherwise return the detection results directly.
"""
if isinstance(imgs, (list, tuple)):
is_batch = True
else:
imgs = [imgs]
is_batch = False
cfg = self.cfg
device = next(self.model.parameters()).device # model device
if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.data.val.pipeline.insert(0, dict(type='LoadImageFromWebcam'))
else:
cfg = cfg.copy()
# set loading pipeline type
cfg.data.val.pipeline.insert(
0,
dict(
type='LoadImageFromFile',
file_client_args=dict(
backend=('http' if imgs[0].startswith('http'
) else 'disk'))))
cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)
transforms = []
for transform in cfg.data.val.pipeline:
if 'img_scale' in transform:
transform['img_scale'] = tuple(transform['img_scale'])
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
transforms.append(transform)
elif callable(transform):
transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
test_pipeline = Compose(transforms)
datas = []
for img in imgs:
# prepare data
if isinstance(img, np.ndarray):
# directly add img
data = dict(img=img)
else:
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
data = test_pipeline(data)
datas.append(data)
data = collate(datas, samples_per_gpu=len(imgs))
# just get the actual data from DataContainer
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
data['img'] = [img.data[0] for img in data['img']]
if next(self.model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
for m in self.model.modules():
assert not isinstance(
m, RoIPool
), 'CPU inference with RoIPool is not supported currently.'
# forward the model
with torch.no_grad():
results = self.model(mode='test', **data)
return results
def visualize(self,
img,
results,
score_thr=0.3,
show=False,
out_file=None):
bboxes = results['detection_boxes'][0]
scores = results['detection_scores'][0]
labels = results['detection_classes'][0].tolist()
# If self.CLASSES is not None, class_id will be converted to self.CLASSES for visualization,
# otherwise the class_id will be displayed.
# And don't try to modify the value in results, it may cause some bugs or even precision problems,
# because `self.evaluate` will also use the results, refer to: https://github.com/alibaba/EasyCV/pull/67
if self.CLASSES is not None and len(self.CLASSES) > 0:
for i, classes_id in enumerate(labels):
if classes_id is None:
labels[i] = None
else:
labels[i] = self.CLASSES[int(classes_id)]
if scores is not None and score_thr > 0:
inds = scores > score_thr
bboxes = bboxes[inds]
labels = np.array(labels)[inds]
imshow_bboxes( imshow_bboxes(
img, img,
bboxes, bboxes,
@ -263,6 +119,12 @@ class DetrPredictor(PredictorInterface):
out_file=out_file) out_file=out_file)
@deprecated(reason='Please use DetectionPredictor.')
@PREDICTORS.register_module()
class DetrPredictor(DetectionPredictor):
""""""
@PREDICTORS.register_module() @PREDICTORS.register_module()
class TorchYoloXPredictor(PredictorInterface): class TorchYoloXPredictor(PredictorInterface):

View File

@ -25,6 +25,11 @@ class FaceKeypointsPredictor(PredictorV2):
Args: Args:
model_path (str): Path of model path model_path (str): Path of model path
config_file (str): config file path for model and processor to init. Defaults to None. config_file (str): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
""" """
def __init__(self, def __init__(self,
@ -34,7 +39,7 @@ class FaceKeypointsPredictor(PredictorV2):
device=None, device=None,
save_results=False, save_results=False,
save_path=None, save_path=None,
mode='bgr'): pipelines=None):
super(FaceKeypointsPredictor, self).__init__( super(FaceKeypointsPredictor, self).__init__(
model_path, model_path,
config_file, config_file,
@ -42,7 +47,7 @@ class FaceKeypointsPredictor(PredictorV2):
device=device, device=device,
save_results=save_results, save_results=save_results,
save_path=save_path, save_path=save_path,
mode=mode) pipelines=pipelines)
self.input_size = self.cfg.IMAGE_SIZE self.input_size = self.cfg.IMAGE_SIZE
self.point_number = self.cfg.POINT_NUMBER self.point_number = self.cfg.POINT_NUMBER

View File

@ -25,9 +25,11 @@ class HandKeypointsPredictor(PredictorV2):
config_file: path or ``Config`` of config file config_file: path or ``Config`` of config file
detection_model_config: dict of hand detection model predictor config, detection_model_config: dict of hand detection model predictor config,
example like ``dict(type="", model_path="", config_file="", ......)`` example like ``dict(type="", model_path="", config_file="", ......)``
batch_size: batch_size to infer batch_size (int): batch size for forward.
save_results: bool device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_path: path of result image save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
""" """
def __init__(self, def __init__(self,
@ -38,7 +40,7 @@ class HandKeypointsPredictor(PredictorV2):
device=None, device=None,
save_results=False, save_results=False,
save_path=None, save_path=None,
mode='rgb', pipelines=None,
*args, *args,
**kwargs): **kwargs):
super(HandKeypointsPredictor, self).__init__( super(HandKeypointsPredictor, self).__init__(
@ -48,7 +50,7 @@ class HandKeypointsPredictor(PredictorV2):
device=device, device=device,
save_results=save_results, save_results=save_results,
save_path=save_path, save_path=save_path,
mode=mode, pipelines=pipelines,
*args, *args,
**kwargs) **kwargs)
self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO) self.dataset_info = DatasetInfo(COCO_WHOLEBODY_HAND_DATASET_INFO)
@ -70,52 +72,48 @@ class HandKeypointsPredictor(PredictorV2):
} }
} }
""" """
image_paths = input['inputs'] image_path = input['inputs']
batch_data = [] data_list = []
box_id = 0 box_id = 0
for batch_index, image_path in enumerate(image_paths): det_bbox_result = input['detection_boxes']
det_bbox_result = input['results']['detection_boxes'][batch_index] det_bbox_scores = input['detection_scores']
det_bbox_scores = input['results']['detection_scores'][batch_index] img = mmcv.imread(image_path, 'color', self.INPUT_IMAGE_MODE)
img = mmcv.imread(image_path, 'color', self.mode) for bbox, score in zip(det_bbox_result, det_bbox_scores):
for bbox, score in zip(det_bbox_result, det_bbox_scores): center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox)
center, scale = _box2cs(self.cfg.data_cfg['image_size'], bbox) # prepare data
# prepare data data = {
data = { 'image_file':
'image_file': image_path,
image_path, 'img':
'img': img,
img, 'image_id':
'image_id': 0,
batch_index, 'center':
'center': center,
center, 'scale':
'scale': scale,
scale, 'bbox_score':
'bbox_score': score,
score, 'bbox_id':
'bbox_id': box_id, # need to be assigned if batch_size > 1
box_id, # need to be assigned if batch_size > 1 'dataset':
'dataset': 'coco_wholebody_hand',
'coco_wholebody_hand', 'joints_3d':
'joints_3d': np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
np.zeros((self.cfg.data_cfg.num_joints, 3), 'joints_3d_visible':
dtype=np.float32), np.zeros((self.cfg.data_cfg.num_joints, 3), dtype=np.float32),
'joints_3d_visible': 'rotation':
np.zeros((self.cfg.data_cfg.num_joints, 3), 0,
dtype=np.float32), 'flip_pairs':
'rotation': self.dataset_info.flip_pairs,
0, 'ann_info': {
'flip_pairs': 'image_size': np.array(self.cfg.data_cfg['image_size']),
self.dataset_info.flip_pairs, 'num_joints': self.cfg.data_cfg['num_joints']
'ann_info': {
'image_size':
np.array(self.cfg.data_cfg['image_size']),
'num_joints': self.cfg.data_cfg['num_joints']
}
} }
batch_data.append(data) }
box_id += 1 data_list.append(data)
return batch_data box_id += 1
return data_list
def preprocess_single(self, input): def preprocess_single(self, input):
results = [] results = []
@ -128,8 +126,11 @@ class HandKeypointsPredictor(PredictorV2):
"""Process all inputs list. And collate to batch and put to target device. """Process all inputs list. And collate to batch and put to target device.
If you need custom ops to load or process a batch samples, you need to reimplement it. If you need custom ops to load or process a batch samples, you need to reimplement it.
""" """
# hand det and return source image
det_results = self.detection_predictor(inputs, keep_inputs=True)
batch_outputs = [] batch_outputs = []
for i in inputs: for i in det_results:
for res in self.preprocess_single(i, *args, **kwargs): for res in self.preprocess_single(i, *args, **kwargs):
batch_outputs.append(res) batch_outputs.append(res)
batch_outputs = self._collate_fn(batch_outputs) batch_outputs = self._collate_fn(batch_outputs)
@ -137,37 +138,25 @@ class HandKeypointsPredictor(PredictorV2):
return batch_outputs return batch_outputs
def postprocess(self, inputs, *args, **kwargs): def postprocess(self, inputs, *args, **kwargs):
output = {} keypoints = inputs['preds']
output['keypoints'] = inputs['preds'] boxes = inputs['boxes']
output['boxes'] = inputs['boxes'] for i, bbox in enumerate(boxes):
for i, bbox in enumerate(output['boxes']):
center, scale = bbox[:2], bbox[2:4] center, scale = bbox[:2], bbox[2:4]
output['boxes'][i][:4] = bbox_cs2xyxy(center, scale) boxes[i][:4] = bbox_cs2xyxy(center, scale)
output['boxes'] = output['boxes'][:, :4] boxes = boxes[:, :4]
return output # TODO: support multi bboxes for a single sample
assert len(keypoints.shape) == 3
def __call__(self, inputs, keep_inputs=False): assert len(boxes.shape) == 2
if isinstance(inputs, str): batch_outputs = []
inputs = [inputs] batch_size = keypoints.shape[0]
keypoints = np.split(keypoints, batch_size)
results_list = [] boxes = np.split(boxes, batch_size)
for i in range(0, len(inputs), self.batch_size): for i in range(batch_size):
batch = inputs[i:max(len(inputs) - 1, i + self.batch_size)] batch_outputs.append({
# hand det and return source image 'keypoints': keypoints[i],
det_results = self.detection_predictor(batch, keep_inputs=True) 'boxes': boxes[i]
# hand keypoints })
batch_outputs = self.preprocess(det_results) return batch_outputs
batch_outputs = self.forward(batch_outputs)
results = self.postprocess(batch_outputs)
if keep_inputs:
results = {'inputs': batch, 'results': results}
# if dump, the outputs will not added to the return value to prevent taking up too much memory
if self.save_results:
self.dump([results], self.save_path, mode='ab+')
else:
results_list.append(results)
return results_list
def show_result(self, def show_result(self,
image_path, image_path,

View File

@ -5,22 +5,25 @@ import numpy as np
import torch import torch
from matplotlib.collections import PatchCollection from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon from matplotlib.patches import Polygon
from torchvision.transforms import Compose
from easycv.core.visualization.image import imshow_bboxes from easycv.core.visualization.image import imshow_bboxes
from easycv.datasets.registry import PIPELINES
from easycv.file import io
from easycv.models import build_model
from easycv.predictors.builder import PREDICTORS from easycv.predictors.builder import PREDICTORS
from easycv.predictors.interface import PredictorInterface
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg
from .base import PredictorV2 from .base import PredictorV2
@PREDICTORS.register_module() @PREDICTORS.register_module()
class SegmentationPredictor(PredictorV2): class SegmentationPredictor(PredictorV2):
"""Predictor for Segmentation.
Args:
model_path (str): Path of model path.
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
"""
def __init__(self, def __init__(self,
model_path, model_path,
@ -28,20 +31,21 @@ class SegmentationPredictor(PredictorV2):
batch_size=1, batch_size=1,
device=None, device=None,
save_results=False, save_results=False,
save_path=None): save_path=None,
"""Predict pipeline for Segmentation pipelines=None,
*args,
**kwargs):
Args:
model_path (str): Path of model path
config_file (str): config file path for model and processor to init. Defaults to None.
"""
super(SegmentationPredictor, self).__init__( super(SegmentationPredictor, self).__init__(
model_path, model_path,
config_file, config_file,
batch_size=batch_size, batch_size=batch_size,
device=device, device=device,
save_results=save_results, save_results=save_results,
save_path=save_path) save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.CLASSES = self.cfg.CLASSES self.CLASSES = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE self.PALETTE = self.cfg.PALETTE
@ -123,71 +127,61 @@ class SegmentationPredictor(PredictorV2):
@PREDICTORS.register_module() @PREDICTORS.register_module()
class Mask2formerPredictor(PredictorInterface): class Mask2formerPredictor(SegmentationPredictor):
"""Predictor for Mask2former.
def __init__(self, model_path, model_config=None): Args:
"""init model model_path (str): Path of model path.
config_file (Optinal[str]): config file path for model and processor to init. Defaults to None.
batch_size (int): batch size for forward.
device (str): Support 'cuda' or 'cpu', if is None, detect device automatically.
save_results (bool): Whether to save predict results.
save_path (str): File path for saving results, only valid when `save_results` is True.
pipelines (list[dict]): Data pipeline configs.
"""
Args: def __init__(self,
model_path (str): Path of model path model_path,
model_config (config, optional): config string for model to init. Defaults to None. config_file=None,
batch_size=1,
device=None,
save_results=False,
save_path=None,
pipelines=None,
task_mode='panoptic',
*args,
**kwargs):
super(Mask2formerPredictor, self).__init__(
model_path,
config_file,
batch_size=batch_size,
device=device,
save_results=save_results,
save_path=save_path,
pipelines=pipelines,
*args,
**kwargs)
self.task_mode = task_mode
def forward(self, inputs):
"""Model forward.
""" """
self.model_path = model_path with torch.no_grad():
outputs = self.model(**inputs, mode='test', encode=False)
return outputs
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' def postprocess(self, inputs):
self.model = None output = {}
with io.open(self.model_path, 'rb') as infile: if self.task_mode == 'panoptic':
checkpoint = torch.load(infile, map_location='cpu') output['pan'] = inputs['pan_results'][0]
elif self.task_mode == 'instance':
assert 'meta' in checkpoint and 'config' in checkpoint[ output['segms'] = inputs['detection_masks'][0]
'meta'], 'meta.config is missing from checkpoint' output['bboxes'] = inputs['detection_boxes'][0]
output['scores'] = inputs['detection_scores'][0]
self.cfg = checkpoint['meta']['config'] output['labels'] = inputs['detection_classes'][0]
self.classes = len(self.cfg.PALETTE) else:
self.class_name = self.cfg.CLASSES raise ValueError(f'Not support model {self.task_mode}')
# build model return output
self.model = build_model(self.cfg.model)
self.ckpt = load_checkpoint(
self.model, self.model_path, map_location=self.device)
self.model.to(self.device)
self.model.eval()
# build pipeline
test_pipeline = self.cfg.test_pipeline
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
self.pipeline = Compose(pipeline)
def predict(self, input_data_list, mode='panoptic'):
"""
Args:
input_data_list: a list of numpy array(in rgb order), each array is a sample
to be predicted
"""
output_list = []
for idx, img in enumerate(input_data_list):
output = {}
if not isinstance(img, np.ndarray):
img = np.asarray(img)
data_dict = {'img': img}
ori_shape = img.shape
data_dict = self.pipeline(data_dict)
img = data_dict['img']
img[0] = torch.unsqueeze(img[0], 0).to(self.device)
img_metas = [[
img_meta._data for img_meta in data_dict['img_metas']
]]
img_metas[0][0]['ori_shape'] = ori_shape
res = self.model.forward_test(img, img_metas, encode=False)
if mode == 'panoptic':
output['pan'] = res['pan_results'][0]
elif mode == 'instance':
output['segms'] = res['detection_masks'][0]
output['bboxes'] = res['detection_boxes'][0]
output['scores'] = res['detection_scores'][0]
output['labels'] = res['detection_classes'][0]
output_list.append(output)
return output_list
def show_panoptic(self, img, pan_mask): def show_panoptic(self, img, pan_mask):
pan_label = np.unique(pan_mask) pan_label = np.unique(pan_mask)
@ -214,147 +208,6 @@ class Mask2formerPredictor(PredictorInterface):
return instance_result return instance_result
@PREDICTORS.register_module()
class SegFormerPredictor(PredictorInterface):
def __init__(self, model_path, model_config):
"""init model
Args:
model_path (str): Path of model path
model_config (config): config string for model to init. Defaults to None.
"""
self.model_path = model_path
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.model = None
with io.open(self.model_path, 'rb') as infile:
checkpoint = torch.load(infile, map_location='cpu')
self.cfg = mmcv_config_fromfile(model_config)
self.CLASSES = self.cfg.CLASSES
self.PALETTE = self.cfg.PALETTE
# build model
self.model = build_model(self.cfg.model)
self.ckpt = load_checkpoint(
self.model, self.model_path, map_location=self.device)
self.model.to(self.device)
self.model.eval()
# build pipeline
test_pipeline = self.cfg.test_pipeline
pipeline = [build_from_cfg(p, PIPELINES) for p in test_pipeline]
self.pipeline = Compose(pipeline)
def predict(self, input_data_list):
"""
using session run predict a number of samples using batch_size
Args:
input_data_list: a list of numpy array(in rgb order), each array is a sample
to be predicted
use a fixed number if you do not want to adjust batch_size in runtime
"""
output_list = []
for idx, img in enumerate(input_data_list):
if type(img) is not np.ndarray:
img = np.asarray(img)
ori_img_shape = img.shape[:2]
data_dict = {'img': img}
data_dict['ori_shape'] = ori_img_shape
data_dict = self.pipeline(data_dict)
img = data_dict['img']
img = torch.unsqueeze(img[0], 0).to(self.device)
data_dict.pop('img')
with torch.no_grad():
out = self.model([img],
mode='test',
img_metas=[[data_dict['img_metas'][0]._data]])
output_list.append(out)
return output_list
def show_result(self,
img,
result,
palette=None,
win_name='',
show=False,
wait_time=0,
out_file=None,
opacity=0.5):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The semantic segmentation results to draw over
`img`.
palette (list[list[int]]] | np.ndarray | None): The palette of
segmentation map. If None is given, random palette will be
generated. Default: None
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
opacity(float): Opacity of painted segmentation map.
Default 0.5.
Must be in (0, 1] range.
Returns:
img (Tensor): Only if not `show` or `out_file`
"""
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
if palette is None:
if self.PALETTE is None:
# Get random state before set seed,
# and restore random state later.
# It will prevent loss of randomness, as the palette
# may be different in each iteration if not specified.
# See: https://github.com/open-mmlab/mmdetection/issues/5844
state = np.random.get_state()
np.random.seed(42)
# random palette
palette = np.random.randint(
0, 255, size=(len(self.CLASSES), 3))
np.random.set_state(state)
else:
palette = self.PALETTE
palette = np.array(palette)
assert palette.shape[0] == len(self.CLASSES)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]
img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
return img
def _get_bias_color(base, max_dist=30): def _get_bias_color(base, max_dist=30):
"""Get different colors for each masks. """Get different colors for each masks.

View File

@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import os import os
import torch import torch
@ -8,6 +9,7 @@ from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer from torch.optim import Optimizer
from easycv.file import io from easycv.file import io
from easycv.file.utils import is_url_path
from easycv.framework.errors import TypeError from easycv.framework.errors import TypeError
from easycv.utils.constant import CACHE_DIR from easycv.utils.constant import CACHE_DIR
@ -32,28 +34,40 @@ def load_checkpoint(model,
Returns: Returns:
dict or OrderedDict: The loaded checkpoint. dict or OrderedDict: The loaded checkpoint.
""" """
if not filename.startswith('oss://'): if filename.startswith('oss://'):
return mmcv_load_checkpoint(
model,
filename,
map_location=map_location,
strict=strict,
logger=logger)
else:
_, fname = os.path.split(filename) _, fname = os.path.split(filename)
cache_file = os.path.join(CACHE_DIR, fname) cache_file = os.path.join(CACHE_DIR, fname)
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
if not os.path.exists(cache_file): if not os.path.exists(cache_file):
print(f'download checkpoint from {filename} to {cache_file}') logging.info(
f'download checkpoint from {filename} to {cache_file}')
io.copy(filename, cache_file) io.copy(filename, cache_file)
if torch.distributed.is_available( if torch.distributed.is_available(
) and torch.distributed.is_initialized(): ) and torch.distributed.is_initialized():
torch.distributed.barrier() torch.distributed.barrier()
return mmcv_load_checkpoint( filename = cache_file
model, elif is_url_path(filename):
cache_file, from torch.hub import urlparse, download_url_to_file
map_location=map_location, parts = urlparse(filename)
strict=strict, base_name = os.path.basename(parts.path)
logger=logger) cache_file = os.path.join(CACHE_DIR, base_name)
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
if not os.path.exists(cache_file):
logging.info(
f'download checkpoint from {filename} to {cache_file}')
download_url_to_file(filename, cache_file)
if torch.distributed.is_available(
) and torch.distributed.is_initialized():
torch.distributed.barrier()
filename = cache_file
return mmcv_load_checkpoint(
model,
filename,
map_location=map_location,
strict=strict,
logger=logger)
def save_checkpoint(model, filename, optimizer=None, meta=None): def save_checkpoint(model, filename, optimizer=None, meta=None):

View File

@ -1,4 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
CACHE_DIR = '.easycv_cache' import os
CACHE_DIR = os.path.expanduser('~/.cache/easycv/')
MAX_READ_IMAGE_TRY_TIMES = 10 MAX_READ_IMAGE_TRY_TIMES = 10

View File

@ -1,12 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import inspect
import logging import logging
import warnings
from functools import partial from functools import partial
import mmcv import mmcv
import numpy as np import numpy as np
from six.moves import map, zip
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock
def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
@ -79,6 +79,8 @@ def reparameterize_models(model):
Args: Args:
model: nn.Module model: nn.Module
""" """
from easycv.models.backbones.repvgg_yolox_backbone import RepVGGBlock
reparameterize_count = 0 reparameterize_count = 0
for layer in model.modules(): for layer in model.modules():
if isinstance(layer, RepVGGBlock): if isinstance(layer, RepVGGBlock):
@ -89,3 +91,31 @@ def reparameterize_models(model):
.format(reparameterize_count)) .format(reparameterize_count))
print('reparam:', reparameterize_count) print('reparam:', reparameterize_count)
return model return model
def deprecated(reason):
"""
This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
when the function is used.
"""
def decorator(func1):
if inspect.isclass(func1):
fmt1 = 'Call to deprecated class {name} ({reason}).'
else:
fmt1 = 'Call to deprecated function {name} ({reason}).'
@functools.wraps(func1)
def new_func1(*args, **kwargs):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn(
fmt1.format(name=func1.__name__, reason=reason),
category=DeprecationWarning,
stacklevel=2)
warnings.simplefilter('default', DeprecationWarning)
return func1(*args, **kwargs)
return new_func1
return decorator

View File

@ -7,9 +7,7 @@ from tests.ut_config import (IMG_NORM_CFG_255, SEG_DATA_SMALL_RAW_LOCAL,
from easycv.core.evaluation.builder import build_evaluator from easycv.core.evaluation.builder import build_evaluator
from easycv.datasets.builder import build_datasource from easycv.datasets.builder import build_datasource
from easycv.datasets.segmentation.data_sources.raw import SegSourceRaw
from easycv.datasets.segmentation.raw import SegDataset from easycv.datasets.segmentation.raw import SegDataset
from easycv.file import io
class SegDatasetTest(unittest.TestCase): class SegDatasetTest(unittest.TestCase):

View File

@ -8,14 +8,57 @@ import unittest
import cv2 import cv2
import torch import torch
from easycv.predictors.builder import build_predictor
from easycv.predictors.classifier import TorchClassifier
from easycv.utils.test_util import clean_up, get_tmp_dir from easycv.utils.test_util import clean_up, get_tmp_dir
from easycv.utils.config_tools import mmcv_config_fromfile
from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD, from tests.ut_config import (PRETRAINED_MODEL_RESNET50_WITHOUTHEAD,
IMAGENET_LABEL_TXT, TEST_IMAGES_DIR) IMAGENET_LABEL_TXT, TEST_IMAGES_DIR)
class ClassificationPredictorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_single(self):
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
cfg = mmcv_config_fromfile(config_file)
predict_op = build_predictor(
dict(
**cfg.predict,
model_path=checkpoint,
config_file=config_file,
label_map_path=IMAGENET_LABEL_TXT))
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
results = predict_op([img_path])[0]
self.assertListEqual(results['class'], [283])
self.assertListEqual(results['class_name'], ['"Persian cat",'])
self.assertEqual(len(results['class_probs']), 1000)
def test_batch(self):
checkpoint = PRETRAINED_MODEL_RESNET50_WITHOUTHEAD
config_file = 'configs/classification/imagenet/resnet/imagenet_resnet50_jpg.py'
cfg = mmcv_config_fromfile(config_file)
predict_op = build_predictor(
dict(
**cfg.predict,
model_path=checkpoint,
config_file=config_file,
label_map_path=IMAGENET_LABEL_TXT,
batch_size=3))
img_path = os.path.join(TEST_IMAGES_DIR, 'catb.jpg')
num_imgs = 4
results = predict_op([img_path] * num_imgs)
self.assertEqual(len(results), num_imgs)
for res in results:
self.assertListEqual(res['class'], [283])
self.assertListEqual(res['class_name'], ['"Persian cat",'])
self.assertEqual(len(res['class_probs']), 1000)
class TorchClassifierTest(unittest.TestCase): class TorchClassifierTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -62,6 +105,8 @@ class TorchClassifierTest(unittest.TestCase):
output_ckpt = f'{self.tmp_dir}/export.pth' output_ckpt = f'{self.tmp_dir}/export.pth'
torch.save(output_dict, output_ckpt) torch.save(output_dict, output_ckpt)
from easycv.predictors.classifier import TorchClassifier
fe = TorchClassifier( fe = TorchClassifier(
output_ckpt, topk=topk, label_map_path=IMAGENET_LABEL_TXT) output_ckpt, topk=topk, label_map_path=IMAGENET_LABEL_TXT)

View File

@ -4,11 +4,11 @@ isort:skip_file
""" """
import os import os
import unittest import unittest
import tempfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from easycv.predictors.detector import TorchYoloXPredictor, DetrPredictor from easycv.predictors.detector import TorchYoloXPredictor, DetectionPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT, from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT,
PRETRAINED_MODEL_YOLOXS_EXPORT_OLD, PRETRAINED_MODEL_YOLOXS_EXPORT_OLD,
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT, PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
@ -154,25 +154,18 @@ class DetectorTest(unittest.TestCase):
[510.37033, 268.4982, 527.67017, 273.04935]]), [510.37033, 268.4982, 527.67017, 273.04935]]),
decimal=1) decimal=1)
def test_vitdet_detector(self): def _detection_detector_assert(self, output):
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
out_file = './result.jpg'
vitdet = DetrPredictor(model_path)
output = vitdet.predict(img)
vitdet.visualize(img, output, out_file=out_file)
self.assertIn('detection_boxes', output) self.assertIn('detection_boxes', output)
self.assertIn('detection_scores', output) self.assertIn('detection_scores', output)
self.assertIn('detection_classes', output) self.assertIn('detection_classes', output)
self.assertIn('detection_masks', output) self.assertIn('detection_masks', output)
self.assertIn('img_metas', output) self.assertIn('img_metas', output)
self.assertEqual(len(output['detection_boxes'][0]), 33) self.assertEqual(len(output['detection_boxes']), 33)
self.assertEqual(len(output['detection_scores'][0]), 33) self.assertEqual(len(output['detection_scores']), 33)
self.assertEqual(len(output['detection_classes'][0]), 33) self.assertEqual(len(output['detection_classes']), 33)
self.assertListEqual( self.assertListEqual(
output['detection_classes'][0].tolist(), output['detection_classes'].tolist(),
np.array([ np.array([
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56 2, 2, 2, 2, 2, 2, 7, 7, 13, 13, 13, 56
@ -180,7 +173,7 @@ class DetectorTest(unittest.TestCase):
dtype=np.int32).tolist()) dtype=np.int32).tolist())
assert_array_almost_equal( assert_array_almost_equal(
output['detection_scores'][0], output['detection_scores'],
np.array([ np.array([
0.9975854158401489, 0.9965696334838867, 0.9922919869422913, 0.9975854158401489, 0.9965696334838867, 0.9922919869422913,
0.9833580851554871, 0.983080267906189, 0.970454752445221, 0.9833580851554871, 0.983080267906189, 0.970454752445221,
@ -198,7 +191,7 @@ class DetectorTest(unittest.TestCase):
decimal=2) decimal=2)
assert_array_almost_equal( assert_array_almost_equal(
output['detection_boxes'][0], output['detection_boxes'],
np.array([[ np.array([[
294.22674560546875, 116.6078109741211, 379.4328918457031, 294.22674560546875, 116.6078109741211, 379.4328918457031,
150.14097595214844 150.14097595214844
@ -333,6 +326,32 @@ class DetectorTest(unittest.TestCase):
]]), ]]),
decimal=1) decimal=1)
def test_detection_detector_single(self):
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
vitdet = DetectionPredictor(model_path, score_threshold=0.0)
output = vitdet(img)
output = output[0]
with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp_file:
tmp_save_path = tmp_file.name
vitdet.visualize(img, output, out_file=tmp_save_path)
self._detection_detector_assert(output)
def test_detection_detector_batch(self):
model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/vitdet/vit_base/epoch_100_export.pth'
img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
vitdet = DetectionPredictor(
model_path, score_threshold=0.0, batch_size=2)
num_samples = 3
images = [img] * num_samples
outputs = vitdet(images)
self.assertEqual(len(outputs), num_samples)
for output in outputs:
with tempfile.NamedTemporaryFile(suffix='.jpg') as tmp_file:
tmp_save_path = tmp_file.name
vitdet.visualize(img, output, out_file=tmp_save_path)
self._detection_detector_assert(output)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -3,22 +3,14 @@
isort:skip_file isort:skip_file
""" """
import os import os
import tempfile
import unittest import unittest
import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from easycv.predictors.detector import TorchYoloXPredictor from easycv.predictors.detector import TorchYoloXPredictor
from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_EXPORT, from tests.ut_config import (PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_JIT,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_JIT,
PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE,
PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE, PRETRAINED_MODEL_YOLOXS_PRE_NOTRT_BLADE,
DET_DATA_SMALL_COCO_LOCAL) DET_DATA_SMALL_COCO_LOCAL)
from easycv.utils.test_util import benchmark
import logging
import pandas as pd
import torch import torch
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
@ -37,7 +29,6 @@ class DetectorTest(unittest.TestCase):
input_data_list = [np.asarray(Image.open(img))] input_data_list = [np.asarray(Image.open(img))]
blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE blade_path = PRETRAINED_MODEL_YOLOXS_NOPRE_NOTRT_BLADE
# blade_path = '/home/zouxinyi.zxy/easycv_nfs/pretrained_models/detection/infer_yolox/debug_blade.pt.blade'
predictor_blade = TorchYoloXPredictor( predictor_blade = TorchYoloXPredictor(
model_path=blade_path, score_thresh=0.5) model_path=blade_path, score_thresh=0.5)

View File

@ -19,7 +19,7 @@ class FaceKeypointsPredictorWithoutDetectorTest(unittest.TestCase):
def test_single(self): def test_single(self):
predict_pipeline = FaceKeypointsPredictor( predict_pipeline = FaceKeypointsPredictor(
model_path=self.model_path, config_file=self.model_config_path) model_path=self.model_path, config_file=self.model_config_path)
output = predict_pipeline(self.image_path)[0][0] output = predict_pipeline(self.image_path)[0]
output_keypoints = output['point'] output_keypoints = output['point']
output_pose = output['pose'] output_pose = output['pose']
img = cv2.imread(self.image_path) img = cv2.imread(self.image_path)
@ -38,18 +38,10 @@ class FaceKeypointsPredictorWithoutDetectorTest(unittest.TestCase):
total_samples = 3 total_samples = 3
output = predict_pipeline([self.image_path] * total_samples) output = predict_pipeline([self.image_path] * total_samples)
self.assertEqual(len(output), 2) self.assertEqual(len(output), total_samples)
self.assertEqual(len(output[0]), 2) for out in output:
self.assertEqual(len(output[1]), 1) self.assertEqual(out['point'].shape, (106, 2))
self.assertEqual(output[0][0]['point'].shape[0], 106) self.assertEqual(out['pose'].shape, (3, ))
self.assertEqual(output[0][0]['point'].shape[1], 2)
self.assertEqual(output[0][0]['pose'].shape[0], 3)
self.assertEqual(output[0][1]['point'].shape[0], 106)
self.assertEqual(output[0][1]['point'].shape[1], 2)
self.assertEqual(output[0][1]['pose'].shape[0], 3)
self.assertEqual(output[1][0]['point'].shape[0], 106)
self.assertEqual(output[1][0]['point'].shape[1], 2)
self.assertEqual(output[1][0]['pose'].shape[0], 3)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -39,6 +39,37 @@ class HandKeypointsPredictorTest(unittest.TestCase):
self.assertEqual(keypoints.shape[1], 21) self.assertEqual(keypoints.shape[1], 21)
self.assertEqual(keypoints.shape[2], 3) self.assertEqual(keypoints.shape[2], 3)
def test_batch(self):
config = mmcv_config_fromfile(self.model_config_path)
predict_pipeline = HandKeypointsPredictor(
model_path=self.model_path,
config_file=config,
batch_size=2,
detection_predictor_config=dict(
type='DetectionPredictor',
model_path=MM_DEFAULT_HAND_DETECTION_SSDLITE_MODEL_PATH,
config_file=MM_DEFAULT_HAND_DETECTION_SSDLITE_CONFIG_FILE,
score_threshold=0.5))
num_samples = 4
outputs = predict_pipeline(
[self.image_path] * num_samples, keep_inputs=True)
base_keypoints = outputs[0]['keypoints']
base_boxes = outputs[0]['boxes']
for output in outputs:
keypoints = output['keypoints']
boxes = output['boxes']
image_show = predict_pipeline.show_result(
self.image_path,
keypoints,
boxes,
save_path=self.save_image_path)
self.assertEqual(keypoints.shape, (1, 21, 3))
self.assertEqual(boxes.shape, (1, 4))
self.assertListEqual(keypoints.tolist(), base_keypoints.tolist())
self.assertListEqual(boxes.tolist(), base_boxes.tolist())
self.assertEqual(output['inputs'], self.image_path)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -8,6 +8,7 @@ import unittest
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tests.ut_config import (MODEL_CONFIG_SEGFORMER, from tests.ut_config import (MODEL_CONFIG_SEGFORMER,
PRETRAINED_MODEL_MASK2FORMER_DIR,
PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR) PRETRAINED_MODEL_SEGFORMER, TEST_IMAGES_DIR)
from easycv.predictors.segmentation import SegmentationPredictor from easycv.predictors.segmentation import SegmentationPredictor
@ -31,14 +32,14 @@ class SegmentationPredictorTest(unittest.TestCase):
outputs = predict_pipeline(img_path, keep_inputs=True) outputs = predict_pipeline(img_path, keep_inputs=True)
self.assertEqual(len(outputs), 1) self.assertEqual(len(outputs), 1)
self.assertEqual(outputs[0]['inputs'], [img_path]) results = outputs[0]
self.assertEqual(results['inputs'], img_path)
results = outputs[0]['results']
self.assertListEqual( self.assertListEqual(
list(img.shape)[:2], list(results['seg_pred'][0].shape)) list(img.shape)[:2], list(results['seg_pred'].shape))
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(), self.assertListEqual(results['seg_pred'][1, :10].tolist(),
[161 for i in range(10)]) [161 for i in range(10)])
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(), self.assertListEqual(results['seg_pred'][-1, -10:].tolist(),
[133 for i in range(10)]) [133 for i in range(10)])
def test_batch(self): def test_batch(self):
@ -56,19 +57,15 @@ class SegmentationPredictorTest(unittest.TestCase):
total_samples = 3 total_samples = 3
outputs = predict_pipeline( outputs = predict_pipeline(
[img_path] * total_samples, keep_inputs=True) [img_path] * total_samples, keep_inputs=True)
self.assertEqual(len(outputs), 2) self.assertEqual(len(outputs), 3)
self.assertEqual(outputs[0]['inputs'], [img_path] * 2) for i in range(len(outputs)):
self.assertEqual(outputs[1]['inputs'], [img_path] * 1) self.assertEqual(outputs[i]['inputs'], img_path)
self.assertEqual(len(outputs[0]['results']['seg_pred']), 2)
self.assertEqual(len(outputs[1]['results']['seg_pred']), 1)
for result in [outputs[0]['results'], outputs[1]['results']]:
self.assertListEqual( self.assertListEqual(
list(img.shape)[:2], list(result['seg_pred'][0].shape)) list(img.shape)[:2], list(outputs[i]['seg_pred'].shape))
self.assertListEqual(result['seg_pred'][0][1, :10].tolist(), self.assertListEqual(outputs[i]['seg_pred'][1, :10].tolist(),
[161 for i in range(10)]) [161 for i in range(10)])
self.assertListEqual(result['seg_pred'][0][-1, -10:].tolist(), self.assertListEqual(outputs[i]['seg_pred'][-1, -10:].tolist(),
[133 for i in range(10)]) [133 for i in range(10)])
def test_dump(self): def test_dump(self):
@ -91,17 +88,47 @@ class SegmentationPredictorTest(unittest.TestCase):
total_samples = 3 total_samples = 3
outputs = predict_pipeline( outputs = predict_pipeline(
[img_path] * total_samples, keep_inputs=True) [img_path] * total_samples, keep_inputs=False)
self.assertEqual(outputs, []) self.assertEqual(outputs, [])
with open(tmp_path, 'rb') as f: with open(tmp_path, 'rb') as f:
results = pickle.loads(f.read()) results = pickle.loads(f.read())
self.assertIn('inputs', results[0]) for res in results:
self.assertIn('results', results[0]) self.assertNotIn('inputs', res)
self.assertIn('seg_pred', res)
shutil.rmtree(temp_dir, ignore_errors=True) shutil.rmtree(temp_dir, ignore_errors=True)
@unittest.skipIf(True, 'WIP')
class Mask2formerPredictorTest(unittest.TestCase):
def test_single(self):
import cv2
from easycv.predictors.segmentation import Mask2formerPredictor
pan_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
'mask2former_pan_export.pth')
instance_ckpt = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
'mask2former_r50_instance.pth')
img_path = os.path.join(TEST_IMAGES_DIR, 'mask2former.jpg')
# panop
predictor = Mask2formerPredictor(
model_path=pan_ckpt, output_mode='panoptic')
img = cv2.imread(img_path)
predict_out = predictor([img])
pan_img = predictor.show_panoptic(img, predict_out[0]['pan'])
cv2.imwrite('pan_out.jpg', pan_img)
# instance
predictor = Mask2formerPredictor(
model_path=instance_ckpt, output_mode='instance')
img = cv2.imread(img_path)
predict_out = predictor.predict([img], mode='instance')
instance_img = predictor.show_instance(img, **predict_out[0])
cv2.imwrite('instance_out.jpg', instance_img)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,48 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
isort:skip_file
"""
import os
import unittest
import numpy as np
from PIL import Image
from tests.ut_config import TEST_IMAGES_DIR
from tests.ut_config import (PRETRAINED_MODEL_SEGFORMER,
MODEL_CONFIG_SEGFORMER)
from easycv.predictors.segmentation import SegFormerPredictor
class SegmentorTest(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
def test_segformer_detector(self):
segmentation_model_path = PRETRAINED_MODEL_SEGFORMER
segmentation_model_config = MODEL_CONFIG_SEGFORMER
img = os.path.join(TEST_IMAGES_DIR, '000000289059.jpg')
if not os.path.exists(img):
img = './data/test/segmentation/coco_stuff_164k/val2017/000000289059.jpg'
input_data_list = [np.asarray(Image.open(img))]
predictor = SegFormerPredictor(
model_path=segmentation_model_path,
model_config=segmentation_model_config)
output = predictor.predict(input_data_list)[0]
self.assertIn('seg_pred', output)
self.assertListEqual(
list(input_data_list[0].shape)[:2],
list(output['seg_pred'][0].shape))
self.assertListEqual(output['seg_pred'][0][1, :10].tolist(),
[161 for i in range(10)])
self.assertListEqual(output['seg_pred'][0][-1, -10:].tolist(),
[133 for i in range(10)])
if __name__ == '__main__':
unittest.main()

View File

@ -120,10 +120,10 @@ PRETRAINED_MODEL_YOLOX_COMPRESSION = os.path.join(
BASE_LOCAL_PATH, 'pretrained_models/compression/yolox_compression.pth') BASE_LOCAL_PATH, 'pretrained_models/compression/yolox_compression.pth')
PRETRAINED_MODEL_MAE = os.path.join( PRETRAINED_MODEL_MAE = os.path.join(
BASE_LOCAL_PATH, 'pretrained_models/classification/vit/mae_vit_b_1600.pth') BASE_LOCAL_PATH, 'pretrained_models/classification/vit/mae_vit_b_1600.pth')
PRETRAINED_MODEL_MASK2FORMER = os.path.join( PRETRAINED_MODEL_MASK2FORMER_DIR = os.path.join(
BASE_LOCAL_PATH, BASE_LOCAL_PATH, 'pretrained_models/segmentation/mask2former/')
'pretrained_models/segmentation/mask2former/mask2former_r50_instance.pth') PRETRAINED_MODEL_MASK2FORMER = os.path.join(PRETRAINED_MODEL_MASK2FORMER_DIR,
'mask2former_r50_instance.pth')
PRETRAINED_MODEL_SEGFORMER = os.path.join( PRETRAINED_MODEL_SEGFORMER = os.path.join(
BASE_LOCAL_PATH, BASE_LOCAL_PATH,
'pretrained_models/segmentation/segformer/segformer_b0/SegmentationEvaluator_mIoU_best.pth' 'pretrained_models/segmentation/segformer/segformer_b0/SegmentationEvaluator_mIoU_best.pth'

View File

@ -21,6 +21,7 @@ except:
from easycv.predictors.builder import build_predictor, PREDICTORS from easycv.predictors.builder import build_predictor, PREDICTORS
from easycv.utils.constant import CACHE_DIR
def normPRED(d): def normPRED(d):
@ -47,8 +48,8 @@ class SODPredictor(object):
""" """
def load_url_weights(name, url_index="http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/evtorch_thirdparty/u2net_sod/", map_location=None): def load_url_weights(name, url_index="http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/release/evtorch_thirdparty/u2net_sod/", map_location=None):
os.makedirs('.easycv_cache', exist_ok=True) os.makedirs(CACHE_DIR, exist_ok=True)
local_model = os.path.join('.easycv_cache', name+'.pth') local_model = os.path.join(CACHE_DIR, name+'.pth')
if os.path.exists(local_model): if os.path.exists(local_model):
weights = torch.load(local_model) weights = torch.load(local_model)
if weights is not None: if weights is not None: