mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Fix cpu inference (#152)
* Add missing map_location * Add docstring * Update mmseg/apis/inference.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com> * Update inference.py * Update inference.py Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
This commit is contained in:
parent
746c8d3785
commit
7f8bc7935c
@ -16,7 +16,8 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
||||
object.
|
||||
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
||||
will not load any weights.
|
||||
|
||||
device (str, optional) CPU/CUDA device option. Default 'cuda:0'.
|
||||
Use 'cpu' for loading model on CPU.
|
||||
Returns:
|
||||
nn.Module: The constructed segmentor.
|
||||
"""
|
||||
@ -28,7 +29,7 @@ def init_segmentor(config, checkpoint=None, device='cuda:0'):
|
||||
config.model.pretrained = None
|
||||
model = build_segmentor(config.model, test_cfg=config.test_cfg)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint)
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
model.PALETTE = checkpoint['meta']['PALETTE']
|
||||
model.cfg = config # save the config in the model for convenience
|
||||
|
Loading…
x
Reference in New Issue
Block a user