fixed a missed amp.autocast

pull/13244/head
Jacob Brown 2024-08-05 16:19:18 -06:00
parent 0c6aaa0665
commit 3e9b43e8d2
1 changed file with 7 additions and 1 deletion


@@ -870,7 +870,13 @@ class AutoShape(nn.Module):
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
 
-        with amp.autocast(autocast):
+        amp_autocast = None
+        if check_version(torch.__version__, "2.4.0"):
+            amp_autocast = torch.amp.autocast("cuda", enabled=autocast)
+        else:
+            amp_autocast = torch.cuda.amp.autocast(enabled=autocast)
+
+        with amp_autocast:
             # Inference
             with dt[1]:
                 y = self.model(x, augment=augment)  # forward
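
For context, a minimal standalone sketch of what the added lines do: on torch >= 2.4, where torch.cuda.amp.autocast is deprecated, use the device-agnostic torch.amp.autocast("cuda", ...), and fall back to the older CUDA-specific API otherwise. The check_version function below is a simplified stand-in for YOLOv5's check_version helper, and make_autocast is a hypothetical wrapper used only for illustration, not part of the patch.

import torch
from packaging import version


def check_version(current: str, minimum: str) -> bool:
    """Simplified stand-in: True if the installed version is >= the minimum."""
    return version.parse(current) >= version.parse(minimum)


def make_autocast(enabled: bool = True):
    """Pick the autocast context manager that matches the installed torch."""
    if check_version(torch.__version__, "2.4.0"):
        # torch >= 2.4: torch.cuda.amp.autocast is deprecated in favour of
        # the device-agnostic torch.amp.autocast("cuda", ...)
        return torch.amp.autocast("cuda", enabled=enabled)
    # Older torch releases only expose the CUDA-specific context manager
    return torch.cuda.amp.autocast(enabled=enabled)


if __name__ == "__main__":
    model = torch.nn.Linear(8, 2)
    x = torch.randn(1, 8)
    # Enable mixed precision only when a CUDA device is actually available
    with make_autocast(enabled=torch.cuda.is_available()):
        y = model(x)
    print(y.shape)

Either branch returns a context manager, so the surrounding `with amp_autocast:` block in the patch works unchanged on both old and new torch versions.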