mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
More models in sotabench, more control over sotabench run, dataset filename extraction consistency
This commit is contained in:
parent
9c406532bd
commit
e8ca45854c
@ -73,7 +73,7 @@ def main():
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
config = resolve_data_config(vars(args), model=model)
|
||||
model, test_time_pool = apply_test_time_pool(model, config, args)
|
||||
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
|
||||
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
@ -115,9 +115,8 @@ def main():
|
||||
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
|
||||
|
||||
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file:
|
||||
filenames = loader.dataset.filenames()
|
||||
filenames = loader.dataset.filenames(basename=True)
|
||||
for filename, label in zip(filenames, topk_ids):
|
||||
filename = os.path.basename(filename)
|
||||
out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
|
||||
filename, label[0], label[1], label[2], label[3], label[4]))
|
||||
|
||||
|
81
sotabench.py
81
sotabench.py
@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from torchbench.image_classification import ImageNet
|
||||
from sotabencheval.image_classification import ImageNetEvaluator
|
||||
from sotabencheval.utils import is_server
|
||||
from timm import create_model
|
||||
from timm.data import resolve_data_config, create_transform
|
||||
from timm.models import TestTimePoolHead
|
||||
from timm.data import resolve_data_config, create_loader, DatasetTar
|
||||
from timm.models import apply_test_time_pool
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
NUM_GPU = 1
|
||||
@ -148,6 +150,10 @@ model_list = [
|
||||
_entry('ese_vovnet19b_dw', 'VoVNet-19-DW-V2', '1911.06667'),
|
||||
_entry('ese_vovnet39b', 'VoVNet-39-V2', '1911.06667'),
|
||||
|
||||
_entry('cspresnet50', 'CSPResNet-50', '1911.11929'),
|
||||
_entry('cspresnext50', 'CSPResNeXt-50', '1911.11929'),
|
||||
_entry('cspdarknet53', 'CSPDarkNet-53', '1911.11929'),
|
||||
|
||||
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
|
||||
model_desc='Ported from official Google AI Tensorflow weights'),
|
||||
_entry('tf_efficientnet_b1', 'EfficientNet-B1 (AutoAugment)', '1905.11946',
|
||||
@ -448,8 +454,20 @@ model_list = [
|
||||
_entry('regnety_160', 'RegNetY-16GF', '2003.13678'),
|
||||
_entry('regnety_320', 'RegNetY-32GF', '2003.13678', batch_size=BATCH_SIZE // 2),
|
||||
|
||||
_entry('rexnet_100', 'ReXNet-1.0x', '2007.00992'),
|
||||
_entry('rexnet_130', 'ReXNet-1.3x', '2007.00992'),
|
||||
_entry('rexnet_150', 'ReXNet-1.5x', '2007.00992'),
|
||||
_entry('rexnet_200', 'ReXNet-2.0x', '2007.00992'),
|
||||
]
|
||||
|
||||
if is_server():
|
||||
DATA_ROOT = './.data/vision/imagenet'
|
||||
else:
|
||||
# local settings
|
||||
DATA_ROOT = './'
|
||||
DATA_FILENAME = 'ILSVRC2012_img_val.tar'
|
||||
TAR_PATH = os.path.join(DATA_ROOT, DATA_FILENAME)
|
||||
|
||||
for m in model_list:
|
||||
model_name = m['model']
|
||||
# create model from name
|
||||
@ -457,25 +475,60 @@ for m in model_list:
|
||||
param_count = sum([m.numel() for m in model.parameters()])
|
||||
print('Model %s, %s created. Param count: %d' % (model_name, m['paper_model_name'], param_count))
|
||||
|
||||
dataset = DatasetTar(TAR_PATH)
|
||||
filenames = [os.path.splitext(f)[0] for f in dataset.filenames()]
|
||||
|
||||
# get appropriate transform for model's default pretrained config
|
||||
data_config = resolve_data_config(m['args'], model=model, verbose=True)
|
||||
test_time_pool = False
|
||||
if m['ttp']:
|
||||
model = TestTimePoolHead(model, model.default_cfg['pool_size'])
|
||||
model, test_time_pool = apply_test_time_pool(model, data_config)
|
||||
data_config['crop_pct'] = 1.0
|
||||
input_transform = create_transform(**data_config)
|
||||
|
||||
# Run the benchmark
|
||||
ImageNet.benchmark(
|
||||
model=model,
|
||||
model_description=m.get('model_description', None),
|
||||
paper_model_name=m['paper_model_name'],
|
||||
batch_size = m['batch_size']
|
||||
loader = create_loader(
|
||||
dataset,
|
||||
input_size=data_config['input_size'],
|
||||
batch_size=batch_size,
|
||||
use_prefetcher=True,
|
||||
interpolation=data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=6,
|
||||
crop_pct=data_config['crop_pct'],
|
||||
pin_memory=True)
|
||||
|
||||
evaluator = ImageNetEvaluator(
|
||||
root=DATA_ROOT,
|
||||
model_name=m['paper_model_name'],
|
||||
paper_arxiv_id=m['paper_arxiv_id'],
|
||||
input_transform=input_transform,
|
||||
batch_size=m['batch_size'],
|
||||
num_gpu=NUM_GPU,
|
||||
data_root=os.environ.get('IMAGENET_DIR', './.data/vision/imagenet')
|
||||
model_description=m.get('model_description', None),
|
||||
)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# warmup
|
||||
input = torch.randn((batch_size,) + data_config['input_size']).cuda()
|
||||
model(input)
|
||||
|
||||
bar = tqdm(desc="Evaluation", mininterval=5, total=50000)
|
||||
evaluator.reset_time()
|
||||
sample_count = 0
|
||||
for input, target in loader:
|
||||
output = model(input)
|
||||
num_samples = len(output)
|
||||
image_ids = [filenames[i] for i in range(sample_count, sample_count + num_samples)]
|
||||
output = output.cpu().numpy()
|
||||
evaluator.add(dict(zip(image_ids, list(output))))
|
||||
sample_count += num_samples
|
||||
bar.update(num_samples)
|
||||
bar.close()
|
||||
|
||||
evaluator.save()
|
||||
for k, v in evaluator.results.items():
|
||||
print(k, v)
|
||||
for k, v in evaluator.speed_mem_metrics.items():
|
||||
print(k, v)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
@ -3,10 +3,11 @@ source /workspace/venv/bin/activate
|
||||
|
||||
pip install -r requirements-sotabench.txt
|
||||
|
||||
apt-get install -y libjpeg-dev zlib1g-dev libpng-dev libwebp-dev
|
||||
pip uninstall -y pillow
|
||||
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
|
||||
|
||||
# FIXME this shouldn't be needed but sb dataset upload functionality doesn't seem to work
|
||||
apt-get install wget
|
||||
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
|
||||
wget https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet
|
||||
#wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_devkit_t12.tar.gz -P ./.data/vision/imagenet
|
||||
wget -q https://onedrive.hyper.ai/down/ImageNet/data/ImageNet2012/ILSVRC2012_img_val.tar -P ./.data/vision/imagenet
|
||||
|
@ -94,17 +94,21 @@ class Dataset(data.Dataset):
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def filenames(self, indices=[], basename=False):
|
||||
if indices:
|
||||
if basename:
|
||||
return [os.path.basename(self.samples[i][0]) for i in indices]
|
||||
else:
|
||||
return [self.samples[i][0] for i in indices]
|
||||
else:
|
||||
if basename:
|
||||
return [os.path.basename(x[0]) for x in self.samples]
|
||||
else:
|
||||
return [x[0] for x in self.samples]
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0]
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
elif not absolute:
|
||||
filename = os.path.relpath(filename, self.root)
|
||||
return filename
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
fn = lambda x: x
|
||||
if basename:
|
||||
fn = os.path.basename
|
||||
elif not absolute:
|
||||
fn = lambda x: os.path.relpath(x, self.root)
|
||||
return [fn(x[0]) for x in self.samples]
|
||||
|
||||
|
||||
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
|
||||
@ -160,6 +164,16 @@ class DatasetTar(data.Dataset):
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def filename(self, index, basename=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
||||
|
||||
def filenames(self, basename=False):
|
||||
fn = os.path.basename if basename else lambda x: x
|
||||
return [fn(x[0].name) for x in self.samples]
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
|
||||
"""Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
|
||||
|
@ -36,13 +36,12 @@ class TestTimePoolHead(nn.Module):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
def apply_test_time_pool(model, config, args):
|
||||
def apply_test_time_pool(model, config):
|
||||
test_time_pool = False
|
||||
if not hasattr(model, 'default_cfg') or not model.default_cfg:
|
||||
return model, False
|
||||
if not args.no_test_pool and \
|
||||
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
|
||||
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
|
||||
if (config['input_size'][-1] > model.default_cfg['input_size'][-1] and
|
||||
config['input_size'][-2] > model.default_cfg['input_size'][-2]):
|
||||
_logger.info('Target input size %s > pretrained default %s, using test time pooling' %
|
||||
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
|
||||
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
|
||||
|
@ -166,6 +166,7 @@ class ReXNetV1(nn.Module):
|
||||
se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
|
||||
super(ReXNetV1, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.num_classes = num_classes
|
||||
|
||||
assert output_stride == 32 # FIXME support dilation
|
||||
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
|
||||
|
@ -139,7 +139,7 @@ def validate(args):
|
||||
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
||||
|
||||
data_config = resolve_data_config(vars(args), model=model)
|
||||
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
||||
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, data_config)
|
||||
|
||||
if args.torchscript:
|
||||
torch.jit.optimized_execution(True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user