diff --git a/models/tf.py b/models/tf.py index 010fafdfa..1c6da43ad 100644 --- a/models/tf.py +++ b/models/tf.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn from tensorflow import keras -from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3 +from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad from models.experimental import CrossConv, MixConv2d, attempt_load from models.yolo import Detect from utils.general import make_divisible, print_args, set_logging @@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer): return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3)) +class TFSPPF(keras.layers.Layer): + # Spatial pyramid pooling-Fast layer + def __init__(self, c1, c2, k=5, w=None): + super(TFSPPF, self).__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1) + self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2) + self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME') + + def call(self, inputs): + x = self.cv1(inputs) + y1 = self.m(x) + y2 = self.m(y1) + return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3)) + + class TFDetect(keras.layers.Layer): def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer super(TFDetect, self).__init__() @@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: + if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: c1, c2 = ch[f], args[0] c2 = make_divisible(c2 * gw, 8) if c2 != no else c2