parent
7a7b048f23
commit
8b56a78f93
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -26,7 +26,7 @@ def nondist_forward_collect(func: object, data_loader: DataLoader,
|
|||
results_all (Dict(np.ndarray)): The concatenated outputs.
|
||||
"""
|
||||
results = []
|
||||
prog_bar = mmcv.ProgressBar(len(data_loader))
|
||||
prog_bar = mmengine.ProgressBar(len(data_loader))
|
||||
for _, data in enumerate(data_loader):
|
||||
with torch.no_grad():
|
||||
result = func(data) # output: feat_dict
|
||||
|
@ -65,7 +65,7 @@ def dist_forward_collect(func: object,
|
|||
"""
|
||||
results = []
|
||||
if rank == 0:
|
||||
prog_bar = mmcv.ProgressBar(len(data_loader))
|
||||
prog_bar = mmengine.ProgressBar(len(data_loader))
|
||||
for _, data in enumerate(data_loader):
|
||||
with torch.no_grad():
|
||||
result = func(data) # dict{key: tensor}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
attrs
|
||||
einops
|
||||
future
|
||||
matplotlib
|
||||
mmcls>=1.0.0rc0
|
||||
|
|
|
@ -2,9 +2,9 @@
|
|||
import argparse
|
||||
import os.path as osp
|
||||
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmcv import Config, DictAction
|
||||
from mmengine import Config, DictAction
|
||||
|
||||
from mmselfsup.datasets.builder import build_dataset
|
||||
from mmselfsup.registry import VISUALIZERS
|
||||
|
@ -53,25 +53,25 @@ def main():
|
|||
visualizer = VISUALIZERS.build(cfg.visualizer)
|
||||
visualizer.dataset_meta = dataset.METAINFO
|
||||
|
||||
progress_bar = mmcv.ProgressBar(len(dataset))
|
||||
progress_bar = mmengine.ProgressBar(len(dataset))
|
||||
for item in dataset:
|
||||
if 'pseudo_label' in item['data_sample']:
|
||||
if 'pseudo_label' in item['data_samples']:
|
||||
# for rotation_pred
|
||||
if 'rot_label' in item['data_sample'].pseudo_label:
|
||||
if 'rot_label' in item['data_samples'].pseudo_label:
|
||||
img = np.concatenate(item['inputs'], axis=-1)
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
# for relative_loc
|
||||
else:
|
||||
img = item['inputs'][0].permute(1, 2, 0).numpy()
|
||||
# for contrastive learning
|
||||
elif len(item['inputs']) == 2 and 'mask' not in item['data_sample']:
|
||||
elif len(item['inputs']) == 2 and 'mask' not in item['data_samples']:
|
||||
img = np.concatenate(item['inputs'], axis=-1)
|
||||
img = np.transpose(img, (1, 2, 0))
|
||||
# for mask image modeling
|
||||
else:
|
||||
img = item['inputs'][0].permute(1, 2, 0).numpy()
|
||||
data_sample = item['data_sample']
|
||||
img_path = osp.basename(item['data_sample'].img_path)
|
||||
data_sample = item['data_samples']
|
||||
img_path = osp.basename(item['data_samples'].img_path)
|
||||
|
||||
out_file = osp.join(
|
||||
args.output_dir,
|
||||
|
|
Loading…
Reference in New Issue