mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Convert samples/targets in ParserImageInTar to numpy arrays, slightly less mem usage for massive datasets. Add a few more se/eca model defs to resnet.py
This commit is contained in:
parent
5d4c3d0af3
commit
22748f1a2d
@ -155,9 +155,11 @@ def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extension
|
|||||||
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
|
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
|
||||||
if sort:
|
if sort:
|
||||||
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
|
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
|
||||||
|
samples, targets = zip(*samples_and_targets)
|
||||||
_logger.info(f'Finished processing {len(samples_and_targets)} samples across {len(tarfiles)} tar files.')
|
samples = np.array(samples)
|
||||||
return samples_and_targets, class_name_to_idx, tarfiles
|
targets = np.array(targets)
|
||||||
|
_logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
|
||||||
|
return samples, targets, class_name_to_idx, tarfiles
|
||||||
|
|
||||||
|
|
||||||
class ParserImageInTar(Parser):
|
class ParserImageInTar(Parser):
|
||||||
@ -171,7 +173,7 @@ class ParserImageInTar(Parser):
|
|||||||
if class_map:
|
if class_map:
|
||||||
class_name_to_idx = load_class_map(class_map, root)
|
class_name_to_idx = load_class_map(class_map, root)
|
||||||
self.root = root
|
self.root = root
|
||||||
self.samples_and_targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
||||||
self.root,
|
self.root,
|
||||||
class_name_to_idx=class_name_to_idx,
|
class_name_to_idx=class_name_to_idx,
|
||||||
cache_tarinfo=cache_tarinfo,
|
cache_tarinfo=cache_tarinfo,
|
||||||
@ -186,10 +188,11 @@ class ParserImageInTar(Parser):
|
|||||||
self.cache_tarfiles = cache_tarfiles
|
self.cache_tarfiles = cache_tarfiles
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.samples_and_targets)
|
return len(self.samples)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample, target = self.samples_and_targets[index]
|
sample = self.samples[index]
|
||||||
|
target = self.targets[index]
|
||||||
sample_ti, parent_fn, child_ti = sample
|
sample_ti, parent_fn, child_ti = sample
|
||||||
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
|
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
|
||||||
|
|
||||||
@ -213,7 +216,7 @@ class ParserImageInTar(Parser):
|
|||||||
return tf.extractfile(sample_ti), target
|
return tf.extractfile(sample_ti), target
|
||||||
|
|
||||||
def _filename(self, index, basename=False, absolute=False):
|
def _filename(self, index, basename=False, absolute=False):
|
||||||
filename = self.samples_and_targets[index][0][0].name
|
filename = self.samples[index][0].name
|
||||||
if basename:
|
if basename:
|
||||||
filename = os.path.basename(filename)
|
filename = os.path.basename(filename)
|
||||||
return filename
|
return filename
|
||||||
|
@ -162,6 +162,12 @@ default_cfgs = {
|
|||||||
'seresnet152d_320': _cfg(
|
'seresnet152d_320': _cfg(
|
||||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
|
||||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
||||||
|
'seresnet200d': _cfg(
|
||||||
|
url='',
|
||||||
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||||
|
'seresnet269d': _cfg(
|
||||||
|
url='',
|
||||||
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||||
|
|
||||||
|
|
||||||
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
|
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
|
||||||
@ -216,6 +222,12 @@ default_cfgs = {
|
|||||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
||||||
interpolation='bicubic',
|
interpolation='bicubic',
|
||||||
first_conv='conv1.0'),
|
first_conv='conv1.0'),
|
||||||
|
'ecaresnet200d': _cfg(
|
||||||
|
url='',
|
||||||
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||||
|
'ecaresnet269d': _cfg(
|
||||||
|
url='',
|
||||||
|
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||||
|
|
||||||
# Efficient Channel Attention ResNeXts
|
# Efficient Channel Attention ResNeXts
|
||||||
'ecaresnext26tn_32x4d': _cfg(
|
'ecaresnext26tn_32x4d': _cfg(
|
||||||
@ -1123,6 +1135,26 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
|
|||||||
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
|
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ecaresnet200d(pretrained=False, **kwargs):
|
||||||
|
"""Constructs a ResNet-200-D model with ECA.
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
|
return _create_resnet('ecaresnet200d', pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def ecaresnet269d(pretrained=False, **kwargs):
|
||||||
|
"""Constructs a ResNet-269-D model with ECA.
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
|
block_args=dict(attn_layer='eca'), **kwargs)
|
||||||
|
return _create_resnet('ecaresnet269d', pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
|
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
|
||||||
"""Constructs an ECA-ResNeXt-26-TN model.
|
"""Constructs an ECA-ResNeXt-26-TN model.
|
||||||
@ -1198,6 +1230,26 @@ def seresnet152d(pretrained=False, **kwargs):
|
|||||||
return _create_resnet('seresnet152d', pretrained, **model_args)
|
return _create_resnet('seresnet152d', pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def seresnet200d(pretrained=False, **kwargs):
|
||||||
|
"""Constructs a ResNet-200-D model with SE attn.
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
|
block_args=dict(attn_layer='se'), **kwargs)
|
||||||
|
return _create_resnet('seresnet200d', pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def seresnet269d(pretrained=False, **kwargs):
|
||||||
|
"""Constructs a ResNet-269-D model with SE attn.
|
||||||
|
"""
|
||||||
|
model_args = dict(
|
||||||
|
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
|
||||||
|
block_args=dict(attn_layer='se'), **kwargs)
|
||||||
|
return _create_resnet('seresnet269d', pretrained, **model_args)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def seresnet152d_320(pretrained=False, **kwargs):
|
def seresnet152d_320(pretrained=False, **kwargs):
|
||||||
model_args = dict(
|
model_args = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user