diff --git a/mmpretrain/engine/hooks/swav_hook.py b/mmpretrain/engine/hooks/swav_hook.py
index 27cf73d1..be5f3a36 100644
--- a/mmpretrain/engine/hooks/swav_hook.py
+++ b/mmpretrain/engine/hooks/swav_hook.py
@@ -3,6 +3,7 @@ import os.path as osp
 from typing import Dict, List, Optional, Sequence
 
 import torch
+from mmengine.device import get_device
 from mmengine.dist import get_rank, get_world_size, is_distributed
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger
@@ -97,25 +98,13 @@ class SwAVHook(Hook):
         if self.queue_length > 0 \
                 and runner.epoch >= self.epoch_queue_starts \
                 and self.queue is None:
-            if torch.cuda.is_available():
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                ).cuda()
-            elif hasattr(torch.backends,
-                         'mps') and torch.backends.mps.is_available():
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                ).to(torch.device('mps'))
-            else:
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                )
+
+            self.queue = torch.zeros(
+                len(self.crops_for_assign),
+                self.queue_length // runner.world_size,
+                self.feat_dim,
+                device=get_device(),
+            )
 
         # set the boolean type of use_the_queue
         get_ori_model(runner.model).head.loss_module.queue = self.queue
diff --git a/mmpretrain/models/heads/itpn_clip_head.py b/mmpretrain/models/heads/itpn_clip_head.py
index 5fd19131..52c49b8c 100644
--- a/mmpretrain/models/heads/itpn_clip_head.py
+++ b/mmpretrain/models/heads/itpn_clip_head.py
@@ -3,6 +3,7 @@ from typing import List, Optional, Union
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.model import BaseModule
 
 from mmpretrain.registry import MODELS
@@ -43,11 +44,7 @@ class iTPNClipHead(BaseModule):
             target (torch.Tensor): Target generated by target_generator.
             mask (torch.Tensor): Generated mask for pretraing.
         """
-        if torch.cuda.is_available():
-            mask = mask.to(torch.device('cuda'), non_blocking=True)
-        elif hasattr(torch.backends,
-                     'mps') and torch.backends.mps.is_available():
-            mask = mask.to(torch.device('mps'), non_blocking=True)
+        mask = mask.to(get_device(), non_blocking=True)
 
         mask = mask.flatten(1).to(torch.bool)
         target = target[mask]
diff --git a/projects/gradio_demo/launch.py b/projects/gradio_demo/launch.py
index 03d31f1a..48b9e1df 100644
--- a/projects/gradio_demo/launch.py
+++ b/projects/gradio_demo/launch.py
@@ -20,27 +20,26 @@ mmpretrain.utils.progress.disable_progress_bar = True
 logger = MMLogger('mmpretrain', logger_name='mmpre')
 
 if torch.cuda.is_available():
-    gpus = [
+    devices = [
         torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
     ]
-    logger.info(f'Available GPUs: {len(gpus)}')
+    logger.info(f'Available GPUs: {len(devices)}')
 elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    devices = [torch.device('mps')]
     logger.info('Available MPS.')
 else:
-    gpus = None
-    logger.info('No available GPU.')
+    devices = [torch.device('cpu')]
+    logger.info('Available CPU.')
 
 
 def get_free_device():
-    if gpus is None:
-        return torch.device('cpu')
     if hasattr(torch.cuda, 'mem_get_info'):
-        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in gpus]
+        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
         select = max(zip(free, range(len(free))))[1]
     else:
         import random
-        select = random.randint(0, len(gpus) - 1)
-    return gpus[select]
+        select = random.randint(0, len(devices) - 1)
+    return devices[select]
 
 
 class InferencerCache: