Merge branch 'npu_support' of github.com:MengqingCao/pytorch-image-models into MengqingCao-npu_support

Ross Wightman 2024-10-18 14:50:00 -07:00
commit 81b59faf77
4 changed files with 22 additions and 5 deletions


@@ -113,13 +113,17 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ class PrefetchLoader:
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target

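For context, the loader change boils down to picking a side stream and stream context per device type. A minimal standalone sketch of that pattern (the make_stream_context helper is hypothetical, not timm API; the NPU branch assumes the torch_npu package, which registers a torch.npu namespace mirroring the torch.cuda stream API):

from contextlib import suppress
from functools import partial

import torch


def make_stream_context(device):
    # CUDA: prefetch on a side stream so host-to-device copies overlap with compute.
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    # Ascend NPU (assumption: torch_npu installed, exposing a CUDA-like API).
    if device.type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available():
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    # CPU / other devices: no side stream; suppress() is a no-op context manager.
    return None, suppress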

@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and

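The backend table above maps each device type to its collective-communication library; Ascend NPUs use HCCL through the standard torch.distributed API, just as CUDA uses NCCL. A sketch of how such a table is consumed (pick_dist_backend is a hypothetical helper, not the timm function itself):

import torch.distributed as dist


def pick_dist_backend(device_type):
    # Fall back to gloo (CPU) when no accelerator-specific backend applies.
    return {
        "xpu": "ccl",
        "hpu": "hccl",
        "cuda": "nccl",
        "npu": "hccl",
    }.get(device_type, "gloo")


# e.g., with env:// rendezvous variables (RANK, WORLD_SIZE, MASTER_ADDR/PORT) set:
# dist.init_process_group(backend=pick_dist_backend("npu"), init_method="env://")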

@@ -1054,8 +1054,11 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))

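Both train_one_epoch and validate now branch on device type before synchronizing, so step timers measure real device time rather than just kernel-launch time. The same logic as a standalone sketch (sync_device is a hypothetical helper; assumes torch_npu exposes torch.npu.synchronize() analogous to torch.cuda.synchronize()):

import torch


def sync_device(device):
    # Block until all queued kernels on the accelerator have finished,
    # so host-side timing reflects actual step duration.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu' and hasattr(torch, 'npu'):
        torch.npu.synchronize()
    # CPU ops run synchronously; nothing to wait for.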

@@ -395,8 +395,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
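
The retry loop above shrinks the batch size after an out-of-memory RuntimeError; clearing the allocator cache first gives the smaller batch a clean slate on either accelerator. A sketch of the device-agnostic cache clear (empty_device_cache is a hypothetical helper; assumes torch_npu provides torch.npu.empty_cache()):

import torch


def empty_device_cache(device_str):
    # Release cached allocator blocks so the retry starts from free memory.
    if 'cuda' in device_str and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif 'npu' in device_str and hasattr(torch, 'npu') and torch.npu.is_available():
        torch.npu.empty_cache()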