mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
* tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. * update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unuseful code Co-authored-by: liukai <liukai@pjlab.org.cn> Co-authored-by: jacky <jacky@xx.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: liukai <your_email@abc.example>
18 lines
489 B
Python
18 lines
489 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
|
|
|
|
class SetTorchThread:
    """Context manager that temporarily overrides torch's thread count.

    On entry the current value of ``torch.get_num_threads()`` is recorded
    and, unless ``num_thread`` is ``-1``, ``torch.set_num_threads`` is called
    with the requested value. On exit the recorded value is restored.

    Args:
        num_thread (int): Thread count to apply inside the ``with`` block.
            ``-1`` (the default) leaves the current setting untouched.

    Example:
        >>> with SetTorchThread(1):
        ...     pass  # code here runs with torch.get_num_threads() == 1
    """

    def __init__(self, num_thread: int = -1) -> None:
        # Initial snapshot kept so the attribute exists before entry; it is
        # refreshed in ``__enter__`` so the restore target is never stale.
        self.prev_num_threads = torch.get_num_threads()
        self.num_threads = num_thread

    def __enter__(self):
        """Apply the requested thread count and return this manager."""
        if self.num_threads != -1:
            # Snapshot at entry (not construction) so that changes made
            # between creating the object and entering the block, or between
            # successive uses of the same object, are restored correctly.
            self.prev_num_threads = torch.get_num_threads()
            torch.set_num_threads(self.num_threads)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Restore the thread count recorded on entry.

        Returns ``None`` so any exception raised in the block propagates.
        """
        if self.num_threads != -1:
            torch.set_num_threads(self.prev_num_threads)
|