fix tf conversion in new v6 models (#5153)

* fix `tf` conversion in new v6 (#5147)

* sort imports

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Yoni Chechik 2021-10-12 18:38:54 +03:00 committed by GitHub
parent 956be8e642
commit 34da872ab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -28,7 +28,7 @@ import torch
import torch.nn as nn
from tensorflow import keras
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging
@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
class TFSPPF(keras.layers.Layer):
# Spatial pyramid pooling-Fast layer
def __init__(self, c1, c2, k=5, w=None):
super(TFSPPF, self).__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
def call(self, inputs):
x = self.cv1(inputs)
y1 = self.m(x)
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(TFDetect, self).__init__()
@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass
n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2