added testing using pretrains

pull/33/head
gudovskiy 2021-08-02 10:58:02 -07:00
parent fb604ccb07
commit 4f6bc9c254
4 changed files with 23 additions and 21 deletions

View File

@ -25,7 +25,7 @@ def get_args():
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
help='image resize dimensions (default: 256)')
parser.add_argument("--action-type", default='norm-train', type=str, metavar='T',
help='norm-train (default: norm-train)')
help='norm-train/norm-test (default: norm-train)')
parser.add_argument('-bs', '--batch-size', default=32, type=int, metavar='B',
help='train batch size (default: 32)')
parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',

View File

@ -18,9 +18,7 @@ def init_seeds(seed=0):
def main(c):
# model
if c.action_type == 'video-train':
c.model = "{}_{}_{}".format(c.enc_arch, c.dec_arch, c.video_path)
elif c.action_type == 'norm-train' or c.action_type == 'norm-test':
if c.action_type in ['norm-train', 'norm-test']:
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, c.input_size, c.run_name, c.class_name)
else:
@ -81,7 +79,7 @@ def main(c):
init_seeds(seed=int(time.time()))
c.device = torch.device("cuda" if c.use_cuda else "cpu")
# selected function:
if c.action_type == 'norm-train':
if c.action_type in ['norm-train', 'norm-test']:
train(c)
else:
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))

View File

@ -279,8 +279,6 @@ def train(c):
elif c.dataset == 'stc':
train_dataset = StcDataset(c, is_train=True)
test_dataset = StcDataset(c, is_train=False)
#elif c.dataset == 'video':
# c.data_path = c.video_path
else:
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
#
@ -293,13 +291,16 @@ def train(c):
det_roc_obs = Score_Observer('DET_AUROC')
seg_roc_obs = Score_Observer('SEG_AUROC')
seg_pro_obs = Score_Observer('SEG_AUPRO')
if c.action_type == 'norm-test':
c.meta_epochs = 1
for epoch in range(c.meta_epochs):
if c.viz:
if c.checkpoint:
load_weights(encoder, decoders, c.checkpoint)
else:
if c.action_type == 'norm-test' and c.checkpoint:
load_weights(encoder, decoders, c.checkpoint)
elif c.action_type == 'norm-train':
print('Train meta epoch: {}'.format(epoch))
train_meta_epoch(c, epoch, train_loader, encoder, decoders, optimizer, pool_layers, N)
else:
raise NotImplementedError('{} is not supported action type!'.format(c.action_type))
#height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_fps(
# c, epoch, test_loader, encoder, decoders, pool_layers, N)
@ -315,7 +316,6 @@ def train(c):
test_norm-= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant
test_prob = torch.exp(test_norm) # convert to probs in range [0:1]
test_mask = test_prob.reshape(-1, height[l], width[l])
#print('Prob shape:', test_prob.shape, test_prob.min(), test_prob.max())
test_mask = test_prob.reshape(-1, height[l], width[l])
# upsample
test_map[l] = F.interpolate(test_mask.unsqueeze(1),
@ -325,17 +325,19 @@ def train(c):
for l, p in enumerate(pool_layers):
score_map += test_map[l]
score_mask = score_map
# superpixels
super_mask = score_mask.max() - score_mask # /score_mask.max() # normality score to anomaly score
# invert probs to anomaly scores
super_mask = score_mask.max() - score_mask
# calculate detection AUROC
score_label = np.max(super_mask, axis=(1, 2))
gt_label = np.asarray(gt_label_list, dtype=np.bool)
det_roc_auc = roc_auc_score(gt_label, score_label)
det_roc_obs.update(100.0*det_roc_auc, epoch)
_ = det_roc_obs.update(100.0*det_roc_auc, epoch)
# calculate segmentation AUROC
gt_mask = np.squeeze(np.asarray(gt_mask_list, dtype=np.bool), axis=1)
seg_roc_auc = roc_auc_score(gt_mask.flatten(), super_mask.flatten())
seg_roc_obs.update(100.0*seg_roc_auc, epoch)
save_best_seg_weights = seg_roc_obs.update(100.0*seg_roc_auc, epoch)
if save_best_seg_weights and c.action_type != 'norm-test':
save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
# calculate segmentation AUPRO
# from https://github.com/YoungGod/DFR:
if c.pro: # and (epoch % 4 == 0): # AUPRO is expensive to compute
@ -403,7 +405,9 @@ def train(c):
fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1]
pros_mean_selected = pros_mean[idx]
seg_pro_auc = auc(fprs_selected, pros_mean_selected)
seg_pro_obs.update(100.0*seg_pro_auc, epoch)
_ = seg_pro_obs.update(100.0*seg_pro_auc, epoch)
#
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
# export visualuzations
if c.viz:
precision, recall, thresholds = precision_recall_curve(gt_label, score_label)
@ -422,7 +426,3 @@ def train(c):
export_scores(c, test_image_list, super_mask, seg_threshold)
export_test_images(c, test_image_list, gt_mask, super_mask, seg_threshold)
export_hist(c, gt_mask, super_mask, seg_threshold)
#save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
elif c.save_results:
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves

View File

@ -15,11 +15,15 @@ class Score_Observer:
def update(self, score, epoch, print_score=True):
self.last = score
save_weights = False
if epoch == 0 or score > self.max_score:
self.max_score = score
self.max_epoch = epoch
save_weights = True
if print_score:
self.print_score()
return save_weights
def print_score(self):
print('{:s}: \t last: {:.2f} \t max: {:.2f} \t epoch_max: {:d}'.format(