delete print
parent
fb8e883fa6
commit
b79fee116a
|
@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||||
print(f"FLOPs before pruning: {flops}")
|
logger.info(f"FLOPs before pruning: {flops}")
|
||||||
|
|
||||||
from paddleslim.dygraph import FPGMFilterPruner
|
from paddleslim.dygraph import FPGMFilterPruner
|
||||||
model.train()
|
model.train()
|
||||||
|
@ -144,10 +144,10 @@ def main(config, device, logger, vdl_writer):
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
if ("weights" in param.name and "conv" in param.name) or (
|
if ("weights" in param.name and "conv" in param.name) or (
|
||||||
"w_0" in param.name and "conv2d" in param.name):
|
"w_0" in param.name and "conv2d" in param.name):
|
||||||
print(f"{param.name}: {param.shape}")
|
logger.info(f"{param.name}: {param.shape}")
|
||||||
|
|
||||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||||
print(f"FLOPs after pruning: {flops}")
|
logger.info(f"FLOPs after pruning: {flops}")
|
||||||
|
|
||||||
# start train
|
# start train
|
||||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||||
|
|
Loading…
Reference in New Issue