diff --git a/models/export.py b/models/export.py
index b262df83b..00e3b2a4f 100644
--- a/models/export.py
+++ b/models/export.py
@@ -8,12 +8,14 @@ import argparse
 
 import torch
 
+from models.common import Conv
+from models.experimental import attempt_load
+from utils.activations import Hardswish
 from utils.general import set_logging
-from utils.google_utils import attempt_download
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')
+    parser.add_argument('--weights', type=str, default='./yolov5s.pt', help='weights path')  # from yolov5/models/
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
     opt = parser.parse_args()
@@ -25,12 +27,15 @@ if __name__ == '__main__':
     img = torch.zeros((opt.batch_size, 3, *opt.img_size))  # image size(1,3,320,192) iDetection
 
     # Load PyTorch model
-    attempt_download(opt.weights)
-    model = torch.load(opt.weights, map_location=torch.device('cpu'))['model'].float()
-    model.eval()
-    model.fuse()
+    model = attempt_load(opt.weights, map_location=torch.device('cpu'))  # load FP32 model
 
     # Update model
+    for k, m in model.named_modules():
+        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability
+        if isinstance(m, Conv):
+            m.act = Hardswish()  # assign activation
+        # if isinstance(m, Detect):
+        #    m.forward = m.forward_export  # assign forward (optional)
     model.model[-1].export = True  # set Detect() layer export=True
     y = model(img)  # dry run
 
@@ -56,7 +61,7 @@ if __name__ == '__main__':
         # Checks
         onnx_model = onnx.load(f)  # load onnx model
         onnx.checker.check_model(onnx_model)  # check onnx model
-        print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
+        # print(onnx.helper.printable_graph(onnx_model.graph))  # print a human readable model
         print('ONNX export success, saved as %s' % f)
     except Exception as e:
         print('ONNX export failure: %s' % e)
diff --git a/utils/activations.py b/utils/activations.py
index 58225c6de..162cb9fc3 100644
--- a/utils/activations.py
+++ b/utils/activations.py
@@ -10,11 +10,11 @@ class Swish(nn.Module):  #
         return x * torch.sigmoid(x)
 
 
-class Hardswish(nn.Module):  # alternative to nn.Hardswish() for export
+class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()
     @staticmethod
     def forward(x):
-        # return x * F.hardsigmoid(x)
-        return x * F.hardtanh(x + 3, 0., 6.) / 6.
+        # return x * F.hardsigmoid(x)  # for torchscript and CoreML
+        return x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNX
 
 
 class MemoryEfficientSwish(nn.Module):