# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat

AUG_LEN = 10582


def convert_mat(mat_file, in_dir, out_dir):
    data = loadmat(osp.join(in_dir, mat_file))
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
    Image.fromarray(mask).save(seg_filename, 'PNG')


def generate_aug_list(merged_list, excluded_list):
    return list(set(merged_list) - set(excluded_list))


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('aug_path', help='pascal voc aug path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mmcv.mkdir_or_exist(out_dir)
    in_dir = osp.join(aug_path, 'dataset', 'cls')

    mmcv.track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(mmcv.scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    full_aug_list = []
    with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
        full_aug_list += [line.strip() for line in f]
    with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
        full_aug_list += [line.strip() for line in f]

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'train.txt')) as f:
        ori_train_list = [line.strip() for line in f]
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'val.txt')) as f:
        val_list = [line.strip() for line in f]

    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
        AUG_LEN)

    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
                     'trainaug.txt'), 'w') as f:
        f.writelines(line + '\n' for line in aug_train_list)

    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    assert len(aug_list) == AUG_LEN - len(
        ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
                                                      len(ori_train_list))
    with open(
            osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
            'w') as f:
        f.writelines(line + '\n' for line in aug_list)

    print('Done!')


if __name__ == '__main__':
    main()