Add missing TF layers (#792)

Add following layers to tf.py:
- TFMP (MP)
- TFSPPCSPC (SPPCSPC)
- TFRepConv (RepConv)
u5
Ibai Gorordo 2022-09-20 06:31:41 +09:00 committed by GitHub
parent a6215c0dbb
commit f2439f894c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 62 additions and 5 deletions

View File

@ -27,8 +27,8 @@ import torch
import torch.nn as nn
from tensorflow import keras
from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
DWConvTranspose2d, Focus, autopad)
from models.common import (C3, MP, SPP, SPPF, SPPCSPC, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
RepConv, DWConvTranspose2d, Focus, autopad)
from models.experimental import MixConv2d, attempt_load
from models.yolo import Detect
from utils.activations import SiLU
@ -86,6 +86,36 @@ class TFConv(keras.layers.Layer):
def call(self, inputs):
return self.act(self.bn(self.conv(inputs)))
class TFRepConv(keras.layers.Layer):
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, act=True, w=None):
super().__init__()
self.groups = g
self.in_channels = c1
self.out_channels = c2
assert k == 3
assert autopad(k, p) == 1
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
padding_11 = autopad(k, p) - k // 2
self.act = activations(w.act) if act else tf.identity
rbr_reparam = keras.layers.Conv2D(
filters=c2,
kernel_size=k,
strides=s,
padding='SAME' if s == 1 else 'VALID',
use_bias=True,
kernel_initializer=keras.initializers.Constant(w.rbr_reparam.weight.permute(2, 3, 1, 0).numpy()),
bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.rbr_reparam.bias.numpy()))
self.rbr_reparam = rbr_reparam if s == 1 else keras.Sequential([TFPad(autopad(k, p)), rbr_reparam])
self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
def call(self, inputs):
return self.act(self.bn(self.rbr_reparam(inputs)))
class TFDWConv(keras.layers.Layer):
# Depthwise convolution
@ -239,6 +269,14 @@ class TFC3x(keras.layers.Layer):
def call(self, inputs):
return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
class TFMP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, k=2, w=None):
super().__init__()
self.m = keras.layers.MaxPooling2D(pool_size=2, strides=2, padding='VALID')
def call(self, inputs):
return self.m(inputs)
class TFSPP(keras.layers.Layer):
# Spatial pyramid pooling layer used in YOLOv3-SPP
@ -269,6 +307,24 @@ class TFSPPF(keras.layers.Layer):
y2 = self.m(y1)
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
class TFSPPCSPC(keras.layers.Layer):
def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13), w=None):
super().__init__()
c_ = int(2 * c2 * e) # hidden channels
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
self.cv3 = TFConv(c_, c_, 3, 1, w=w.cv3)
self.cv4 = TFConv(c_, c_, 1, 1, w=w.cv4)
self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
self.cv5 = TFConv(4 * c_, c_, 1, 1, w=w.cv5)
self.cv6 = TFConv(c_, c_, 3, 1, w=w.cv6)
self.cv7 = TFConv(2 * c_, c2, 1, 1, w=w.cv7)
def call(self, inputs):
x1 = self.cv4(self.cv3(self.cv1(inputs)))
y1 = self.cv6(self.cv5(tf.concat([x1] + [m(x1) for m in self.m], 3)))
y2 = self.cv2(inputs)
return self.cv7(tf.concat((y1, y2), 3))
class TFDetect(keras.layers.Layer):
# TF YOLOv5 Detect layer
@ -355,6 +411,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
m_str = m
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
@ -364,8 +421,8 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [
nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
BottleneckCSP, C3, C3x]:
nn.Conv2d, Conv, DWConv, RepConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, SPPCSPC, MixConv2d,
Focus, CrossConv, BottleneckCSP, C3, C3x]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
@ -373,7 +430,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
if m in [BottleneckCSP, C3, C3x]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
elif m in [nn.BatchNorm2d, MP]:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)