diff --git a/projects/FastRT/tools/gen_wts.py b/projects/FastRT/tools/gen_wts.py index 1fd7b86..dbcaf86 100644 --- a/projects/FastRT/tools/gen_wts.py +++ b/projects/FastRT/tools/gen_wts.py @@ -95,13 +95,13 @@ if __name__ == '__main__': model.eval() if args.verify: - input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) + input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) * 255. out = model(input).view(-1).cpu().detach().numpy() print('[Model output]: \n', out) if args.benchmark: start_time = time.time() - input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) + input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) * 255. for i in range(100): out = model(input).view(-1).cpu().detach() print("--- %s seconds ---" % ((time.time() - start_time)/100.) )