Use get_device to refactor the device selection.

pull/1699/head
mzr1996 2023-07-25 09:50:33 +08:00
parent 2354e052a5
commit a0a220d84b
3 changed files with 18 additions and 33 deletions
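For orientation, a minimal sketch (not part of the commit) of the helper this refactor leans on: mmengine.device.get_device() is assumed to report the best available backend as a device string such as 'cuda', 'mps', or 'cpu', so a tensor can be created on, or moved to, the right device in one call instead of branching per backend.

    import torch
    from mmengine.device import get_device

    # One allocation on the detected device; the shape is only a placeholder.
    queue = torch.zeros(2, 16, 128, device=get_device())
    print(queue.device)  # e.g. cuda:0, mps:0 or cpu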


@@ -3,6 +3,7 @@ import os.path as osp
 from typing import Dict, List, Optional, Sequence
 
 import torch
+from mmengine.device import get_device
 from mmengine.dist import get_rank, get_world_size, is_distributed
 from mmengine.hooks import Hook
 from mmengine.logging import MMLogger
@@ -97,25 +98,13 @@ class SwAVHook(Hook):
         if self.queue_length > 0 \
                 and runner.epoch >= self.epoch_queue_starts \
                 and self.queue is None:
-            if torch.cuda.is_available():
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                ).cuda()
-            elif hasattr(torch.backends,
-                         'mps') and torch.backends.mps.is_available():
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                ).to(torch.device('mps'))
-            else:
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // runner.world_size,
-                    self.feat_dim,
-                )
+            self.queue = torch.zeros(
+                len(self.crops_for_assign),
+                self.queue_length // runner.world_size,
+                self.feat_dim,
+                device=get_device(),
+            )
 
         # set the boolean type of use_the_queue
         get_ori_model(runner.model).head.loss_module.queue = self.queue


@@ -3,6 +3,7 @@ from typing import List, Optional, Union
 
 import torch
 import torch.nn as nn
+from mmengine.device import get_device
 from mmengine.model import BaseModule
 
 from mmpretrain.registry import MODELS
@@ -43,11 +44,7 @@ class iTPNClipHead(BaseModule):
             target (torch.Tensor): Target generated by target_generator.
             mask (torch.Tensor): Generated mask for pretraing.
         """
-        if torch.cuda.is_available():
-            mask = mask.to(torch.device('cuda'), non_blocking=True)
-        elif hasattr(torch.backends,
-                     'mps') and torch.backends.mps.is_available():
-            mask = mask.to(torch.device('mps'), non_blocking=True)
+        mask = mask.to(get_device(), non_blocking=True)
 
         mask = mask.flatten(1).to(torch.bool)
         target = target[mask]
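An aside on the one-line replacement above, as a hedged sketch: because get_device() is assumed to return a device string, the mask transfer stays a single .to() call. Note that non_blocking=True only makes the copy asynchronous for pinned host memory going to a CUDA device and is a no-op on CPU or MPS.

    import torch
    from mmengine.device import get_device

    mask = torch.zeros(2, 196)                       # placeholder mask
    mask = mask.to(get_device(), non_blocking=True)  # async only for pinned -> CUDA copies
    mask = mask.flatten(1).to(torch.bool)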


@@ -20,27 +20,26 @@ mmpretrain.utils.progress.disable_progress_bar = True
 logger = MMLogger('mmpretrain', logger_name='mmpre')
 if torch.cuda.is_available():
-    gpus = [
+    devices = [
         torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())
     ]
-    logger.info(f'Available GPUs: {len(gpus)}')
+    logger.info(f'Available GPUs: {len(devices)}')
+elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+    devices = [torch.device('mps')]
+    logger.info('Available MPS.')
 else:
-    gpus = None
-    logger.info('No available GPU.')
+    devices = [torch.device('cpu')]
+    logger.info('Available CPU.')
 
 
 def get_free_device():
-    if gpus is None:
-        return torch.device('cpu')
     if hasattr(torch.cuda, 'mem_get_info'):
-        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in gpus]
+        free = [torch.cuda.mem_get_info(gpu)[0] for gpu in devices]
         select = max(zip(free, range(len(free))))[1]
     else:
         import random
-        select = random.randint(0, len(gpus) - 1)
-    return gpus[select]
+        select = random.randint(0, len(devices) - 1)
+    return devices[select]
 
 
 class InferencerCache:
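As a rough, self-contained restatement (not part of the commit) of the selection logic in get_free_device above: prefer the CUDA device with the most free memory when torch.cuda.mem_get_info is available, otherwise fall back to a random pick. pick_device and the local devices list are hypothetical names for illustration, and the sketch adds a cuda-availability guard so it also runs on CPU-only machines.

    import random

    import torch


    def pick_device(devices):
        # Choose the CUDA device with the most free bytes when the query is
        # supported, otherwise pick one at random (mirrors the diff above).
        if torch.cuda.is_available() and hasattr(torch.cuda, 'mem_get_info'):
            free = [torch.cuda.mem_get_info(d)[0] for d in devices]
            return devices[max(range(len(free)), key=free.__getitem__)]
        return devices[random.randint(0, len(devices) - 1)]


    devices = [torch.device(f'cuda:{i}')
               for i in range(torch.cuda.device_count())] or [torch.device('cpu')]
    print(pick_device(devices))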