mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix device check
This commit is contained in:
parent
234f975787
commit
37c731ca37
@ -113,8 +113,8 @@ class PrefetchLoader:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.random_erasing = None
|
self.random_erasing = None
|
||||||
self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
|
self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
|
||||||
self.is_npu = torch.npu.is_available() and device.type == 'npu'
|
self.is_npu = device.type == 'npu' and torch.npu.is_available()
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
first = True
|
first = True
|
||||||
|
@ -395,9 +395,9 @@ def _try_run(args, initial_batch_size):
|
|||||||
while batch_size:
|
while batch_size:
|
||||||
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
|
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available() and 'cuda' in args.device:
|
if 'cuda' in args.device and torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
elif torch.npu.is_available() and "npu" in args.device:
|
elif "npu" in args.device and torch.npu.is_available():
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
results = validate(args)
|
results = validate(args)
|
||||||
return results
|
return results
|
||||||
|
Loading…
x
Reference in New Issue
Block a user