polish code
parent
c9e1077daa
commit
1b2ca6e641
|
@ -326,6 +326,4 @@ class STN_ON(nn.Layer):
|
||||||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||||
x, _ = self.tps(image, ctrl_points)
|
x, _ = self.tps(image, ctrl_points)
|
||||||
#print("x:", np.sum(x.numpy()))
|
|
||||||
# print(x.shape)
|
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -215,9 +215,6 @@ def train(config,
|
||||||
preds = model(images, data=batch[1:])
|
preds = model(images, data=batch[1:])
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
state_dict = model.state_dict()
|
|
||||||
# for key in state_dict:
|
|
||||||
# print(key)
|
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
avg_loss = loss['loss']
|
avg_loss = loss['loss']
|
||||||
avg_loss.backward()
|
avg_loss.backward()
|
||||||
|
@ -414,7 +411,6 @@ def preprocess(is_train=False):
|
||||||
yaml.dump(
|
yaml.dump(
|
||||||
dict(config), f, default_flow_style=False, sort_keys=False)
|
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||||
log_file = '{}/train.log'.format(save_model_dir)
|
log_file = '{}/train.log'.format(save_model_dir)
|
||||||
print("log has save in {}/train.log".format(save_model_dir))
|
|
||||||
else:
|
else:
|
||||||
log_file = None
|
log_file = None
|
||||||
logger = get_logger(name='root', log_file=log_file)
|
logger = get_logger(name='root', log_file=log_file)
|
||||||
|
|
|
@ -72,8 +72,6 @@ def main(config, device, logger, vdl_writer):
|
||||||
# for rec algorithm
|
# for rec algorithm
|
||||||
if hasattr(post_process_class, 'character'):
|
if hasattr(post_process_class, 'character'):
|
||||||
char_num = len(getattr(post_process_class, 'character'))
|
char_num = len(getattr(post_process_class, 'character'))
|
||||||
character = getattr(post_process_class, 'character')
|
|
||||||
print("getattr character:", character)
|
|
||||||
if config['Architecture']["algorithm"] in ["Distillation",
|
if config['Architecture']["algorithm"] in ["Distillation",
|
||||||
]: # distillation model
|
]: # distillation model
|
||||||
for key in config['Architecture']["Models"]:
|
for key in config['Architecture']["Models"]:
|
||||||
|
|
Loading…
Reference in New Issue