[Fix] Support training with data without metainfo. (#417)

* support training with data without metainfo

* clean the code

* clean the code
This commit is contained in:
Mashiro 2022-08-11 14:51:11 +08:00 committed by GitHub
parent c287e1fb92
commit ee56f151f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 6 deletions

View File

@@ -34,8 +34,9 @@ class RuntimeInfoHook(Hook):
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
runner.message_hub.update_info('max_iters', runner.max_iters)
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)
if hasattr(runner.train_dataloader.dataset, 'metainfo'):
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)
def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch."""

View File

@@ -1915,9 +1915,9 @@ class Runner:
self._randomness_cfg.update(seed=resumed_seed)
self.set_randomness(**self._randomness_cfg)
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
if (dataset_meta is not None
and dataset_meta != self.train_dataloader.dataset.metainfo):
resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None)
dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)
if resumed_dataset_meta != dataset_meta:
warnings.warn(
'The dataset metainfo from the resumed checkpoint is '
'different from the current training dataset, please '
@@ -2045,12 +2045,14 @@ class Runner:
meta.update(
cfg=self.cfg.pretty_text,
dataset_meta=self.train_dataloader.dataset.metainfo,
seed=self.seed,
experiment_name=self.experiment_name,
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
mmengine_version=mmengine.__version__ + get_git_hash())
if hasattr(self.train_dataloader.dataset, 'metainfo'):
meta.update(dataset_meta=self.train_dataloader.dataset.metainfo)
if is_model_wrapper(self.model):
model = self.model.module
else:

View File

@@ -173,6 +173,18 @@ class ToyDataset(Dataset):
return dict(inputs=self.data[index], data_sample=self.label[index])
@DATASETS.register_module()
class ToyDatasetNoMeta(Dataset):
    """Toy dataset that deliberately lacks a ``metainfo`` attribute.

    Exercises the code paths that must tolerate datasets carrying no meta
    information (see the ``hasattr(..., 'metainfo')`` guards added in this
    commit).
    """

    # 12 random 2-d feature vectors with all-ones labels, shared by all
    # instances (class attributes, matching ToyDataset above).
    data = torch.randn(12, 2)
    label = torch.ones(12)

    def __len__(self):
        """Return the number of samples (12)."""
        return len(self.data)

    def __getitem__(self, index):
        """Return one sample as a dict of input tensor and its label."""
        return dict(inputs=self.data[index], data_sample=self.label[index])
@METRICS.register_module()
class ToyMetric1(BaseMetric):
@@ -1526,6 +1538,12 @@ class TestRunner(TestCase):
runner = runner.from_cfg(cfg)
runner.train()
# 9 Test training with a dataset without metainfo
cfg = copy.deepcopy(cfg)
cfg.train_dataloader.dataset = dict(type='ToyDatasetNoMeta')
runner = runner.from_cfg(cfg)
runner.train()
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'