fast-reid/projects/Shoe/visual_augmenter.py
2021-11-22 17:06:34 +08:00

51 lines
1.2 KiB
Python

# coding: utf-8
import os
from collections import defaultdict
import sys
import shutil
sys.path.append('')
from fastreid.utils.env import seed_all_rng
from fastreid.data.datasets import DATASET_REGISTRY
import projects.Shoe.shoe.data
# Seed the random number generators with the fixed value 0 through fastreid's
# seed_all_rng helper before any sample is drawn.
# NOTE(review): presumably this makes the sampled pairs reproducible across
# runs -- confirm which RNGs seed_all_rng actually covers.
seed_all_rng(0)
# Output directory for the dumped images. It is wiped and recreated on every
# run so files left over from an earlier run never mix with the new ones.
save_root = 'debug/neg_aug'
if os.path.exists(save_root):
    shutil.rmtree(save_root)
# makedirs (rather than mkdir) also creates the missing parent 'debug'
# directory; os.mkdir raises FileNotFoundError when the parent is absent.
os.makedirs(save_root)
# Locations of the cropped shoe images and of the training annotation file,
# both resolved relative to the dataset root.
root = '/data97/bijia/shoe/'
img_root = os.path.join(root, 'shoe_crop_all_images')
anno_path = os.path.join(root, 'labels/1102/train_1102.json')

# Look the dataset class up by name in fastreid's registry, then build it in
# train mode with no transform, so items come back without any transform
# applied on top.
# NOTE(review): 'PairDataset' is presumably registered as a side effect of
# importing projects.Shoe.shoe.data -- confirm.
pair_dataset_cls = DATASET_REGISTRY.get('PairDataset')
dataset = pair_dataset_cls(
    img_root=img_root,
    anno_path=anno_path,
    transform=None,
    mode='train',
)
# Fetch the same dataset entry 100 times and sort the returned images into a
# positive pool and a negative pool according to each sample's target.
# NOTE(review): index 100 is read on every pass; presumably the dataset
# returns a freshly sampled/augmented pair per access in train mode, which
# is what this script visualises -- confirm.
pos_imgs = []
neg_imgs = []
for _ in range(100):
    sample = dataset[100]
    first = sample['img1']
    second = sample['img2']
    label = sample['target']
    second_is_negative = label == 0
    # img1 is always collected as a positive; img2 joins the negatives only
    # when the target is 0, otherwise it is a positive as well.
    pos_imgs.append(first)
    if second_is_negative:
        neg_imgs.append(second)
    else:
        pos_imgs.append(second)
# Bucket the positive images by their ``size`` attribute so that only one
# representative per distinct size gets written out.
pos_dict = defaultdict(list)
for picture in pos_imgs:
    bucket = pos_dict[picture.size]
    bucket.append(picture)

# Save the first image seen for every distinct size as p-<index>.jpg.
for idx, bucket in enumerate(pos_dict.values()):
    out_path = os.path.join(save_root, 'p-' + str(idx) + '.jpg')
    bucket[0].save(out_path)

# Save every collected negative image as n-<index>.jpg.
for idx, picture in enumerate(neg_imgs):
    out_path = os.path.join(save_root, 'n-' + str(idx) + '.jpg')
    picture.save(out_path)