From 2a731fc5d2f90389c9ac07e3b51acd4b22b71826 Mon Sep 17 00:00:00 2001 From: Mohammad Khoshbin <khoshbin.mohammad.mk@gmail.com> Date: Thu, 28 Jul 2022 20:19:06 +0430 Subject: [PATCH] fix reparametrization for any nc (#342) --- tools/reparameterization.ipynb | 51 ++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/tools/reparameterization.ipynb b/tools/reparameterization.ipynb index 4e9a810..84e4326 100644 --- a/tools/reparameterization.ipynb +++ b/tools/reparameterization.ipynb @@ -28,6 +28,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -35,6 +36,10 @@ "# reparameterized model in cfg/deploy/*.yaml\n", "model = Model('cfg/deploy/yolov7.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -44,7 +49,7 @@ "model.nc = ckpt['model'].nc\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.105.m.0.weight'].data[i, :, :, :] *= state_dict['model.105.im.0.implicit'].data[:, i, : :].squeeze()\n", " model.state_dict()['model.105.m.1.weight'].data[i, :, :, :] *= state_dict['model.105.im.1.implicit'].data[:, i, : :].squeeze()\n", " model.state_dict()['model.105.m.2.weight'].data[i, :, :, :] *= state_dict['model.105.im.2.implicit'].data[:, i, : :].squeeze()\n", @@ -85,6 +90,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -92,6 +98,10 @@ "# 
reparameterized model in cfg/deploy/*.yaml\n", "model = Model('cfg/deploy/yolov7x.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7x.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -101,7 +111,7 @@ "model.nc = ckpt['model'].nc\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.121.m.0.weight'].data[i, :, :, :] *= state_dict['model.121.im.0.implicit'].data[:, i, : :].squeeze()\n", " model.state_dict()['model.121.m.1.weight'].data[i, :, :, :] *= state_dict['model.121.im.1.implicit'].data[:, i, : :].squeeze()\n", " model.state_dict()['model.121.m.2.weight'].data[i, :, :, :] *= state_dict['model.121.im.2.implicit'].data[:, i, : :].squeeze()\n", @@ -142,6 +152,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -149,6 +160,10 @@ "# reparameterized model in cfg/deploy/*.yaml\n", "model = Model('cfg/deploy/yolov7-w6.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7-w6.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -179,7 +194,7 @@ "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " 
model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", @@ -223,6 +238,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -230,6 +246,10 @@ "# reparameterized model in cfg/deploy/*.yaml\n", "model = Model('cfg/deploy/yolov7-e6.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7-e6.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -260,7 +280,7 @@ "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", @@ -304,6 +324,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -311,6 +332,10 @@ "# reparameterized model in cfg/deploy/*.yaml\n", "model = 
Model('cfg/deploy/yolov7-d6.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7-d6.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -341,7 +366,7 @@ "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", @@ -385,6 +410,7 @@ "from models.yolo import Model\n", "import torch\n", "from utils.torch_utils import select_device, is_parallel\n", + "import yaml\n", "\n", "device = select_device('0', batch_size=1)\n", "# model trained by cfg/training/*.yaml\n", @@ -392,6 +418,10 @@ "# reparameterized model in cfg/deploy/*.yaml\n", "model = Model('cfg/deploy/yolov7-e6e.yaml', ch=3, nc=80).to(device)\n", "\n", + "with open('cfg/deploy/yolov7-e6e.yaml') as f:\n", + " yml = yaml.load(f, Loader=yaml.SafeLoader)\n", + "anchors = len(yml['anchors'][0]) // 2  # anchors per detection layer\n", + "\n", "# copy intersect weights\n", "state_dict = ckpt['model'].float().state_dict()\n", "exclude = []\n", @@ -422,7 +452,7 @@ "model.state_dict()['model.{}.m.3.bias'.format(idx)].data += state_dict['model.{}.m.3.bias'.format(idx2)].data\n", "\n", "# reparametrized YOLOR\n", - "for i in range(255):\n", + "for i in range((model.nc+5)*anchors):\n", " model.state_dict()['model.{}.m.0.weight'.format(idx)].data[i, :, :, :] *= 
state_dict['model.{}.im.0.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.1.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.1.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", " model.state_dict()['model.{}.m.2.weight'.format(idx)].data[i, :, :, :] *= state_dict['model.{}.im.2.implicit'.format(idx2)].data[:, i, : :].squeeze()\n", @@ -457,7 +487,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.7.0 ('py37')", "language": "python", "name": "python3" }, @@ -471,7 +501,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.7.0" + }, + "vscode": { + "interpreter": { + "hash": "73080970ff6fd25f9fcdf9c6f9e85b950a97864bb936ee53fb633f473cbfae4b" + } } }, "nbformat": 4,