mirror of https://github.com/open-mmlab/mmcv.git
remove parallel_test (#238)
parent af02ac9f01
commit 010b1a0ffc
mmcv/runner/__init__.py
@@ -7,7 +7,6 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
                     OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
                     TextLoggerHook, WandbLoggerHook)
 from .log_buffer import LogBuffer
-from .parallel_test import parallel_test
 from .priority import Priority, get_priority
 from .runner import Runner
 from .utils import get_host_info, get_time_str, obj_from_dict
@@ -17,7 +16,7 @@ __all__ = [
     'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
     'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
     'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
-    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
-    'Priority', 'get_priority', 'get_host_info', 'get_time_str',
-    'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
+    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+    'init_dist', 'get_dist_info', 'master_only'
 ]
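Note: the effect of the __init__.py change on downstream code is that
"from mmcv.runner import parallel_test" stops working from this commit
onward. A minimal guarded-import sketch (the None fallback is illustrative,
not part of mmcv):

    try:
        from mmcv.runner import parallel_test  # removed by this commit
    except ImportError:
        # Illustrative fallback: callers must now bring their own
        # multi-GPU test loop.
        parallel_test = None
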
mmcv/runner/parallel_test.py
@@ -1,75 +0,0 @@
-# Copyright (c) Open-MMLab. All rights reserved.
-import multiprocessing
-
-import torch
-
-import mmcv
-from .checkpoint import load_checkpoint
-
-
-def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
-                gpu_id, idx_queue, result_queue):
-    model = model_cls(**model_kwargs)
-    load_checkpoint(model, checkpoint, map_location='cpu')
-    torch.cuda.set_device(gpu_id)
-    model.cuda()
-    model.eval()
-    with torch.no_grad():
-        while True:
-            idx = idx_queue.get()
-            data = dataset[idx]
-            result = model(**data_func(data, gpu_id))
-            result_queue.put((idx, result))
-
-
-def parallel_test(model_cls,
-                  model_kwargs,
-                  checkpoint,
-                  dataset,
-                  data_func,
-                  gpus,
-                  workers_per_gpu=1):
-    """Parallel testing on multiple GPUs.
-
-    Args:
-        model_cls (type): Model class type.
-        model_kwargs (dict): Arguments to init the model.
-        checkpoint (str): Checkpoint filepath.
-        dataset (:obj:`Dataset`): The dataset to be tested.
-        data_func (callable): The function that generates model inputs.
-        gpus (list[int]): GPU ids to be used.
-        workers_per_gpu (int): Number of processes on each GPU. It is possible
-            to run multiple workers on each GPU.
-
-    Returns:
-        list: Test results.
-    """
-    ctx = multiprocessing.get_context('spawn')
-    idx_queue = ctx.Queue()
-    result_queue = ctx.Queue()
-    num_workers = len(gpus) * workers_per_gpu
-    workers = [
-        ctx.Process(
-            target=worker_func,
-            args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
-                  gpus[i % len(gpus)], idx_queue, result_queue))
-        for i in range(num_workers)
-    ]
-    for w in workers:
-        w.daemon = True
-        w.start()
-
-    for i in range(len(dataset)):
-        idx_queue.put(i)
-
-    results = [None for _ in range(len(dataset))]
-    prog_bar = mmcv.ProgressBar(task_num=len(dataset))
-    for _ in range(len(dataset)):
-        idx, res = result_queue.get()
-        results[idx] = res
-        prog_bar.update()
-    print('\n')
-    for worker in workers:
-        worker.terminate()
-
-    return results
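For reference, this is how the removed helper was typically driven, going by
its signature and docstring. MyModel, my_dataset and to_gpu_inputs below are
hypothetical stand-ins, not names from mmcv:

    from mmcv.runner import parallel_test  # pre-#238 mmcv only

    def to_gpu_inputs(data, gpu_id):
        # Turn one dataset sample (assumed here to be a dict of tensors)
        # into keyword arguments for the model's forward call, moved onto
        # the worker's GPU.
        return {k: v.cuda(gpu_id) for k, v in data.items()}

    results = parallel_test(
        model_cls=MyModel,                  # instantiated once per worker
        model_kwargs=dict(num_classes=10),  # passed as MyModel(**model_kwargs)
        checkpoint='work_dir/latest.pth',   # loaded with map_location='cpu'
        dataset=my_dataset,                 # must support len() and indexing
        data_func=to_gpu_inputs,
        gpus=[0, 1],                        # GPU ids, assigned round-robin
        workers_per_gpu=2)                  # 2 GPUs x 2 = 4 spawned workers
    # results is a list of per-sample outputs ordered by dataset index.

Design-wise, the helper used the 'spawn' start method so each worker could
initialize CUDA independently, marked the workers as daemons, and simply
terminated them after collecting len(dataset) results, which is why the
worker loop is a bare "while True" with no shutdown signal.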