# Mirrored from https://github.com/alibaba/EasyCV.git (101 lines, 3.2 KiB, Python).
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from easycv.models.backbones import ResNet
|
|
from easycv.models.backbones.resnet_jit import ResNetJIT
|
|
from easycv.utils.profiling import benchmark_torch_function
|
|
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(),
                 'CUDA is required: all models and inputs are moved to cuda')
class ResnetTest(unittest.TestCase):
    """Check and benchmark eager vs TorchScript ResNet-50 variants on CUDA.

    ``test_resnet_jit`` compares EasyCV's ``ResNet`` against its
    ``torch.jit.trace`` and ``ResNetJIT`` against its ``torch.jit.script``,
    then prints per-image latencies. ``test_vision_resnet`` prints latencies
    for torchvision's ``resnet50`` run eagerly, traced and scripted.
    Latencies are only printed, never asserted. The whole class is skipped
    when CUDA is unavailable, because every model and input is placed on
    ``cuda``.
    """

    # Timed iterations passed to ``benchmark_torch_function``.
    _BENCHMARK_ITERS = 100
    # Device every model and input is moved to.
    _DEVICE = 'cuda'

    def setUp(self):
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def _assert_last_output_close(self, ref_model, jit_model, inputs,
                                  atol=1e-2):
        """Assert that the last outputs of ``ref_model`` and ``jit_model`` agree.

        Both models are called on ``inputs`` and their results are indexed
        with ``[-1]``. The two arrays are compared with ``np.allclose`` at
        absolute tolerance ``atol``. Presumably each model returns a sequence
        of feature maps (one per ``out_indices`` entry); confirm for new
        callers.
        """
        expected = ref_model(inputs)[-1].detach().cpu().numpy()
        actual = jit_model(inputs)[-1].detach().cpu().numpy()
        self.assertTrue(np.allclose(expected, actual, atol=atol))

    def _benchmark(self, label, model, inputs):
        """Time ``model(inputs)`` and print the per-image figure tagged ``label``.

        The value returned by ``benchmark_torch_function`` is divided by the
        batch size (``inputs.size(0)``). It is presumably seconds per call;
        confirm against ``easycv.utils.profiling``.
        """
        elapsed = benchmark_torch_function(self._BENCHMARK_ITERS, model,
                                           inputs)
        print(f'{label}: {elapsed / inputs.size(0)} s/per image')

    def test_resnet_jit(self):
        """EasyCV ResNet-50: eager vs ``jit.trace``; ResNetJIT vs ``jit.script``."""
        device = self._DEVICE
        batch_size = 1
        with torch.no_grad():
            inputs = torch.rand(batch_size, 3, 224, 224, device=device)

            r50 = ResNet(50, out_indices=[4]).to(device)
            r50.init_weights()
            r50.eval()
            # r50 already lives on ``device``, so the traced module does too;
            # no extra ``.to(device)`` is needed.
            r50_trace = torch.jit.trace(r50, inputs)
            r50_trace.eval()

            r50_jittable = ResNetJIT(50, out_indices=[4]).to(device)
            r50_jittable.init_weights()
            r50_jittable.eval()
            # ``torch.jit.script`` compiles from the module's source and takes
            # no example inputs. Its second positional parameter is not an
            # input tensor, so ``inputs`` must not be passed here.
            r50_script = torch.jit.script(r50_jittable)
            r50_script.eval()

            self._assert_last_output_close(r50, r50_trace, inputs)
            self._assert_last_output_close(r50_jittable, r50_script, inputs)

            r50(inputs)  # warm-up before timing the eager model
            self._benchmark('origin', r50, inputs)
            self._benchmark('trace r50', r50_trace, inputs)
            self._benchmark('script r50', r50_script, inputs)

        # Reference timings recorded by the original author (hardware and
        # exact test not stated; indicative only):
        #   no jit:     0.001548142358660698 s/per image
        #   jit trace:  0.0016211424767971039 s/per image
        #   jit script: 0.0016227740794420241 s/per image

    @torch.no_grad()
    def test_vision_resnet(self):
        """torchvision ResNet-50: benchmark ``jit.trace``, ``jit.script`` and eager."""
        from torchvision import models

        device = self._DEVICE
        batch_size = 1
        inputs = torch.rand(batch_size, 3, 224, 224, device=device)

        r50 = models.resnet50().to(device)
        # Switch to inference mode BEFORE tracing/scripting. In training mode,
        # BatchNorm would normalize with batch statistics and keep updating
        # its running stats during the benchmark, and the traced graph would
        # record training-mode behavior.
        r50.eval()
        r50_trace = torch.jit.trace(r50, inputs)
        r50_script = torch.jit.script(r50)

        for label, model in (('jit trace', r50_trace),
                             ('jit script', r50_script),
                             ('origin', r50)):
            model(inputs)  # warm-up before timing
            self._benchmark(label, model, inputs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main(module='__main__')
|