diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py index d80025f77..073d8011b 100644 --- a/mmseg/models/decode_heads/ham_head.py +++ b/mmseg/models/decode_heads/ham_head.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule +from mmengine.device import get_device from mmseg.registry import MODELS from ..utils import resize @@ -52,7 +53,7 @@ class Matrix_Decomposition_2D_Base(nn.Module): self.rand_init = rand_init - def _build_bases(self, B, S, D, R, cuda=False): + def _build_bases(self, B, S, D, R, device=None): raise NotImplementedError def local_step(self, x, bases, coef): @@ -80,14 +81,13 @@ class Matrix_Decomposition_2D_Base(nn.Module): D = C // self.S N = H * W x = x.view(B * self.S, D, N) - cuda = 'cuda' in str(x.device) if not self.rand_init and not hasattr(self, 'bases'): - bases = self._build_bases(1, self.S, D, self.R, cuda=cuda) + bases = self._build_bases(1, self.S, D, self.R, device=x.device) self.register_buffer('bases', bases) # (S, D, R) -> (B * S, D, R) if self.rand_init: - bases = self._build_bases(B, self.S, D, self.R, cuda=cuda) + bases = self._build_bases(B, self.S, D, self.R, device=x.device) else: bases = self.bases.repeat(B, 1, 1) @@ -116,13 +116,11 @@ class NMF2D(Matrix_Decomposition_2D_Base): self.inv_t = 1 - def _build_bases(self, B, S, D, R, cuda=False): + def _build_bases(self, B, S, D, R, device=None): """Build bases in initialization.""" - if cuda: - bases = torch.rand((B * S, D, R)).cuda() - else: - bases = torch.rand((B * S, D, R)) - + if device is None: + device = get_device() + bases = torch.rand((B * S, D, R)).to(device) bases = F.normalize(bases, dim=1) return bases