from argparse import ArgumentParser

import torchreid

import json
import logging


def read_jsonl(in_file_path):
    """Lazily yield one decoded JSON value per non-blank line of a file.

    Args:
        in_file_path: Path of the JSON Lines file to read.

    Yields:
        The value returned by ``json.loads`` for each non-blank line, in
        file order.

    Raises:
        OSError: If the file cannot be opened (raised on first iteration,
            because this is a generator).
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(in_file_path, encoding="utf-8") as f:
        for line in f:
            # Blank or whitespace-only lines carry no record; skip them
            # instead of letting json.loads fail on an empty document.
            if line.strip():
                yield json.loads(line)


# Configure logging at DEBUG level. This runs at module top level, i.e. as a
# side effect of importing or executing this file.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger named after this module.
# NOTE(review): not referenced in the visible part of this file.
LOGGER = logging.getLogger(__name__)


# Checkpoint restored when the caller does not supply one. This is the path
# that used to be hard-coded inside main().
_DEFAULT_CHECKPOINT = (
    '/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15'
)


def main(data_dir, save_dir, checkpoint_path=_DEFAULT_CHECKPOINT):
    """Restore a ResNet-50 re-id model from a checkpoint and run the engine.

    Builds a torchreid image data manager over the
    ``safex_carla_simulation`` source, a ``resnet50`` model with triplet
    loss, an Adam optimizer and a single-step LR scheduler, restores model
    and optimizer state from ``checkpoint_path``, then calls
    ``engine.run`` with ``test_only=True`` and ``visrank=True``.

    Args:
        data_dir: Passed as ``root`` to ``ImageDataManager``.
        save_dir: Passed as ``save_dir`` to ``engine.run``.
        checkpoint_path: File handed to
            ``torchreid.utils.resume_from_checkpoint``. Defaults to the
            path this script originally hard-coded.
    """
    # NOTE(review): 'safex_carla_simulation' is presumably a custom dataset
    # registered with torchreid elsewhere -- confirm it is registered before
    # this call.
    datamanager = torchreid.data.ImageDataManager(
        root=data_dir,
        sources=['safex_carla_simulation'],
        height=256,
        width=128,
        batch_size_train=32,
        batch_size_test=100,
        transforms=['random_flip'],
        split_id=1
    )

    model = torchreid.models.build_model(
        name='resnet50',
        num_classes=datamanager.num_train_pids,
        loss='triplet',
        pretrained=True
    )

    # The model is pinned to the CPU.
    # NOTE(review): if the engine moves batches to the GPU whenever CUDA is
    # available, this would cause a device mismatch on GPU hosts -- confirm.
    model = model.cpu()

    optimizer = torchreid.optim.build_optimizer(
        model,
        optim='adam',
        lr=0.0003
    )

    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler='single_step',
        stepsize=20
    )

    engine = torchreid.engine.ImageTripletEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        label_smooth=True
    )

    # Restores model and optimizer state; the returned value is forwarded to
    # engine.run as the starting epoch.
    start_epoch = torchreid.utils.resume_from_checkpoint(
        checkpoint_path,
        model,
        optimizer
    )

    engine.run(
        start_epoch=start_epoch,
        save_dir=save_dir,
        max_epoch=60,
        eval_freq=1,
        print_freq=2,
        test_only=True,
        visrank=True
    )


if __name__ == "__main__":
    # Command-line entry point: both directory flags are optional and share
    # the same default location.
    cli = ArgumentParser()
    for flag in ("--data_dir", "--save_dir"):
        cli.add_argument(flag, default="./reid_out/", required=False)

    options = cli.parse_args()
    main(data_dir=options.data_dir, save_dir=options.save_dir)