From 3024cea3a3fed1475bb500215aacef0d0c0b655c Mon Sep 17 00:00:00 2001
From: liaoxingyu <sherlockliao01@gmail.com>
Date: Mon, 28 Sep 2020 17:05:40 +0800
Subject: [PATCH] add more datasets

Summary: add wildtracker.py datasets and cuhk_sysu.py datasets.
---
 fastreid/data/datasets/__init__.py    | 78 +++++++++++++--------------
 fastreid/data/datasets/cuhk_sysu.py   | 58 ++++++++++++++++++++
 fastreid/data/datasets/wildtracker.py | 59 ++++++++++++++++++++
 3 files changed, 156 insertions(+), 39 deletions(-)
 create mode 100644 fastreid/data/datasets/cuhk_sysu.py
 create mode 100644 fastreid/data/datasets/wildtracker.py

diff --git a/fastreid/data/datasets/__init__.py b/fastreid/data/datasets/__init__.py
index ef3b498..f7b8ccd 100644
--- a/fastreid/data/datasets/__init__.py
+++ b/fastreid/data/datasets/__init__.py
@@ -1,39 +1,39 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-from ...utils.registry import Registry
-
-DATASET_REGISTRY = Registry("DATASET")
-DATASET_REGISTRY.__doc__ = """
-Registry for datasets
-It must returns an instance of :class:`Backbone`.
-"""
-
-# Person re-id datasets
-from .cuhk03 import CUHK03
-from .dukemtmcreid import DukeMTMC
-from .market1501 import Market1501
-from .msmt17 import MSMT17
-from .AirportALERT import AirportALERT
-from .iLIDS import iLIDS
-from .pku import PKU
-from .prai import PRAI
-from .sensereid import SenseReID
-from .sysu_mm import SYSU_mm
-from .thermalworld import Thermalworld
-from .pes3d import PeS3D
-from .caviara import CAVIARa
-from .viper import VIPeR
-from .lpw import LPW
-from .shinpuhkan import Shinpuhkan
-
-# Vehicle re-id datasets
-from .veri import VeRi
-from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
-from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
-
-
-__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from ...utils.registry import Registry
+
+DATASET_REGISTRY = Registry("DATASET")
+DATASET_REGISTRY.__doc__ = """
+Registry for datasets
+It must returns an instance of :class:`Backbone`.
+"""
+
+# Person re-id datasets
+from .cuhk03 import CUHK03
+from .dukemtmcreid import DukeMTMC
+from .market1501 import Market1501
+from .msmt17 import MSMT17
+from .AirportALERT import AirportALERT
+from .iLIDS import iLIDS
+from .pku import PKU
+from .prai import PRAI
+from .sensereid import SenseReID
+from .sysu_mm import SYSU_mm
+from .thermalworld import Thermalworld
+from .pes3d import PeS3D
+from .caviara import CAVIARa
+from .viper import VIPeR
+from .lpw import LPW
+from .shinpuhkan import Shinpuhkan
+
+# Vehicle re-id datasets
+from .veri import VeRi
+from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID
+from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild
+
+
+__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
diff --git a/fastreid/data/datasets/cuhk_sysu.py b/fastreid/data/datasets/cuhk_sysu.py
new file mode 100644
index 0000000..75a1488
--- /dev/null
+++ b/fastreid/data/datasets/cuhk_sysu.py
@@ -0,0 +1,58 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import glob
+import os.path as osp
+import re
+import warnings
+
+from .bases import ImageDataset
+from ..datasets import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class cuhkSYSU(ImageDataset):
+    r"""CUHK SYSU datasets.
+
+    The dataset is collected from two sources: street snap and movie.
+    In street snap, 12,490 images and 6,057 query persons were collected
+    with movable cameras across hundreds of scenes while 5,694 images and
+    2,375 query persons were selected from movies and TV dramas.
+
+    Dataset statistics:
+        - identities: xxx.
+        - images: 12936 (train).
+    """
+    dataset_dir = 'cuhk_sysu'
+    dataset_name = "cuhksysu"
+
+    def __init__(self, root='datasets', **kwargs):
+        self.root = root
+        self.dataset_dir = osp.join(self.root, self.dataset_dir)
+
+        self.data_dir = osp.join(self.dataset_dir, "cropped_images")
+
+        required_files = [self.data_dir]
+        self.check_before_run(required_files)
+
+        train = self.process_dir(self.data_dir)
+        query = []
+        gallery = []
+
+        super(cuhkSYSU, self).__init__(train, query, gallery, **kwargs)
+
+    def process_dir(self, dir_path):
+        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
+        pattern = re.compile(r'p([-\d]+)_s(\d)')
+
+        data = []
+        for img_path in img_paths:
+            pid, _ = map(int, pattern.search(img_path).groups())
+            pid = self.dataset_name + "_" + str(pid)
+            camid = self.dataset_name + "_0"
+            data.append((img_path, pid, camid))
+
+        return data
diff --git a/fastreid/data/datasets/wildtracker.py b/fastreid/data/datasets/wildtracker.py
new file mode 100644
index 0000000..d163d5d
--- /dev/null
+++ b/fastreid/data/datasets/wildtracker.py
@@ -0,0 +1,59 @@
+# encoding: utf-8
+"""
+@author:  wangguanan
+@contact: guan.wang0706@gmail.com
+"""
+
+import glob
+import os
+
+from .bases import ImageDataset
+from ..datasets import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class WildTrackCrop(ImageDataset):
+    """WildTrack.
+    Reference:
+        WILDTRACK: A Multi-camera HD Dataset for Dense Unscripted Pedestrian Detection
+            T. Chavdarova; P. BaquƩ; A. Maksai; S. Bouquet; C. Jose et al.
+    URL: `<https://www.epfl.ch/labs/cvlab/data/data-wildtrack/>`_
+    Dataset statistics:
+        - identities: 313
+        - images: 33979 (train only)
+        - cameras: 7
+    Args:
+        data_path(str): path to WildTrackCrop dataset
+        combineall(bool): combine train and test sets as train set if True
+    """
+    dataset_url = None
+    dataset_dir = 'Wildtrack_crop_dataset'
+    dataset_name = 'wildtrack'
+
+    def __init__(self, root='datasets', **kwargs):
+        self.root = root
+        self.dataset_dir = os.path.join(self.root, self.dataset_dir)
+
+        self.train_dir = os.path.join(self.dataset_dir, "crop")
+
+        train = self.process_dir(self.train_dir)
+        query = []
+        gallery = []
+
+        super(WildTrackCrop, self).__init__(train, query, gallery, **kwargs)
+
+    def process_dir(self, dir_path):
+        r"""
+        :param dir_path: directory path saving images
+        Returns
+            data(list) = [img_path, pid, camid]
+        """
+        data = []
+        for dir_name in os.listdir(dir_path):
+            img_lists = glob.glob(os.path.join(dir_path, dir_name, "*.png"))
+            for img_path in img_lists:
+                pid = self.dataset_name + "_" + dir_name
+                camid = img_path.split('/')[-1].split('_')[0]
+                camid = self.dataset_name + "_" + camid
+                data.append([img_path, pid, camid])
+        return data