From dd97fdf22609cda1fa3fd87fb72c627a38c850f5 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Tue, 7 Sep 2021 21:49:25 +0800 Subject: [PATCH] [Enhancement] Support full match (#79) --- mim/commands/search.py | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/mim/commands/search.py b/mim/commands/search.py index f2fb7be..c887ba0 100644 --- a/mim/commands/search.py +++ b/mim/commands/search.py @@ -564,18 +564,21 @@ def sort_by(dataframe: DataFrame, matched_fields = [] invalid_fields = set() for input_field in input_fields: - contain_index = valid_fields.str.contains(input_field) - contain_fields = valid_fields[contain_index] - if len(contain_fields) == 1: - matched_fields.extend(contain_fields) - elif len(contain_fields) > 2: - raise ValueError( - highlighted_error( - f'{input_field} matchs {contain_fields}. However, ' - 'the number of matched fields should be 1, but got' - f' {len(contain_fields)}.')) + if any(valid_fields.isin([input_field])): + matched_fields.append(input_field) else: - invalid_fields.add(input_field) + contain_index = valid_fields.str.contains(input_field) + contain_fields = valid_fields[contain_index] + if len(contain_fields) == 1: + matched_fields.extend(contain_fields) + elif len(contain_fields) > 2: + raise ValueError( + highlighted_error( + f'{input_field} matchs {contain_fields}. However, ' + 'the number of matched fields should be 1, but got' + f' {len(contain_fields)}.')) + else: + invalid_fields.add(input_field) return matched_fields, invalid_fields if sorted_fields is None: @@ -620,14 +623,17 @@ def select_by(dataframe: DataFrame, # not consistent with the input_fields seen_fields = set() for input_field in input_fields: - contain_index = valid_fields.str.contains(input_field) - contain_fields = valid_fields[contain_index] - if len(contain_fields) > 0: - matched_fields.extend( - field for field in (set(contain_fields) - seen_fields)) - seen_fields.update(set(contain_fields)) + if any(valid_fields.isin([input_field])): + matched_fields.append(input_field) else: - invalid_fields.add(input_field) + contain_index = valid_fields.str.contains(input_field) + contain_fields = valid_fields[contain_index] + if len(contain_fields) > 0: + matched_fields.extend( + field for field in (set(contain_fields) - seen_fields)) + seen_fields.update(set(contain_fields)) + else: + invalid_fields.add(input_field) return matched_fields, invalid_fields if shown_fields is None and unshown_fields is None: