From 8879db3fbaa62e276198f0c8e41ea4d27fca766a Mon Sep 17 00:00:00 2001 From: liaoxingyu Date: Tue, 16 Jun 2020 19:43:36 +0800 Subject: [PATCH] update training instruction Summary: update dataset configuration and training instruction --- GETTING_STARTED.md | 22 ++++++++++++++++++++++ datasets/README.md | 24 ++++++++++++++++++++++-- fastreid/data/build.py | 7 +++++-- fastreid/data/datasets/__init__.py | 2 ++ fastreid/data/datasets/bases.py | 5 +---- fastreid/utils/checkpoint.py | 2 +- tools/train_net.py | 1 + 7 files changed, 54 insertions(+), 9 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index b0236ab..608db5f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -12,3 +12,25 @@ Then you should set the pretrain model path in `configs/Base-bagtricks.yml`. ```bash cd fastreid/evaluation/rank_cylib; make all ``` + +## Training & Evaluation in Command Line + +We provide a script in "tools/train_net.py" that is made to train all the configs provided in fastreid. +You may want to use it as a reference to write your own training script. + +To train a model with "train_net.py", first set up the corresponding datasets following [datasets/README.md](https://github.com/JDAI-CV/fast-reid/tree/master/datasets), then run: + +```bash +CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml +``` + +The configs are made for 1-GPU training. + +To evaluate a model's performance, use + +```bash +CUDA_VISIBLE_DEVICES=$gpus tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml \ +--eval-only MODEL.WEIGHTS /path/to/checkpoint_file +``` + +For more options, see `tools/train_net.py -h`. diff --git a/datasets/README.md b/datasets/README.md index 8607714..f890144 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -2,9 +2,9 @@ Fastreid has buildin support for a few datasets. The datasets are assumed to exist in a directory specified by the environment variable `FASTREID_DATASETS`. 
Under this directory, fastreid expects to find datasets in the structure described below. -You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `./datasets` relative to your current working directory. +You can set the location for builtin datasets by `export FASTREID_DATASETS=/path/to/datasets/`. If left unset, the default is `datasets/` relative to your current working directory. -The model zoo contains configs and models that use these buildin datasets. +The [model zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md) contains configs and models that use these builtin datasets. ## Expected dataset structure for Market1501 @@ -20,4 +20,24 @@ datasets/ ## Expected dataset structure for DukeMTMC +1. Download datasets to `datasets/` +2. Extract dataset. The dataset structure would look like: + +```bash +datasets/ + DukeMTMC-reID/ + bounding_box_train/ + bounding_box_test/ +``` + ## Expected dataset structure for MSMT17 + +1. Download datasets to `datasets/` +2. Extract dataset. 
The dataset structure would look like: + +```bash +datasets/ + MSMT17_V2/ + mask_train_v2/ + mask_test_v2/ +``` diff --git a/fastreid/data/build.py b/fastreid/data/build.py index e7005a9..861e135 100644 --- a/fastreid/data/build.py +++ b/fastreid/data/build.py @@ -4,6 +4,7 @@ @contact: sherlockliao01@gmail.com """ +import os import torch from torch._six import container_abcs, string_classes, int_classes from torch.utils.data import DataLoader @@ -13,13 +14,15 @@ from .common import CommDataset from .datasets import DATASET_REGISTRY from .transforms import build_transforms +_root = os.getenv("FASTREID_DATASETS", "datasets") + def build_reid_train_loader(cfg): train_transforms = build_transforms(cfg, is_train=True) train_items = list() for d in cfg.DATASETS.NAMES: - dataset = DATASET_REGISTRY.get(d)(combineall=cfg.DATASETS.COMBINEALL) + dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL) dataset.show_train() train_items.extend(dataset.train) @@ -50,7 +53,7 @@ def build_reid_train_loader(cfg): def build_reid_test_loader(cfg, dataset_name): test_transforms = build_transforms(cfg, is_train=False) - dataset = DATASET_REGISTRY.get(dataset_name)() + dataset = DATASET_REGISTRY.get(dataset_name)(root=_root) dataset.show_test() test_items = dataset.query + dataset.gallery diff --git a/fastreid/data/datasets/__init__.py b/fastreid/data/datasets/__init__.py index 2f10264..8518cb9 100644 --- a/fastreid/data/datasets/__init__.py +++ b/fastreid/data/datasets/__init__.py @@ -19,3 +19,5 @@ from .msmt17 import MSMT17 from .veri import VeRi from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild + +__all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] diff --git a/fastreid/data/datasets/bases.py b/fastreid/data/datasets/bases.py index 910df7b..28a1780 100644 --- a/fastreid/data/datasets/bases.py +++ 
b/fastreid/data/datasets/bases.py @@ -5,11 +5,8 @@ """ import copy -import os - -import numpy as np -import torch import logging +import os class Dataset(object): diff --git a/fastreid/utils/checkpoint.py b/fastreid/utils/checkpoint.py index 74baba9..738a9cd 100644 --- a/fastreid/utils/checkpoint.py +++ b/fastreid/utils/checkpoint.py @@ -95,7 +95,7 @@ class Checkpointer(object): if not path: # no checkpoint provided self.logger.info( - "No checkpoint found. Initializing model from scratch" + "No checkpoint found. Training model from scratch" ) return {} self.logger.info("Loading checkpoint from {}".format(path)) diff --git a/tools/train_net.py b/tools/train_net.py index 14e3030..72df9b5 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python # encoding: utf-8 """ @author: sherlock