mirror of https://github.com/open-mmlab/mim.git
[Fix] Fix filename of model zoo and output of search command (#53)
* fix filename of model zoo
* add comment
* fix install
* update search cmd
* update search command
* fix model zoo
* format output
* update readme
* fix typo
* reset display_width
* fix typo
* support substring match
* simplify _filter_field

pull/61/head
parent 1bb17cd47d
commit 0a1ffd3c1c
@@ -154,11 +154,11 @@ Please refer to [installation.md](docs/installation.md) for installation.
 > mim search mmcls --model resnet
 > mim search mmcls --dataset cifar-10
 > mim search mmcls --valid-field
-> mim search mmcls --condition 'bs>45,epoch>100'
-> mim search mmcls --condition 'bs>45 epoch>100'
-> mim search mmcls --condition '128<bs<=256'
-> mim search mmcls --sort bs epoch
-> mim search mmcls --field epoch bs weight
+> mim search mmcls --condition 'batch_size>45,epochs>100'
+> mim search mmcls --condition 'batch_size>45 epochs>100'
+> mim search mmcls --condition '128<batch_size<=256'
+> mim search mmcls --sort batch_size epochs
+> mim search mmcls --field epochs batch_size weight
 > mim search mmcls --exclude-field weight paper
 ```
@@ -171,11 +171,11 @@ Please refer to [installation.md](docs/installation.md) for installation.
 get_model_info('mmcls==0.11.0', local=False)
 get_model_info('mmcls', models=['resnet'])
 get_model_info('mmcls', training_datasets=['cifar-10'])
-get_model_info('mmcls', filter_conditions='bs>45,epoch>100')
-get_model_info('mmcls', filter_conditions='bs>45 epoch>100')
-get_model_info('mmcls', filter_conditions='128<bs<=256')
-get_model_info('mmcls', sorted_fields=['bs', 'epoch'])
-get_model_info('mmcls', shown_fields=['epoch', 'bs', 'weight'])
+get_model_info('mmcls', filter_conditions='batch_size>45,epochs>100')
+get_model_info('mmcls', filter_conditions='batch_size>45 epochs>100')
+get_model_info('mmcls', filter_conditions='128<batch_size<=256')
+get_model_info('mmcls', sorted_fields=['batch_size', 'epochs'])
+get_model_info('mmcls', shown_fields=['epochs', 'batch_size', 'weight'])
 ```

 </details>
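Because the search backend now resolves fields by substring (see the `sort_by`/`select_by` changes below), shortened field names should also work through the same API. A hedged sketch, assuming `get_model_info` is importable from the top-level `mim` package and that mmcls is installed with metadata exposing `epochs` and `batch_size`:

```python
from mim import get_model_info

# 'epo' is not a column itself; substring matching resolves it to 'epochs'.
info = get_model_info('mmcls', sorted_fields=['epo'])

# Full names of course still work, and several fields can be selected at once.
info = get_model_info('mmcls', shown_fields=['epochs', 'batch_size'])
print(info.head())
```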
@@ -418,7 +418,9 @@ def install_from_repo(repo_root: str,
     """

     def copy_file_to_package():
-        items = ['tools', 'configs', 'model_zoo.yml']
+        # rename the model_zoo.yml to model-index.yml but support both of them
+        # for backward compatibility
+        items = ['tools', 'configs', 'model_zoo.yml', 'model-index.yml']
         module_name = PKG2MODULE.get(package, package)
         pkg_root = osp.join(repo_root, module_name)
@@ -439,7 +441,9 @@ def install_from_repo(repo_root: str,
     def link_file_to_package():
         # When user installs package with editable mode, we should create
         # symlinks to package, which will synchronize the modified files.
-        items = ['tools', 'configs', 'model_zoo.yml']
+        # Besides, rename the model_zoo.yml to model-index.yml but support both
+        # of them for backward compatibility
+        items = ['tools', 'configs', 'model_zoo.yml', 'model-index.yml']
         module_name = PKG2MODULE.get(package, package)
         pkg_root = osp.join(repo_root, module_name)
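For context, here is a minimal sketch of what the expanded `items` list implies for the copy path: only files that actually exist in the repo get copied, since a given release ships either `model_zoo.yml` or `model-index.yml`. This is a hypothetical standalone helper, not the function body elided from this hunk:

```python
import os.path as osp
import shutil


def copy_repo_items(repo_root: str, pkg_root: str) -> None:
    """Copy tools/configs and whichever metadata file the repo ships."""
    items = ['tools', 'configs', 'model_zoo.yml', 'model-index.yml']
    for item in items:
        src = osp.join(repo_root, item)
        if not osp.exists(src):
            # Older releases only have model_zoo.yml, newer ones
            # model-index.yml; skip whichever one is absent.
            continue
        dst = osp.join(pkg_root, item)
        if osp.isdir(src):
            shutil.copytree(src, dst)
        else:
            shutil.copy2(src, dst)
```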
@@ -3,14 +3,14 @@ import pickle
 import re
 import subprocess
 import tempfile
+import typing
 from pkg_resources import resource_filename
 from typing import Any, List, Optional

 import click
-import pandas as pd
 from modelindex.load_model_index import load
 from modelindex.models.ModelIndex import ModelIndex
-from pandas import DataFrame
+from pandas import DataFrame, Series

 from mim.click import OptionEatAll, get_downstream_package, param2lowercase
 from mim.utils import (
@@ -19,19 +19,11 @@ from mim.utils import (
     cast2lowercase,
     echo_success,
     get_github_url,
     get_installed_version,
     highlighted_error,
     is_installed,
     split_package_version,
 )

-abbrieviation = {
-    'batch_size': 'bs',
-    'epochs': 'epoch',
-    'inference_time': 'fps',
-    'inference_time_(fps)': 'fps',
-}
-
-
 @click.command('search')
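The net effect of dropping the `abbrieviation` table: metadata keys keep their full names in the DataFrame, and shortened user input is resolved by substring matching at lookup time instead of a fixed rewrite on ingestion. A small before/after illustration:

```python
# Before: keys were rewritten on ingestion via a fixed table.
abbrieviation = {'batch_size': 'bs', 'epochs': 'epoch'}
assert abbrieviation.get('batch_size', 'batch_size') == 'bs'

# After: columns stay as 'batch_size' / 'epochs', and a query like 'epo'
# matches as a substring of 'epochs' (see sort_by / select_by below).
assert 'epo' in 'epochs'
```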
@@ -98,11 +90,11 @@ def cli(packages: List[str],
     > mim search mmcls --model resnet
     > mim search mmcls --dataset cifar-10
     > mim search mmcls --valid-field
-    > mim search mmcls --condition 'bs>45,epoch>100'
-    > mim search mmcls --condition 'bs>45 epoch>100'
-    > mim search mmcls --condition '128<bs<=256'
-    > mim search mmcls --sort bs epoch
-    > mim search mmcls --field epoch bs weight
+    > mim search mmcls --condition 'batch_size>45,epochs>100'
+    > mim search mmcls --condition 'batch_size>45 epochs>100'
+    > mim search mmcls --condition '128<batch_size<=256'
+    > mim search mmcls --sort batch_size epochs
+    > mim search mmcls --field epochs batch_size weight
     > mim search mmcls --exclude-field weight paper
     """
     packages_info = {}
@@ -215,10 +207,17 @@ def load_metadata_from_local(package: str):
         >>> metadata = load_metadata_from_local('mmcls')
     """
     if is_installed(package):
         version = get_installed_version(package)
         click.echo(f'local version: {version}')
-        metadata_path = resource_filename(package, 'model_zoo.yml')
+        # rename the model_zoo.yml to model-index.yml but support both of them
+        # for backward compatibility
+        metadata_path = resource_filename(package, 'model-index.yml')
+        if not osp.exists(metadata_path):
+            metadata_path = resource_filename(package, 'model_zoo.yml')
+        if not osp.exists(metadata_path):
+            raise FileNotFoundError(
+                highlighted_error(
+                    'model-index.yml or model_zoo.yml is not found, please'
+                    f' upgrade your {package} to support search command'))
         metadata = load(metadata_path)

         return metadata
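`resource_filename` resolves a path inside the installed distribution, which is why the lookup can fall back from the new filename to the legacy one. A hedged standalone check, assuming mmcls is installed:

```python
import os.path as osp

from pkg_resources import resource_filename

# e.g. <site-packages>/mmcls/model-index.yml on new releases,
# <site-packages>/mmcls/model_zoo.yml on old ones.
path = resource_filename('mmcls', 'model-index.yml')
if not osp.exists(path):
    path = resource_filename('mmcls', 'model_zoo.yml')
print(path, osp.exists(path))
```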
@@ -260,12 +259,18 @@ def load_metadata_from_remote(package: str) -> Optional[ModelIndex]:
         clone_cmd.append(repo_root)
         subprocess.check_call(
             clone_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)
-        metadata_path = osp.join(repo_root, 'model_zoo.yml')
-        if not osp.exists(metadata_path):
-            raise FileNotFoundError(
-                highlighted_error(
-                    'current version can not support "mim search '
-                    f'{package}", please upgrade your {package}.'))
+
+        # rename the model_zoo.yml to model-index.yml but support both of
+        # them for backward compatibility
+        metadata_path = osp.join(repo_root, 'model-index.yml')
+        if not osp.exists(metadata_path):
+            metadata_path = osp.join(repo_root, 'model_zoo.yml')
+        if not osp.exists(metadata_path):
+            raise FileNotFoundError(
+                highlighted_error(
+                    'model-index.yml or model_zoo.yml is not found, '
+                    f'please upgrade your {package} to support search '
+                    'command'))
         metadata = load(metadata_path)
@@ -278,22 +283,51 @@ def load_metadata_from_remote(package: str) -> Optional[ModelIndex]:

 def convert2df(metadata: ModelIndex) -> DataFrame:
     """Convert metadata into DataFrame format."""

+    def _parse(data: dict) -> dict:
+        parsed_data = {}
+        for key, value in data.items():
+            unit = ''
+            name = key.split()
+            if '(' in key:
+                # 'inference time (ms/im)' will be split into
+                # `inference time` and `(ms/im)`
+                name, unit = name[0:-1], name[-1]
+            name = '_'.join(name)
+            name = cast2lowercase(name)
+
+            if isinstance(value, str):
+                parsed_data[name] = cast2lowercase(value)
+            elif isinstance(value, (list, tuple)):
+                if isinstance(value[0], dict):
+                    # inference time is a list of dicts (List[dict]); each
+                    # item records the environment where it was tested
+                    for _value in value:
+                        envs = [
+                            str(_value.get(env)) for env in [
+                                'hardware', 'backend', 'batch size', 'mode',
+                                'resolution'
+                            ]
+                        ]
+                        new_name = f'inference_time{unit}[{",".join(envs)}]'
+                        parsed_data[new_name] = _value.get('value')
+                else:
+                    new_name = f'{name}{unit}'
+                    parsed_data[new_name] = ','.join(cast2lowercase(value))
+            else:
+                new_name = f'{name}{unit}'
+                parsed_data[new_name] = value
+
+        return parsed_data
+
     name2model = {}
     name2collection = {}
     for collection in metadata.collections:
         collection_info = {}
         data = getattr(collection.metadata, 'data', None)
         if data:
-            for key, value in data.items():
-                name = '_'.join(key.split())
-                name = cast2lowercase(name)
-                name = abbrieviation.get(name, name)
-                if isinstance(value, str):
-                    collection_info[name] = cast2lowercase(value)
-                elif isinstance(value, (list, tuple)):
-                    collection_info[name] = ','.join(cast2lowercase(value))
-                else:
-                    collection_info[name] = value
+            collection_info.update(_parse(data))

         paper = getattr(collection, 'paper', None)
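A worked example of what `_parse` produces for the shapes described in its comments. The input dict is made up to mirror the `inference time (ms/im)` case, and it assumes `cast2lowercase` lowercases each element when given a list:

```python
data = {
    'Epochs': 100,
    'Training Techniques': ['SGD with Momentum', 'Weight Decay'],
    'inference time (ms/im)': [{
        'hardware': 'V100',
        'backend': 'PyTorch',
        'batch size': 1,
        'mode': 'FP32',
        'resolution': '(224, 224)',
        'value': 2.1,
    }],
}
# _parse(data) would return:
# {'epochs': 100,
#  'training_techniques': 'sgd with momentum,weight decay',
#  'inference_time(ms/im)[V100,PyTorch,1,FP32,(224, 224)]': 2.1}
```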
|
@ -312,16 +346,7 @@ def convert2df(metadata: ModelIndex) -> DataFrame:
|
|||
model_info = {}
|
||||
data = getattr(model.metadata, 'data', None)
|
||||
if data:
|
||||
for key, value in model.metadata.data.items():
|
||||
name = '_'.join(key.split())
|
||||
name = cast2lowercase(name)
|
||||
name = abbrieviation.get(name, name)
|
||||
if isinstance(value, str):
|
||||
model_info[name] = cast2lowercase(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
model_info[name] = ','.join(cast2lowercase(value))
|
||||
else:
|
||||
model_info[name] = value
|
||||
model_info.update(_parse(data))
|
||||
|
||||
results = getattr(model, 'results', None)
|
||||
for result in results:
|
||||
|
@@ -333,7 +358,6 @@ def convert2df(metadata: ModelIndex) -> DataFrame:
             for key, value in metrics.items():
                 name = '_'.join(key.split())
                 name = cast2lowercase(name)
-                name = abbrieviation.get(name, name)
                 model_info[f'{dataset}/{name}'] = value

         paper = getattr(model, 'paper', None)
@@ -445,7 +469,8 @@ def filter_by_conditions(
     and_conditions = []
     or_conditions = []

-    # 'fps>45,epoch>100' or 'fps>45 epoch>100' -> ['fps>40', 'epoch>100']
+    # 'inference_time>45,epoch>100' or 'inference_time>45 epoch>100' will be
+    # parsed into ['inference_time>45', 'epoch>100']
     filter_conditions = re.split(r'[ ,]+', filter_conditions)  # type: ignore

     valid_fields = dataframe.columns
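The single `re.split` handles both separators the comment mentions:

```python
import re

# Comma- and space-separated condition strings collapse to the same list.
print(re.split(r'[ ,]+', 'batch_size>45,epochs>100'))
# -> ['batch_size>45', 'epochs>100']
print(re.split(r'[ ,]+', 'batch_size>45 epochs>100'))
# -> ['batch_size>45', 'epochs>100']
```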
@@ -519,26 +544,48 @@ def sort_by(dataframe: DataFrame,
             ascending: bool = True) -> DataFrame:
     """Sort by the fields.

+    When sorting output by some fields, substring matching is supported. For
+    example, if sorted_fields is ['epo'], the actual sorted field will be
+    ['epochs'].
+
     Args:
         dataframe (DataFrame): Data to be sorted.
         sorted_fields (List[str], optional): Sort output by sorted_fields.
             Default: None.
         ascending (bool): Sort by ascending or descending. Default: True.
     """

+    @typing.no_type_check
+    def _filter_field(valid_fields: Series, input_fields: List[str]):
+        matched_fields = []
+        invalid_fields = set()
+        for input_field in input_fields:
+            contain_index = valid_fields.str.contains(input_field)
+            contain_fields = valid_fields[contain_index]
+            if len(contain_fields) == 1:
+                matched_fields.extend(contain_fields)
+            elif len(contain_fields) > 1:
+                raise ValueError(
+                    highlighted_error(
+                        f'{input_field} matches {contain_fields}. However, '
+                        'the number of matched fields should be 1, but got'
+                        f' {len(contain_fields)}.'))
+            else:
+                invalid_fields.add(input_field)
+        return matched_fields, invalid_fields
+
     if sorted_fields is None:
         return dataframe

     sorted_fields = cast2lowercase(sorted_fields)

-    valid_fields = set(dataframe.columns)
-    invalid_fields = set(sorted_fields) - valid_fields  # type: ignore
+    valid_fields = dataframe.columns
+    matched_fields, invalid_fields = _filter_field(valid_fields, sorted_fields)
     if invalid_fields:
         raise ValueError(
             highlighted_error(
                 f'Expected fields: {valid_fields}, but got {invalid_fields}'))

-    sorted_fields = list(sorted_fields)  # type: ignore
-    return dataframe.sort_values(by=sorted_fields, ascending=ascending)
+    return dataframe.sort_values(by=matched_fields, ascending=ascending)
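A standalone illustration of the lookup this `_filter_field` performs, in plain pandas rather than through mim itself:

```python
import pandas as pd

df = pd.DataFrame({'epochs': [300, 100], 'batch_size': [64, 256]})
valid_fields = df.columns

# 'epo' matches exactly one column, so sort_by sorts by 'epochs'.
print(list(valid_fields[valid_fields.str.contains('epo')]))  # ['epochs']

# 'e' occurs in both column names, so sort_by raises ValueError rather than
# guessing; an input matching nothing is reported as an invalid field.
print(list(valid_fields[valid_fields.str.contains('e')]))
# ['epochs', 'batch_size']
```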
@@ -546,6 +593,10 @@ def select_by(dataframe: DataFrame,
               unshown_fields: Optional[List[str]] = None) -> DataFrame:
     """Select by the fields.

+    When selecting fields to be shown or hidden, substring matching is
+    supported. For example, if shown_fields is ['epo'], every field containing
+    'epo' will be chosen, so the shown fields will be ['epochs'].
+
     Args:
         dataframe (DataFrame): Data to be filtered.
         shown_fields (List[str], optional): Fields to be outputted.
@@ -553,6 +604,27 @@ def select_by(dataframe: DataFrame,
         unshown_fields (List[str], optional): Fields to be hidden.
             Default: None.
     """

+    @typing.no_type_check
+    def _filter_field(valid_fields: Series, input_fields: List[str]):
+        matched_fields = []
+        invalid_fields = set()
+        # Record the fields already added to matched_fields to avoid
+        # duplicates. seen_fields would be unnecessary if matched_fields were
+        # a set, but then the order of matched_fields would not be consistent
+        # with input_fields.
+        seen_fields = set()
+        for input_field in input_fields:
+            contain_index = valid_fields.str.contains(input_field)
+            contain_fields = valid_fields[contain_index]
+            if len(contain_fields) > 0:
+                matched_fields.extend(
+                    field for field in (set(contain_fields) - seen_fields))
+                seen_fields.update(set(contain_fields))
+            else:
+                invalid_fields.add(input_field)
+        return matched_fields, invalid_fields
+
     if shown_fields is None and unshown_fields is None:
         return dataframe
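Unlike the sorting variant, this `_filter_field` keeps every match, so one short input can select several columns:

```python
import pandas as pd

df = pd.DataFrame({'epochs': [100], 'batch_size': [256], 'weight': ['url']})
valid_fields = df.columns

matched = list(valid_fields[valid_fields.str.contains('e')])
print(matched)  # ['epochs', 'batch_size', 'weight']; all contain 'e'
print(df.filter(items=matched))  # what select_by shows for `--field e`
```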
@@ -561,27 +633,27 @@ def select_by(dataframe: DataFrame,
             highlighted_error(
                 'shown_fields and unshown_fields must be mutually exclusive.'))

-    valid_fields = set(dataframe.columns)
+    valid_fields = dataframe.columns
     if shown_fields:
         shown_fields = cast2lowercase(shown_fields)
-        invalid_fields = set(shown_fields) - valid_fields  # type: ignore
+        matched_fields, invalid_fields = _filter_field(valid_fields,
+                                                       shown_fields)
         if invalid_fields:
             raise ValueError(
                 highlighted_error(f'Expected fields: {valid_fields}, but got '
                                   f'{invalid_fields}'))

-        dataframe = dataframe.filter(items=shown_fields)
+        dataframe = dataframe.filter(items=matched_fields)
     else:
         unshown_fields = cast2lowercase(unshown_fields)  # type: ignore
-        invalid_fields = set(unshown_fields) - valid_fields  # type: ignore
+        matched_fields, invalid_fields = _filter_field(valid_fields,
+                                                       unshown_fields)
         if invalid_fields:
             raise ValueError(
                 highlighted_error(f'Expected fields: {valid_fields}, but got '
                                   f'{invalid_fields}'))

-        dataframe = dataframe.drop(
-            columns=list(unshown_fields))  # type: ignore
+        dataframe = dataframe.drop(columns=matched_fields)

     dataframe = dataframe.dropna(axis=0, how='all')
@@ -598,15 +670,45 @@ def dump2json(dataframe: DataFrame, json_path: str) -> None:
     dataframe.to_json(json_path)


-def print_df(dataframe: DataFrame) -> None:
+def print_df(dataframe: DataFrame, display_width: int = 80) -> None:
     """Print Dataframe into terminal."""

+    def _max_len(dataframe):
+        key_max_len = 0
+        value_max_len = 0
+        for row in dataframe.iterrows():
+            for key, value in row[1].to_dict().items():
+                key_max_len = max(key_max_len, len(key))
+                value_max_len = max(value_max_len, len(str(value)))
+        return key_max_len, value_max_len
+
+    key_max_len, value_max_len = _max_len(dataframe)
+    key_max_len += 2
+    if key_max_len + value_max_len > display_width:
+        value_max_len = display_width - key_max_len
+
+    def _table(row):
+        output = ''
+        output += '-' * (key_max_len + value_max_len)
+        output += '\n'
+        output += click.style(f'config id: {row[0]}\n', fg='green')
+        row_dict = row[1].dropna().to_dict()
+        keys = sorted(row_dict.keys())
+        for key in keys:
+            output += key.ljust(key_max_len)
+            value = str(row_dict[key])
+            if len(value) > value_max_len:
+                if value_max_len > 3:
+                    output += f'{value[:value_max_len-3]}...'
+                else:
+                    output += '.' * value_max_len
+            else:
+                output += value
+            output += '\n'
+        return output
+
     def _generate_output():
         for row in dataframe.iterrows():
-            config_msg = click.style(f'config id: {row[0]}\n', fg='green')
-            yield from [
-                config_msg, '-' * pd.get_option('display.width'),
-                f'\n{row[1].dropna().to_string()}\n'
-            ]
+            yield _table(row)

     click.echo_via_pager(_generate_output())
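The truncation arithmetic in `_table`, traced by hand with the default `display_width=80`:

```python
display_width = 80
key_max_len = 18 + 2   # longest key is 18 chars here, plus 2 padding
value_max_len = 70     # longest value found by _max_len

# 20 + 70 = 90 > 80, so values are clipped to the remaining 60 columns.
if key_max_len + value_max_len > display_width:
    value_max_len = display_width - key_max_len  # 60

value = 'x' * 70
print(f'{value[:value_max_len - 3]}...')  # 57 x's + '...', 60 chars total
```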
@@ -59,30 +59,31 @@ def test_search():
     result = runner.invoke(search, ['mmcls', '--dataset', 'cifar-10'])
     assert result.exit_code == 0

-    # mim search mmcls --condition 'bs>45,epoch>100'
-    result = runner.invoke(search, ['mmcls', '--condition', 'bs>45,epoch>100'])
+    # mim search mmcls --condition 'batch_size>45,epochs>100'
+    result = runner.invoke(
+        search, ['mmcls', '--condition', 'batch_size>45,epochs>100'])
     assert result.exit_code == 0

-    # mim search mmcls --condition 'bs>45 epoch>100'
-    result = runner.invoke(search, ['mmcls', '--condition', 'bs>45 epoch>100'])
+    # mim search mmcls --condition 'batch_size>45 epochs>100'
+    result = runner.invoke(
+        search, ['mmcls', '--condition', 'batch_size>45 epochs>100'])
     assert result.exit_code == 0

-    # mim search mmcls --condition '128<bs<=256'
-    result = runner.invoke(search, ['mmcls', '--condition', '128<bs<=256'])
+    # mim search mmcls --condition '128<batch_size<=256'
+    result = runner.invoke(search,
+                           ['mmcls', '--condition', '128<batch_size<=256'])
     assert result.exit_code == 0

-    # mim search mmcls --sort epochs
-    # invalid field
-    result = runner.invoke(search, ['mmcls', '--sort', 'epochs'])
-    assert result.exit_code == 1
     # mim search mmcls --sort epoch
     result = runner.invoke(search, ['mmcls', '--sort', 'epoch'])
     assert result.exit_code == 0
+    # mim search mmcls --sort epochs
+    result = runner.invoke(search, ['mmcls', '--sort', 'epochs'])
+    assert result.exit_code == 0

-    # mim search mmcls --field epochs
-    # invalid field
-    result = runner.invoke(search, ['mmcls', '--field', 'epochs'])
-    assert result.exit_code == 1
     # mim search mmcls --field epoch
     result = runner.invoke(search, ['mmcls', '--field', 'epoch'])
     assert result.exit_code == 0
+    # mim search mmcls --field epochs
+    result = runner.invoke(search, ['mmcls', '--field', 'epochs'])
+    assert result.exit_code == 0