From 2a731fc5d2f90389c9ac07e3b51acd4b22b71826 Mon Sep 17 00:00:00 2001
From: Mohammad Khoshbin <khoshbin.mohammad.mk@gmail.com>
Date: Thu, 28 Jul 2022 20:19:06 +0430
Subject: [PATCH] fix reparameterization for any nc (#342)
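
The detection-head convolutions carry (nc + 5) * na output channels per
layer (x, y, w, h, objectness plus nc class scores for each of the na
anchors per layer), so the hard-coded range(255) only holds for nc = 80
and na = 3. Take nc from the checkpoint and na from the deploy cfg so the
implicit-knowledge multiply/add loops cover every output channel for any
number of classes. Note that len(yml['anchors']) counts detection layers
(4 for the w6/e6/d6/e6e cfgs), so the per-layer pair count
len(yml['anchors'][0]) // 2 is used instead.

A minimal sketch of the channel arithmetic, assuming the usual cfg layout
where each entry of 'anchors' lists the (w, h) pairs of one detection
layer (the path below is only an example):

    import yaml

    with open('cfg/deploy/yolov7.yaml') as f:
        cfg = yaml.safe_load(f)

    nc = cfg['nc']                    # number of classes, e.g. 80 for COCO
    na = len(cfg['anchors'][0]) // 2  # anchors per detection layer (w, h pairs)
    print((nc + 5) * na)              # 255 for nc = 80, na = 3

The same per-layer anchor count should also be retrievable from the built
model as model.model[-1].na, which avoids re-reading the yaml.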

---
 tools/reparameterization.ipynb | 51 ++++++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 8 deletions(-)

diff --git a/tools/reparameterization.ipynb b/tools/reparameterization.ipynb
index 4e9a810..84e4326 100644
--- a/tools/reparameterization.ipynb
+++ b/tools/reparameterization.ipynb
@@ -28,6 +28,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -35,6 +36,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -44,7 +49,7 @@
     "model.nc = ckpt['model'].nc\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.105.m.0.weight'].data[i, :, :, :] *= state_dict['model.105.im.0.implicit'].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.105.m.1.weight'].data[i, :, :, :] *= state_dict['model.105.im.1.implicit'].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.105.m.2.weight'].data[i, :, :, :] *= state_dict['model.105.im.2.implicit'].data[:, i, : :].squeeze()\n",
@@ -85,6 +90,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -92,6 +98,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7x.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7x.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -101,7 +111,7 @@
     "model.nc = ckpt['model'].nc\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.121.m.0.weight'].data[i, :, :, :] *= state_dict['model.121.im.0.implicit'].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.121.m.1.weight'].data[i, :, :, :] *= state_dict['model.121.im.1.implicit'].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.121.m.2.weight'].data[i, :, :, :] *= state_dict['model.121.im.2.implicit'].data[:, i, : :].squeeze()\n",
@@ -142,6 +152,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -149,6 +160,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7-w6.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7-w6.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -179,7 +194,7 @@
     "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
@@ -223,6 +238,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -230,6 +246,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7-e6.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7-e6.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -260,7 +280,7 @@
     "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
@@ -304,6 +324,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -311,6 +332,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7-d6.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7-d6.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -341,7 +366,7 @@
     "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
@@ -385,6 +410,7 @@
     "from models.yolo import Model\n",
     "import torch\n",
     "from utils.torch_utils import select_device, is_parallel\n",
+    "import yaml\n",
     "\n",
     "device = select_device('0', batch_size=1)\n",
     "# model trained by cfg/training/*.yaml\n",
@@ -392,6 +418,10 @@
     "# reparameterized model in cfg/deploy/*.yaml\n",
     "model = Model('cfg/deploy/yolov7-e6e.yaml', ch=3, nc=80).to(device)\n",
     "\n",
+    "with open('cfg/deploy/yolov7-e6e.yaml') as f:\n",
+    "    yml = yaml.load(f, Loader=yaml.SafeLoader)\n",
+    "anchors = len(yml['anchors'])\n",
+    "\n",
     "# copy intersect weights\n",
     "state_dict = ckpt['model'].float().state_dict()\n",
     "exclude = []\n",
@@ -422,7 +452,7 @@
     "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n",
     "\n",
     "# reparametrized YOLOR\n",
-    "for i in range(255):\n",
+    "for i in range((model.nc+5)*anchors):\n",
     "    model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
     "    model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n",
@@ -457,7 +487,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3.7.0 ('py37')",
    "language": "python",
    "name": "python3"
   },
@@ -471,7 +501,12 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.10"
+   "version": "3.7.0"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "73080970ff6fd25f9fcdf9c6f9e85b950a97864bb936ee53fb633f473cbfae4b"
+   }
   }
  },
  "nbformat": 4,