mirror of
https://github.com/megvii-research/NAFNet.git
synced 2025-06-03 21:55:00 +08:00
NAFNet_arch simplify
This commit is contained in:
parent
0f79b98eb3
commit
35d0eb9b37
@ -175,51 +175,22 @@ class NAFNetLocal(Local_Base, NAFNet):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import resource
|
|
||||||
def using(point=""):
|
|
||||||
# print(f'using .. {point}')
|
|
||||||
usage = resource.getrusage(resource.RUSAGE_SELF)
|
|
||||||
global Total, LastMem
|
|
||||||
|
|
||||||
# if usage[2]/1024.0 - LastMem > 0.01:
|
|
||||||
# print(point, usage[2]/1024.0)
|
|
||||||
print(point, usage[2] / 1024.0)
|
|
||||||
|
|
||||||
LastMem = usage[2] / 1024.0
|
|
||||||
return usage[2] / 1024.0
|
|
||||||
|
|
||||||
img_channel = 3
|
img_channel = 3
|
||||||
width = 32
|
width = 32
|
||||||
|
|
||||||
enc_blks = [2, 2, 2, 20]
|
# enc_blks = [2, 2, 4, 8]
|
||||||
middle_blk_num = 2
|
# middle_blk_num = 12
|
||||||
dec_blks = [2, 2, 2, 2]
|
# dec_blks = [2, 2, 2, 2]
|
||||||
|
|
||||||
print('enc blks', enc_blks, 'middle blk num', middle_blk_num, 'dec blks', dec_blks, 'width' , width)
|
enc_blks = [1, 1, 1, 28]
|
||||||
|
middle_blk_num = 1
|
||||||
|
dec_blks = [1, 1, 1, 1]
|
||||||
|
|
||||||
using('start . ')
|
|
||||||
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
|
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
|
||||||
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
|
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
|
||||||
|
|
||||||
using('network .. ')
|
|
||||||
|
|
||||||
# for n, p in net.named_parameters()
|
inp_shape = (3, 256, 256)
|
||||||
# print(n, p.shape)
|
|
||||||
|
|
||||||
|
|
||||||
inp = torch.randn((4, 3, 256, 256))
|
|
||||||
|
|
||||||
out = net(inp)
|
|
||||||
final_mem = using('end .. ')
|
|
||||||
# out.sum().backward()
|
|
||||||
|
|
||||||
# out.sum().backward()
|
|
||||||
|
|
||||||
# using('backward .. ')
|
|
||||||
|
|
||||||
# exit(0)
|
|
||||||
|
|
||||||
inp_shape = (3, 512, 512)
|
|
||||||
|
|
||||||
from ptflops import get_model_complexity_info
|
from ptflops import get_model_complexity_info
|
||||||
|
|
||||||
@ -229,8 +200,3 @@ if __name__ == '__main__':
|
|||||||
macs = float(macs[:-4])
|
macs = float(macs[:-4])
|
||||||
|
|
||||||
print(macs, params)
|
print(macs, params)
|
||||||
|
|
||||||
print('total .. ', params * 8 + final_mem)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user