Update tf.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13194/head
Glenn Jocher 2024-07-17 17:40:05 +02:00 committed by GitHub
parent 05e1da6105
commit d9569c5f01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 24 additions and 39 deletions

View File

@ -1769,8 +1769,7 @@ def run(
batch_size=1, # batch size
dynamic=False, # dynamic batch size
):
# PyTorch model
"""Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):"""
"""
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
Args:
@ -1793,6 +1792,7 @@ def run(
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
summaries. It will also perform a dummy inference to validate the export.
"""
# PyTorch model
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
@ -1808,21 +1808,6 @@ def run(
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
keras_model.summary()
"""
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
_ = model(im) # inference
model.info()
# TensorFlow model
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
_ = tf_model.predict(im) # inference
# Keras model
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
keras_model.summary()
LOGGER.info("PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.")