Update tf.py
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/13194/head
parent
05e1da6105
commit
d9569c5f01
19
models/tf.py
19
models/tf.py
|
@ -1769,8 +1769,7 @@ def run(
|
|||
batch_size=1, # batch size
|
||||
dynamic=False, # dynamic batch size
|
||||
):
|
||||
# PyTorch model
|
||||
"""Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):"""
|
||||
"""
|
||||
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
|
||||
|
||||
Args:
|
||||
|
@ -1793,6 +1792,7 @@ def run(
|
|||
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
|
||||
summaries. It will also perform a dummy inference to validate the export.
|
||||
"""
|
||||
|
||||
# PyTorch model
|
||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
||||
|
@ -1808,21 +1808,6 @@ def run(
|
|||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
|
||||
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
|
||||
keras_model.summary()
|
||||
"""
|
||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
||||
_ = model(im) # inference
|
||||
model.info()
|
||||
|
||||
# TensorFlow model
|
||||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
|
||||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
||||
_ = tf_model.predict(im) # inference
|
||||
|
||||
# Keras model
|
||||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
|
||||
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
|
||||
keras_model.summary()
|
||||
|
||||
LOGGER.info("PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.")
|
||||
|
||||
|
|
Loading…
Reference in New Issue