fixed a missed amp.autocast
parent 0c6aaa0665
commit 3e9b43e8d2
@@ -870,7 +870,13 @@ class AutoShape(nn.Module):
             x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
             x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32
 
-        with amp.autocast(autocast):
+        amp_autocast = None
+        if check_version(torch.__version__, "2.4.0"):
+            amp_autocast = torch.amp.autocast("cuda", enabled=autocast)
+        else:
+            amp_autocast = torch.cuda.amp.autocast(enabled=autocast)
+
+        with amp_autocast:
             # Inference
             with dt[1]:
                 y = self.model(x, augment=augment)  # forward
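For context, a minimal standalone sketch of the pattern the commit introduces: choose the autocast context manager once, based on the installed torch version, then wrap inference in it. The helper name get_amp_autocast and the use of packaging.version in place of YOLOv5's check_version() are illustrative assumptions, not part of this commit.

# Sketch only: stands in for the version-gated autocast selection shown in the diff above.
# Assumes the `packaging` package is available; enabled=False makes autocast a no-op on CPU.
import torch
from packaging import version


def get_amp_autocast(enabled: bool):
    """Return the autocast context manager that matches the installed torch release."""
    if version.parse(torch.__version__) >= version.parse("2.4.0"):
        # Newer torch prefers the device-agnostic entry point with an explicit device type.
        return torch.amp.autocast("cuda", enabled=enabled)
    # Older releases expose only the CUDA-specific context manager.
    return torch.cuda.amp.autocast(enabled=enabled)


if __name__ == "__main__":
    autocast = torch.cuda.is_available()  # mirror AutoShape: autocast only when CUDA is present
    device = "cuda" if autocast else "cpu"
    x = torch.zeros(1, 3, 640, 640, device=device)
    with get_amp_autocast(autocast):
        # A model forward pass placed here runs in mixed precision when autocast is enabled.
        y = (x * 2).sum()
    print(y.dtype)

Selecting the context manager up front keeps a single `with` site around inference while avoiding the deprecation warning that torch 2.4+ emits for torch.cuda.amp.autocast, and it falls back to the old API on earlier releases.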