added testing using pretrains
parent
fb604ccb07
commit
4f6bc9c254
|
@ -25,7 +25,7 @@ def get_args():
|
|||
parser.add_argument('-inp', '--input-size', default=256, type=int, metavar='C',
|
||||
help='image resize dimensions (default: 256)')
|
||||
parser.add_argument("--action-type", default='norm-train', type=str, metavar='T',
|
||||
help='norm-train (default: norm-train)')
|
||||
help='norm-train/norm-test (default: norm-train)')
|
||||
parser.add_argument('-bs', '--batch-size', default=32, type=int, metavar='B',
|
||||
help='train batch size (default: 32)')
|
||||
parser.add_argument('--lr', type=float, default=2e-4, metavar='LR',
|
||||
|
|
6
main.py
6
main.py
|
@ -18,9 +18,7 @@ def init_seeds(seed=0):
|
|||
|
||||
def main(c):
|
||||
# model
|
||||
if c.action_type == 'video-train':
|
||||
c.model = "{}_{}_{}".format(c.enc_arch, c.dec_arch, c.video_path)
|
||||
elif c.action_type == 'norm-train' or c.action_type == 'norm-test':
|
||||
if c.action_type in ['norm-train', 'norm-test']:
|
||||
c.model = "{}_{}_{}_pl{}_cb{}_inp{}_run{}_{}".format(
|
||||
c.dataset, c.enc_arch, c.dec_arch, c.pool_layers, c.coupling_blocks, c.input_size, c.run_name, c.class_name)
|
||||
else:
|
||||
|
@ -81,7 +79,7 @@ def main(c):
|
|||
init_seeds(seed=int(time.time()))
|
||||
c.device = torch.device("cuda" if c.use_cuda else "cpu")
|
||||
# selected function:
|
||||
if c.action_type == 'norm-train':
|
||||
if c.action_type in ['norm-train', 'norm-test']:
|
||||
train(c)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported action-type!'.format(c.action_type))
|
||||
|
|
32
train.py
32
train.py
|
@ -279,8 +279,6 @@ def train(c):
|
|||
elif c.dataset == 'stc':
|
||||
train_dataset = StcDataset(c, is_train=True)
|
||||
test_dataset = StcDataset(c, is_train=False)
|
||||
#elif c.dataset == 'video':
|
||||
# c.data_path = c.video_path
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported dataset!'.format(c.dataset))
|
||||
#
|
||||
|
@ -293,13 +291,16 @@ def train(c):
|
|||
det_roc_obs = Score_Observer('DET_AUROC')
|
||||
seg_roc_obs = Score_Observer('SEG_AUROC')
|
||||
seg_pro_obs = Score_Observer('SEG_AUPRO')
|
||||
if c.action_type == 'norm-test':
|
||||
c.meta_epochs = 1
|
||||
for epoch in range(c.meta_epochs):
|
||||
if c.viz:
|
||||
if c.checkpoint:
|
||||
load_weights(encoder, decoders, c.checkpoint)
|
||||
else:
|
||||
if c.action_type == 'norm-test' and c.checkpoint:
|
||||
load_weights(encoder, decoders, c.checkpoint)
|
||||
elif c.action_type == 'norm-train':
|
||||
print('Train meta epoch: {}'.format(epoch))
|
||||
train_meta_epoch(c, epoch, train_loader, encoder, decoders, optimizer, pool_layers, N)
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported action type!'.format(c.action_type))
|
||||
|
||||
#height, width, test_image_list, test_dist, gt_label_list, gt_mask_list = test_meta_fps(
|
||||
# c, epoch, test_loader, encoder, decoders, pool_layers, N)
|
||||
|
@ -315,7 +316,6 @@ def train(c):
|
|||
test_norm-= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant
|
||||
test_prob = torch.exp(test_norm) # convert to probs in range [0:1]
|
||||
test_mask = test_prob.reshape(-1, height[l], width[l])
|
||||
#print('Prob shape:', test_prob.shape, test_prob.min(), test_prob.max())
|
||||
test_mask = test_prob.reshape(-1, height[l], width[l])
|
||||
# upsample
|
||||
test_map[l] = F.interpolate(test_mask.unsqueeze(1),
|
||||
|
@ -325,17 +325,19 @@ def train(c):
|
|||
for l, p in enumerate(pool_layers):
|
||||
score_map += test_map[l]
|
||||
score_mask = score_map
|
||||
# superpixels
|
||||
super_mask = score_mask.max() - score_mask # /score_mask.max() # normality score to anomaly score
|
||||
# invert probs to anomaly scores
|
||||
super_mask = score_mask.max() - score_mask
|
||||
# calculate detection AUROC
|
||||
score_label = np.max(super_mask, axis=(1, 2))
|
||||
gt_label = np.asarray(gt_label_list, dtype=np.bool)
|
||||
det_roc_auc = roc_auc_score(gt_label, score_label)
|
||||
det_roc_obs.update(100.0*det_roc_auc, epoch)
|
||||
_ = det_roc_obs.update(100.0*det_roc_auc, epoch)
|
||||
# calculate segmentation AUROC
|
||||
gt_mask = np.squeeze(np.asarray(gt_mask_list, dtype=np.bool), axis=1)
|
||||
seg_roc_auc = roc_auc_score(gt_mask.flatten(), super_mask.flatten())
|
||||
seg_roc_obs.update(100.0*seg_roc_auc, epoch)
|
||||
save_best_seg_weights = seg_roc_obs.update(100.0*seg_roc_auc, epoch)
|
||||
if save_best_seg_weights and c.action_type != 'norm-test':
|
||||
save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
|
||||
# calculate segmentation AUPRO
|
||||
# from https://github.com/YoungGod/DFR:
|
||||
if c.pro: # and (epoch % 4 == 0): # AUPRO is expensive to compute
|
||||
|
@ -403,7 +405,9 @@ def train(c):
|
|||
fprs_selected = rescale(fprs_selected) # rescale fpr [0,0.3] -> [0, 1]
|
||||
pros_mean_selected = pros_mean[idx]
|
||||
seg_pro_auc = auc(fprs_selected, pros_mean_selected)
|
||||
seg_pro_obs.update(100.0*seg_pro_auc, epoch)
|
||||
_ = seg_pro_obs.update(100.0*seg_pro_auc, epoch)
|
||||
#
|
||||
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
|
||||
# export visualuzations
|
||||
if c.viz:
|
||||
precision, recall, thresholds = precision_recall_curve(gt_label, score_label)
|
||||
|
@ -422,7 +426,3 @@ def train(c):
|
|||
export_scores(c, test_image_list, super_mask, seg_threshold)
|
||||
export_test_images(c, test_image_list, gt_mask, super_mask, seg_threshold)
|
||||
export_hist(c, gt_mask, super_mask, seg_threshold)
|
||||
#save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
|
||||
elif c.save_results:
|
||||
save_results(det_roc_obs, seg_roc_obs, seg_pro_obs, c.model, c.class_name, run_date)
|
||||
save_weights(encoder, decoders, c.model, run_date) # avoid unnecessary saves
|
||||
|
|
4
utils.py
4
utils.py
|
@ -15,11 +15,15 @@ class Score_Observer:
|
|||
|
||||
def update(self, score, epoch, print_score=True):
|
||||
self.last = score
|
||||
save_weights = False
|
||||
if epoch == 0 or score > self.max_score:
|
||||
self.max_score = score
|
||||
self.max_epoch = epoch
|
||||
save_weights = True
|
||||
if print_score:
|
||||
self.print_score()
|
||||
|
||||
return save_weights
|
||||
|
||||
def print_score(self):
|
||||
print('{:s}: \t last: {:.2f} \t max: {:.2f} \t epoch_max: {:d}'.format(
|
||||
|
|
Loading…
Reference in New Issue