diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py index 718bc7b..d6e81d8 100644 --- a/groundingdino/util/inference.py +++ b/groundingdino/util/inference.py @@ -250,8 +250,10 @@ class Model: def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray: class_ids = [] for phrase in phrases: - try: - class_ids.append(classes.index(phrase)) - except ValueError: + for class_ in classes: + if class_ in phrase: + class_ids.append(classes.index(class_)) + break + else: class_ids.append(None) return np.array(class_ids)