polish code
parent
c9e1077daa
commit
1b2ca6e641
|
@ -326,6 +326,4 @@ class STN_ON(nn.Layer):
|
|||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
x, _ = self.tps(image, ctrl_points)
|
||||
#print("x:", np.sum(x.numpy()))
|
||||
# print(x.shape)
|
||||
return x
|
||||
|
|
|
@ -215,9 +215,6 @@ def train(config,
|
|||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
state_dict = model.state_dict()
|
||||
# for key in state_dict:
|
||||
# print(key)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
|
@ -414,7 +411,6 @@ def preprocess(is_train=False):
|
|||
yaml.dump(
|
||||
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
log_file = '{}/train.log'.format(save_model_dir)
|
||||
print("log has save in {}/train.log".format(save_model_dir))
|
||||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
|
|
@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer):
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
character = getattr(post_process_class, 'character')
|
||||
print("getattr character:", character)
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
|
Loading…
Reference in New Issue