mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
fix torch 1.10 amp error (#330)
This commit is contained in:
parent
2b8a32eca0
commit
a4f5533db6
@ -77,7 +77,20 @@ def autocast(enabled: bool = True, **kwargs):
|
|||||||
'If pytorch versions is between 1.5.0 and 1.10, '
|
'If pytorch versions is between 1.5.0 and 1.10, '
|
||||||
'`autocast` is only available in gpu mode')
|
'`autocast` is only available in gpu mode')
|
||||||
|
|
||||||
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
|
elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
|
||||||
|
digit_version('1.10.0')):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
kwargs.setdefault('device_type', 'cuda')
|
||||||
|
else:
|
||||||
|
kwargs.setdefault('device_type', 'cpu')
|
||||||
|
# torch.autocast only support `dtype=torch.bfloat16` in
|
||||||
|
# pytorch 1.10
|
||||||
|
kwargs.setdefault('dtype', torch.bfloat16)
|
||||||
|
|
||||||
|
with torch.autocast(enabled=enabled, **kwargs):
|
||||||
|
yield
|
||||||
|
|
||||||
|
elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
kwargs.setdefault('device_type', 'cuda')
|
kwargs.setdefault('device_type', 'cuda')
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user