[Fix]: fix args/cfg bug in extract.py (#357)
parent
b9647eb72c
commit
9c28733ddd
|
@ -149,13 +149,13 @@ def main():
|
||||||
# run
|
# run
|
||||||
outputs = extractor.extract(model, data_loader, distributed=distributed)
|
outputs = extractor.extract(model, data_loader, distributed=distributed)
|
||||||
rank, _ = get_dist_info()
|
rank, _ = get_dist_info()
|
||||||
mmcv.mkdir_or_exist(f'{args.work_dir}/features/')
|
mmcv.mkdir_or_exist(f'{cfg.work_dir}/features/')
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
for key, val in outputs.items():
|
for key, val in outputs.items():
|
||||||
split_num = len(dataset_cfg.split_name)
|
split_num = len(dataset_cfg.split_name)
|
||||||
split_at = dataset_cfg.split_at
|
split_at = dataset_cfg.split_at
|
||||||
for ss in range(split_num):
|
for ss in range(split_num):
|
||||||
output_file = f'{args.work_dir}/features/' \
|
output_file = f'{cfg.work_dir}/features/' \
|
||||||
f'{dataset_cfg.split_name[ss]}_{key}.npy'
|
f'{dataset_cfg.split_name[ss]}_{key}.npy'
|
||||||
if ss == 0:
|
if ss == 0:
|
||||||
np.save(output_file, val[:split_at[0]])
|
np.save(output_file, val[:split_at[0]])
|
||||||
|
|
Loading…
Reference in New Issue