Mirror of https://github.com/huggingface/pytorch-image-models.git
Ported Tensorflow pretrained EfficientNet weights and some model cleanup

* B0-B3 weights ported from TF with close-to-paper accuracy
* Renamed gen_mobilenet to gen_efficientnet since the scaling params go well beyond 'mobile' specific
* Added a Tensorflow preprocessing option to produce images closer to the source repo's pipeline
This commit is contained in: parent 4efecfdc47, commit 4bb5e9b224

README.md (92 lines changed)
@@ -29,8 +29,8 @@ I've included a few of my favourite models, but this is not an exhaustive collection.
 * PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
   * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
-  * EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- work in progress, validating
+* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
+  * EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
   * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
   * MobileNet-V1 (https://arxiv.org/abs/1704.04861)
   * MobileNet-V2 (https://arxiv.org/abs/1801.04381)
@@ -39,60 +39,8 @@ I've included a few of my favourite models, but this is not an exhaustive collection.
   * FBNet-C (https://arxiv.org/abs/1812.03443) -- TODO A/B variants
   * Single-Path NAS (https://arxiv.org/abs/1904.02877) -- pixel1 variant

-The full list of model strings that can be passed to model factory via `--model` arg for train, validation, inference scripts:
-```
-chamnetv1_100
-chamnetv2_100
-densenet121
-densenet161
-densenet169
-densenet201
-dpn107
-dpn131
-dpn68
-dpn68b
-dpn92
-dpn98
-fbnetc_100
-inception_resnet_v2
-inception_v4
-mnasnet_050
-mnasnet_075
-mnasnet_100
-mnasnet_140
-mnasnet_small
-mobilenetv1_100
-mobilenetv2_100
-mobilenetv3_050
-mobilenetv3_075
-mobilenetv3_100
-pnasnet5large
-resnet101
-resnet152
-resnet18
-resnet34
-resnet50
-resnext101_32x4d
-resnext101_64x4d
-resnext152_32x4d
-resnext50_32x4d
-semnasnet_050
-semnasnet_075
-semnasnet_100
-semnasnet_140
-seresnet101
-seresnet152
-seresnet18
-seresnet34
-seresnet50
-seresnext101_32x4d
-seresnext26_32x4d
-seresnext50_32x4d
-spnasnet_100
-tflite_mnasnet_100
-tflite_semnasnet_100
-xception
-```
-
+Use the `--model` arg to specify model for train, validation, inference scripts. Match the all lowercase
+creation fn for the model you'd like.

 ## Features
 Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:
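For example, the newly ported TF EfficientNets resolve through the model factory by their lowercase creation-fn names. A minimal sketch; `create_model`'s exact signature beyond the args visible elsewhere in this diff is an assumption:

```python
from models.model_factory import create_model

# 'tf_efficientnet_b0' matches the all-lowercase creation fn added in this commit
model = create_model('tf_efficientnet_b0', num_classes=1000, in_chans=3, pretrained=True)
model.eval()
```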
@@ -147,20 +95,40 @@ I've leveraged the training scripts in this repository to train a few of the models
 | gluon_resnet50_v1s | 78.712 (21.288) | 94.242 (5.758) | 25.68 | bicubic | |
 | gluon_resnet50_v1c | 78.010 (21.990) | 93.988 (6.012) | 25.58 | bicubic | |
 | gluon_resnet50_v1b | 77.578 (22.422) | 93.718 (6.282) | 25.56 | bicubic | |
+| tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
 | gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | |
-| SE-MNASNet 1.00 (A1) | 73.086 (26.914) | 91.336 (8.664) | 3.87 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
-| MNASNet 1.00 (B1) | 72.398 (27.602) | 90.930 (9.070) | 4.36 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
+| tflite_semnasnet_100 | 73.086 (26.914) | 91.336 (8.664) | 3.87 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
+| tflite_mnasnet_100 | 72.398 (27.602) | 90.930 (9.070) | 4.36 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
 | gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | |

-#### @ 299x299
+#### @ 240x240
 | Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
 |---|---|---|---|---|---|
-| Gluon Inception-V3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | [MxNet Gluon](https://gluon-cv.mxnet.io/model_zoo/classification.html) |
-| Tensorflow Inception-V3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
-| Adversarially trained Inception-V3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |
+| tf_efficientnet_b1 *tfp | 78.796 (21.204) | 94.232 (5.768) | 7.79 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| tf_efficientnet_b1 | 78.554 (21.446) | 94.098 (5.902) | 7.79 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |

+#### @ 260x260
+| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
+|---|---|---|---|---|---|
+| tf_efficientnet_b2 *tfp | 79.782 (20.218) | 94.800 (5.200) | 9.11 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| tf_efficientnet_b2 | 79.606 (20.394) | 94.712 (5.288) | 9.11 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |

+#### @ 299x299 and 300x300
+| Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Source |
+|---|---|---|---|---|---|
+| tf_efficientnet_b3 *tfp | 80.982 (19.018) | 95.332 (4.668) | 12.23 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| tf_efficientnet_b3 | 80.874 (19.126) | 95.302 (4.698) | 12.23 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
+| gluon_inception_v3 | 78.804 (21.196) | 94.380 (5.620) | 27.16M | bicubic | [MxNet Gluon](https://gluon-cv.mxnet.io/model_zoo/classification.html) |
+| tf_inception_v3 | 77.856 (22.144) | 93.644 (6.356) | 27.16M | bicubic | [Tensorflow Slim](https://github.com/tensorflow/models/tree/master/research/slim) |
+| adv_inception_v3 | 77.576 (22.424) | 93.724 (6.276) | 27.16M | bicubic | [Tensorflow Adv models](https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models) |

 NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far leveled off in the same 72-73% range.

+Models with `*tfp` next to them were scored with `--tf-preprocessing` flag.
+
+The `tf_efficientnet` and `tflite_(se)mnasnet` models require an equivalent for 'SAME' padding as their arch results in asymmetric padding. I've added this in the model creation wrapper, but it does come with a performance penalty.

 ## Script Usage

 ### Training
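For reference, the TF 'SAME' padding equivalence mentioned in the note above boils down to padding each spatial dim so the output size is ceil(input / stride), with any odd remainder going to the bottom/right. A standalone sketch under those assumptions (not the repo's actual wrapper code; `conv2d_same` is a made-up name here):

```python
import math
import torch.nn.functional as F

def conv2d_same(x, weight, bias=None, stride=1, dilation=1):
    # TF 'SAME': pad so out = ceil(in / stride); the odd remainder goes
    # bottom/right, which is why the padding ends up asymmetric.
    ih, iw = x.shape[-2:]
    kh, kw = weight.shape[-2:]
    pad_h = max((math.ceil(ih / stride) - 1) * stride + (kh - 1) * dilation + 1 - ih, 0)
    pad_w = max((math.ceil(iw / stride) - 1) * stride + (kw - 1) * dilation + 1 - iw, 0)
    # F.pad order is [left, right, top, bottom]
    x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
    return F.conv2d(x, weight, bias, stride=stride, dilation=dilation)
```

The extra `F.pad` on every forward pass is where the performance penalty mentioned above comes from.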
@@ -7,7 +7,7 @@ import gluoncv
 import torch
 from models.model_factory import create_model

-parser = argparse.ArgumentParser(description='Training')
+parser = argparse.ArgumentParser(description='Convert from MXNet')
 parser.add_argument('--model', default='all', type=str, metavar='MODEL',
                     help='Name of model to train (default: "all")')
@@ -54,6 +54,7 @@ class Dataset(data.Dataset):
     def __init__(
             self,
             root,
+            load_bytes=False,
             transform=None):

         imgs, _, _ = find_images_and_targets(root)
@@ -62,11 +63,12 @@ class Dataset(data.Dataset):
                 "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
         self.root = root
         self.imgs = imgs
+        self.load_bytes = load_bytes
         self.transform = transform

     def __getitem__(self, index):
         path, target = self.imgs[index]
-        img = Image.open(path).convert('RGB')
+        img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
@@ -89,12 +89,17 @@ def create_loader(
         distributed=False,
         crop_pct=None,
         collate_fn=None,
+        tf_preprocessing=False,
 ):
     if isinstance(input_size, tuple):
         img_size = input_size[-2:]
     else:
         img_size = input_size

+    if tf_preprocessing and use_prefetcher:
+        from data.tf_preprocessing import TfPreprocessTransform
+        transform = TfPreprocessTransform(is_training=is_training, size=img_size)
+    else:
     if is_training:
         transform = transforms_imagenet_train(
             img_size,
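Putting the two hunks above together, a hedged usage sketch (the module paths for `Dataset` and `create_loader` are inferred, not shown in this diff):

```python
from data.dataset import Dataset       # assumed location
from data.loader import create_loader  # assumed location

# TfPreprocessTransform consumes raw JPEG bytes, so the dataset must
# return undecoded bytes instead of a PIL image.
dataset = Dataset('/imagenet/validation', load_bytes=True)
loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=64,
    use_prefetcher=True,
    tf_preprocessing=True)
```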
data/tf_preprocessing.py (new file, 220 lines)

```python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing for MnasNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

IMAGE_SIZE = 224
CROP_PADDING = 32


def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image_bytes: `Tensor` of binary image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
    scope: Optional `str` for name scope.
  Returns:
    cropped image `Tensor`
  """
  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
    shape = tf.image.extract_jpeg_shape(image_bytes)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
    image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)

    return image


def _at_least_x_are_equal(a, b, x):
  """At least `x` of `a` and `b` `Tensors` are equal."""
  match = tf.equal(a, b)
  match = tf.cast(match, tf.int32)
  return tf.greater_equal(tf.reduce_sum(match), x)


def _decode_and_random_crop(image_bytes, image_size):
  """Make a random crop of image_size."""
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10,
      scope=None)
  original_shape = tf.image.extract_jpeg_shape(image_bytes)
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)

  image = tf.cond(
      bad,
      lambda: _decode_and_center_crop(image_bytes, image_size),
      lambda: tf.image.resize_bicubic([image],  # pylint: disable=g-long-lambda
                                      [image_size, image_size])[0])

  return image


def _decode_and_center_crop(image_bytes, image_size):
  """Crops to center of image with padding then scales image_size."""
  shape = tf.image.extract_jpeg_shape(image_bytes)
  image_height = shape[0]
  image_width = shape[1]

  padded_center_crop_size = tf.cast(
      ((image_size / (image_size + CROP_PADDING)) *
       tf.cast(tf.minimum(image_height, image_width), tf.float32)),
      tf.int32)

  offset_height = ((image_height - padded_center_crop_size) + 1) // 2
  offset_width = ((image_width - padded_center_crop_size) + 1) // 2
  crop_window = tf.stack([offset_height, offset_width,
                          padded_center_crop_size, padded_center_crop_size])
  image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
  image = tf.image.resize_bicubic([image], [image_size, image_size])[0]

  return image


def _flip(image):
  """Random horizontal image flip."""
  image = tf.image.random_flip_left_right(image)
  return image


def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
  """Preprocesses the given image for training.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_random_crop(image_bytes, image_size)
  image = _flip(image)
  image = tf.reshape(image, [image_size, image_size, 3])
  image = tf.image.convert_image_dtype(
      image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
  return image


def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE):
  """Preprocesses the given image for evaluation.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor`.
  """
  image = _decode_and_center_crop(image_bytes, image_size)
  image = tf.reshape(image, [image_size, image_size, 3])
  image = tf.image.convert_image_dtype(
      image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
  return image


def preprocess_image(image_bytes,
                     is_training=False,
                     use_bfloat16=False,
                     image_size=IMAGE_SIZE):
  """Preprocesses the given image.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size.
    is_training: `bool` for whether the preprocessing is for training.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.

  Returns:
    A preprocessed image `Tensor` with value range of [0, 255].
  """
  if is_training:
    return preprocess_for_train(image_bytes, use_bfloat16, image_size)
  else:
    return preprocess_for_eval(image_bytes, use_bfloat16, image_size)


class TfPreprocessTransform:

    def __init__(self, is_training=False, size=224):
        self.is_training = is_training
        self.size = size[0] if isinstance(size, tuple) else size
        self._image_bytes = None
        self.process_image = self._build_tf_graph()
        self.sess = None

    def _build_tf_graph(self):
        with tf.device('/cpu:0'):
            self._image_bytes = tf.placeholder(
                shape=[],
                dtype=tf.string,
            )
            img = preprocess_image(self._image_bytes, self.is_training, False, self.size)
        return img

    def __call__(self, image_bytes):
        if self.sess is None:
            self.sess = tf.Session()
        img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes})
        img = img.round().clip(0, 255).astype(np.uint8)
        if img.ndim < 3:
            img = np.expand_dims(img, axis=-1)
        img = np.rollaxis(img, 2)  # HWC to CHW
        return img
```
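The transform can also be exercised standalone; a quick sketch (TF 1.x session API per the file above; the sample path is made up):

```python
from data.tf_preprocessing import TfPreprocessTransform

transform = TfPreprocessTransform(is_training=False, size=224)  # eval-mode, B0 resolution

with open('val/n01440764/ILSVRC2012_val_00000293.JPEG', 'rb') as f:
    image_bytes = f.read()

img = transform(image_bytes)  # numpy uint8, CHW, values in [0, 255]
print(img.shape)              # (3, 224, 224)
```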
models/genmobilenet.py → models/gen_efficientnet.py (renamed)

@@ -1,6 +1,6 @@
-""" Generic MobileNet
+""" Generic EfficientNets

-A generic MobileNet class with building blocks to support a variety of models:
+A generic class with building blocks to support a variety of models with efficient architectures:
 * EfficientNet (B0-B4 in code right now, work in progress, still verifying)
 * MNasNet B1, A1 (SE), Small
 * MobileNet V1, V2, and V3 (work in progress)
@@ -32,8 +32,9 @@ _models = [
     'semnasnet_100', 'semnasnet_140', 'mnasnet_small', 'mobilenetv1_100', 'mobilenetv2_100',
     'mobilenetv3_050', 'mobilenetv3_075', 'mobilenetv3_100', 'chamnetv1_100', 'chamnetv2_100',
     'fbnetc_100', 'spnasnet_100', 'tflite_mnasnet_100', 'tflite_semnasnet_100', 'efficientnet_b0',
-    'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4']
-__all__ = ['GenMobileNet', 'genmobilenet_model_names'] + _models
+    'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'tf_efficientnet_b0',
+    'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
+__all__ = ['GenEfficientNet', 'gen_efficientnet_model_names'] + _models


 def _cfg(url='', **kwargs):
@@ -74,6 +75,18 @@ default_cfgs = {
     'efficientnet_b2': _cfg(url='', input_size=(3, 260, 260)),
     'efficientnet_b3': _cfg(url='', input_size=(3, 300, 300)),
     'efficientnet_b4': _cfg(url='', input_size=(3, 380, 380)),
+    'tf_efficientnet_b0': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
+        input_size=(3, 224, 224), interpolation='bicubic'),
+    'tf_efficientnet_b1': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
+        input_size=(3, 240, 240), interpolation='bicubic', crop_pct=0.882),
+    'tf_efficientnet_b2': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
+        input_size=(3, 260, 260), interpolation='bicubic', crop_pct=0.890),
+    'tf_efficientnet_b3': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
+        input_size=(3, 300, 300), interpolation='bicubic', crop_pct=0.904),
 }

 _DEBUG = False
@@ -648,10 +661,10 @@ class InvertedResidual(nn.Module):
         return x


-class GenMobileNet(nn.Module):
-    """ Generic Mobile Net
+class GenEfficientNet(nn.Module):
+    """ Generic EfficientNet

-    An implementation of mobile optimized networks that covers:
+    An implementation of efficient network architectures, in many cases mobile optimized networks:
       * MobileNet-V1
       * MobileNet-V2
       * MobileNet-V3
@@ -659,7 +672,7 @@ class GenMobileNet(nn.Module):
       * FBNet A, B, and C
       * ChamNet (arch details are murky)
       * Single-Path NAS Pixel1
-      * EfficientNet
+      * EfficientNet B0-B4 (rest easy to add)
     """

     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
@@ -669,7 +682,7 @@ class GenMobileNet(nn.Module):
                  se_gate_fn=torch.sigmoid, se_reduce_mid=False,
                  global_pool='avg', head_conv='default', weight_init='goog',
                  folded_bn=False, padding_same=False,):
-        super(GenMobileNet, self).__init__()
+        super(GenEfficientNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.drop_connect_rate = drop_connect_rate
@@ -783,7 +796,7 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c320'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -823,7 +836,7 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -856,7 +869,7 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c144']
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=8,
@@ -883,7 +896,7 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['dsa_r2_k3_s2_c1024'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -915,7 +928,7 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c320'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -956,7 +969,7 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
         ['cn_r1_k1_s1_c960'],  # hard-swish
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
@@ -992,7 +1005,7 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e10_c104'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -1025,7 +1038,7 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c112'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -1059,7 +1072,7 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c352'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=16,
@@ -1099,7 +1112,7 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c320_noskip']
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    model = GenEfficientNet(
         _decode_arch_def(arch_def),
         num_classes=num_classes,
         stem_size=32,
@@ -1119,9 +1132,21 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
     Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
     Paper: https://arxiv.org/abs/1905.11946

+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+
     Args:
       channel_multiplier: multiplier to number of channels per layer
       depth_multiplier: multiplier to number of repeats per stage

     """
     arch_def = [
         ['ds_r1_k3_s1_e1_c16_se0.25'],
@@ -1133,13 +1158,16 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
     bn_momentum, bn_eps = _resolve_bn_params(kwargs)
-    model = GenMobileNet(
+    # NOTE: other models in the family didn't scale the feature count
+    num_features = _round_channels(1280, channel_multiplier, 8, None)
+    model = GenEfficientNet(
         _decode_arch_def(arch_def, depth_multiplier),
         num_classes=num_classes,
         stem_size=32,
         channel_multiplier=channel_multiplier,
         channel_divisor=8,
         channel_min=None,
+        num_features=num_features,
         bn_momentum=bn_momentum,
         bn_eps=bn_eps,
         act_fn=swish,
@@ -1357,19 +1385,8 @@ def spnasnet_100(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model


-# EfficientNet params
-# (width_coefficient, depth_coefficient, resolution, dropout_rate)
-# 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
-# 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
-# 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
-# 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
-# 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
-# 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
-# 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
-# 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
-
 def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
-    """ EfficientNet """
+    """ EfficientNet-B0 """
     default_cfg = default_cfgs['efficientnet_b0']
     # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
@@ -1382,7 +1399,7 @@ def efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):


 def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
-    """ EfficientNet """
+    """ EfficientNet-B1 """
     default_cfg = default_cfgs['efficientnet_b1']
     # NOTE for train, drop_rate should be 0.2
     model = _gen_efficientnet(
@@ -1395,7 +1412,7 @@ def efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):


 def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
-    """ EfficientNet """
+    """ EfficientNet-B2 """
     default_cfg = default_cfgs['efficientnet_b2']
     # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
@@ -1408,7 +1425,7 @@ def efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):


 def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
-    """ EfficientNet """
+    """ EfficientNet-B3 """
     default_cfg = default_cfgs['efficientnet_b3']
     # NOTE for train, drop_rate should be 0.3
     model = _gen_efficientnet(
@@ -1421,7 +1438,7 @@ def efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):


 def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
-    """ EfficientNet """
+    """ EfficientNet-B4 """
     default_cfg = default_cfgs['efficientnet_b4']
     # NOTE for train, drop_rate should be 0.4
     model = _gen_efficientnet(
@@ -1433,5 +1450,61 @@ def efficientnet_b4(num_classes, in_chans=3, pretrained=False, **kwargs):
     return model


-def genmobilenet_model_names():
+def tf_efficientnet_b0(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B0. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_b0']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['padding_same'] = True
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.0,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def tf_efficientnet_b1(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B1. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_b1']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['padding_same'] = True
+    model = _gen_efficientnet(
+        channel_multiplier=1.0, depth_multiplier=1.1,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def tf_efficientnet_b2(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B2. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_b2']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['padding_same'] = True
+    model = _gen_efficientnet(
+        channel_multiplier=1.1, depth_multiplier=1.2,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def tf_efficientnet_b3(num_classes, in_chans=3, pretrained=False, **kwargs):
+    """ EfficientNet-B3. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_b3']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['padding_same'] = True
+    model = _gen_efficientnet(
+        channel_multiplier=1.2, depth_multiplier=1.4,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+def gen_efficientnet_model_names():
     return set(_models)
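On the `num_features = _round_channels(1280, channel_multiplier, 8, None)` line above: the helper's body isn't part of this diff, but the TF reference rounding it mirrors looks roughly like this (a sketch under that assumption, not the repo's exact code):

```python
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Scale channels by multiplier, then round to a multiple of divisor."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(int(channels + divisor / 2) // divisor * divisor, channel_min)
    if new_channels < 0.9 * channels:
        new_channels += divisor  # never round down by more than 10%
    return new_channels

print(round_channels(1280, 1.1))  # B2 head: 1280 -> 1408
```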
@@ -6,7 +6,7 @@ from models.dpn import *
 from models.senet import *
 from models.xception import *
 from models.pnasnet import *
-from models.genmobilenet import *
+from models.gen_efficientnet import *
 from models.inception_v3 import *
 from models.gluon_resnet import *

@@ -23,8 +23,8 @@ def create_model(

     margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)

-    # Not all models have support for batchnorm params passed as args, only genmobilenet variants
-    supports_bn_params = model_name in genmobilenet_model_names()
+    # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
+    supports_bn_params = model_name in gen_efficientnet_model_names()
     if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
         kwargs.pop('bn_tf', None)
         kwargs.pop('bn_momentum', None)
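So batchnorm overrides only flow through for the gen_efficientnet family; everything else drops them. A sketch of what that permits (kwarg names per the guard above; the exact `create_model` call signature and values are assumptions):

```python
from models.model_factory import create_model

# bn_eps/bn_momentum pass straight into GenEfficientNet construction...
model = create_model('efficientnet_b0', num_classes=1000, in_chans=3,
                     pretrained=False, bn_eps=1e-3, bn_momentum=0.01)

# ...but are silently popped for models that can't accept them.
resnet = create_model('resnet50', num_classes=1000, in_chans=3,
                      pretrained=False, bn_eps=1e-3)
```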
@@ -44,6 +44,8 @@ parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool')
+parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
+                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')


 def main():
@@ -71,7 +73,7 @@ def main():
     criterion = nn.CrossEntropyLoss().cuda()

     loader = create_loader(
-        Dataset(args.data),
+        Dataset(args.data, load_bytes=args.tf_preprocessing),
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=True,
@@ -79,7 +81,8 @@ def main():
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else data_config['crop_pct'])
+        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
+        tf_preprocessing=args.tf_preprocessing)

     batch_time = AverageMeter()
     losses = AverageMeter()
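End to end, scoring one of the ported models with the new flag would look something like this (the script name is inferred from context; the `--pretrained` flag and data-path conventions are assumptions):

```
python validate.py /imagenet/validation --model tf_efficientnet_b0 --pretrained --tf-preprocessing
```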