From 35d0eb9b3752511ca258cd942734a3083b3ef74d Mon Sep 17 00:00:00 2001 From: mayorx Date: Wed, 3 Aug 2022 02:36:28 +0800 Subject: [PATCH] NAFNet_arch simplify --- basicsr/models/archs/NAFNet_arch.py | 54 ++++++----------------------- 1 file changed, 10 insertions(+), 44 deletions(-) diff --git a/basicsr/models/archs/NAFNet_arch.py b/basicsr/models/archs/NAFNet_arch.py index 119786b..5735e09 100644 --- a/basicsr/models/archs/NAFNet_arch.py +++ b/basicsr/models/archs/NAFNet_arch.py @@ -175,51 +175,22 @@ class NAFNetLocal(Local_Base, NAFNet): if __name__ == '__main__': - import resource - def using(point=""): - # print(f'using .. {point}') - usage = resource.getrusage(resource.RUSAGE_SELF) - global Total, LastMem - - # if usage[2]/1024.0 - LastMem > 0.01: - # print(point, usage[2]/1024.0) - print(point, usage[2] / 1024.0) - - LastMem = usage[2] / 1024.0 - return usage[2] / 1024.0 - img_channel = 3 width = 32 + + # enc_blks = [2, 2, 4, 8] + # middle_blk_num = 12 + # dec_blks = [2, 2, 2, 2] + + enc_blks = [1, 1, 1, 28] + middle_blk_num = 1 + dec_blks = [1, 1, 1, 1] - enc_blks = [2, 2, 2, 20] - middle_blk_num = 2 - dec_blks = [2, 2, 2, 2] - - print('enc blks', enc_blks, 'middle blk num', middle_blk_num, 'dec blks', dec_blks, 'width' , width) - - using('start . ') - net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, + net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) - using('network .. ') - # for n, p in net.named_parameters() - # print(n, p.shape) - - - inp = torch.randn((4, 3, 256, 256)) - - out = net(inp) - final_mem = using('end .. ') - # out.sum().backward() - - # out.sum().backward() - - # using('backward .. ') - - # exit(0) - - inp_shape = (3, 512, 512) + inp_shape = (3, 256, 256) from ptflops import get_model_complexity_info @@ -229,8 +200,3 @@ if __name__ == '__main__': macs = float(macs[:-4]) print(macs, params) - - print('total .. ', params * 8 + final_mem) - - -