diff --git a/README.md b/README.md
index 38113bf..bdd74ab 100644
--- a/README.md
+++ b/README.md
@@ -153,20 +153,26 @@ python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inferen
 
 
 ## Export
-Tested with: Python 3.7.13 and Pytorch 1.12.0+cu113 
-Pytorch to ONNX, use `--include-nms` flag for the end-to-end ONNX model with `EfficientNMS`.
+
+PyTorch -> ONNX -> TensorRT -> Detection on TensorRT in Python <a href="https://colab.research.google.com/gist/AlexeyAB/fcb47ae544cf284eb24d8ad8e880d45c/yolov7trtlinaom.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+
+
+**PyTorch to ONNX**: use the `--include-nms` flag to export an end-to-end ONNX model with `EfficientNMS`
 ```shell
 wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7-tiny.pt
 python export.py --weights yolov7-tiny.pt --grid --include-nms
 ```
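+
+To sanity-check the exported graph before building an engine, one option (a quick sketch, not part of this repo, assuming the `onnx` Python package is installed) is to load it and inspect its outputs; with `--include-nms` the graph should expose NMS outputs rather than raw predictions:
+```python
+import onnx
+
+model = onnx.load("yolov7-tiny.onnx")  # written by export.py next to the .pt weights
+onnx.checker.check_model(model)        # structural validation of the graph
+print([o.name for o in model.graph.output])
+```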
 
-ONNX to TensorRT
+**ONNX to TensorRT**
 ```shell
 git clone https://github.com/Linaom1214/tensorrt-python.git
 cd tensorrt-python
 python export.py -o yolov7-tiny.onnx -e yolov7-tiny-nms.trt -p fp16
 ```
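+
+**Detection on TensorRT in Python**
+
+The Colab notebook linked above walks through this step end to end. A rough minimal sketch with the TensorRT Python API is below; it assumes the `tensorrt` and `pycuda` packages, a batch-1 engine with a 640x640 input, and the usual binding names of the end-to-end export (`images`, `num_dets`, `det_boxes`, `det_scores`, `det_classes`), so adjust these to match your engine:
+```python
+import numpy as np
+import pycuda.autoinit  # noqa: F401 (creates a CUDA context)
+import pycuda.driver as cuda
+import tensorrt as trt
+
+# Deserialize the engine built in the previous step
+logger = trt.Logger(trt.Logger.WARNING)
+with open("yolov7-tiny-nms.trt", "rb") as f, trt.Runtime(logger) as runtime:
+    engine = runtime.deserialize_cuda_engine(f.read())
+context = engine.create_execution_context()
+
+# Allocate page-locked host buffers and device buffers for every binding
+host, device, bindings = {}, {}, []
+for i in range(engine.num_bindings):
+    name = engine.get_binding_name(i)
+    dtype = trt.nptype(engine.get_binding_dtype(i))
+    shape = tuple(context.get_binding_shape(i))
+    host[name] = cuda.pagelocked_empty(shape, dtype)
+    device[name] = cuda.mem_alloc(host[name].nbytes)
+    bindings.append(int(device[name]))
+
+# Replace with a letterboxed RGB image, CHW, float32, scaled to [0, 1]
+img = np.random.rand(1, 3, 640, 640).astype(np.float32)
+np.copyto(host["images"], img)
+
+# Copy the input to the GPU, run the engine, copy the outputs back
+stream = cuda.Stream()
+cuda.memcpy_htod_async(device["images"], host["images"], stream)
+context.execute_async_v2(bindings, stream.handle)
+for i in range(engine.num_bindings):
+    if not engine.binding_is_input(i):
+        name = engine.get_binding_name(i)
+        cuda.memcpy_dtoh_async(host[name], device[name], stream)
+stream.synchronize()
+
+n = int(host["num_dets"][0, 0])  # detections kept by EfficientNMS for image 0
+print(host["det_boxes"][0, :n], host["det_scores"][0, :n], host["det_classes"][0, :n])
+```
+Boxes come back in the coordinates of the 640x640 network input, so they still need to be rescaled to the original image before drawing.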
 
+Tested with: Python 3.7.13 and PyTorch 1.12.0+cu113
+
+
 ## Citation
 
 ```
diff --git a/models/yolo.py b/models/yolo.py
index 5d2845f..b988de7 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -84,6 +84,7 @@ class Detect(nn.Module):
 class IDetect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
+    include_nms = False  # export (boxes, scores) outputs for the TensorRT EfficientNMS plugin
 
     def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
         super(IDetect, self).__init__()
@@ -139,7 +140,10 @@ class IDetect(nn.Module):
                 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                 z.append(y.view(bs, -1, self.no))
 
-        return x if self.training else (torch.cat(z, 1), x)    
+        if self.include_nms:
+            z = self.convert(z)  # repack into (boxes, scores) for EfficientNMS
+
+        # training: raw feature maps; include_nms: (boxes, scores); default: (predictions, feature maps)
+        return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)
     
     def fuse(self):
         print("IDetect.fuse")
@@ -160,6 +164,18 @@ class IDetect(nn.Module):
         yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
         return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
 
+    def convert(self, z):
+        # Repack raw predictions into the (boxes, scores) pair expected by EfficientNMS
+        z = torch.cat(z, 1)
+        box = z[:, :, :4]    # (cx, cy, w, h)
+        conf = z[:, :, 4:5]  # objectness
+        score = z[:, :, 5:]  # class probabilities
+        score *= conf        # class score = class probability * objectness
+        # (cx, cy, w, h) -> (x1, y1, x2, y2): x1 = cx - w/2, y1 = cy - h/2, x2 = cx + w/2, y2 = cy + h/2
+        convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
+                                      dtype=torch.float32, device=z.device)
+        box @= convert_matrix
+        return (box, score)
+
 
 class IKeypoint(nn.Module):
     stride = None  # strides computed during build