# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import shutil

import torch
from pyppeteer import launch
from torchvision.models import resnet18

from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model
from mmdeploy.utils import get_root_logger


@FUNCTION_REWRITER.register_rewriter(
    func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(ctx, self, x):
    """Rewrite the forward implementation of resnet.

    Early return the feature map after two down-sampling steps: only
    ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1`` are applied.

    Args:
        ctx: First argument of the registered rewriter; unused in this
            rewrite.
        self: The ResNet module whose ``_forward_impl`` is replaced.
        x: Input passed to ``self.conv1``.

    Returns:
        The output of ``self.layer1``.
    """
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    return x


def rewrite_resnet18(original_path: str, rewritten_path: str):
    """Export resnet18 to ONNX twice: as-is and with the rewriter applied.

    Args:
        original_path (str): Output path of the unmodified model.
        rewritten_path (str): Output path of the patched model, exported
            inside a ``RewriterContext``.
    """
    # prepare inputs and original model
    inputs = torch.rand(1, 3, 224, 224)
    original_model = resnet18(pretrained=False)

    # export original model
    torch.onnx.export(original_model, inputs, original_path)

    # patch model
    patched_model = patch_model(original_model, cfg={}, backend='default')

    # export rewritten onnx under a rewriter context manager
    with RewriterContext(cfg={}, backend='default'), torch.no_grad():
        torch.onnx.export(patched_model, inputs, rewritten_path)


def screen_size():
    """Get the screen size through tkinter.

    Returns:
        tuple: ``(width, height)`` as reported by ``winfo_screenwidth`` and
        ``winfo_screenheight``.
    """
    import tkinter
    tk = tkinter.Tk()
    try:
        width = tk.winfo_screenwidth()
        height = tk.winfo_screenheight()
    finally:
        # No mainloop is ever started here, so the root window has to be
        # destroyed explicitly to release it.
        tk.destroy()
    return width, height


async def _screenshot_center(page, image_path: str, width, height):
    """Save the horizontally centered half of ``page`` to ``image_path``."""
    await page.screenshot({'path': image_path},
                          clip={
                              'x': width / 4,
                              'y': 0,
                              'width': width / 2,
                              'height': height
                          })


async def visualize(original_path: str, rewritten_path: str):
    """Open both ONNX files in netron.app and save a screenshot of each.

    The screenshots are written next to the models, with the ``.onnx``
    suffix replaced by ``.png``.

    Args:
        original_path (str): Path of the original ONNX file.
        rewritten_path (str): Path of the rewritten ONNX file.
    """
    # launch a web browser
    browser = await launch(headless=False, args=['--start-maximized'])
    try:
        # create two new pages
        page2 = await browser.newPage()
        page1 = await browser.newPage()

        # go to netron.app
        width, height = screen_size()
        await page1.setViewport({'width': width, 'height': height})
        await page2.setViewport({'width': width, 'height': height})
        await page1.goto('https://netron.app/')
        await page2.goto('https://netron.app/')
        # fixed delay; presumably lets the pages finish loading
        await asyncio.sleep(2)

        # open local two onnx files
        mupinput1 = await page1.querySelector("input[type='file']")
        mupinput2 = await page2.querySelector("input[type='file']")
        # Upload the paths given as arguments rather than relying on
        # module-level globals that only exist when run as a script.
        await mupinput1.uploadFile(original_path)
        await mupinput2.uploadFile(rewritten_path)
        await asyncio.sleep(4)

        # zoom out the page showing the original model
        for _ in range(6):
            await page1.click('#zoom-out-button')
            await asyncio.sleep(0.3)
        await asyncio.sleep(1)

        await _screenshot_center(page1,
                                 original_path.replace('.onnx', '.png'),
                                 width, height)
        await _screenshot_center(page2,
                                 rewritten_path.replace('.onnx', '.png'),
                                 width, height)
    finally:
        # always shut the browser down, even if a step above failed
        await browser.close()


if __name__ == '__main__':
    tmp_dir = os.path.join(os.getcwd(), 'tmp')
    os.makedirs(tmp_dir, exist_ok=True)
    original_file_path = os.path.join(tmp_dir, 'original.onnx')
    rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx')

    logger = get_root_logger()
    logger.info('Generating resnet18 and its rewritten model...')
    rewrite_resnet18(original_file_path, rewritten_file_path)

    logger.info('Visualizing models through netron...')
    asyncio.get_event_loop().run_until_complete(
        visualize(original_file_path, rewritten_file_path))

    import mmcv
    image1 = mmcv.imread(original_file_path.replace('.onnx', '.png'))
    image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png'))
    mmcv.imshow(image1, win_name='original')
    mmcv.imshow(image2, win_name='rewritten')
    shutil.rmtree(tmp_dir)