fix `tf` conversion in new v6 models (#5153)
* fix `tf` conversion in new v6 (#5147) * sort imports Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/5154/head
parent
956be8e642
commit
34da872ab6
20
models/tf.py
20
models/tf.py
|
@ -28,7 +28,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from tensorflow import keras
|
||||
|
||||
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3
|
||||
from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
|
||||
from models.experimental import CrossConv, MixConv2d, attempt_load
|
||||
from models.yolo import Detect
|
||||
from utils.general import make_divisible, print_args, set_logging
|
||||
|
@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer):
|
|||
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
|
||||
|
||||
|
||||
class TFSPPF(keras.layers.Layer):
|
||||
# Spatial pyramid pooling-Fast layer
|
||||
def __init__(self, c1, c2, k=5, w=None):
|
||||
super(TFSPPF, self).__init__()
|
||||
c_ = c1 // 2 # hidden channels
|
||||
self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
|
||||
self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
|
||||
self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.cv1(inputs)
|
||||
y1 = self.m(x)
|
||||
y2 = self.m(y1)
|
||||
return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
|
||||
|
||||
|
||||
class TFDetect(keras.layers.Layer):
|
||||
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
|
||||
super(TFDetect, self).__init__()
|
||||
|
@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
|
|||
pass
|
||||
|
||||
n = max(round(n * gd), 1) if n > 1 else n # depth gain
|
||||
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
|
||||
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
|
||||
c1, c2 = ch[f], args[0]
|
||||
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
|
||||
|
||||
|
|
Loading…
Reference in New Issue