Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge branch 'npu_support' of github.com:MengqingCao/pytorch-image-models into MengqingCao-npu_support

This commit is contained in commit 81b59faf77.
@@ -113,13 +113,17 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
@@ -139,7 +143,10 @@ class PrefetchLoader:
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
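Taken together, the two hunks above make PrefetchLoader pick its prefetch side stream per device type. Below is a minimal sketch of that selection logic, assuming the torch_npu plugin is installed so that torch.npu exists; the helper name _select_stream is illustrative, not part of the patch:

from contextlib import suppress
from functools import partial

import torch


def _select_stream(device: torch.device):
    # Return (stream, stream_context) used to overlap host-to-device copies
    # with compute; fall back to no stream on CPU or other devices.
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available():
        stream = torch.npu.Stream()  # torch.npu is registered by `import torch_npu`
        return stream, partial(torch.npu.stream, stream=stream)
    return None, suppress

After the first batch, the main stream waits on this side stream (wait_stream) before the prefetched tensors are handed to the training loop, exactly as in the second hunk.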
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
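For reference, a condensed sketch of the two additions above as standalone helpers; the function names are hypothetical, and in the patch the logic lives inline in init_distributed_device_so:

import torch


def pick_dist_backend(device_type: str) -> str:
    # Ascend NPUs use HCCL as their collective backend; 'gloo' stays the fallback.
    dist_backends = {
        "xpu": "ccl",
        "hpu": "hccl",
        "cuda": "nccl",
        "npu": "hccl",
    }
    return dist_backends.get(device_type, 'gloo')


def check_device_available(device_type: str, device: str) -> None:
    if device_type == 'cuda':
        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
    if device_type == 'npu':
        # torch.npu is only present once the torch_npu plugin has been imported.
        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'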
train.py: 7 lines changed
@@ -1054,8 +1054,11 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
@@ -1155,6 +1158,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
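The two synchronize hunks above share the same branch; it could equally be expressed as a small device-agnostic helper. This is only a sketch to show the intent (timing real device work rather than async kernel launches), not how train.py is actually written:

import torch


def synchronize_device(device: torch.device) -> None:
    # Block until all queued kernels on `device` have finished, so the
    # step timer sampled right afterwards measures completed work.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu':
        torch.npu.synchronize()

In train_one_epoch it is only reached when args.synchronize_step is set, immediately before time.time() is read.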
@@ -395,8 +395,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e:
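A minimal sketch of the OOM-retry pattern this hunk extends, with the validate function and the batch-size decay passed in as parameters; decrease_fn and validate_fn are illustrative stand-ins for the script's own logic:

import torch


def try_run_sketch(args, initial_batch_size, validate_fn, decrease_fn=lambda b: b // 2):
    # Shrink the batch size on a runtime error (typically OOM), clearing the
    # device cache before each attempt so freed memory is actually reusable.
    batch_size = initial_batch_size
    while batch_size:
        args.batch_size = batch_size * args.num_gpu  # DataParallel case
        try:
            if 'cuda' in args.device and torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif 'npu' in args.device and torch.npu.is_available():
                torch.npu.empty_cache()
            return validate_fn(args)
        except RuntimeError:
            batch_size = decrease_fn(batch_size)
    return None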