[Refactor] Handle case where device is neither CPU nor CUDA in HamHead (#2868)

This commit is contained in:
Junhwa Song 2023-04-14 11:12:49 +09:00 committed by GitHub
parent 969f50459d
commit ced29fcaf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.device import get_device
from mmseg.registry import MODELS
from ..utils import resize
@ -52,7 +53,7 @@ class Matrix_Decomposition_2D_Base(nn.Module):
self.rand_init = rand_init
def _build_bases(self, B, S, D, R, cuda=False):
def _build_bases(self, B, S, D, R, device=None):
raise NotImplementedError
def local_step(self, x, bases, coef):
@ -80,14 +81,13 @@ class Matrix_Decomposition_2D_Base(nn.Module):
D = C // self.S
N = H * W
x = x.view(B * self.S, D, N)
cuda = 'cuda' in str(x.device)
if not self.rand_init and not hasattr(self, 'bases'):
bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
bases = self._build_bases(1, self.S, D, self.R, device=x.device)
self.register_buffer('bases', bases)
# (S, D, R) -> (B * S, D, R)
if self.rand_init:
bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
bases = self._build_bases(B, self.S, D, self.R, device=x.device)
else:
bases = self.bases.repeat(B, 1, 1)
@ -116,13 +116,11 @@ class NMF2D(Matrix_Decomposition_2D_Base):
self.inv_t = 1
def _build_bases(self, B, S, D, R, cuda=False):
def _build_bases(self, B, S, D, R, device=None):
"""Build bases in initialization."""
if cuda:
bases = torch.rand((B * S, D, R)).cuda()
else:
bases = torch.rand((B * S, D, R))
if device is None:
device = get_device()
bases = torch.rand((B * S, D, R)).to(device)
bases = F.normalize(bases, dim=1)
return bases