fix tf conversion in new v6 models (#5153)

* fix `tf` conversion in new v6 (#5147)

* sort imports

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Yoni Chechik 2021-10-12 18:38:54 +03:00 committed by GitHub
parent 956be8e642
commit 34da872ab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,7 +28,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from tensorflow import keras from tensorflow import keras
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3 from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging from utils.general import make_divisible, print_args, set_logging
@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3)) return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
class TFSPPF(keras.layers.Layer):
# Spatial pyramid pooling-Fast layer
def __init__(self, c1, c2, k=5, w=None):
super(TFSPPF, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
def call(self, inputs):
x = self.cv1(inputs)
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
class TFDetect(keras.layers.Layer): class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(TFDetect, self).__init__() super(TFDetect, self).__init__()
@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass pass
n = max(round(n * gd), 1) if n > 1 else n # depth gain n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 c2 = make_divisible(c2 * gw, 8) if c2 != no else c2