add load_ext warning (#1089)

pull/1094/head
pc 2021-06-11 16:30:39 +08:00 committed by GitHub
parent a88d1d28c1
commit 69146fe3d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 5 deletions

View File

@ -1,6 +1,7 @@
import importlib import importlib
import os import os
import pkgutil import pkgutil
import warnings
from collections import namedtuple from collections import namedtuple
import torch import torch
@ -14,6 +15,7 @@ if torch.__version__ != 'parrots':
return ext return ext
else: else:
from parrots import extension from parrots import extension
from parrots.base import ParrotsException
has_return_value_ops = [ has_return_value_ops = [
'nms', 'nms',
@ -33,11 +35,11 @@ else:
'ms_deform_attn_forward', 'ms_deform_attn_forward',
] ]
def get_fake_func(name): def get_fake_func(name, e):
def fake_func(*args, **kwargs): def fake_func(*args, **kwargs):
raise RuntimeError( warnings.warn(f'{name} is not supported in parrots now')
'{} is not supported in parrots now'.format(name)) raise e
return fake_func return fake_func
@ -48,8 +50,10 @@ else:
for fun in funcs: for fun in funcs:
try: try:
ext_fun = extension.load(fun, name, lib_dir=lib_root) ext_fun = extension.load(fun, name, lib_dir=lib_root)
except Exception: except ParrotsException as e:
ext_fun = get_fake_func(fun) if 'No element registered' not in e.message:
warnings.warn(e.message)
ext_fun = get_fake_func(fun, e)
ext_list.append(ext_fun) ext_list.append(ext_fun)
else: else:
if fun in has_return_value_ops: if fun in has_return_value_ops: