diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 3ba6b62ce..6fa7e3b34 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -16,7 +16,8 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): object. checkpoint (str, optional): Checkpoint path. If left as None, the model will not load any weights. - + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. Returns: nn.Module: The constructed segmentor. """ @@ -28,7 +29,7 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'): config.model.pretrained = None model = build_segmentor(config.model, test_cfg=config.test_cfg) if checkpoint is not None: - checkpoint = load_checkpoint(model, checkpoint) + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') model.CLASSES = checkpoint['meta']['CLASSES'] model.PALETTE = checkpoint['meta']['PALETTE'] model.cfg = config # save the config in the model for convenience