From f2439f894c32a5ad09dd099aae0586ff01080301 Mon Sep 17 00:00:00 2001
From: Ibai Gorordo <43162939+ibaiGorordo@users.noreply.github.com>
Date: Tue, 20 Sep 2022 06:31:41 +0900
Subject: [PATCH] Add missing TF layers (#792)

Add following layers to tf.py:
- TFMP (MP)
- TFSPPCSPC (SPPCSPC)
- TFRepConv (RepConv)
---
 models/tf.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 62 insertions(+), 5 deletions(-)

diff --git a/models/tf.py b/models/tf.py
index b0d98cc..2d85066 100644
--- a/models/tf.py
+++ b/models/tf.py
@@ -27,8 +27,8 @@ import torch
 import torch.nn as nn
 from tensorflow import keras
 
-from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
-                           DWConvTranspose2d, Focus, autopad)
+from models.common import (C3, MP, SPP, SPPF, SPPCSPC, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
+                           RepConv, DWConvTranspose2d, Focus, autopad)
 from models.experimental import MixConv2d, attempt_load
 from models.yolo import Detect
 from utils.activations import SiLU
@@ -86,6 +86,36 @@ class TFConv(keras.layers.Layer):
     def call(self, inputs):
         return self.act(self.bn(self.conv(inputs)))
 
+class TFRepConv(keras.layers.Layer):
+
+    def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, w=None):
+        super().__init__()
+
+
+        self.groups = g
+        self.in_channels = c1
+        self.out_channels = c2
+
+        assert k == 3
+        assert autopad(k, p) == 1
+        assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
+
+        padding_11 = autopad(k, p) - k // 2
+
+        self.act = activations(w.act) if act else tf.identity
+        rbr_reparam = keras.layers.Conv2D(
+            filters=c2,
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=True,
+            kernel_initializer=keras.initializers.Constant(w.rbr_reparam.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.rbr_reparam.bias.numpy()))
+        self.rbr_reparam = rbr_reparam if s == 1 else keras.Sequential([TFPad(autopad(k, p)), rbr_reparam])
+        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+
+    def call(self, inputs):
+        return self.act(self.bn(self.rbr_reparam(inputs)))
 
 class TFDWConv(keras.layers.Layer):
     # Depthwise convolution
@@ -239,6 +269,14 @@ class TFC3x(keras.layers.Layer):
     def call(self, inputs):
         return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
 
+class TFMP(keras.layers.Layer):
+    # Spatial pyramid pooling layer used in YOLOv3-SPP
+    def __init__(self, k=2, w=None):
+        super().__init__()
+        self.m = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='VALID')
+
+    def call(self, inputs):
+        return self.m(inputs)
 
 class TFSPP(keras.layers.Layer):
     # Spatial pyramid pooling layer used in YOLOv3-SPP
@@ -269,6 +307,24 @@ class TFSPPF(keras.layers.Layer):
         y2 = self.m(y1)
         return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
 
+class TFSPPCSPC(keras.layers.Layer):
+    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13), w=None):
+        super().__init__()
+        c_ = int(2 * c2 * e)  # hidden channels
+        self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
+        self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
+        self.cv3 = TFConv(c_, c_, 3, 1, w=w.cv3)
+        self.cv4 = TFConv(c_, c_, 1, 1, w=w.cv4)
+        self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
+        self.cv5 = TFConv(4 * c_, c_, 1, 1, w=w.cv5)
+        self.cv6 = TFConv(c_, c_, 3, 1, w=w.cv6)
+        self.cv7 = TFConv(2 * c_, c2, 1, 1, w=w.cv7)
+
+    def call(self, inputs):
+        x1 = self.cv4(self.cv3(self.cv1(inputs)))
+        y1 = self.cv6(self.cv5(tf.concat([x1] + [m(x1) for m in self.m], 3)))
+        y2 = self.cv2(inputs)
+        return self.cv7(tf.concat((y1, y2), 3))
 
 class TFDetect(keras.layers.Layer):
     # TF YOLOv5 Detect layer
@@ -355,6 +411,7 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
         m_str = m
+
         m = eval(m) if isinstance(m, str) else m  # eval strings
         for j, a in enumerate(args):
             try:
@@ -364,8 +421,8 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
 
         n = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in [
-                nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
-                BottleneckCSP, C3, C3x]:
+                nn.Conv2d, Conv, DWConv, RepConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, SPPCSPC, MixConv2d,
+                Focus, CrossConv, BottleneckCSP, C3, C3x]:
             c1, c2 = ch[f], args[0]
             c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
 
@@ -373,7 +430,7 @@ def parse_model(d, ch, model, imgsz):  # model_dict, input_channels(3)
             if m in [BottleneckCSP, C3, C3x]:
                 args.insert(2, n)
                 n = 1
-        elif m is nn.BatchNorm2d:
+        elif m in [nn.BatchNorm2d, MP]:
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)