mirror of https://github.com/open-mmlab/mmcv.git
remove parallel_test (#238)
parent
af02ac9f01
commit
010b1a0ffc
|
@ -7,7 +7,6 @@ from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistSamplerSeedHook,
|
||||||
OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
|
OptimizerHook, PaviLoggerHook, TensorboardLoggerHook,
|
||||||
TextLoggerHook, WandbLoggerHook)
|
TextLoggerHook, WandbLoggerHook)
|
||||||
from .log_buffer import LogBuffer
|
from .log_buffer import LogBuffer
|
||||||
from .parallel_test import parallel_test
|
|
||||||
from .priority import Priority, get_priority
|
from .priority import Priority, get_priority
|
||||||
from .runner import Runner
|
from .runner import Runner
|
||||||
from .utils import get_host_info, get_time_str, obj_from_dict
|
from .utils import get_host_info, get_time_str, obj_from_dict
|
||||||
|
@ -17,7 +16,7 @@ __all__ = [
|
||||||
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
|
'LrUpdaterHook', 'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook',
|
||||||
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
|
'LoggerHook', 'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
|
||||||
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
|
'WandbLoggerHook', '_load_checkpoint', 'load_state_dict',
|
||||||
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'parallel_test',
|
'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
|
||||||
'Priority', 'get_priority', 'get_host_info', 'get_time_str',
|
'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
|
||||||
'obj_from_dict', 'init_dist', 'get_dist_info', 'master_only'
|
'init_dist', 'get_dist_info', 'master_only'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
# Copyright (c) Open-MMLab. All rights reserved.
|
|
||||||
import multiprocessing
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import mmcv
|
|
||||||
from .checkpoint import load_checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
def worker_func(model_cls, model_kwargs, checkpoint, dataset, data_func,
|
|
||||||
gpu_id, idx_queue, result_queue):
|
|
||||||
model = model_cls(**model_kwargs)
|
|
||||||
load_checkpoint(model, checkpoint, map_location='cpu')
|
|
||||||
torch.cuda.set_device(gpu_id)
|
|
||||||
model.cuda()
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
while True:
|
|
||||||
idx = idx_queue.get()
|
|
||||||
data = dataset[idx]
|
|
||||||
result = model(**data_func(data, gpu_id))
|
|
||||||
result_queue.put((idx, result))
|
|
||||||
|
|
||||||
|
|
||||||
def parallel_test(model_cls,
|
|
||||||
model_kwargs,
|
|
||||||
checkpoint,
|
|
||||||
dataset,
|
|
||||||
data_func,
|
|
||||||
gpus,
|
|
||||||
workers_per_gpu=1):
|
|
||||||
"""Parallel testing on multiple GPUs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_cls (type): Model class type.
|
|
||||||
model_kwargs (dict): Arguments to init the model.
|
|
||||||
checkpoint (str): Checkpoint filepath.
|
|
||||||
dataset (:obj:`Dataset`): The dataset to be tested.
|
|
||||||
data_func (callable): The function that generates model inputs.
|
|
||||||
gpus (list[int]): GPU ids to be used.
|
|
||||||
workers_per_gpu (int): Number of processes on each GPU. It is possible
|
|
||||||
to run multiple workers on each GPU.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: Test results.
|
|
||||||
"""
|
|
||||||
ctx = multiprocessing.get_context('spawn')
|
|
||||||
idx_queue = ctx.Queue()
|
|
||||||
result_queue = ctx.Queue()
|
|
||||||
num_workers = len(gpus) * workers_per_gpu
|
|
||||||
workers = [
|
|
||||||
ctx.Process(
|
|
||||||
target=worker_func,
|
|
||||||
args=(model_cls, model_kwargs, checkpoint, dataset, data_func,
|
|
||||||
gpus[i % len(gpus)], idx_queue, result_queue))
|
|
||||||
for i in range(num_workers)
|
|
||||||
]
|
|
||||||
for w in workers:
|
|
||||||
w.daemon = True
|
|
||||||
w.start()
|
|
||||||
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
idx_queue.put(i)
|
|
||||||
|
|
||||||
results = [None for _ in range(len(dataset))]
|
|
||||||
prog_bar = mmcv.ProgressBar(task_num=len(dataset))
|
|
||||||
for _ in range(len(dataset)):
|
|
||||||
idx, res = result_queue.get()
|
|
||||||
results[idx] = res
|
|
||||||
prog_bar.update()
|
|
||||||
print('\n')
|
|
||||||
for worker in workers:
|
|
||||||
worker.terminate()
|
|
||||||
|
|
||||||
return results
|
|
Loading…
Reference in New Issue