diff --git a/detect.py b/detect.py
index 53b63eb..841926c 100644
--- a/detect.py
+++ b/detect.py
@@ -73,7 +73,7 @@ def detect(save_img=False):
 
         # Inference
         t1 = time_synchronized()
-        pred = model(img, augment=opt.augment)[0]
+        pred = model(img, augment=opt.augment)
 
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
diff --git a/models/export.py b/export.py
similarity index 86%
rename from models/export.py
rename to export.py
index dc12559..06dfc94 100644
--- a/models/export.py
+++ b/export.py
@@ -21,6 +21,7 @@ if __name__ == '__main__':
     parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes')
     parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
     parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
     opt = parser.parse_args()
     opt.img_size *= 2 if len(opt.img_size) == 1 else 1  # expand
     print(opt)
@@ -68,6 +69,7 @@ if __name__ == '__main__':
 
         print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
         f = opt.weights.replace('.pt', '.onnx')  # filename
+        model.eval()
         torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
                           output_names=['classes', 'boxes'] if y is None else ['output'],
                           dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'},  # size(1,3,640,640)
@@ -76,6 +78,27 @@ if __name__ == '__main__':
         # Checks
         onnx_model = onnx.load(f)  # load onnx model
         onnx.checker.check_model(onnx_model)  # check onnx model
+
+        # # Metadata
+        # d = {'stride': int(max(model.stride))}
+        # for k, v in d.items():
+        #     meta = onnx_model.metadata_props.add()
+        #     meta.key, meta.value = k, str(v)
+        # onnx.save(onnx_model, f)
+
+        # Optional graph simplification (--simplify); best-effort, so a failure keeps the unsimplified export in f
+        if opt.simplify:
+            try:
+                import onnxsim
+
+                print('\nStarting to simplify ONNX...')
+                simplified_model, check = onnxsim.simplify(onnx_model)
+                if not check:  # explicit raise: a bare assert is stripped under python -O
+                    raise RuntimeError('assert check failed')
+                onnx_model = simplified_model
+                onnx.save(onnx_model, f)  # write the simplified graph back to f; otherwise it is discarded
+            except Exception as e:
+                print(f'Simplifier failure: {e}')
         # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
         print('ONNX export success, saved as %s' % f)
     except Exception as e:
diff --git a/models/yolo.py b/models/yolo.py
index dcb073e..951452d 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -50,11 +50,16 @@ class Detect(nn.Module):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if not torch.onnx.is_in_onnx_export():
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))
 
-        return x if self.training else (torch.cat(z, 1), x)
+        return x if self.training else torch.cat(z, 1)
 
     @staticmethod
     def _make_grid(nx=20, ny=20):