diff --git a/models/export.py b/models/export.py
index 90855d258..6a9f1df57 100644
--- a/models/export.py
+++ b/models/export.py
@@ -26,9 +26,9 @@ if __name__ == '__main__':
     parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
-    parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
     parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
+    parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True')
     parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')  # ONNX-only
     parser.add_argument('--simplify', action='store_true', help='simplify ONNX model')  # ONNX-only
     opt = parser.parse_args()
@@ -60,9 +60,11 @@ if __name__ == '__main__':
                 m.act = Hardswish()
             elif isinstance(m.act, nn.SiLU):
                 m.act = SiLU()
-        # elif isinstance(m, models.yolo.Detect):
-        #     m.forward = m.forward_export  # assign forward (optional)
-    model.model[-1].export = not opt.grid  # set Detect() layer grid export
+        elif isinstance(m, models.yolo.Detect):
+            m.inplace = opt.inplace
+            m.onnx_dynamic = opt.dynamic
+            # m.forward = m.forward_export  # assign forward (optional)
+
     for _ in range(2):
         y = model(img)  # dry runs
     print(f"\n{colorstr('PyTorch:')} starting from {opt.weights} ({file_size(opt.weights):.1f} MB)")
diff --git a/models/yolo.py b/models/yolo.py
index 520047ff3..314fd806f 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -24,7 +24,7 @@ except ImportError:
 
 class Detect(nn.Module):
     stride = None  # strides computed during build
-    export = False  # onnx export
+    onnx_dynamic = False  # ONNX export parameter
 
     def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
         super(Detect, self).__init__()
@@ -42,14 +42,13 @@ class Detect(nn.Module):
     def forward(self, x):
         # x = x.copy()  # for profiling
         z = []  # inference output
-        self.training |= self.export
         for i in range(self.nl):
             x[i] = self.m[i](x[i])  # conv
             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
 
             if not self.training:  # inference
-                if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+                if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
@@ -58,7 +57,7 @@ class Detect(nn.Module):
                     y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                 else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                     xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
                     y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))