Update tf.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13194/head
Glenn Jocher 2024-07-17 17:40:05 +02:00 committed by GitHub
parent 05e1da6105
commit d9569c5f01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 24 additions and 39 deletions

View File

@ -1769,46 +1769,31 @@ def run(
batch_size=1, # batch size
dynamic=False, # dynamic batch size
):
# PyTorch model
"""Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):"""
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
Args:
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
batch_size (int): Number of images to process in a batch.
dynamic (bool): Specifies dynamic batch size when set to True.
Returns:
None: The function does not return any value. It displays model summaries and performs inference.
Example:
```python
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
```
Note:
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
with the model configuration.
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
summaries. It will also perform a dummy inference to validate the export.
"""
# PyTorch model
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
_ = model(im) # inference
model.info()
# TensorFlow model
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
_ = tf_model.predict(im) # inference
# Keras model
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
keras_model.summary()
"""
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
Args:
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
batch_size (int): Number of images to process in a batch.
dynamic (bool): Specifies dynamic batch size when set to True.
Returns:
None: The function does not return any value. It displays model summaries and performs inference.
Example:
```python
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
```
Note:
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
with the model configuration.
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
summaries. It will also perform a dummy inference to validate the export.
"""
# PyTorch model
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
_ = model(im) # inference