[Refactor] Refactor some of the functionality of `dataset_analysis` (#294)

* Refactor some of the functionality of data_analysis

* Remove redundant code in dataset_analysis.py

* Update tools/analysis_tools/dataset_analysis.py

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>

* simplified code

* add docstring and simplify code

Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
pull/367/head
kitecats 2022-11-22 11:21:35 +08:00 committed by GitHub
parent 47bb3ce408
commit 748f151886
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 25 additions and 11 deletions

View File

@ -7,7 +7,6 @@ import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from mmengine.config import Config
from mmengine.dataset.dataset_wrapper import ConcatDataset
from mmengine.utils import ProgressBar
from prettytable import PrettyTable
@ -378,21 +377,36 @@ def main():
# register all modules in mmdet into the registries
register_all_modules()
def replace_pipeline_to_none(cfg):
"""Recursively iterate over all dataset(or datasets) and set their
pipelines to none.Datasets are mean ConcatDataset.
Recursively terminates only when all dataset(or datasets) have been
traversed
"""
if cfg.get('dataset', None) is None and cfg.get('datasets',
None) is None:
return
dataset = cfg.dataset if cfg.get('dataset', None) else cfg.datasets
if isinstance(dataset, list):
for item in dataset:
item.pipeline = None
elif dataset.get('pipeline', None):
dataset.pipeline = None
else:
replace_pipeline_to_none(dataset)
# 1.Build Dataset
if args.val_dataset is False:
replace_pipeline_to_none(cfg.train_dataloader)
dataset = DATASETS.build(cfg.train_dataloader.dataset)
elif args.val_dataset is True:
else:
replace_pipeline_to_none(cfg.val_dataloader)
dataset = DATASETS.build(cfg.val_dataloader.dataset)
# Determine whether the dataset is ConcatDataset
if isinstance(dataset, ConcatDataset):
datasets = dataset.datasets
data_list = []
for idx in range(len(datasets)):
datasets_list = datasets[idx].load_data_list()
data_list += datasets_list
else:
data_list = dataset.load_data_list()
# Build lists to store data for all raw data
data_list = dataset
# 2.Prepare data
# Drawing settings