From fa2526a894efd0a551db20cee1e10c546ea9e9b3 Mon Sep 17 00:00:00 2001
From: Joshua Newman
Date: Mon, 8 Feb 2021 21:16:18 -0800
Subject: [PATCH] Implement safe intersection training

---
 scripts/extract_features.py                 | 65 ++++++++++++
 scripts/test.py                             | 99 +++++++++++++++++++
 scripts/train.py                            | 27 ++---
 scripts/train_resnet.py                     | 92 +++++++++++++++++
 .../data/datasets/image/safex_simulation.py | 27 ++++--
 5 files changed, 290 insertions(+), 20 deletions(-)
 create mode 100644 scripts/extract_features.py
 create mode 100644 scripts/test.py
 create mode 100644 scripts/train_resnet.py

diff --git a/scripts/extract_features.py b/scripts/extract_features.py
new file mode 100644
index 0000000..80c78bf
--- /dev/null
+++ b/scripts/extract_features.py
@@ -0,0 +1,65 @@
+import json
+import logging
+import numpy as np
+
+from argparse import ArgumentParser
+
+from tqdm import tqdm
+
+from torchreid.utils import FeatureExtractor
+
+from torchreid.utils.tools import read_jsonl
+
+logging.basicConfig(level=logging.INFO)
+LOGGER = logging.getLogger(__name__)
+
+
+class NumpyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        else:
+            return super(NumpyEncoder, self).default(obj)
+
+
+def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int):
+    extractor = FeatureExtractor(
+        model_name="resnet50",
+        model_path=model_fp,
+        device=device
+    )
+
+    manifest_entries = read_jsonl(in_fp)
+    with open(out_fp, "w") as f:
+        for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):
+
+            if max_objects and ix >= max_objects:
+                break
+
+            path = manifest_entry["path"]
+            features = extractor([path])
+            features = features[0].cpu().detach().numpy()
+            manifest_entry["features"] = features
+            f.write(json.dumps(manifest_entry, cls=NumpyEncoder) + "\n")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--manifest_fp", required=True, type=str)
+    parser.add_argument("--model_fp", required=True, type=str)
+    parser.add_argument("--out_fp", required=True, type=str)
+    parser.add_argument("--max_objects", required=False, type=int, default=None)
+    parser.add_argument("--device", required=False, type=str, default="cpu")
+
+    args = parser.parse_args()
+    main(
+        in_fp=args.manifest_fp,
+        model_fp=args.model_fp,
+        out_fp=args.out_fp,
+        device=args.device,
+        max_objects=args.max_objects
+    )
diff --git a/scripts/test.py b/scripts/test.py
new file mode 100644
index 0000000..74550d6
--- /dev/null
+++ b/scripts/test.py
@@ -0,0 +1,99 @@
+from argparse import ArgumentParser
+
+import torchreid
+
+import json
+import logging
+
+
+def read_jsonl(in_file_path):
+    with open(in_file_path) as f:
+        data = map(json.loads, f)
+        for datum in data:
+            yield datum
+
+
+logging.basicConfig(level=logging.DEBUG)
+LOGGER = logging.getLogger(__name__)
+
+
+def main(data_dir, save_dir):
+    """
+    datamanager = torchreid.data.ImageDataManager(
+        root='reid-data',
+        sources='market1501',
+        targets='market1501',
+        height=256,
+        width=128,
+        batch_size_train=32,
+        batch_size_test=100,
+        transforms=['random_flip', 'random_crop']
+    )
+    """
+    datamanager = torchreid.data.ImageDataManager(
+        root=data_dir,
+        sources=['safex_carla_simulation'],
+        height=256,
+        width=128,
+        batch_size_train=32,
+        batch_size_test=100,
+        transforms=['random_flip'],
+        split_id=1
+    )
+
+    model = torchreid.models.build_model(
+        name='resnet50',
+        num_classes=datamanager.num_train_pids,
+        loss='triplet',
+        pretrained=True
+    )
+
+    model = model.cpu()
+
+    optimizer = torchreid.optim.build_optimizer(
+        model,
+        optim='adam',
+        lr=0.0003
+    )
+
+    scheduler = torchreid.optim.build_lr_scheduler(
+        optimizer,
+        lr_scheduler='single_step',
+        stepsize=20
+    )
+
+    engine = torchreid.engine.ImageTripletEngine(
+        datamanager,
+        model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        label_smooth=True
+    )
+
+    start_epoch = torchreid.utils.resume_from_checkpoint(
+        '/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15',
+        model,
+        optimizer
+    )
+
+    engine.run(
+        start_epoch=start_epoch,
+        save_dir=save_dir,
+        max_epoch=60,
+        eval_freq=1,
+        print_freq=2,
+        test_only=True,
+        visrank=True
+    )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--data_dir", default="./reid_out/", required=False)
+    parser.add_argument("--save_dir", default="./reid_out/", required=False)
+
+    args = parser.parse_args()
+    main(
+        data_dir=args.data_dir,
+        save_dir=args.save_dir
+    )
diff --git a/scripts/train.py b/scripts/train.py
index 201ea70..72d3ad8 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -33,31 +33,33 @@ def main(data_dir, save_dir):
     datamanager = torchreid.data.ImageDataManager(
         root=data_dir,
         sources=['safex_carla_simulation'],
-        height=50,
-        width=50,
+        height=256,
+        width=128,
+        split_id=0,
         batch_size_train=32,
         batch_size_test=100,
-        transforms=['random_flip', 'color_jitter']
+        transforms=['random_flip'],
+
     )
 
     model = torchreid.models.build_model(
         name='osnet_ain_x1_0',
         num_classes=datamanager.num_train_pids,
         loss='triplet',
-        pretrained=False
+        pretrained=True
     )
 
     model = model.cpu()
 
     optimizer = torchreid.optim.build_optimizer(
         model,
-        optim='adam',
-        lr=0.0003
+        optim='amsgrad',
+        lr=0.0015
     )
 
     scheduler = torchreid.optim.build_lr_scheduler(
         optimizer,
-        lr_scheduler='single_step',
+        lr_scheduler='cosine',
         stepsize=20
     )
 
@@ -68,12 +70,13 @@
         scheduler=scheduler,
         label_smooth=True
     )
-
     engine.run(
         save_dir=save_dir,
-        max_epoch=60,
-        eval_freq=10,
-        print_freq=10,
+        max_epoch=10,
+        fixbase_epoch=2,
+        dist_metric="cosine",
+        eval_freq=1,
+        print_freq=2,
         test_only=False
     )
 
@@ -81,7 +84,7 @@ if __name__ == "__main__":
 
     parser = ArgumentParser()
     parser.add_argument("--data_dir", default="./reid_out/", required=False)
-    parser.add_argument("--save_dir", default="./reid_out/", required=False)
+    parser.add_argument("--save_dir", default="./reid_out/osnet", required=False)
 
     args = parser.parse_args()
     main(
diff --git a/scripts/train_resnet.py b/scripts/train_resnet.py
new file mode 100644
index 0000000..575239d
--- /dev/null
+++ b/scripts/train_resnet.py
@@ -0,0 +1,92 @@
+from argparse import ArgumentParser
+
+import torchreid
+
+import json
+import logging
+
+
+def read_jsonl(in_file_path):
+    with open(in_file_path) as f:
+        data = map(json.loads, f)
+        for datum in data:
+            yield datum
+
+
+logging.basicConfig(level=logging.DEBUG)
+LOGGER = logging.getLogger(__name__)
+
+
+def main(data_dir, save_dir):
+    """
+    datamanager = torchreid.data.ImageDataManager(
+        root='reid-data',
+        sources='market1501',
+        targets='market1501',
+        height=256,
+        width=128,
+        batch_size_train=32,
+        batch_size_test=100,
+        transforms=['random_flip', 'random_crop']
+    )
+    """
+    datamanager = torchreid.data.ImageDataManager(
+        root=data_dir,
+        sources=['safex_carla_simulation'],
+        height=256,
+        width=128,
+        batch_size_train=32,
+        batch_size_test=100,
+        transforms=['random_flip'],
+
+    )
+
+    model = torchreid.models.build_model(
+        name='resnet50',
+        num_classes=datamanager.num_train_pids,
+        loss='triplet',
+        pretrained=True
+    )
+
+    model = model.cpu()
+
+    optimizer = torchreid.optim.build_optimizer(
+        model,
+        optim='amsgrad',
+        lr=0.0015
+    )
+
+    scheduler = torchreid.optim.build_lr_scheduler(
+        optimizer,
+        lr_scheduler='cosine',
+        stepsize=20
+    )
+
+    engine = torchreid.engine.ImageTripletEngine(
+        datamanager,
+        model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        label_smooth=True
+    )
+    engine.run(
+        save_dir=save_dir,
+        max_epoch=10,
+        fixbase_epoch=2,
+        dist_metric="cosine",
+        eval_freq=1,
+        print_freq=2,
+        test_only=False
+    )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--data_dir", default="./reid_out/", required=False)
+    parser.add_argument("--save_dir", default="./reid_out/resnet/", required=False)
+
+    args = parser.parse_args()
+    main(
+        data_dir=args.data_dir,
+        save_dir=args.save_dir
+    )
diff --git a/torchreid/data/datasets/image/safex_simulation.py b/torchreid/data/datasets/image/safex_simulation.py
index 73573ae..d49c8ac 100644
--- a/torchreid/data/datasets/image/safex_simulation.py
+++ b/torchreid/data/datasets/image/safex_simulation.py
@@ -2,7 +2,7 @@ import copy
 import os
 import random
 from collections import defaultdict
-
+import shutil
 from ..dataset import ImageDataset
 import logging
 
@@ -93,7 +93,6 @@
 
         splits = []
        for split_ix in range(10):
-            manifest_entries = read_jsonl(manifest_fp)
             # randomly choose num_train_pids train IDs and the rest for test IDs
             pids_copy = copy.deepcopy(pids)
             random.shuffle(pids_copy)
@@ -106,25 +105,37 @@
 
             gallery = []
 
-            # for train IDs, all images are used in the train set.
+            # for train IDs, use at most 100 images per object/camera pair in the train set.
+            seen_objects_counter = defaultdict(int)
             for pid in train_pids:
                 entries = object_manifest_entries_mapping[pid]
                 random.shuffle(entries)
                 for entry in entries:
                     path, object_guid, camera_guid = entry[0], entry[1], entry[2]
+                    query_key = "_".join([str(object_guid), str(camera_guid)])
+                    seen_objects_counter[query_key] += 1
+                    if seen_objects_counter[query_key] > 100:
+                        continue
+
                     guid_label = train_pid2label[object_guid]
                     entry = (path, guid_label, camera_guid)
                     train.append(entry)
+                    #shutil.copyfile(path, os.path.join(self.train_dir, str(object_guid) + "_" + str(camera_guid) + ".jpg"))
 
-            # for each test ID, randomly choose two images, one for
-            # query and the other one for gallery.
+            # for each test ID, randomly choose up to 20 images: one for
+            # the query set and the rest for the gallery.
             for pid in test_pids:
                 entries = object_manifest_entries_mapping[pid]
-                samples = random.sample(entries, 2)
+                samples = random.sample(entries, min(20, len(entries)))
+
                 query_sample = samples[0]
-                gallery_sample = samples[1]
+                gallery_samples = samples[1:]
                 query.append(query_sample)
-                gallery.append(gallery_sample)
+                gallery.extend(gallery_samples)
+                #shutil.copyfile(query_sample[0], os.path.join(self.query_dir, str(query_sample[1]) + "_" + str(query_sample[2]) + ".jpg"))
+                for gallery_sample in gallery_samples:
+                    pass
+                    #shutil.copyfile(gallery_sample[0], os.path.join(self.gallery_dir, str(gallery_sample[1]) + "_" + str(gallery_sample[2]) + ".jpg"))
 
             split = {'train': train, 'query': query, 'gallery': gallery}
             splits.append(split)
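
Not part of the patch: a minimal sketch of how the embeddings written by
scripts/extract_features.py can be consumed downstream, assuming only that each
JSONL line carries the "path" key from the manifest plus the "features" list the
script adds. The file name "features.jsonl" and the helper load_features are
illustrative, not names defined by this patch.

import json

import numpy as np


def load_features(fp):
    """Read the JSONL produced by extract_features.py into parallel arrays."""
    paths, feats = [], []
    with open(fp) as f:
        for line in f:
            entry = json.loads(line)
            paths.append(entry["path"])
            feats.append(np.asarray(entry["features"], dtype=np.float32))
    return paths, np.stack(feats)


if __name__ == "__main__":
    paths, feats = load_features("features.jsonl")  # illustrative path
    # cosine similarity between L2-normalised embeddings
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)
    sims = feats @ feats.T
    np.fill_diagonal(sims, -1.0)  # ignore self-matches
    for i, path in enumerate(paths):
        j = int(np.argmax(sims[i]))
        print(f"{path} -> {paths[j]} (cosine={sims[i, j]:.3f})")

Printing each crop's nearest neighbour this way is a quick sanity check that the
trained model groups the same intersection object across cameras before wiring
the features into any downstream matching logic.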