65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
import json
|
|
import logging
|
|
import numpy as np
|
|
|
|
from argparse import ArgumentParser
|
|
from tqdm import tqdm
|
|
|
|
from torchreid.utils import FeatureExtractor
|
|
from torchreid.utils.tools import read_jsonl
|
|
|
|
# Configure the root logger so that records at INFO level and above are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Pass it as ``cls=NumpyEncoder`` to ``json.dump`` / ``json.dumps``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Conversions:
            * ``np.integer``  -> ``int``
            * ``np.floating`` -> ``float``
            * ``np.bool_``    -> ``bool``
            * ``np.ndarray``  -> (nested) ``list`` via ``tolist()``

        Raises:
            TypeError: raised by the base encoder for any other type.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer, so it needs its own branch;
        # without it a NumPy boolean would fall through to TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
|
|
|
|
|
def main(in_fp: str, model_fp: str, out_fp: str, device: str, max_objects: int) -> None:
    """Compute features for manifest entries and write them out as JSON lines.

    For each entry read from ``in_fp``, the value under its ``"path"`` key is
    passed to a ``resnet18`` ``FeatureExtractor``; the first result is
    converted to a NumPy array, stored on the entry under ``"features"``, and
    the entry is written to ``out_fp`` as one JSON document per line.

    Args:
        in_fp: Path of the input manifest, read with ``read_jsonl``.
        model_fp: Passed to ``FeatureExtractor`` as ``model_path``.
        out_fp: Output file path; opened in ``"w"`` mode, so any existing
            content is overwritten.
        device: Passed to ``FeatureExtractor`` as ``device``.
        max_objects: Maximum number of entries to process. A falsy value
            (``None`` or ``0``) disables the limit.
    """
    extractor = FeatureExtractor(
        model_name="resnet18",
        model_path=model_fp,
        device=device,
    )

    manifest_entries = read_jsonl(in_fp)
    # TODO(Josh): speed this up by batching
    with open(out_fp, "w") as f:
        for ix, manifest_entry in tqdm(enumerate(manifest_entries), desc="objects"):
            # ix is zero-based, so ix == max_objects is already one entry too
            # many. Stop outright instead of skipping: nothing after the limit
            # is ever processed, so walking the rest of the manifest is wasted.
            if max_objects and ix >= max_objects:
                break

            path = manifest_entry["path"]
            # The extractor is called with a one-element list; take the single
            # result and move it to a NumPy array so NumpyEncoder can dump it.
            # NOTE(review): presumably a torch tensor, given .cpu().detach() — confirm.
            features = extractor([path])
            features = features[0].cpu().detach().numpy()
            manifest_entry["features"] = features
            f.write(json.dumps(manifest_entry, cls=NumpyEncoder) + "\n")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the flags and forward them to main().
    cli = ArgumentParser()

    # Mandatory string-valued paths.
    for required_flag in ("--manifest_fp", "--model_fp", "--out_fp"):
        cli.add_argument(required_flag, required=True, type=str)

    # Optional settings: no object limit and CPU inference by default.
    cli.add_argument("--max_objects", required=False, type=int, default=None)
    cli.add_argument("--device", required=False, type=str, default="cpu")

    parsed = cli.parse_args()
    main(
        in_fp=parsed.manifest_fp,
        model_fp=parsed.model_fp,
        out_fp=parsed.out_fp,
        device=parsed.device,
        max_objects=parsed.max_objects,
    )
|