116 lines
3.9 KiB
Python
116 lines
3.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import asyncio
|
|
import os
|
|
import shutil
|
|
|
|
import torch
|
|
from pyppeteer import launch
|
|
from torchvision.models import resnet18
|
|
|
|
from mmdeploy.core import FUNCTION_REWRITER, RewriterContext, patch_model
|
|
from mmdeploy.utils import get_root_logger
|
|
|
|
|
|
@FUNCTION_REWRITER.register_rewriter(
    func_name='torchvision.models.ResNet._forward_impl')
def forward_of_resnet(self, x):
    """Replacement for ``torchvision.models.ResNet._forward_impl``.

    Only the stem (``conv1`` -> ``bn1`` -> ``relu`` -> ``maxpool``) and
    ``layer1`` are applied; the resulting intermediate feature map is
    returned early (after two down-sampling steps) instead of running the
    rest of the network.

    Args:
        self: The ResNet instance whose forward is being rewritten.
        x: Input passed to ``self.conv1``.

    Returns:
        The output of ``self.layer1``.
    """
    # Apply the retained sub-modules one after another, in original order.
    for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                  self.layer1):
        x = stage(x)
    return x
|
|
|
|
|
|
def rewrite_resnet18(original_path: str, rewritten_path: str):
    """Export resnet18 to ONNX twice: unmodified and with rewriters applied.

    Args:
        original_path: Destination for the ONNX file exported from the
            plain torchvision ``resnet18`` (randomly initialised).
        rewritten_path: Destination for the ONNX file exported from the
            model returned by ``patch_model``, while a ``RewriterContext``
            is active and autograd is disabled.
    """
    # NOTE: the dummy input is drawn before the model is built; keep this
    # order so the RNG stream (input values and initial weights) matches.
    dummy_input = torch.rand(1, 3, 224, 224)
    net = resnet18(pretrained=False)

    # Baseline export with no rewriting involved.
    torch.onnx.export(net, dummy_input, original_path)

    # Patch using an empty config and the 'default' backend, then export
    # again with the function rewriters in effect.
    patched_net = patch_model(net, cfg={}, backend='default')
    rewriter_ctx = RewriterContext(cfg={}, backend='default')
    with rewriter_ctx, torch.no_grad():
        torch.onnx.export(patched_net, dummy_input, rewritten_path)
|
|
|
|
|
|
def screen_size():
    """Return the primary screen size in pixels as ``(width, height)``.

    A temporary Tk root window is created only to query the display and is
    destroyed before returning. tkinter must be able to open a display for
    ``tkinter.Tk()`` to succeed.

    Returns:
        tuple[int, int]: ``(winfo_screenwidth(), winfo_screenheight())``.
    """
    import tkinter

    root = tkinter.Tk()
    try:
        # Hide the helper window so it does not flash on screen; the screen
        # metrics do not depend on the window being visible.
        root.withdraw()
        return root.winfo_screenwidth(), root.winfo_screenheight()
    finally:
        # ``quit()`` only stops a running mainloop (none is running here)
        # and leaves the window and Tcl interpreter alive; ``destroy()``
        # actually releases them.
        root.destroy()
|
|
|
|
|
|
async def visualize(original_path: str, rewritten_path: str):
    """Render both ONNX models in netron.app and save a screenshot of each.

    A non-headless browser is launched with two tabs on https://netron.app/.
    ``original_path`` is loaded in one tab and ``rewritten_path`` in the
    other. The central half of each tab is saved next to its model, with
    ``.onnx`` replaced by ``.png`` in the path.

    Args:
        original_path: Path to the ONNX file exported from the unmodified
            model.
        rewritten_path: Path to the ONNX file exported with rewriters.

    Raises:
        RuntimeError: If netron.app does not expose a file ``<input>``.
    """
    browser = await launch(headless=False, args=['--start-maximized'])
    try:
        # page2 is opened first, so page1 is the most recently opened tab.
        page2 = await browser.newPage()
        page1 = await browser.newPage()

        width, height = screen_size()
        for page in (page1, page2):
            await page.setViewport({'width': width, 'height': height})
        await page1.goto('https://netron.app/')
        await page2.goto('https://netron.app/')
        # Fixed delays give netron time to settle; no readiness signal is
        # awaited explicitly.
        await asyncio.sleep(2)

        # Load each model through netron's file <input>. Use this
        # coroutine's own arguments rather than module-level globals that
        # only exist when the file is run as a script.
        upload1 = await page1.querySelector("input[type='file']")
        upload2 = await page2.querySelector("input[type='file']")
        if upload1 is None or upload2 is None:
            raise RuntimeError(
                'Could not find the file input on https://netron.app/')
        await upload1.uploadFile(original_path)
        await upload2.uploadFile(rewritten_path)
        await asyncio.sleep(4)

        # Only the first tab (original model) is zoomed out, six steps.
        # NOTE(review): presumably because the unmodified graph is larger.
        for _ in range(6):
            await page1.click('#zoom-out-button')
            await asyncio.sleep(0.3)
        await asyncio.sleep(1)

        # Capture the central half of the screen width, full height.
        clip = {'x': width / 4, 'y': 0, 'width': width / 2, 'height': height}
        # Pass a copy to each call: pyppeteer may write into the clip dict.
        await page1.screenshot(
            {'path': original_path.replace('.onnx', '.png')},
            clip=dict(clip))
        await page2.screenshot(
            {'path': rewritten_path.replace('.onnx', '.png')},
            clip=dict(clip))
    finally:
        # Always shut the browser down, even if a step above failed.
        await browser.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: export resnet18 with and without rewriting,
    # screenshot both graphs in netron, and display the screenshots.
    import tempfile

    # Imported up front so a missing dependency fails before the slow
    # export/browser steps rather than after them.
    import mmcv

    logger = get_root_logger()

    # Work in a private temporary directory that is removed automatically,
    # even on error. The previous ``<cwd>/tmp`` + ``shutil.rmtree`` approach
    # could wipe a pre-existing user directory and leaked it on failure.
    with tempfile.TemporaryDirectory() as tmp_dir:
        original_file_path = os.path.join(tmp_dir, 'original.onnx')
        rewritten_file_path = os.path.join(tmp_dir, 'rewritten.onnx')

        logger.info('Generating resnet18 and its rewritten model...')
        rewrite_resnet18(original_file_path, rewritten_file_path)

        logger.info('Visualizing models through netron...')
        # ``asyncio.run`` replaces the deprecated implicit-loop
        # ``get_event_loop().run_until_complete`` pattern.
        asyncio.run(visualize(original_file_path, rewritten_file_path))

        image1 = mmcv.imread(original_file_path.replace('.onnx', '.png'))
        image2 = mmcv.imread(rewritten_file_path.replace('.onnx', '.png'))
        mmcv.imshow(image1, win_name='original')
        mmcv.imshow(image2, win_name='rewritten')
|