fix torch 1.10 amp error (#330)
parent
2b8a32eca0
commit
a4f5533db6
|
@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs):
|
|||
'If pytorch versions is between 1.5.0 and 1.10, '
|
||||
'`autocast` is only available in gpu mode')
|
||||
|
||||
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
|
||||
elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
|
||||
digit_version('1.10.0')):
|
||||
if torch.cuda.is_available():
|
||||
kwargs.setdefault('device_type', 'cuda')
|
||||
else:
|
||||
kwargs.setdefault('device_type', 'cpu')
|
||||
# torch.autocast only support `dtype=torch.bfloat16` in
|
||||
# pytorch 1.10
|
||||
kwargs.setdefault('dtype', torch.bfloat16)
|
||||
|
||||
with torch.autocast(enabled=enabled, **kwargs):
|
||||
yield
|
||||
|
||||
elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
|
||||
if torch.cuda.is_available():
|
||||
kwargs.setdefault('device_type', 'cuda')
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue