diff --git a/projects/easydeploy/tools/image-demo.py b/projects/easydeploy/tools/image-demo.py index 197ad070..c85f31a0 100644 --- a/projects/easydeploy/tools/image-demo.py +++ b/projects/easydeploy/tools/image-demo.py @@ -74,6 +74,7 @@ def main(): model.to(args.device) cfg = Config.fromfile(args.config) + class_names = cfg.get('class_name') test_pipeline = get_test_pipeline_cfg(cfg) test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'}) @@ -125,8 +126,12 @@ def main(): for (bbox, score, label) in zip(bboxes, scores, labels): bbox = bbox.tolist() color = colors[label] - label_name = cfg.get('class_name', {})[label] - name = f'cls:{label_name}_score:{score:0.4f}' + + if class_names is not None: + label_name = class_names[label] + name = f'cls:{label_name}_score:{score:0.4f}' + else: + name = f'cls:{label}_score:{score:0.4f}' cv2.rectangle(bgr, bbox[:2], bbox[2:], color, 2) cv2.putText(