Merge pull request #4 from AICradle/feature/safex-reid-training

CLean up scripts
pull/462/head
Joshua Newman 2021-07-30 10:21:41 -06:00 committed by GitHub
commit c087521d24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 18 additions and 54 deletions

1
.gitignore vendored
View File

@ -137,6 +137,7 @@ dmypy.json
# ReID # ReID
reid-data/ reid-data/
log/ log/
saved-models/ saved-models/
model-zoo/ model-zoo/

View File

View File

@ -3,11 +3,9 @@ import logging
import numpy as np import numpy as np
from argparse import ArgumentParser from argparse import ArgumentParser
from tqdm import tqdm from tqdm import tqdm
from torchreid.utils import FeatureExtractor from torchreid.utils import FeatureExtractor
from torchreid.utils.tools import read_jsonl from torchreid.utils.tools import read_jsonl
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -28,12 +26,13 @@ class NumpyEncoder(json.JSONEncoder):
def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int): def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int):
extractor = FeatureExtractor( extractor = FeatureExtractor(
model_name="resnet50", model_name="resnet18",
model_path=model_fp, model_path=model_fp,
device=device device=device
) )
manifest_entries = read_jsonl(in_fp) manifest_entries = read_jsonl(in_fp)
# Todo (Josh) speed this up my batching
with open(out_fp, "w") as f: with open(out_fp, "w") as f:
for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"): for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):

View File

@ -17,34 +17,22 @@ logging.basicConfig(level=logging.DEBUG)
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir): def main(data_dir, save_dir, model_fp):
"""
datamanager = torchreid.data.ImageDataManager(
root='reid-data',
sources='market1501',
targets='market1501',
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip', 'random_crop']
)
"""
datamanager = torchreid.data.ImageDataManager( datamanager = torchreid.data.ImageDataManager(
root=data_dir, root=data_dir,
sources=['safex_carla_simulation'], sources="market1501",
targets="market1501",
height=256, height=256,
width=128, width=128,
batch_size_train=32, batch_size_train=32,
batch_size_test=100, batch_size_test=100,
transforms=['random_flip'], transforms=["random_flip", "random_crop"]
split_id=1
) )
model = torchreid.models.build_model( model = torchreid.models.build_model(
name='resnet50', name="resnet18",
num_classes=datamanager.num_train_pids, num_classes=datamanager.num_train_pids,
loss='triplet', loss="triplet",
pretrained=True pretrained=True
) )
@ -52,13 +40,13 @@ def main(data_dir, save_dir):
optimizer = torchreid.optim.build_optimizer( optimizer = torchreid.optim.build_optimizer(
model, model,
optim='adam', optim="adam",
lr=0.0003 lr=0.0003
) )
scheduler = torchreid.optim.build_lr_scheduler( scheduler = torchreid.optim.build_lr_scheduler(
optimizer, optimizer,
lr_scheduler='single_step', lr_scheduler="single_step",
stepsize=20 stepsize=20
) )
@ -71,7 +59,7 @@ def main(data_dir, save_dir):
) )
start_epoch = torchreid.utils.resume_from_checkpoint( start_epoch = torchreid.utils.resume_from_checkpoint(
'/home/ubuntu/deep-person-reid/reid_out/resnet/model/model.pth.tar-15', model_fp,
model, model,
optimizer optimizer
) )
@ -89,11 +77,13 @@ def main(data_dir, save_dir):
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--model_fp", required=True)
parser.add_argument("--data_dir", default="./reid_out/", required=False) parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/", required=False) parser.add_argument("--save_dir", default="./reid_out/", required=False)
args = parser.parse_args() args = parser.parse_args()
main( main(
model_fp=args.model_path,
data_dir=args.data_dir, data_dir=args.data_dir,
save_dir=args.save_dir save_dir=args.save_dir
) )

View File

@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir): def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager( datamanager = torchreid.data.ImageDataManager(
root='reid-data', root=data_dir,
sources='market1501', sources='market1501',
targets='market1501', targets='market1501',
height=256, height=256,
@ -29,19 +28,6 @@ def main(data_dir, save_dir):
batch_size_test=100, batch_size_test=100,
transforms=['random_flip', 'random_crop'] transforms=['random_flip', 'random_crop']
) )
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
split_id=0,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
)
model = torchreid.models.build_model( model = torchreid.models.build_model(
name='osnet_ain_x1_0', name='osnet_ain_x1_0',
num_classes=datamanager.num_train_pids, num_classes=datamanager.num_train_pids,

View File

@ -18,9 +18,8 @@ LOGGER = logging.getLogger(__name__)
def main(data_dir, save_dir): def main(data_dir, save_dir):
"""
datamanager = torchreid.data.ImageDataManager( datamanager = torchreid.data.ImageDataManager(
root='reid-data', root=data_dir,
sources='market1501', sources='market1501',
targets='market1501', targets='market1501',
height=256, height=256,
@ -29,20 +28,9 @@ def main(data_dir, save_dir):
batch_size_test=100, batch_size_test=100,
transforms=['random_flip', 'random_crop'] transforms=['random_flip', 'random_crop']
) )
"""
datamanager = torchreid.data.ImageDataManager(
root=data_dir,
sources=['safex_carla_simulation'],
height=256,
width=128,
batch_size_train=32,
batch_size_test=100,
transforms=['random_flip'],
)
model = torchreid.models.build_model( model = torchreid.models.build_model(
name='resnet50', name='resnet18',
num_classes=datamanager.num_train_pids, num_classes=datamanager.num_train_pids,
loss='triplet', loss='triplet',
pretrained=True pretrained=True

View File

@ -40,7 +40,7 @@ class FeatureExtractor(object):
extractor = FeatureExtractor( extractor = FeatureExtractor(
model_name='osnet_x1_0', model_name='osnet_x1_0',
model_path='a/b/c/model.pth.tar', model_fp='a/b/c/model.pth.tar',
device='cuda' device='cuda'
) )