mirror of https://github.com/JosephKJ/OWOD.git

Thresholding based auto-labelling

parent 73d7a85f88
commit 2f553726bf
@@ -55,3 +55,9 @@ If you use Detectron2 in your research or wish to refer to the baseline results
  year = {2019}
}
```

## Command

```python
python tools/train_net.py --num-gpus 4 --config-file ./configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml SOLVER.IMS_PER_BATCH 4 SOLVER.BASE_LR 0.005
```
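The trailing `KEY VALUE` pairs are Detectron2's standard command-line config overrides. A minimal sketch of what they do under the hood (same config file and values as the command above):

```python
from detectron2.config import get_cfg

# Sketch only: mirrors how train_net.py applies the trailing "KEY VALUE" overrides.
cfg = get_cfg()
cfg.merge_from_file("./configs/PascalVOC-Detection/faster_rcnn_R_50_C4.yaml")
cfg.merge_from_list(["SOLVER.IMS_PER_BATCH", "4", "SOLVER.BASE_LR", "0.005"])
print(cfg.SOLVER.IMS_PER_BATCH, cfg.SOLVER.BASE_LR)  # 4 0.005
```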
@@ -5,14 +5,22 @@ MODEL:
  RESNETS:
    DEPTH: 50
  ROI_HEADS:
    NUM_CLASSES: 20
    NUM_CLASSES: 21 # 0-19 Known class; 20 -> Unknown; 21 -> Background.
    # NUM_CLASSES: 20
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
DATASETS:

  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test',)
SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000 # 17.4 epochs
  WARMUP_ITERS: 100
OUTPUT_DIR: "./output/al_threshold"
OWOD:
  ENABLE_THRESHOLD_AUTOLABEL_UNK: True
  NUM_UNK_PER_IMAGE: 1
  ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
  ENABLE_CLUSTERING: False
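For orientation, the index layout implied by `NUM_CLASSES: 21` (Detectron2 reserves the index equal to `NUM_CLASSES` for background; the constant names below are illustrative, not from the repo):

```python
# Illustrative constants only; the code itself works with the raw indices.
NUM_KNOWN_CLASSES = 20     # the original VOC categories, indices 0-19
UNKNOWN_CLASS_INDEX = 20   # proposals auto-labelled as "unknown"
BACKGROUND_INDEX = 21      # == MODEL.ROI_HEADS.NUM_CLASSES
```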
@@ -0,0 +1,19 @@
_BASE_: "../Base-RCNN-C4.yaml"
MODEL:
  WEIGHTS: "/home/joseph/workspace/OWOD/output/baseline_pascal/model_final.pth"
  MASK_ON: False
  RESNETS:
    DEPTH: 50
  ROI_HEADS:
    NUM_CLASSES: 20
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
DATASETS:
  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test',)
SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000 # 17.4 epochs
  WARMUP_ITERS: 100
OUTPUT_DIR: "./output/baseline_pascal"
@@ -16,3 +16,4 @@ SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000 # 17.4 epochs
  WARMUP_ITERS: 100
OUTPUT_DIR: "./output/baseline_pascal_fpn"
@@ -595,6 +595,15 @@ _C.TEST.AUG.FLIP = True
_C.TEST.PRECISE_BN = CN({"ENABLED": False})
_C.TEST.PRECISE_BN.NUM_ITER = 200

# ---------------------------------------------------------------------------- #
# OpenWorld Object Detection
# ---------------------------------------------------------------------------- #
_C.OWOD = CN()
_C.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK = False
_C.OWOD.NUM_UNK_PER_IMAGE = 1
_C.OWOD.ENABLE_UNCERTAINITY_AUTOLABEL_UNK = False
_C.OWOD.ENABLE_CLUSTERING = False

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
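Declaring the `OWOD` node in the defaults is what makes the `OWOD:` section of the YAML configs (and command-line overrides of those keys) legal, since `CfgNode` only merges keys that were declared beforehand. A small sketch, assuming this fork's `get_cfg()`:

```python
from detectron2.config import get_cfg

cfg = get_cfg()
print(cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK)  # False, the default declared above
cfg.merge_from_list(["OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK", "True", "OWOD.NUM_UNK_PER_IMAGE", "1"])
print(cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK)  # True
```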
@@ -246,5 +246,6 @@ if __name__.endswith(".builtin"):
    register_all_lvis(_root)
    register_all_cityscapes(_root)
    register_all_cityscapes_panoptic(_root)
    register_all_pascal_voc(_root)
    # register_all_pascal_voc(_root)
    register_all_pascal_voc('/home/joseph/workspace/OWOD/datasets')
    register_all_ade20k(_root)
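The new call hard-codes a machine-specific dataset root. A portable alternative, meant as a drop-in line inside builtin.py and mirroring how `_root` is resolved for the other datasets (a sketch, not in the commit):

```python
import os

# Resolve the VOC root from the environment, falling back to ./datasets.
voc_root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
register_all_pascal_voc(voc_root)
```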
@@ -17,7 +17,7 @@ __all__ = ["load_voc_instances", "register_pascal_voc"]
CLASS_NAMES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    "pottedplant", "sheep", "sofa", "train", "tvmonitor", "unknown"
)
# fmt: on
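With `"unknown"` appended, `CLASS_NAMES` has 21 entries, which must line up with `MODEL.ROI_HEADS.NUM_CLASSES: 21` in the config above and with the auto-label index used in roi_heads.py below. A hypothetical sanity check (not part of the commit) that could sit under the tuple:

```python
# Hypothetical check tying the pieces of this commit together.
assert len(CLASS_NAMES) == 21               # 20 VOC classes + "unknown"
assert CLASS_NAMES.index("unknown") == 20   # == num_classes - 1, the auto-label target
```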
@@ -87,6 +87,9 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            # recs = defaultdict(list)
            # precs = defaultdict(list)

            for cls_id, cls_name in enumerate(self._class_names):
                lines = predictions.get(cls_id, [""])
@@ -103,10 +106,27 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)
                    # recs[thresh].append(rec * 100)
                    # precs[thresh].append(prec * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}

        # Extra logging of class-wise APs
        self._logger.info(self._class_names)
        self._logger.info("AP__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in aps.items()], axis=0))]))
        self._logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
        self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))

        # self._logger.info("R__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in recs.items()], axis=0))]))
        # self._logger.info("R50: " + str(['%.1f' % x for x in recs[50]]))
        # self._logger.info("R75: " + str(['%.1f' % x for x in recs[75]]))
        #
        # self._logger.info("P__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in precs.items()], axis=0))]))
        # self._logger.info("P50: " + str(['%.1f' % x for x in precs[50]]))
        # self._logger.info("P75: " + str(['%.1f' % x for x in precs[75]]))

        return ret
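In the logging above, `aps` maps each IoU threshold to a per-class AP list, so the `AP__` row is the per-class mean over thresholds while `ret["bbox"]["AP"]` averages over both classes and thresholds. A self-contained sketch with dummy numbers:

```python
import numpy as np
from collections import defaultdict

aps = defaultdict(list)
aps[50] = [80.0, 60.0]  # e.g. two classes evaluated at IoU 0.50
aps[75] = [40.0, 20.0]  # the same two classes at IoU 0.75
per_class_mean = np.mean([x for _, x in aps.items()], axis=0)  # array([60., 40.]) -> the "AP__" row
overall_ap = np.mean([np.mean(x) for x in aps.values()])       # 50.0 -> ret["bbox"]["AP"]
print(per_class_mean, overall_ap)
```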
@@ -2,6 +2,8 @@
import inspect
import logging
import numpy as np
import heapq
import operator
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -143,7 +145,9 @@ class ROIHeads(torch.nn.Module):
        batch_size_per_image,
        positive_fraction,
        proposal_matcher,
        proposal_append_gt=True
        enable_thresold_autolabelling,
        unk_k,
        proposal_append_gt=True,
    ):
        """
        NOTE: this interface is experimental.
@@ -162,6 +166,8 @@ class ROIHeads(torch.nn.Module):
        self.num_classes = num_classes
        self.proposal_matcher = proposal_matcher
        self.proposal_append_gt = proposal_append_gt
        self.enable_thresold_autolabelling = enable_thresold_autolabelling
        self.unk_k = unk_k

    @classmethod
    def from_config(cls, cfg):
@@ -176,10 +182,13 @@ class ROIHeads(torch.nn.Module):
                cfg.MODEL.ROI_HEADS.IOU_LABELS,
                allow_low_quality_matches=False,
            ),
            "enable_thresold_autolabelling": cfg.OWOD.ENABLE_THRESHOLD_AUTOLABEL_UNK,
            "unk_k": cfg.OWOD.NUM_UNK_PER_IMAGE,
        }

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor,
        objectness_logits: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Based on the matching between N proposals and M groundtruth,
@@ -214,7 +223,22 @@ class ROIHeads(torch.nn.Module):
        )

        sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
        return sampled_idxs, gt_classes[sampled_idxs]
        gt_classes_ss = gt_classes[sampled_idxs]

        if self.enable_thresold_autolabelling:
            matched_labels_ss = matched_labels[sampled_idxs]
            pred_objectness_score_ss = objectness_logits[sampled_idxs]

            # 1) Remove FG objectness score. 2) Sort and select top k. 3) Build and apply mask.
            mask = torch.zeros((pred_objectness_score_ss.shape), dtype=torch.bool)
            pred_objectness_score_ss[matched_labels_ss != 0] = -1
            sorted_indices = list(zip(
                *heapq.nlargest(self.unk_k, enumerate(pred_objectness_score_ss), key=operator.itemgetter(1))))[0]
            for index in sorted_indices:
                mask[index] = True
            gt_classes_ss[mask] = self.num_classes - 1

        return sampled_idxs, gt_classes_ss

    @torch.no_grad()
    def label_and_sample_proposals(
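In words: among the sampled proposals that did not match any ground-truth box, the `unk_k` with the highest RPN objectness are relabelled from background to the unknown class (`num_classes - 1`, i.e. 20). A minimal sketch of that step with dummy tensors, using `torch.topk` in place of the heapq selection above (all values illustrative):

```python
import torch

num_classes = 21                                           # 0-19 known, 20 unknown, 21 = background label
unk_k = 1
matched_labels_ss = torch.tensor([1, 0, 0, 0, 1])          # 1 = matched to a GT box, 0 = background
objectness = torch.tensor([0.90, 0.70, 0.20, 0.80, 0.95])  # RPN objectness of the sampled proposals
gt_classes_ss = torch.tensor([3, 21, 21, 21, 7])           # background proposals carry the label 21

scores = objectness.clone()
scores[matched_labels_ss != 0] = -1                        # drop foreground proposals from the ranking
top_bg = torch.topk(scores, k=unk_k).indices               # highest-objectness background proposals
gt_classes_ss[top_bg] = num_classes - 1                    # auto-label them as "unknown" (index 20)
print(gt_classes_ss)                                       # tensor([ 3, 21, 21, 20,  7])
```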
@@ -269,7 +293,7 @@ class ROIHeads(torch.nn.Module):
            )
            matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix)
            sampled_idxs, gt_classes = self._sample_proposals(
                matched_idxs, matched_labels, targets_per_image.gt_classes
                matched_idxs, matched_labels, targets_per_image.gt_classes, proposals_per_image.objectness_logits
            )

            # Set target attributes of the sampled proposals: