mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Support training with data without metainfo
. (#417)
* support training with data without metainfo * clean the code * clean the code
This commit is contained in:
parent
c287e1fb92
commit
ee56f151f6
@ -34,8 +34,9 @@ class RuntimeInfoHook(Hook):
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||
runner.message_hub.update_info(
|
||||
'dataset_meta', runner.train_dataloader.dataset.metainfo)
|
||||
if hasattr(runner.train_dataloader.dataset, 'dataset_meta'):
|
||||
runner.message_hub.update_info(
|
||||
'dataset_meta', runner.train_dataloader.dataset.metainfo)
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""Update current epoch information before every epoch."""
|
||||
|
@ -1915,9 +1915,9 @@ class Runner:
|
||||
self._randomness_cfg.update(seed=resumed_seed)
|
||||
self.set_randomness(**self._randomness_cfg)
|
||||
|
||||
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
|
||||
if (dataset_meta is not None
|
||||
and dataset_meta != self.train_dataloader.dataset.metainfo):
|
||||
resumed_dataset_meta = checkpoint['meta'].get('dataset_meta', None)
|
||||
dataset_meta = getattr(self.train_dataloader.dataset, 'metainfo', None)
|
||||
if resumed_dataset_meta != dataset_meta:
|
||||
warnings.warn(
|
||||
'The dataset metainfo from the resumed checkpoint is '
|
||||
'different from the current training dataset, please '
|
||||
@ -2045,12 +2045,14 @@ class Runner:
|
||||
|
||||
meta.update(
|
||||
cfg=self.cfg.pretty_text,
|
||||
dataset_meta=self.train_dataloader.dataset.metainfo,
|
||||
seed=self.seed,
|
||||
experiment_name=self.experiment_name,
|
||||
time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
|
||||
mmengine_version=mmengine.__version__ + get_git_hash())
|
||||
|
||||
if hasattr(self.train_dataloader.dataset, 'metainfo'):
|
||||
meta.update(dataset_meta=self.train_dataloader.dataset.metainfo)
|
||||
|
||||
if is_model_wrapper(self.model):
|
||||
model = self.model.module
|
||||
else:
|
||||
|
@ -173,6 +173,18 @@ class ToyDataset(Dataset):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ToyDatasetNoMeta(Dataset):
|
||||
data = torch.randn(12, 2)
|
||||
label = torch.ones(12)
|
||||
|
||||
def __len__(self):
|
||||
return self.data.size(0)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return dict(inputs=self.data[index], data_sample=self.label[index])
|
||||
|
||||
|
||||
@METRICS.register_module()
|
||||
class ToyMetric1(BaseMetric):
|
||||
|
||||
@ -1526,6 +1538,12 @@ class TestRunner(TestCase):
|
||||
runner = runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 9 Test training with a dataset without metainfo
|
||||
cfg = copy.deepcopy(cfg)
|
||||
cfg.train_dataloader.dataset = dict(type='ToyDatasetNoMeta')
|
||||
runner = runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
def test_val(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_val1'
|
||||
|
Loading…
x
Reference in New Issue
Block a user