Thresholding based auto-labelling

Nov3
Joseph 2020-10-04 09:59:13 +05:30
parent 73d7a85f88
commit 2f553726bf
9 changed files with 95 additions and 7 deletions

View File

@ -55,3 +55,9 @@ If you use Detectron2 in your research or wish to refer to the baseline results
year = {2019} year = {2019}
} }
``` ```
## Command
```python
python tools/train_net.py --num-gpus 4 --config-file ./configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yamlSOLVER.IMS_PER_BATCH 4 SOLVER.BASE_LR 0.005
```

View File

@ -5,14 +5,22 @@ MODEL:
RESNETS: RESNETS:
DEPTH: 50 DEPTH: 50
ROI_HEADS: ROI_HEADS:
NUM_CLASSES: 20 NUM_CLASSES: 21 # 0-19 Known class; 20 -> Unknown; 21 -> Background.
# NUM_CLASSES: 20
INPUT: INPUT:
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
MIN_SIZE_TEST: 800 MIN_SIZE_TEST: 800
DATASETS: DATASETS:
TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
TEST: ('voc_2007_test',) TEST: ('voc_2007_test',)
SOLVER: SOLVER:
STEPS: (12000, 16000) STEPS: (12000, 16000)
MAX_ITER: 18000 # 17.4 epochs MAX_ITER: 18000 # 17.4 epochs
WARMUP_ITERS: 100 WARMUP_ITERS: 100
OUTPUT_DIR: "./output/al_threshold"
OWOD:
ENABLE_THRESHOLD_AUTOLABEL_UNK: True
NUM_UNK_PER_IMAGE: 1
ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
ENABLE_CLUSTERING: False

View File

@ -0,0 +1,19 @@
_BASE_: "../Base-RCNN-C4.yaml"
MODEL:
WEIGHTS: "/home/joseph/workspace/OWOD/output/baseline_pascal/model_final.pth"
MASK_ON: False
RESNETS:
DEPTH: 50
ROI_HEADS:
NUM_CLASSES: 20
INPUT:
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
MIN_SIZE_TEST: 800
DATASETS:
TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
TEST: ('voc_2007_test',)
SOLVER:
STEPS: (12000, 16000)
MAX_ITER: 18000 # 17.4 epochs
WARMUP_ITERS: 100
OUTPUT_DIR: "./output/baseline_pascal"

View File

@ -16,3 +16,4 @@ SOLVER:
STEPS: (12000, 16000) STEPS: (12000, 16000)
MAX_ITER: 18000 # 17.4 epochs MAX_ITER: 18000 # 17.4 epochs
WARMUP_ITERS: 100 WARMUP_ITERS: 100
OUTPUT_DIR: "./output/baseline_pascal_fpn"

View File

@ -595,6 +595,15 @@ _C.TEST.AUG.FLIP = True
_C.TEST.PRECISE_BN = CN({"ENABLED": False}) _C.TEST.PRECISE_BN = CN({"ENABLED": False})
_C.TEST.PRECISE_BN.NUM_ITER = 200 _C.TEST.PRECISE_BN.NUM_ITER = 200
# ---------------------------------------------------------------------------- #
# OpenWorld Object Detection
# ---------------------------------------------------------------------------- #
_C.OWOD = CN()
_C.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK = False
_C.OWOD.NUM_UNK_PER_IMAGE = 1
_C.OWOD.ENABLE_UNCERTAINITY_AUTOLABEL_UNK = False
_C.OWOD.ENABLE_CLUSTERING = False
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# Misc options # Misc options
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #

View File

@ -246,5 +246,6 @@ if __name__.endswith(".builtin"):
register_all_lvis(_root) register_all_lvis(_root)
register_all_cityscapes(_root) register_all_cityscapes(_root)
register_all_cityscapes_panoptic(_root) register_all_cityscapes_panoptic(_root)
register_all_pascal_voc(_root) # register_all_pascal_voc(_root)
register_all_pascal_voc('/home/joseph/workspace/OWOD/datasets')
register_all_ade20k(_root) register_all_ade20k(_root)

View File

@ -17,7 +17,7 @@ __all__ = ["load_voc_instances", "register_pascal_voc"]
CLASS_NAMES = ( CLASS_NAMES = (
"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
"chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
"pottedplant", "sheep", "sofa", "train", "tvmonitor" "pottedplant", "sheep", "sofa", "train", "tvmonitor", "unknown"
) )
# fmt: on # fmt: on

View File

