mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix Mobilenet V3 model name for sotabench. Minor res2net cleanup.
This commit is contained in:
parent
b5a8bb52fd
commit
d3ba34ee7e
@ -1,3 +1,4 @@
|
|||||||
|
import torch
|
||||||
from torchbench.image_classification import ImageNet
|
from torchbench.image_classification import ImageNet
|
||||||
from timm import create_model
|
from timm import create_model
|
||||||
from timm.data import resolve_data_config, create_transform
|
from timm.data import resolve_data_config, create_transform
|
||||||
@ -77,7 +78,7 @@ model_list = [
|
|||||||
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
|
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
|
||||||
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
|
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
|
||||||
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
|
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
|
||||||
_entry('mobilenetv3_100', 'MobileNet V3(1.0)', '1905.02244',
|
_entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244',
|
||||||
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
||||||
'paper as closely as possible.'),
|
'paper as closely as possible.'),
|
||||||
_entry('resnet18', 'ResNet-18', '1812.01187'),
|
_entry('resnet18', 'ResNet-18', '1812.01187'),
|
||||||
@ -216,4 +217,6 @@ for m in model_list:
|
|||||||
data_root=os.environ.get('IMAGENET_DIR', './imagenet')
|
data_root=os.environ.get('IMAGENET_DIR', './imagenet')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class Bottle2neck(nn.Module):
|
|||||||
super(Bottle2neck, self).__init__()
|
super(Bottle2neck, self).__init__()
|
||||||
assert dilation == 1 and previous_dilation == 1 # FIXME support dilation
|
assert dilation == 1 and previous_dilation == 1 # FIXME support dilation
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.is_first = True if stride > 1 or downsample is not None else False
|
self.is_first = stride > 1 or downsample is not None
|
||||||
self.num_scales = max(1, scale - 1)
|
self.num_scales = max(1, scale - 1)
|
||||||
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
|
width = int(math.floor(planes * (base_width / 64.0))) * cardinality
|
||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
|
Loading…
x
Reference in New Issue
Block a user