fix device check

MengqingCao 2024-10-17 12:38:02 +00:00
parent 234f975787
commit 37c731ca37
2 changed files with 4 additions and 4 deletions


@@ -113,8 +113,8 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
-        self.is_npu = torch.npu.is_available() and device.type == 'npu'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True

@@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
-            elif torch.npu.is_available() and "npu" in args.device:
+            elif "npu" in args.device and torch.npu.is_available():
                 torch.npu.empty_cache()
             results = validate(args)
             return results