- Adding VOC-style COCO

- YAML configs for each task
- Modifying the loss functions accordingly
Nov3
Joseph 2020-10-26 09:27:36 +05:30
parent e662289ad4
commit 7c927204af
26 changed files with 136263 additions and 40 deletions

View File

@@ -0,0 +1,36 @@
MODEL:
  META_ARCHITECTURE: "GeneralizedRCNN"
  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 1000
  ROI_HEADS:
    # NUM_CLASSES: 81 # 0-79 Known class; 80 -> Unknown; 81 -> Background.
    NUM_CLASSES: 81
    NAME: "Res5ROIHeads"
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
  MASK_ON: False
  RESNETS:
    DEPTH: 50
INPUT:
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
DATASETS:
  TRAIN: ("coco_2017_train",)
  TEST: ("coco_2017_val",)
SOLVER:
  IMS_PER_BATCH: 16
  BASE_LR: 0.02
  STEPS: (60000, 80000)
  MAX_ITER: 90000
VERSION: 2
OWOD:
  ENABLE_THRESHOLD_AUTOLABEL_UNK: True
  NUM_UNK_PER_IMAGE: 1
  ENABLE_UNCERTAINITY_AUTOLABEL_UNK: False
  ENABLE_CLUSTERING: True
  CLUSTERING:
    ITEMS_PER_CLASS: 20
    START_ITER: 20
    UPDATE_MU_ITER: 3000
    MOMENTUM: 0.99
    Z_DIMENSION: 128
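For orientation, a minimal sketch of how this base config is consumed, assuming the in-tree detectron2 defaults already carry the _C.OWOD node (added to defaults.py later in this commit); the config path is hypothetical:

from detectron2.config import get_cfg

cfg = get_cfg()  # picks up the patched defaults, including cfg.OWOD
cfg.merge_from_file("configs/OWOD/Base-RCNN-C4-OWOD.yaml")  # hypothetical path
print(cfg.MODEL.ROI_HEADS.NUM_CLASSES)  # 81: 0-79 known, 80 unknown
print(cfg.OWOD.ENABLE_CLUSTERING)       # True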

View File

@@ -0,0 +1,14 @@
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test', 't2_test_unk', 't3_test_unk', 't4_test_unk')
SOLVER:
  STEPS: (12000, 16000)
  MAX_ITER: 18000
  WARMUP_ITERS: 100
OUTPUT_DIR: "./output/t1"
OWOD:
  PREV_INTRODUCED_CLS: 0
  CUR_INTRODUCED_CLS: 20

View File

@@ -0,0 +1,14 @@
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
  TRAIN: ('t2_train',)
  TEST: ('voc_2007_test', 't2_test', 't3_test_unk', 't4_test_unk')
SOLVER:
  STEPS: (22000, 26000)
  MAX_ITER: 28000
  WARMUP_ITERS: 0
OUTPUT_DIR: "./output/t2"
OWOD:
  PREV_INTRODUCED_CLS: 20
  CUR_INTRODUCED_CLS: 20

View File

@@ -0,0 +1,14 @@
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
  TRAIN: ('t3_train',)
  TEST: ('voc_2007_test', 't2_test', 't3_test', 't4_test_unk')
SOLVER:
  STEPS: (32000, 36000)
  MAX_ITER: 38000
  WARMUP_ITERS: 0
OUTPUT_DIR: "./output/t3"
OWOD:
  PREV_INTRODUCED_CLS: 40
  CUR_INTRODUCED_CLS: 20

View File

@@ -0,0 +1,14 @@
_BASE_: "../Base-RCNN-C4-OWOD.yaml"
MODEL:
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl"
DATASETS:
  TRAIN: ('t4_train',)
  TEST: ('voc_2007_test', 't2_test', 't3_test', 't4_test')
SOLVER:
  STEPS: (42000, 46000)
  MAX_ITER: 48000
  WARMUP_ITERS: 0
OUTPUT_DIR: "./output/t4"
OWOD:
  PREV_INTRODUCED_CLS: 60
  CUR_INTRODUCED_CLS: 20
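Across the four task configs, PREV_INTRODUCED_CLS and CUR_INTRODUCED_CLS advance in steps of 20. A small sketch of the resulting bookkeeping, mirroring the invalid_class_range computation added to fast_rcnn.py below:

NUM_CLASSES = 81  # 80 object classes plus the "unknown" slot
tasks = {"t1": (0, 20), "t2": (20, 20), "t3": (40, 20), "t4": (60, 20)}
for task, (prev, cur) in tasks.items():
    seen = prev + cur
    invalid = list(range(seen, NUM_CLASSES - 1))  # unseen-class logits get masked
    print(task, "seen:", seen, "masked:", len(invalid))  # t1: 60 masked ... t4: 0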

View File

@@ -15,12 +15,12 @@ DATASETS:
   TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
   TEST: ('voc_2007_test',)
 SOLVER:
-  STEPS: (37000, 40000)
+  STEPS: (27000, 30000)
   MAX_ITER: 46000 # 17.4 epochs
   # STEPS: (12000, 16000)
   # MAX_ITER: 18000 # 17.4 epochs
   WARMUP_ITERS: 100
-OUTPUT_DIR: "./output/baseline_run_46000"
+OUTPUT_DIR: "./output/baseline_run"
 OWOD:
   ENABLE_THRESHOLD_AUTOLABEL_UNK: False
   NUM_UNK_PER_IMAGE: 1

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,40 @@
import os
import xml.etree.ElementTree as ET

from pycocotools.coco import COCO


def coco_to_voc_detection(coco_annotation_file, target_folder):
    os.makedirs(os.path.join(target_folder, 'Annotations'), exist_ok=True)
    coco_instance = COCO(coco_annotation_file)
    for index, image_id in enumerate(coco_instance.imgToAnns):
        image_details = coco_instance.imgs[image_id]
        annotation_el = ET.Element('annotation')
        ET.SubElement(annotation_el, 'filename').text = image_details['file_name']
        size_el = ET.SubElement(annotation_el, 'size')
        ET.SubElement(size_el, 'width').text = str(image_details['width'])
        ET.SubElement(size_el, 'height').text = str(image_details['height'])
        ET.SubElement(size_el, 'depth').text = str(3)
        for annotation in coco_instance.imgToAnns[image_id]:
            object_el = ET.SubElement(annotation_el, 'object')
            ET.SubElement(object_el, 'name').text = coco_instance.cats[annotation['category_id']]['name']
            # ET.SubElement(object_el, 'name').text = 'unknown'
            ET.SubElement(object_el, 'difficult').text = '0'
            # COCO boxes are 0-based [x, y, w, h]; VOC uses 1-based [xmin, ymin, xmax, ymax].
            bb_el = ET.SubElement(object_el, 'bndbox')
            ET.SubElement(bb_el, 'xmin').text = str(int(annotation['bbox'][0] + 1.0))
            ET.SubElement(bb_el, 'ymin').text = str(int(annotation['bbox'][1] + 1.0))
            ET.SubElement(bb_el, 'xmax').text = str(int(annotation['bbox'][0] + annotation['bbox'][2] + 1.0))
            ET.SubElement(bb_el, 'ymax').text = str(int(annotation['bbox'][1] + annotation['bbox'][3] + 1.0))
        ET.ElementTree(annotation_el).write(os.path.join(target_folder, 'Annotations', image_details['file_name'].split('.')[0] + '.xml'))
        if index % 10000 == 0:
            print('Processed ' + str(index) + ' images.')


if __name__ == '__main__':
    coco_annotation_file = '/home/fk1/workspace/datasets/annotations/instances_val2017.json'
    target_folder = '/home/fk1/workspace/OWOD/datasets/coco17_voc_style'
    coco_to_voc_detection(coco_annotation_file, target_folder)
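The +1.0 offsets above convert COCO's 0-based [x, y, w, h] boxes into VOC's 1-based [xmin, ymin, xmax, ymax] convention; the loader added later in this commit subtracts 1.0 from xmin/ymin again, following detectron2's usual VOC handling. A hedged spot check on one generated file (the file name is hypothetical):

import xml.etree.ElementTree as ET

tree = ET.parse('datasets/coco17_voc_style/Annotations/000000000139.xml')  # hypothetical file
for obj in tree.findall('object'):
    bndbox = obj.find('bndbox')
    xmin = float(bndbox.find('xmin').text) - 1.0  # back to 0-based, as in the loader
    print(obj.find('name').text, xmin)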

View File

@@ -0,0 +1,63 @@
from pycocotools.coco import COCO
import numpy as np

T2_CLASS_NAMES = [
    "truck", "traffic light", "fire hydrant", "stop sign", "parking meter",
    "bench", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "microwave", "oven", "toaster", "sink", "refrigerator"
]

# Train
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_train2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t2_train.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T2_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created train file')

# Test
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_val2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t2_test.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T2_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test file')

dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t2_test_unk.txt'
with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test_unk file')

View File

@@ -0,0 +1,63 @@
from pycocotools.coco import COCO
import numpy as np

T3_CLASS_NAMES = [
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake"
]

# Train
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_train2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t3_train.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T3_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created train file')

# Test
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_val2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t3_test.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T3_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test file')

dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t3_test_unk.txt'
with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test_unk file')

View File

@@ -0,0 +1,63 @@
from pycocotools.coco import COCO
import numpy as np

T4_CLASS_NAMES = [
    "bed", "toilet", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl"
]

# Train
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_train2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t4_train.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T4_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created train file')

# Test
coco_annotation_file = '/home/joseph/workspace/datasets/mscoco/annotations/instances_val2017.json'
dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t4_test.txt'

coco_instance = COCO(coco_annotation_file)
image_ids = []
cls = []
for index, image_id in enumerate(coco_instance.imgToAnns):
    image_details = coco_instance.imgs[image_id]
    classes = [coco_instance.cats[annotation['category_id']]['name'] for annotation in coco_instance.imgToAnns[image_id]]
    if not set(classes).isdisjoint(T4_CLASS_NAMES):
        image_ids.append(image_details['file_name'].split('.')[0])
        cls.extend(classes)

(unique, counts) = np.unique(cls, return_counts=True)
print({x: y for x, y in zip(unique, counts)})

with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test file')

dest_file = '/home/joseph/workspace/OWOD/datasets/coco17_voc_style/ImageSets/t4_test_unk.txt'
with open(dest_file, 'w') as file:
    for image_id in image_ids:
        file.write(str(image_id) + '\n')
print('Created test_unk file')
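The three ImageSet scripts above differ only in the class list and the destination prefix; note also that each t*_test_unk.txt is a verbatim copy of the corresponding t*_test.txt, since the relabelling to "unknown" happens at load time. A deduplicated sketch under those assumptions (the helper name is hypothetical):

from pycocotools.coco import COCO

def write_split(coco_annotation_file, dest_file, class_names):
    coco = COCO(coco_annotation_file)
    with open(dest_file, 'w') as f:
        for image_id in coco.imgToAnns:
            classes = {coco.cats[a['category_id']]['name'] for a in coco.imgToAnns[image_id]}
            if classes & set(class_names):  # keep images with at least one task class
                f.write(coco.imgs[image_id]['file_name'].split('.')[0] + '\n')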

View File

@@ -611,6 +611,8 @@ _C.OWOD.CLUSTERING.UPDATE_MU_ITER = 200
 _C.OWOD.CLUSTERING.MOMENTUM = 0.9
 _C.OWOD.CLUSTERING.Z_DIMENSION = 64
+_C.OWOD.PREV_INTRODUCED_CLS = 0
+_C.OWOD.CUR_INTRODUCED_CLS = 20

 # ---------------------------------------------------------------------------- #
 # Misc options

View File

@@ -27,6 +27,7 @@ from .cityscapes_panoptic import register_all_cityscapes_panoptic
 from .coco import load_sem_seg
 from .lvis import get_lvis_instances_meta, register_lvis_instances
 from .pascal_voc import register_pascal_voc
+from .voc_style_coco import register_voc_style_coco
 from .register_coco import register_coco_instances, register_coco_panoptic_separated

 # ==== Predefined datasets and splits for COCO ==========
@@ -220,6 +221,24 @@ def register_all_pascal_voc(root):
         MetadataCatalog.get(name).evaluator_type = "pascal_voc"


+def register_all_voc_style_coco(root):
+    SPLITS = [
+        ("t2_train", "coco17_voc_style"),
+        ("t2_test", "coco17_voc_style"),
+        ("t2_test_unk", "coco17_voc_style"),
+        ("t3_train", "coco17_voc_style"),
+        ("t3_test", "coco17_voc_style"),
+        ("t3_test_unk", "coco17_voc_style"),
+        ("t4_train", "coco17_voc_style"),
+        ("t4_test", "coco17_voc_style"),
+        ("t4_test_unk", "coco17_voc_style"),
+    ]
+    for name, dirname in SPLITS:
+        year = 2007
+        register_voc_style_coco(name, os.path.join(root, dirname), name, year)
+        MetadataCatalog.get(name).evaluator_type = "pascal_voc"
+
+
 def register_all_ade20k(root):
     root = os.path.join(root, "ADEChallengeData2016")
     for name, dirname in [("train", "training"), ("val", "validation")]:
@@ -248,5 +267,6 @@ if __name__.endswith(".builtin"):
     register_all_cityscapes_panoptic(_root)
     # register_all_pascal_voc(_root)
     # register_all_pascal_voc('/home/joseph/workspace/OWOD/datasets')
-    register_all_pascal_voc('/home/fk1/workspace/OWOD/datasets')
+    register_all_pascal_voc('/home/joseph/workspace/OWOD/datasets')
+    register_all_voc_style_coco('/home/joseph/workspace/OWOD/datasets')
     register_all_ade20k(_root)
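Once builtin.py has executed (it runs at import time), the new splits are reachable through the standard catalogs; a minimal usage sketch:

from detectron2.data import DatasetCatalog, MetadataCatalog

dicts = DatasetCatalog.get("t2_train")  # list of Detectron2-format dicts
meta = MetadataCatalog.get("t2_train")
print(len(dicts), meta.evaluator_type)  # evaluator_type is "pascal_voc"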

View File

@@ -26,37 +26,6 @@ CLASS_NAMES = (
 # )
 # fmt: on

-VOC_CLASS_NAMES = (
-    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
-    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
-    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
-)
-
-T2_CLASS_NAMES = (
-    "truck", "trafficlight", "firehydrant", "stopsigh", "parkingmeter",
-    "bench", "elephant", "bear", "zebra", "giraffe",
-    "backpack", "umbrella", "handbag", "tie", "suitcase",
-    "microwave", "oven", "toaster", "sink", "refrigerator"
-)
-
-T3_CLASS_NAMES = (
-    "frisbee", "skis", "snowboard", "sportsball", "kite",
-    "baseballbat", "baseballglove", "skateboard", "surfboard", "tennisracket",
-    "banana", "apple", "sandwich", "orange", "broccoli",
-    "carrot", "hotdog", "pizza", "donut", "cake"
-)
-
-T4_CLASS_NAMES = (
-    "bed", "toilet", "tv", "laptop", "mouse",
-    "remote", "keyboard", "cellphone", "book", "clock",
-    "vase", "scissors", "teddybear", "hairdrier", "toothbrush",
-    "bottle", "wineglass", "cup", "fork", "knife", "spoon", "bowl"
-)
-
-UNK_CLASS = ("unknown")

 def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
     """
     Load Pascal VOC detection annotations to Detectron2 format.
View File

@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
import os
import xml.etree.ElementTree as ET
from typing import List, Tuple, Union

import numpy as np
from fvcore.common.file_io import PathManager

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode

__all__ = ["load_voc_coco_instances", "register_voc_style_coco"]

VOC_CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

T2_CLASS_NAMES = [
    "truck", "traffic light", "fire hydrant", "stop sign", "parking meter",
    "bench", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "microwave", "oven", "toaster", "sink", "refrigerator"
]

T3_CLASS_NAMES = [
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake"
]

T4_CLASS_NAMES = [
    "bed", "toilet", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    "wine glass", "cup", "fork", "knife", "spoon", "bowl"
]

UNK_CLASS = ["unknown"]

# The incremental label space: the 20 VOC classes, then the three 20-class
# COCO tasks, then the single "unknown" slot.
INCR_CLASS_NAMES = tuple(itertools.chain(VOC_CLASS_NAMES, T2_CLASS_NAMES, T3_CLASS_NAMES, T4_CLASS_NAMES, UNK_CLASS))


def load_voc_coco_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
    """
    Load Pascal VOC style detection annotations to Detectron2 format.

    Args:
        dirname: Contains "Annotations", "ImageSets", "JPEGImages"
        split (str): one of the registered split names, e.g. "t2_train", "t2_test", "t2_test_unk"
        class_names: list or tuple of class names
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        fileids = np.loadtxt(f, dtype=str)

    # Each split only keeps annotations belonging to its own task's classes.
    known_class_list = None
    if 't2' in split:
        known_class_list = T2_CLASS_NAMES
    elif 't3' in split:
        known_class_list = T3_CLASS_NAMES
    elif 't4' in split:
        known_class_list = T4_CLASS_NAMES

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls_name = obj.find("name").text
            if cls_name not in known_class_list:
                continue
            # "*_test_unk" splits relabel every task class as "unknown".
            if 'unk' in split:
                cls = "unknown"
            else:
                cls = cls_name
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            #     continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            bbox[0] -= 1.0
            bbox[1] -= 1.0
            instances.append(
                {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    return dicts


def register_voc_style_coco(name, dirname, split, year, class_names=INCR_CLASS_NAMES):
    DatasetCatalog.register(name, lambda: load_voc_coco_instances(dirname, split, class_names))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )
View File

@@ -28,7 +28,7 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
     official API.
     """

-    def __init__(self, dataset_name):
+    def __init__(self, dataset_name, cfg=None):
         """
         Args:
             dataset_name (str): name of the dataset, e.g., "voc_2007_test"
@@ -42,6 +42,10 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
         self._is_2007 = meta.year == 2007
         self._cpu_device = torch.device("cpu")
         self._logger = logging.getLogger(__name__)
+        if cfg is not None:
+            self.prev_intro_cls = cfg.OWOD.PREV_INTRODUCED_CLS
+            self.curr_intro_cls = cfg.OWOD.CUR_INTRODUCED_CLS
+            self.total_num_class = cfg.MODEL.ROI_HEADS.NUM_CLASSES

     def reset(self):
         self._predictions = defaultdict(list)  # class name -> list of prediction strings
@@ -114,8 +118,9 @@ class PascalVOCDetectionEvaluator(DatasetEvaluator):
         ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}

         # Extra logging of class-wise APs
+        avg_precs = list(np.mean([x for _, x in aps.items()], axis=0))
         self._logger.info(self._class_names)
-        self._logger.info("AP__: " + str(['%.1f' % x for x in list(np.mean([x for _, x in aps.items()], axis=0))]))
+        self._logger.info("AP__: " + str(['%.1f' % x for x in avg_precs]))
         self._logger.info("AP50: " + str(['%.1f' % x for x in aps[50]]))
         self._logger.info("AP75: " + str(['%.1f' % x for x in aps[75]]))
@@ -152,8 +157,8 @@ def parse_rec(filename):
     for obj in tree.findall("object"):
         obj_struct = {}
         obj_struct["name"] = obj.find("name").text
-        obj_struct["pose"] = obj.find("pose").text
-        obj_struct["truncated"] = int(obj.find("truncated").text)
+        # obj_struct["pose"] = obj.find("pose").text
+        # obj_struct["truncated"] = int(obj.find("truncated").text)
         obj_struct["difficult"] = int(obj.find("difficult").text)
         bbox = obj.find("bndbox")
         obj_struct["bbox"] = [

View File

@@ -144,6 +144,7 @@ class FastRCNNOutputs:
         pred_class_logits,
         pred_proposal_deltas,
         proposals,
+        invalid_class_range,
         smooth_l1_beta=0.0,
         box_reg_loss_type="smooth_l1",
     ):
@@ -178,6 +179,7 @@ class FastRCNNOutputs:
         self.box_reg_loss_type = box_reg_loss_type

         self.image_shapes = [x.image_size for x in proposals]
+        self.invalid_class_range = invalid_class_range

         if len(proposals):
             box_type = type(proposals[0].proposal_boxes)
@@ -231,6 +233,7 @@ class FastRCNNOutputs:
             return 0.0 * self.pred_class_logits.sum()
         else:
             self._log_accuracy()
+            self.pred_class_logits[:, self.invalid_class_range] = -10e10
             return F.cross_entropy(self.pred_class_logits, self.gt_classes, reduction="mean")

     def box_reg_loss(self):
@@ -396,6 +399,8 @@ class FastRCNNOutputLayers(nn.Module):
         clustering_momentum,
         clustering_z_dimension,
         enable_clustering,
+        prev_intro_cls,
+        curr_intro_cls,
         num_classes: int,
         test_score_thresh: float = 0.0,
         test_nms_thresh: float = 0.5,
@@ -460,6 +465,12 @@ class FastRCNNOutputLayers(nn.Module):
         self.hingeloss = nn.HingeEmbeddingLoss(2)
         self.enable_clustering = enable_clustering

+        self.prev_intro_cls = prev_intro_cls
+        self.curr_intro_cls = curr_intro_cls
+        self.seen_classes = self.prev_intro_cls + self.curr_intro_cls
+        self.invalid_class_range = list(range(self.seen_classes, self.num_classes - 1))
+        logging.getLogger(__name__).info("Invalid class range: " + str(self.invalid_class_range))
+
         # self.ae_model = AE(input_size, clustering_z_dimension)
         # self.ae_model.apply(Xavier)
@@ -482,7 +493,9 @@ class FastRCNNOutputLayers(nn.Module):
             "clustering_update_mu_iter" : cfg.OWOD.CLUSTERING.UPDATE_MU_ITER,
             "clustering_momentum" : cfg.OWOD.CLUSTERING.MOMENTUM,
             "clustering_z_dimension": cfg.OWOD.CLUSTERING.Z_DIMENSION,
-            "enable_clustering" : cfg.OWOD.ENABLE_CLUSTERING
+            "enable_clustering" : cfg.OWOD.ENABLE_CLUSTERING,
+            "prev_intro_cls" : cfg.OWOD.PREV_INTRODUCED_CLS,
+            "curr_intro_cls" : cfg.OWOD.CUR_INTRODUCED_CLS
             # fmt: on
         }
@@ -687,6 +700,7 @@ class FastRCNNOutputLayers(nn.Module):
             scores,
             proposal_deltas,
             proposals,
+            self.invalid_class_range,
             self.smooth_l1_beta,
             self.box_reg_loss_type,
         ).losses()
View File

@@ -85,7 +85,7 @@ class Trainer(DefaultTrainer):
             ), "CityscapesEvaluator currently do not work with multiple machines."
             return CityscapesSemSegEvaluator(dataset_name)
         elif evaluator_type == "pascal_voc":
-            return PascalVOCDetectionEvaluator(dataset_name)
+            return PascalVOCDetectionEvaluator(dataset_name, cfg)
         elif evaluator_type == "lvis":
             return LVISEvaluator(dataset_name, cfg, True, output_folder)

         if len(evaluator_list) == 0: