diff --git a/models/common.py b/models/common.py index d735e10fe..4841c09fc 100644 --- a/models/common.py +++ b/models/common.py @@ -870,7 +870,13 @@ class AutoShape(nn.Module): x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 - with amp.autocast(autocast): + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=autocast) + else: + amp_autocast = torch.cuda.amp.autocast(enabled=autocast) + + with amp_autocast: # Inference with dt[1]: y = self.model(x, augment=augment) # forward