NAFNet_arch simplify

pull/96/head
mayorx 2022-08-03 02:36:28 +08:00
parent 0f79b98eb3
commit 35d0eb9b37
1 changed files with 10 additions and 44 deletions

View File

@ -175,51 +175,22 @@ class NAFNetLocal(Local_Base, NAFNet):
if __name__ == '__main__':
import resource
def using(point=""):
# print(f'using .. {point}')
usage = resource.getrusage(resource.RUSAGE_SELF)
global Total, LastMem
# if usage[2]/1024.0 - LastMem > 0.01:
# print(point, usage[2]/1024.0)
print(point, usage[2] / 1024.0)
LastMem = usage[2] / 1024.0
return usage[2] / 1024.0
img_channel = 3
width = 32
enc_blks = [2, 2, 2, 20]
middle_blk_num = 2
dec_blks = [2, 2, 2, 2]
# enc_blks = [2, 2, 4, 8]
# middle_blk_num = 12
# dec_blks = [2, 2, 2, 2]
print('enc blks', enc_blks, 'middle blk num', middle_blk_num, 'dec blks', dec_blks, 'width' , width)
enc_blks = [1, 1, 1, 28]
middle_blk_num = 1
dec_blks = [1, 1, 1, 1]
using('start . ')
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
using('network .. ')
# for n, p in net.named_parameters()
# print(n, p.shape)
inp = torch.randn((4, 3, 256, 256))
out = net(inp)
final_mem = using('end .. ')
# out.sum().backward()
# out.sum().backward()
# using('backward .. ')
# exit(0)
inp_shape = (3, 512, 512)
inp_shape = (3, 256, 256)
from ptflops import get_model_complexity_info
@ -229,8 +200,3 @@ if __name__ == '__main__':
macs = float(macs[:-4])
print(macs, params)
print('total .. ', params * 8 + final_mem)