NAFNet_arch simplify
parent
0f79b98eb3
commit
35d0eb9b37
|
@ -175,51 +175,22 @@ class NAFNetLocal(Local_Base, NAFNet):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import resource
|
||||
def using(point=""):
|
||||
# print(f'using .. {point}')
|
||||
usage = resource.getrusage(resource.RUSAGE_SELF)
|
||||
global Total, LastMem
|
||||
|
||||
# if usage[2]/1024.0 - LastMem > 0.01:
|
||||
# print(point, usage[2]/1024.0)
|
||||
print(point, usage[2] / 1024.0)
|
||||
|
||||
LastMem = usage[2] / 1024.0
|
||||
return usage[2] / 1024.0
|
||||
|
||||
img_channel = 3
|
||||
width = 32
|
||||
|
||||
enc_blks = [2, 2, 2, 20]
|
||||
middle_blk_num = 2
|
||||
dec_blks = [2, 2, 2, 2]
|
||||
# enc_blks = [2, 2, 4, 8]
|
||||
# middle_blk_num = 12
|
||||
# dec_blks = [2, 2, 2, 2]
|
||||
|
||||
print('enc blks', enc_blks, 'middle blk num', middle_blk_num, 'dec blks', dec_blks, 'width' , width)
|
||||
enc_blks = [1, 1, 1, 28]
|
||||
middle_blk_num = 1
|
||||
dec_blks = [1, 1, 1, 1]
|
||||
|
||||
using('start . ')
|
||||
net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
|
||||
enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
|
||||
|
||||
using('network .. ')
|
||||
|
||||
# for n, p in net.named_parameters()
|
||||
# print(n, p.shape)
|
||||
|
||||
|
||||
inp = torch.randn((4, 3, 256, 256))
|
||||
|
||||
out = net(inp)
|
||||
final_mem = using('end .. ')
|
||||
# out.sum().backward()
|
||||
|
||||
# out.sum().backward()
|
||||
|
||||
# using('backward .. ')
|
||||
|
||||
# exit(0)
|
||||
|
||||
inp_shape = (3, 512, 512)
|
||||
inp_shape = (3, 256, 256)
|
||||
|
||||
from ptflops import get_model_complexity_info
|
||||
|
||||
|
@ -229,8 +200,3 @@ if __name__ == '__main__':
|
|||
macs = float(macs[:-4])
|
||||
|
||||
print(macs, params)
|
||||
|
||||
print('total .. ', params * 8 + final_mem)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue