from argparse import ArgumentParser import torchreid import json import logging def read_jsonl(in_file_path): with open(in_file_path) as f: data = map(json.loads, f) for datum in data: yield datum logging.basicConfig(level=logging.DEBUG) LOGGER = logging.getLogger(__name__) def main(data_dir, save_dir, model_fp): datamanager = torchreid.data.ImageDataManager( root=data_dir, sources="market1501", targets="market1501", height=256, width=128, batch_size_train=32, batch_size_test=100, transforms=["random_flip", "random_crop"] ) model = torchreid.models.build_model( name="resnet18", num_classes=datamanager.num_train_pids, loss="triplet", pretrained=True ) model = model.cpu() optimizer = torchreid.optim.build_optimizer( model, optim="adam", lr=0.0003 ) scheduler = torchreid.optim.build_lr_scheduler( optimizer, lr_scheduler="single_step", stepsize=20 ) engine = torchreid.engine.ImageTripletEngine( datamanager, model, optimizer=optimizer, scheduler=scheduler, label_smooth=True ) start_epoch = torchreid.utils.resume_from_checkpoint( model_fp, model, optimizer ) engine.run( start_epoch=start_epoch, save_dir=save_dir, max_epoch=60, eval_freq=1, print_freq=2, test_only=True, visrank=True ) if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--model_fp", required=True) parser.add_argument("--data_dir", default="./reid_out/", required=False) parser.add_argument("--save_dir", default="./reid_out/", required=False) args = parser.parse_args() main( model_fp=args.model_path, data_dir=args.data_dir, save_dir=args.save_dir )