mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Improve] Add exception for PointRend for support CPU-only usage (#1271)
* [Improve] Add exception for PointRend for support CPU-only usage * fixed linting
This commit is contained in:
parent
66b778c064
commit
82a80880d2
@ -4,7 +4,11 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import ConvModule
|
from mmcv.cnn import ConvModule
|
||||||
|
|
||||||
|
try:
|
||||||
from mmcv.ops import point_sample
|
from mmcv.ops import point_sample
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
point_sample = None
|
||||||
|
|
||||||
from mmseg.models.builder import HEADS
|
from mmseg.models.builder import HEADS
|
||||||
from mmseg.ops import resize
|
from mmseg.ops import resize
|
||||||
@ -75,6 +79,9 @@ class PointHead(BaseCascadeDecodeHead):
|
|||||||
init_cfg=dict(
|
init_cfg=dict(
|
||||||
type='Normal', std=0.01, override=dict(name='fc_seg')),
|
type='Normal', std=0.01, override=dict(name='fc_seg')),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
if point_sample is None:
|
||||||
|
raise RuntimeError('Please install mmcv-full for '
|
||||||
|
'point_sample ops')
|
||||||
|
|
||||||
self.num_fcs = num_fcs
|
self.num_fcs = num_fcs
|
||||||
self.coarse_pred_each_layer = coarse_pred_each_layer
|
self.coarse_pred_each_layer = coarse_pred_each_layer
|
||||||
|
Loading…
x
Reference in New Issue
Block a user