@ -87,6 +87,9 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
res_file_template = os.path.join(dirname, "{}.txt") res_file_template = os.path.join(dirname, "{}.txt")
aps = defaultdict(list) # iou -> ap per class aps = defaultdict(list) # iou -> ap per class
# recs = defaultdict(list)
# precs = defaultdict(list)
for cls_id, cls_name in enumerate(self._class_names): for cls_id, cls_name in enumerate(self._class_names):
lines = predictions.get(cls_id, [""]) lines = predictions.get(cls_id, [""])
@ -103,10 +106,27 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
use_07_metric=self._is_2007, use_07_metric=self._is_2007,
) )
aps[thresh].append(ap * 100) aps[thresh].append(ap * 100)
# recs[thresh].append(rec * 100)
# precs[thresh].append(prec * 100)
ret = OrderedDict() ret = OrderedDict()
mAP = {iou: np.mean(x) for iou, x in aps.items()} mAP = {iou: np.mean(x) for iou, x in aps.items()}
ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
# Extra logging of class-wise APs
self._logger.info(self._class_names)
self._logger.info("AP__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in aps.items()], axis=0))]))
self._logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
# self._logger.info("R__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in recs.items()], axis=0))]))
# self._logger.info("R50: " + str(['%.1f' % x for x in recs[50]]))
# self._logger.info("R75: " + str(['%.1f' % x for x in recs[75]]))
#
# self._logger.info("P__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in precs.items()], axis=0))]))
# self._logger.info("P50: " + str(['%.1f' % x for x in precs[50]]))
# self._logger.info("P75: " + str(['%.1f' % x for x in precs[75]]))
return ret return ret

View File

@ -2,6 +2,8 @@
import inspect import inspect
import logging import logging
import numpy as np import numpy as np
import heapq
import operator
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
@ -143,7 +145,9 @@ class ROIHeads(torch.nn.Module):
batch_size_per_image, batch_size_per_image,
positive_fraction, positive_fraction,
proposal_matcher, proposal_matcher,
proposal_append_gt=True enable_thresold_autolabelling,
unk_k,
proposal_append_gt=True,
): ):
""" """
NOTE: this interface is experimental. NOTE: this interface is experimental.
@ -162,6 +166,8 @@ class ROIHeads(torch.nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.proposal_matcher = proposal_matcher self.proposal_matcher = proposal_matcher
self.proposal_append_gt = proposal_append_gt self.proposal_append_gt = proposal_append_gt
self.enable_thresold_autolabelling = enable_thresold_autolabelling
self.unk_k = unk_k
@classmethod @classmethod
def from_config(cls, cfg): def from_config(cls, cfg):
@ -176,10 +182,13 @@ class ROIHeads(torch.nn.Module):
cfg.MODEL.ROI_HEADS.IOU_LABELS, cfg.MODEL.ROI_HEADS.IOU_LABELS,
allow_low_quality_matches=False, allow_low_quality_matches=False,
), ),
"enable_thresold_autolabelling": cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK,
"unk_k": cfg.OWOD.NUM_UNK_PER_IMAGE,
} }
def _sample_proposals( def _sample_proposals(
self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor,
objectness_logits: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Based on the matching between N proposals and M groundtruth, Based on the matching between N proposals and M groundtruth,
@ -214,7 +223,22 @@ class ROIHeads(torch.nn.Module):
) )
sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0) sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
return sampled_idxs, gt_classes[sampled_idxs] gt_classes_ss = gt_classes[sampled_idxs]
if self.enable_thresold_autolabelling:
matched_labels_ss = matched_labels[sampled_idxs]
pred_objectness_score_ss = objectness_logits[sampled_idxs]
# 1) Remove FG objectness score. 2) Sort and select top k. 3) Build and apply mask.
mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
pred_objectness_score_ss[matched_labels_ss != 0] = -1
sorted_indices = list(zip(
*heapq.nlargest(self.unk_k, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
for index in sorted_indices:
mask[index] = True
gt_classes_ss[mask] = self.num_classes - 1
return sampled_idxs, gt_classes_ss
@torch.no_grad() @torch.no_grad()
def label_and_sample_proposals( def label_and_sample_proposals(
@ -269,7 +293,7 @@ class ROIHeads(torch.nn.Module):
) )
matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix) matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
sampled_idxs, gt_classes = self._sample_proposals( sampled_idxs, gt_classes = self._sample_proposals(
matched_idxs, matched_labels, targets_per_image.gt_classes matched_idxs, matched_labels, targets_per_image.gt_classes, proposals_per_image.objectness_logits
) )
# Set target attributes of the sampled proposals: # Set target attributes of the sampled proposals: