remove parallel_test (#238)

pull/246/head
Kai Chen 2020-04-22 23:47:44 +08:00 committed by GitHub
parent af02ac9f01
commit 010b1a0ffc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 79 deletions

View File

@ -7,7 +7,6 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
TextLoggerHook, WandbLoggerHook)
from .log_buffer import LogBuffer
from .parallel_test import parallel_test
from .priority import Priority, get_priority
from .runner import Runner
from .utils import get_host_info, get_time_str, obj_from_dict
@ -17,7 +16,7 @@ __all__ = [
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
'init_dist', 'get_dist_info', 'master_only'
]

View File

@ -1,75 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
import multiprocessing
import torch
import mmcv
from .checkpoint import load_checkpoint
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpu_id, idx_queue, result_queue):
model = model_cls(**model_kwargs)
load_checkpoint(model, checkpoint, map_location='cpu')
torch.cuda.set_device(gpu_id)
model.cuda()
model.eval()
with torch.no_grad():
while True:
idx = idx_queue.get()
data = dataset[idx]
result = model(**data_func(data, gpu_id))
result_queue.put((idx, result))
def parallel_test(model_cls,
model_kwargs,
checkpoint,
dataset,
data_func,
gpus,
workers_per_gpu=1):
"""Parallel testing on multiple GPUs.
Args:
model_cls (type): Model class type.
model_kwargs (dict): Arguments to init the model.
checkpoint (str): Checkpoint filepath.
dataset (:obj:`Dataset`): The dataset to be tested.
data_func (callable): The function that generates model inputs.
gpus (list[int]): GPU ids to be used.
workers_per_gpu (int): Number of processes on each GPU. It is possible
to run multiple workers on each GPU.
Returns:
list: Test results.
"""
ctx = multiprocessing.get_context('spawn')
idx_queue = ctx.Queue()
result_queue = ctx.Queue()
num_workers = len(gpus) * workers_per_gpu
workers = [
ctx.Process(
target=worker_func,
args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
gpus[i % len(gpus)], idx_queue, result_queue))
for i in range(num_workers)
]
for w in workers:
w.daemon = True
w.start()
for i in range(len(dataset)):
idx_queue.put(i)
results = [None for _ in range(len(dataset))]
prog_bar = mmcv.ProgressBar(task_num=len(dataset))
for _ in range(len(dataset)):
idx, res = result_queue.get()
results[idx] = res
prog_bar.update()
print('\n')
for worker in workers:
worker.terminate()
return results