[Fix] Fix dependence and key bug (#611)

* fix dependence bug

* fix
pull/616/head
Yixiao Fang 2022-12-06 18:58:54 +08:00 committed by GitHub
parent 7a7b048f23
commit 8b56a78f93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 11 deletions

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional
import mmcv
import mmengine
import numpy as np
import torch
from torch.utils.data import DataLoader
@ -26,7 +26,7 @@ def nondist_forward_collect(func: object, data_loader: DataLoader,
results_all (Dict(np.ndarray)): The concatenated outputs.
"""
results = []
prog_bar = mmcv.ProgressBar(len(data_loader))
prog_bar = mmengine.ProgressBar(len(data_loader))
for _, data in enumerate(data_loader):
with torch.no_grad():
result = func(data) # output: feat_dict
@ -65,7 +65,7 @@ def dist_forward_collect(func: object,
"""
results = []
if rank == 0:
prog_bar = mmcv.ProgressBar(len(data_loader))
prog_bar = mmengine.ProgressBar(len(data_loader))
for _, data in enumerate(data_loader):
with torch.no_grad():
result = func(data) # dict{key: tensor}

View File

@ -1,4 +1,5 @@
attrs
einops
future
matplotlib
mmcls>=1.0.0rc0

View File

@ -2,9 +2,9 @@
import argparse
import os.path as osp
import mmcv
import mmengine
import numpy as np
from mmcv import Config, DictAction
from mmengine import Config, DictAction
from mmselfsup.datasets.builder import build_dataset
from mmselfsup.registry import VISUALIZERS
@ -53,25 +53,25 @@ def main():
visualizer = VISUALIZERS.build(cfg.visualizer)
visualizer.dataset_meta = dataset.METAINFO
progress_bar = mmcv.ProgressBar(len(dataset))
progress_bar = mmengine.ProgressBar(len(dataset))
for item in dataset:
if 'pseudo_label' in item['data_sample']:
if 'pseudo_label' in item['data_samples']:
# for rotation_pred
if 'rot_label' in item['data_sample'].pseudo_label:
if 'rot_label' in item['data_samples'].pseudo_label:
img = np.concatenate(item['inputs'], axis=-1)
img = np.transpose(img, (1, 2, 0))
# for relative_loc
else:
img = item['inputs'][0].permute(1, 2, 0).numpy()
# for contrastive learning
elif len(item['inputs']) == 2 and 'mask' not in item['data_sample']:
elif len(item['inputs']) == 2 and 'mask' not in item['data_samples']:
img = np.concatenate(item['inputs'], axis=-1)
img = np.transpose(img, (1, 2, 0))
# for mask image modeling
else:
img = item['inputs'][0].permute(1, 2, 0).numpy()
data_sample = item['data_sample']
img_path = osp.basename(item['data_sample'].img_path)
data_sample = item['data_samples']
img_path = osp.basename(item['data_samples'].img_path)
out_file = osp.join(
args.output_dir,