mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'npu_support' of github.com:MengqingCao/pytorch-image-models into MengqingCao-npu_support
This commit is contained in:
commit 81b59faf77
@@ -113,13 +113,17 @@ class PrefetchLoader:
             )
         else:
             self.random_erasing = None
-        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'
+        self.is_cuda = device.type == 'cuda' and torch.cuda.is_available()
+        self.is_npu = device.type == 'npu' and torch.npu.is_available()
 
     def __iter__(self):
         first = True
         if self.is_cuda:
             stream = torch.cuda.Stream()
             stream_context = partial(torch.cuda.stream, stream=stream)
+        elif self.is_npu:
+            stream = torch.npu.Stream()
+            stream_context = partial(torch.npu.stream, stream=stream)
         else:
             stream = None
             stream_context = suppress
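The hunk above generalizes PrefetchLoader's side-stream setup beyond CUDA. As a standalone illustration, a minimal sketch of the same device-conditional stream selection (the torch.npu namespace only exists once the Ascend torch_npu extension has been imported, so it is guarded here; this is a sketch, not the upstream code):

from contextlib import suppress
from functools import partial

import torch


def make_stream_context(device: torch.device):
    # Pick a side stream and a context manager that routes work onto it,
    # mirroring the branching added in the diff above.
    if device.type == 'cuda' and torch.cuda.is_available():
        stream = torch.cuda.Stream()
        return stream, partial(torch.cuda.stream, stream=stream)
    if device.type == 'npu' and hasattr(torch, 'npu') and torch.npu.is_available():
        stream = torch.npu.Stream()
        return stream, partial(torch.npu.stream, stream=stream)
    # CPU or unsupported accelerator: no side stream, no-op context manager.
    return None, suppress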
@@ -139,7 +143,10 @@ class PrefetchLoader:
                 first = False
 
             if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+                if self.is_cuda:
+                    torch.cuda.current_stream().wait_stream(stream)
+                elif self.is_npu:
+                    torch.npu.current_stream().wait_stream(stream)
 
             input = next_input
             target = next_target
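The wait_stream branch is where the prefetch overlap pays off: batch N+1 is copied host-to-device on the side stream while batch N is being consumed, and the main stream waits on the copy before touching the new tensors. A hedged sketch of that loop shape (variable names are illustrative; the trailing yield for the final batch is assumed, as it falls outside the visible hunk):

import torch


def prefetch(loader, device, stream, stream_context):
    first = True
    input = target = None
    for next_input, next_target in loader:
        with stream_context():
            # Host-to-device copy runs on the side stream when one exists.
            next_input = next_input.to(device, non_blocking=True)
            next_target = next_target.to(device, non_blocking=True)
        if not first:
            yield input, target
        else:
            first = False
        if stream is not None:
            # Main stream must wait for the side-stream copy to finish.
            if device.type == 'cuda':
                torch.cuda.current_stream().wait_stream(stream)
            elif device.type == 'npu':
                torch.npu.current_stream().wait_stream(stream)
        input, target = next_input, next_target
    if input is not None:
        yield input, target  # last batch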
@@ -116,6 +116,7 @@ def init_distributed_device_so(
         "xpu": "ccl",
         "hpu": "hccl",
         "cuda": "nccl",
+        "npu": "hccl",
     }
     dist_backend = dist_backends.get(device_type, 'gloo')
     dist_url = dist_url or 'env://'
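This hunk maps the npu device type to the 'hccl' backend string (Huawei's collective communications library; the hpu entry happens to use the same string). For context, a sketch of how such a table is typically consumed when initializing the process group (assumes the usual env:// rendezvous variables RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set; not the upstream function):

import torch.distributed as dist

dist_backends = {
    "xpu": "ccl",
    "hpu": "hccl",
    "cuda": "nccl",
    "npu": "hccl",
}


def init_for(device_type: str, dist_url: str = 'env://') -> str:
    # Unknown device types fall back to the CPU-friendly gloo backend,
    # matching the .get(..., 'gloo') lookup in the diff above.
    backend = dist_backends.get(device_type, 'gloo')
    dist.init_process_group(backend=backend, init_method=dist_url)
    return backend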
@@ -159,6 +160,8 @@ def init_distributed_device_so(
 
     if device_type == 'cuda':
         assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
+    if device_type == 'npu':
+        assert torch.npu.is_available(), f'Ascend NPU is not available but {device} was specified.'
 
     if distributed and device != 'cpu':
         # Ignore manually specified device index in distributed mode and
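One caveat with the new assert: torch.npu only exists after the torch_npu extension has been imported, so on a plain CPU/CUDA build the assert expression itself would raise AttributeError before the message is ever shown. A defensive variant (a sketch, not the upstream code):

import torch


def check_device_available(device_type: str, device: str) -> None:
    if device_type == 'cuda':
        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'
    elif device_type == 'npu':
        # hasattr guards builds where torch_npu was never imported.
        assert hasattr(torch, 'npu') and torch.npu.is_available(), \
            f'Ascend NPU is not available but {device} was specified.'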
train.py (9 lines changed)
@@ -1054,8 +1054,11 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model, step=num_updates)
 
-        if args.synchronize_step and device.type == 'cuda':
-            torch.cuda.synchronize()
+        if args.synchronize_step:
+            if device.type == 'cuda':
+                torch.cuda.synchronize()
+            elif device.type == 'npu':
+                torch.npu.synchronize()
         time_now = time.time()
         update_time_m.update(time.time() - update_start_time)
         update_start_time = time_now
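The synchronize branch matters because kernels launch asynchronously on both CUDA and NPU: without a barrier, the time.time() call right after it would measure only launch overhead, not real step latency. The same pattern as a small device-agnostic helper (illustrative, not part of the commit):

import torch


def synchronize_device(device: torch.device) -> None:
    # Block the host until all queued work on the device has finished,
    # so wall-clock timestamps taken afterwards reflect true step time.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elif device.type == 'npu':
        torch.npu.synchronize()
    # CPU execution is already synchronous; nothing to do.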
@@ -1155,6 +1158,8 @@ def validate(
 
             if device.type == 'cuda':
                 torch.cuda.synchronize()
+            elif device.type == "npu":
+                torch.npu.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
             top1_m.update(acc1.item(), output.size(0))
@@ -395,8 +395,10 @@ def _try_run(args, initial_batch_size):
     while batch_size:
         args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
         try:
-            if torch.cuda.is_available() and 'cuda' in args.device:
+            if 'cuda' in args.device and torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif "npu" in args.device and torch.npu.is_available():
+                torch.npu.empty_cache()
             results = validate(args)
             return results
         except RuntimeError as e: