mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Refactor] Handle case where device is neither CPU nor CUDA in HamHead (#2868)
This commit is contained in:
parent
969f50459d
commit
ced29fcaf8
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.device import get_device
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
@ -52,7 +53,7 @@ class Matrix_Decomposition_2D_Base(nn.Module):
|
||||
|
||||
self.rand_init = rand_init
|
||||
|
||||
def _build_bases(self, B, S, D, R, cuda=False):
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def local_step(self, x, bases, coef):
|
||||
@ -80,14 +81,13 @@ class Matrix_Decomposition_2D_Base(nn.Module):
|
||||
D = C // self.S
|
||||
N = H * W
|
||||
x = x.view(B * self.S, D, N)
|
||||
cuda = 'cuda' in str(x.device)
|
||||
if not self.rand_init and not hasattr(self, 'bases'):
|
||||
bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
|
||||
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
|
||||
self.register_buffer('bases', bases)
|
||||
|
||||
# (S, D, R) -> (B * S, D, R)
|
||||
if self.rand_init:
|
||||
bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
|
||||
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
|
||||
else:
|
||||
bases = self.bases.repeat(B, 1, 1)
|
||||
|
||||
@ -116,13 +116,11 @@ class NMF2D(Matrix_Decomposition_2D_Base):
|
||||
|
||||
self.inv_t = 1
|
||||
|
||||
def _build_bases(self, B, S, D, R, cuda=False):
|
||||
def _build_bases(self, B, S, D, R, device=None):
|
||||
"""Build bases in initialization."""
|
||||
if cuda:
|
||||
bases = torch.rand((B * S, D, R)).cuda()
|
||||
else:
|
||||
bases = torch.rand((B * S, D, R))
|
||||
|
||||
if device is None:
|
||||
device = get_device()
|
||||
bases = torch.rand((B * S, D, R)).to(device)
|
||||
bases = F.normalize(bases, dim=1)
|
||||
|
||||
return bases
|
||||
|
Loading…
x
Reference in New Issue
Block a user