# deep-person-reid/scripts/test.py — 90 lines, 1.9 KiB (Python).
# Extracted from a repository blame view; first hunk dated 2021-02-09 13:16:18 +08:00.
from argparse import ArgumentParser
import torchreid
import json
import logging
def read_jsonl(in_file_path):
    """Lazily yield one parsed JSON value per non-blank line of a JSON Lines file.

    Args:
        in_file_path: Path to a JSON Lines file (one JSON document per line).

    Yields:
        The result of ``json.loads`` for every line that is not blank.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON Lines is UTF-8 by definition; don't depend on the locale default.
    with open(in_file_path, encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. an extra trailing newline), which
            # json.loads would otherwise reject.
            if line.strip():
                yield json.loads(line)
# Named logger for this module (not referenced elsewhere in this file).
LOGGER = logging.getLogger(__name__)
# Root logging setup: basicConfig installs a stderr handler at DEBUG level,
# and is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.DEBUG)
def main(data_dir, save_dir, model_fp):
    """Evaluate a saved ResNet-18 triplet re-ID checkpoint on Market-1501.

    Builds the Market-1501 data manager and a ResNet-18 model, restores the
    checkpoint at ``model_fp`` into the model and optimizer, then runs the
    torchreid triplet engine in test-only mode with rank visualisation
    enabled. The model and the engine are both kept on the CPU.

    Args:
        data_dir: Root directory passed to ``torchreid.data.ImageDataManager``.
        save_dir: Output directory passed to ``engine.run``.
        model_fp: Checkpoint path passed to
            ``torchreid.utils.resume_from_checkpoint``.
    """
    datamanager = torchreid.data.ImageDataManager(
        root=data_dir,
        sources="market1501",
        targets="market1501",
        height=256,
        width=128,
        batch_size_train=32,
        batch_size_test=100,
        transforms=["random_flip", "random_crop"],
    )
    # The classifier size comes from the number of training identities; it
    # presumably has to match the classifier layer stored in the checkpoint.
    model = torchreid.models.build_model(
        name="resnet18",
        num_classes=datamanager.num_train_pids,
        loss="triplet",
        pretrained=True,
    )
    model = model.cpu()
    # The engine constructor takes an optimizer (and one is also handed to
    # resume_from_checkpoint), although no training is requested below.
    optimizer = torchreid.optim.build_optimizer(
        model,
        optim="adam",
        lr=0.0003,
    )
    scheduler = torchreid.optim.build_lr_scheduler(
        optimizer,
        lr_scheduler="single_step",
        stepsize=20,
    )
    engine = torchreid.engine.ImageTripletEngine(
        datamanager,
        model,
        optimizer=optimizer,
        scheduler=scheduler,
        label_smooth=True,
        # The model was moved to the CPU above. torchreid's engine defaults to
        # use_gpu=True and then moves input batches to CUDA whenever a GPU is
        # available, which would mismatch the CPU-resident weights.
        use_gpu=False,
    )
    start_epoch = torchreid.utils.resume_from_checkpoint(
        model_fp,
        model,
        optimizer,
    )
    engine.run(
        start_epoch=start_epoch,
        save_dir=save_dir,
        # Training-schedule settings; presumably unused when test_only=True
        # (TODO confirm against torchreid's Engine.run).
        max_epoch=60,
        eval_freq=1,
        print_freq=2,
        test_only=True,
        visrank=True,
    )
if __name__ == "__main__":
parser = ArgumentParser()
2021-07-31 00:18:59 +08:00
parser.add_argument("--model_fp", required=True)
2021-02-09 13:16:18 +08:00
parser.add_argument("--data_dir", default="./reid_out/", required=False)
parser.add_argument("--save_dir", default="./reid_out/", required=False)
args = parser.parse_args()
main(
2021-07-31 00:18:59 +08:00
model_fp=args.model_path,
2021-02-09 13:16:18 +08:00
data_dir=args.data_dir,
save_dir=args.save_dir
)