import os
import sys
import pathlib

# Make the script's own directory and the directory two levels above this
# file (presumably the project root) importable, so the project-local
# imports below resolve regardless of the current working directory.
# NOTE: the path is deliberately NOT bound to the module global `__dir__`;
# a module-level `__dir__` is the PEP 562 hook that `dir(module)` calls,
# so binding a Path there made `dir()`/`help()` on this module crash.
_script_path = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(_script_path.parent))
sys.path.append(str(_script_path.parent.parent))

import paddle
import paddle.distributed as dist
from utils import Config, ArgsParser


def init_args():
    """Parse command-line arguments with the project's ``ArgsParser``.

    Returns:
        Whatever ``ArgsParser().parse_args()`` returns; the ``__main__`` block
        reads ``config_file``, ``opt`` and ``profiler_options`` from it.
    """
    return ArgsParser().parse_args()


def main(config, profiler_options):
    """Build all training components from ``config`` and run training.

    Args:
        config: Nested, dict-like training configuration. Read keys:
            ``config["dataset"]["train"]`` (including
            ``["dataset"]["args"]["img_mode"]``), the optional
            ``config["dataset"]["validate"]``, ``config["loss"]``,
            ``config["arch"]``, ``config["post_processing"]`` and
            ``config["metric"]``. The config is mutated in place:
            ``config["distributed"]`` and
            ``config["arch"]["backbone"]["in_channels"]`` are set before the
            model is built, and the whole config is handed to ``Trainer``.
        profiler_options: Forwarded unchanged to ``Trainer``.

    Raises:
        ValueError: If ``get_dataloader`` returns ``None`` for the training
            dataset configuration.
    """
    from models import build_model, build_loss
    from data_loader import get_dataloader
    from trainer import Trainer
    from post_processing import get_post_processing
    from utils import get_metric

    # Initialise the parallel environment whenever more than one GPU is visible.
    # NOTE(review): this keys off the visible device count rather than the
    # launched world size, so a single-process run on a multi-GPU host also
    # takes the distributed path — confirm this is intended.
    distributed = paddle.device.cuda.device_count() > 1
    if distributed:
        dist.init_parallel_env()
    config["distributed"] = distributed

    train_loader = get_dataloader(config["dataset"]["train"], config["distributed"])
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if train_loader is None:
        raise ValueError(
            "could not build the training dataloader from config['dataset']['train']"
        )
    # Validation (when configured) always uses a non-distributed loader.
    if "validate" in config["dataset"]:
        validate_loader = get_dataloader(config["dataset"]["validate"], False)
    else:
        validate_loader = None

    criterion = build_loss(config["loss"])
    # "GRAY" images have a single channel; every other img_mode is built with 3.
    img_mode = config["dataset"]["train"]["dataset"]["args"]["img_mode"]
    config["arch"]["backbone"]["in_channels"] = 1 if img_mode == "GRAY" else 3
    model = build_model(config["arch"])
    # NOTE: for benchmarking the model may be wrapped with @to_static
    # (paddle.jit.to_static) at this point; that step is skipped by default.

    post_p = get_post_processing(config["post_processing"])
    metric = get_metric(config["metric"])
    trainer = Trainer(
        config=config,
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        post_process=post_p,
        metric_cls=metric,
        validate_loader=validate_loader,
        profiler_options=profiler_options,
    )
    trainer.train()


if __name__ == "__main__":
    args = init_args()
    # Explicit check rather than `assert`, which is stripped under `python -O`;
    # also reports which path was missing.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError(f"config file not found: {args.config_file}")
    config = Config(args.config_file)
    # Apply command-line overrides on top of the file-based configuration.
    config.merge_dict(args.opt)
    main(config.cfg, args.profiler_options)