Implement safe intersection training

pull/462/head
Joshua Newman 2021-02-08 21:16:18 -08:00
parent a54e520ee2
commit fa2526a894
5 changed files with 287 additions and 17 deletions


@@ -0,0 +1,65 @@
import json
import logging
import numpy as np
from argparse import ArgumentParser
from typing import Optional
from tqdm import tqdm
from torchreid.utils import FeatureExtractor
from torchreid.utils.tools import read_jsonl
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(NumpyEncoder, self).default(obj)
def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: Optional[int]):
extractor = FeatureExtractor(
model_name="resnet50",
model_path=model_fp,
device=device
)
manifest_entries = read_jsonl(in_fp)
with open(out_fp, "w") as f:
for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):
            if max_objects and ix >= max_objects:
                break  # stop once max_objects entries have been written
path = manifest_entry["path"]
features = extractor([path])
features = features[0].cpu().detach().numpy()
manifest_entry["features"] = features
f.write(json.dumps(manifest_entry, cls=NumpyEncoder) + "\n")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--manifest_fp", required=True, type=str)
parser.add_argument("--model_fp", required=True, type=str)
parser.add_argument("--out_fp", required=True, type=str)
parser.add_argument("--max_objects", required=False, type=int, default=None)
parser.add_argument("--device", required=False, type=str, default="cpu")
args = parser.parse_args()
main(
in_fp=args.manifest_fp,
model_fp=args.model_fp,
out_fp=args.out_fp,
device=args.device,
max_objects=args.max_objects
)
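Note on the input format: the extraction script assumes only that each line of the manifest is a JSON object with a "path" field pointing at an image crop; every other field passes through to the output unchanged, with a "features" array appended (hence the NumpyEncoder above). A minimal sketch of writing such a manifest; the object_guid and camera_guid fields are illustrative, borrowed from the dataset class later in this commit:

import json

# Hypothetical rows: only "path" is required by the extraction script.
rows = [
    {"path": "frames/obj_001_front.jpg", "object_guid": 1, "camera_guid": "front"},
    {"path": "frames/obj_001_rear.jpg", "object_guid": 1, "camera_guid": "rear"},
]
with open("manifest.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")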

scripts/test.py (new file, +99 lines)

@@ -0,0 +1,99 @@
from argparse import ArgumentParser
import torchreid
import json
import logging
def read_jsonl(in_file_path):
with open(in_file_path) as f:
        yield from map(json.loads, f)
logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
targets='market1501',
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
split_id=1
)
model = torchreid.models.build_model(
name='resnet50',
num_classes=datamanager.num_train_pids,
loss='triplet',
pretrained=True
)
model = model.cpu()
optimizer = torchreid.optim.build_optimizer(
model,
optim='adam',
lr=0.0003
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer,
lr_scheduler='single_step',
stepsize=20
)
engine = torchreid.engine.ImageTripletEngine(
datamanager,
model,
optimizer=optimizer,
scheduler=scheduler,
label_smooth=True
)
start_epoch = torchreid.utils.resume_from_checkpoint(
'/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15',
model,
optimizer
)
engine.run(
start_epoch=start_epoch,
save_dir=save_dir,
max_epoch=60,
eval_freq=1,
print_freq=2,
test_only=True,
visrank=True
)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/", required=False)
args = parser.parse_args()
main(
data_dir=args.data_dir,
save_dir=args.save_dir
)
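As a quick sanity check outside the torchreid engine, the feature vectors written by the extraction script can be ranked directly. A minimal sketch, assuming the hypothetical "features.jsonl" output path and the cosine metric the training scripts below request via dist_metric="cosine":

import json
import numpy as np

with open("features.jsonl") as f:  # hypothetical output of the extraction script
    entries = [json.loads(line) for line in f]

# L2-normalize so that a dot product equals cosine similarity.
feats = np.asarray([e["features"] for e in entries], dtype=np.float32)
feats /= np.linalg.norm(feats, axis=1, keepdims=True)

# Treat entry 0 as the query and rank the rest by descending similarity.
sims = feats[1:] @ feats[0]
ranking = np.argsort(-sims)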


@@ -33,31 +33,33 @@ def main(data_dir, save_dir):
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
-        height=50,
-        width=50,
+        height=256,
+        width=128,
split_id=0,
batch_size_train=32,
batch_size_test=100,
-        transforms=['random_flip', 'color_jitter']
+        transforms=['random_flip'],
)
model = torchreid.models.build_model(
name='osnet_ain_x1_0',
num_classes=datamanager.num_train_pids,
loss='triplet',
-        pretrained=False
+        pretrained=True
)
model = model.cpu()
optimizer = torchreid.optim.build_optimizer(
model,
-        optim='adam',
-        lr=0.0003
+        optim='amsgrad',
+        lr=0.0015
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer,
-        lr_scheduler='single_step',
+        lr_scheduler='cosine',
stepsize=20
)
@@ -68,12 +70,13 @@ def main(data_dir, save_dir):
scheduler=scheduler,
label_smooth=True
)
engine.run(
save_dir=save_dir,
-        max_epoch=60,
-        eval_freq=10,
-        print_freq=10,
+        max_epoch=10,
+        fixbase_epoch=2,
+        dist_metric="cosine",
+        eval_freq=1,
+        print_freq=2,
test_only=False
)
@@ -81,7 +84,7 @@ def main(data_dir, save_dir):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/osnet", required=False)
args = parser.parse_args()
main(

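The new fixbase_epoch=2 argument enables torchreid's warm-up behaviour: for the first two epochs only the freshly initialized layers train while the pretrained backbone stays frozen, after which all layers open up. A rough equivalent in plain PyTorch, assuming a model whose head lives in a hypothetical classifier attribute:

import torch.nn as nn

def set_backbone_frozen(model: nn.Module, frozen: bool) -> None:
    # Freeze every parameter outside the (hypothetical) classifier head.
    for name, param in model.named_parameters():
        if not name.startswith("classifier"):
            param.requires_grad = not frozen

# Epochs 0-1: warm up the head only; epoch 2 onward: train everything.
# set_backbone_frozen(model, frozen=True)
# ...train two epochs...
# set_backbone_frozen(model, frozen=False)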

@@ -0,0 +1,92 @@
from argparse import ArgumentParser
import torchreid
import json
import logging
def read_jsonl(in_file_path):
with open(in_file_path) as f:
        yield from map(json.loads, f)
logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
targets='market1501',
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
)
model = torchreid.models.build_model(
name='resnet50',
num_classes=datamanager.num_train_pids,
loss='triplet',
pretrained=True
)
model = model.cpu()
optimizer = torchreid.optim.build_optimizer(
model,
optim='amsgrad',
lr=0.0015
)
scheduler = torchreid.optim.build_lr_scheduler(
optimizer,
lr_scheduler='cosine',
stepsize=20
)
engine = torchreid.engine.ImageTripletEngine(
datamanager,
model,
optimizer=optimizer,
scheduler=scheduler,
label_smooth=True
)
engine.run(
save_dir=save_dir,
max_epoch=10,
fixbase_epoch=2,
dist_metric="cosine",
eval_freq=1,
print_freq=2,
test_only=False
)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/resnet/", required=False)
args = parser.parse_args()
main(
data_dir=args.data_dir,
save_dir=args.save_dir
)

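Both training scripts now pair amsgrad with a cosine schedule. A sketch of what torchreid presumably builds internally, in plain PyTorch; the amsgrad option maps to Adam's amsgrad flag, and T_max=10 is assumed to mirror max_epoch:

import torch

params = [torch.zeros(1, requires_grad=True)]  # stand-in for model.parameters()
optimizer = torch.optim.Adam(params, lr=0.0015, amsgrad=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for epoch in range(10):
    # ...run one training epoch (optimizer.step() per batch)...
    scheduler.step()  # anneal the lr along a half cosine from 0.0015 toward 0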

@@ -2,7 +2,7 @@ import copy
import os
import random
from collections import defaultdict
-import shutil
from ..dataset import ImageDataset
+import logging
@@ -93,7 +93,6 @@ class SafeXCARLASimulation(ImageDataset):
splits = []
for split_ix in range(10):
manifest_entries = read_jsonl(manifest_fp)
# randomly choose num_train_pids train IDs and the rest for test IDs
pids_copy = copy.deepcopy(pids)
random.shuffle(pids_copy)
@@ -106,25 +105,37 @@ class SafeXCARLASimulation(ImageDataset):
gallery = []
            # for train IDs, up to 100 images per (object GUID, camera GUID) pair go into the train set.
seen_objects_counter = defaultdict(int)
for pid in train_pids:
entries = object_manifest_entries_mapping[pid]
random.shuffle(entries)
for entry in entries:
path, object_guid, camera_guid = entry[0], entry[1], entry[2]
query_key = "_".join([str(object_guid), str(camera_guid)])
seen_objects_counter[query_key] += 1
if seen_objects_counter[query_key] > 100:
continue
guid_label = train_pid2label[object_guid]
entry = (path, guid_label, camera_guid)
train.append(entry)
#shutil.copyfile(path, os.path.join(self.train_dir, str(object_guid) + "_" + str(camera_guid) + ".jpg"))
            # for each test ID, randomly sample up to 20 images: one for
            # query and the rest for gallery.
for pid in test_pids:
entries = object_manifest_entries_mapping[pid]
-                samples = random.sample(entries, 2)
+                samples = random.sample(entries, min(20, len(entries)))
query_sample = samples[0]
-                gallery_sample = samples[1]
+                gallery_samples = samples[1:]
query.append(query_sample)
-                gallery.append(gallery_sample)
+                gallery.extend(gallery_samples)
#shutil.copyfile(query_sample[0], os.path.join(self.query_dir, str(query_sample[1]) + "_" + str(query_sample[2]) + ".jpg"))
                # for gallery_sample in gallery_samples:
                #     shutil.copyfile(gallery_sample[0], os.path.join(self.gallery_dir, str(gallery_sample[1]) + "_" + str(gallery_sample[2]) + ".jpg"))
split = {'train': train, 'query': query, 'gallery': gallery}
splits.append(split)
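For reference, the sampling this last hunk introduces, isolated as a standalone sketch (helper name hypothetical): each test identity now contributes one query image and up to 19 gallery images, instead of the previous one-and-one.

import random

def split_query_gallery(entries, max_samples=20):
    # Mirrors the hunk above: draw up to max_samples entries without
    # replacement, use the first as the query and the rest as gallery.
    samples = random.sample(entries, min(max_samples, len(entries)))
    return samples[0], samples[1:]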