fix torch 1.10 amp error (#330)

pull/332/head
Mashiro 2022-06-22 23:12:20 +08:00 committed by GitHub
parent 2b8a32eca0
commit a4f5533db6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 1 deletions

View File

@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs):
'If pytorch versions is between 1.5.0 and 1.10, '
'`autocast` is only available in gpu mode')
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
digit_version('1.10.0')):
if torch.cuda.is_available():
kwargs.setdefault('device_type', 'cuda')
else:
kwargs.setdefault('device_type', 'cpu')
# torch.autocast only support `dtype=torch.bfloat16` in
# pytorch 1.10
kwargs.setdefault('dtype', torch.bfloat16)
with torch.autocast(enabled=enabled, **kwargs):
yield
elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
if torch.cuda.is_available():
kwargs.setdefault('device_type', 'cuda')
else: