support cpu deploy_test (#769)

pull/780/head
q.yao 2021-08-10 20:47:08 +08:00 committed by GitHub
parent 9155d9e9ed
commit 58f5dbce7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -53,6 +53,7 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
self.io_binding.bind_output(name)
self.cfg = cfg
self.test_mode = cfg.model.test_cfg.mode
self.is_cuda_available = is_cuda_available
def extract_feat(self, imgs):
raise NotImplementedError('This method is not implemented.')
@ -65,6 +66,10 @@ class ONNXRuntimeSegmentor(BaseSegmentor):
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
if not self.is_cuda_available:
img = img.detach().cpu()
elif self.device_id >= 0:
img = img.cuda(self.device_id)
device_type = img.device.type
self.io_binding.bind_input(
name='input',