mirror of https://github.com/open-mmlab/mmcv.git
add load_ext warning (#1089)
parent
a88d1d28c1
commit
69146fe3d7
|
@ -1,6 +1,7 @@
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
import warnings
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -14,6 +15,7 @@ if torch.__version__ != 'parrots':
|
||||||
return ext
|
return ext
|
||||||
else:
|
else:
|
||||||
from parrots import extension
|
from parrots import extension
|
||||||
|
from parrots.base import ParrotsException
|
||||||
|
|
||||||
has_return_value_ops = [
|
has_return_value_ops = [
|
||||||
'nms',
|
'nms',
|
||||||
|
@ -33,11 +35,11 @@ else:
|
||||||
'ms_deform_attn_forward',
|
'ms_deform_attn_forward',
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_fake_func(name):
|
def get_fake_func(name, e):
|
||||||
|
|
||||||
def fake_func(*args, **kwargs):
|
def fake_func(*args, **kwargs):
|
||||||
raise RuntimeError(
|
warnings.warn(f'{name} is not supported in parrots now')
|
||||||
'{} is not supported in parrots now'.format(name))
|
raise e
|
||||||
|
|
||||||
return fake_func
|
return fake_func
|
||||||
|
|
||||||
|
@ -48,8 +50,10 @@ else:
|
||||||
for fun in funcs:
|
for fun in funcs:
|
||||||
try:
|
try:
|
||||||
ext_fun = extension.load(fun, name, lib_dir=lib_root)
|
ext_fun = extension.load(fun, name, lib_dir=lib_root)
|
||||||
except Exception:
|
except ParrotsException as e:
|
||||||
ext_fun = get_fake_func(fun)
|
if 'No element registered' not in e.message:
|
||||||
|
warnings.warn(e.message)
|
||||||
|
ext_fun = get_fake_func(fun, e)
|
||||||
ext_list.append(ext_fun)
|
ext_list.append(ext_fun)
|
||||||
else:
|
else:
|
||||||
if fun in has_return_value_ops:
|
if fun in has_return_value_ops:
|
||||||
|
|
Loading…
Reference in New Issue