fix pre-train model bugs

fix bugs locks when downloading pre-train model
pull/224/head
liaoxingyu 2020-08-04 15:56:36 +08:00
parent e1aaeb358b
commit d1c20cbe50
12 changed files with 22 additions and 20 deletions

View File

@ -15,7 +15,7 @@ We write a [chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29
## Installation
See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/docs/blob/master/INSTALL.md).
## Quick Start

View File

@ -8,3 +8,7 @@
- [yacs](https://github.com/rbgirshick/yacs)
- Cython (optional to compile evaluation code)
- tensorboard (needed for visualization): `pip install tensorboard`
- gdown (for automatically downloading pre-train model)
- sklearn
- termcolor
- tabulate

View File

@ -60,7 +60,7 @@ Bag of Specials(BoS):
| :---: | :---: | :---: |:---: | :---: |:---: |
| [AGW(R50)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/AGW_R50.yml) | ImageNet | 95.3% | 88.2% | 66.3% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/market_agw_R50.pth) |
| [AGW(R50-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/AGW_R50-ibn.yml) | ImageNet | 95.1% | 88.7% | 67.1% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/market_agw_R50-ibn.pth) |
| [AGW(S50)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/AGW_S50.yml) | ImageNet | 94.7% | 87.1% | 62.2% | - |
| [AGW(S50)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/AGW_S50.yml) | ImageNet | 95.3% | 89.3% | 68.5% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/market_agw_S50.pth) |
| [AGW(R101-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/AGW_R101-ibn.yml) | ImageNet | 95.5% | 89.5% | 69.5% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/market_agw_R101-ibn.pth) |
**SBS**:
@ -76,7 +76,7 @@ Bag of Specials(BoS):
| Method | Pretrained | Rank@1 | mAP | mINP | download |
| :---: | :---: | :---: |:---: | :---: | :---:|
| [SBS(R50-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/mgn_R50-ibn.yml) | ImageNet | 95.8% | 89.7% | 67.0% | - |
| [SBS(R50-ibn)](https://github.com/JDAI-CV/fast-reid/blob/master/configs/Market1501/mgn_R50-ibn.yml) | ImageNet | 95.8% | 89.8% | 67.7% | [model](https://github.com/JDAI-CV/fast-reid/releases/download/v0.1.1/market_mgn_R50-ibn.pth) |
### DukeMTMC Baseline

View File

@ -186,7 +186,7 @@ _C.SOLVER = CN()
_C.SOLVER.OPT = "Adam"
_C.SOLVER.MAX_ITER = 40000
_C.SOLVER.MAX_ITER = 120
_C.SOLVER.BASE_LR = 3e-4
_C.SOLVER.BIAS_LR_FACTOR = 1.
@ -222,7 +222,7 @@ _C.SOLVER.SWA.LR_FACTOR = 10.
_C.SOLVER.SWA.ETA_MIN_LR = 3.5e-6
_C.SOLVER.SWA.LR_SCHED = False
_C.SOLVER.CHECKPOINT_PERIOD = 5000
_C.SOLVER.CHECKPOINT_PERIOD = 20
# Number of images per batch across all machines.
# This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
@ -233,7 +233,7 @@ _C.SOLVER.IMS_PER_BATCH = 64
# see 2 images per batch
_C.TEST = CN()
_C.TEST.EVAL_PERIOD = 50
_C.TEST.EVAL_PERIOD = 20
# Number of images per batch in one process.
_C.TEST.IMS_PER_BATCH = 64

View File

@ -70,7 +70,7 @@ def build_reid_test_loader(cfg, dataset_name):
test_loader = DataLoader(
test_set,
batch_sampler=batch_sampler,
num_workers=4, # save some memory
num_workers=0, # save some memory
collate_fn=fast_batch_collator)
return test_loader, len(dataset.query)

View File

@ -31,8 +31,7 @@ class CUHK03(ImageDataset):
dataset_url = None
dataset_name = "cuhk03"
def __init__(self, root='datasets', split_id=0, cuhk03_labeled=False, cuhk03_classic_split=False, **kwargs):
# self.root = osp.abspath(osp.expanduser(root))
def __init__(self, root='datasets', split_id=0, cuhk03_labeled=True, cuhk03_classic_split=False, **kwargs):
self.root = root
self.dataset_dir = osp.join(self.root, self.dataset_dir)
@ -82,7 +81,6 @@ class CUHK03(ImageDataset):
del tmp_train
query = split['query']
gallery = split['gallery']
from ipdb import set_trace; set_trace()
super(CUHK03, self).__init__(train, query, gallery, **kwargs)
@ -271,5 +269,5 @@ class CUHK03(ImageDataset):
'num_gallery_pids': gallery_info[1],
'num_gallery_imgs': gallery_info[2]
}]
with PathManager.open(self.split_new_lab_json_pat, 'w') as f:
with PathManager.open(self.split_new_lab_json_path, 'w') as f:
json.dump(split, f, indent=4, separators=(',', ': '))

View File

@ -448,8 +448,8 @@ def init_pretrained_weights(model, key=''):
if not os.path.exists(cached_file):
if comm.is_main_process():
gdown.download(model_urls[key], cached_file, quiet=False)
else:
comm.synchronize()
comm.synchronize()
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
model_dict = model.state_dict()

View File

@ -518,8 +518,8 @@ def init_pretrained_weights(key):
if not os.path.exists(cached_file):
if comm.is_main_process():
gdown.download(model_urls[key], cached_file, quiet=False)
else:
comm.synchronize()
comm.synchronize()
logger.info(f"Loading pretrained model from {cached_file}")
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))['model_state']

View File

@ -17,7 +17,6 @@ from fastreid.layers import (
)
from fastreid.utils.checkpoint import get_unexpected_parameters_message, get_missing_parameters_message
from fastreid.utils import comm
from .build import BACKBONE_REGISTRY

View File

@ -275,8 +275,8 @@ def init_pretrained_weights(key):
if not os.path.exists(cached_file):
if comm.is_main_process():
gdown.download(model_urls[key], cached_file, quiet=False)
else:
comm.synchronize()
comm.synchronize()
logger.info(f"Loading pretrained model from {cached_file}")
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))

View File

@ -209,14 +209,15 @@ def init_pretrained_weights(key):
if not os.path.exists(cached_file):
if comm.is_main_process():
gdown.download(model_urls[key], cached_file, quiet=False)
else:
comm.synchronize()
comm.synchronize()
logger.info(f"Loading pretrained model from {cached_file}")
state_dict = torch.load(cached_file, map_location=torch.device('cpu'))
return state_dict
@BACKBONE_REGISTRY.register()
def build_resnext_backbone(cfg):
"""