Update tf.py
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/13194/head
parent
05e1da6105
commit
d9569c5f01
63
models/tf.py
63
models/tf.py
|
@ -1769,46 +1769,31 @@ def run(
|
|||
batch_size=1, # batch size
|
||||
dynamic=False, # dynamic batch size
|
||||
):
|
||||
# PyTorch model
|
||||
"""Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):"""
|
||||
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
|
||||
|
||||
Args:
|
||||
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
|
||||
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
|
||||
batch_size (int): Number of images to process in a batch.
|
||||
dynamic (bool): Specifies dynamic batch size when set to True.
|
||||
|
||||
Returns:
|
||||
None: The function does not return any value. It displays model summaries and performs inference.
|
||||
|
||||
Example:
|
||||
```python
|
||||
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
|
||||
```
|
||||
|
||||
Note:
|
||||
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
|
||||
with the model configuration.
|
||||
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
|
||||
summaries. It will also perform a dummy inference to validate the export.
|
||||
"""
|
||||
# PyTorch model
|
||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
||||
_ = model(im) # inference
|
||||
model.info()
|
||||
|
||||
# TensorFlow model
|
||||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
|
||||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
||||
_ = tf_model.predict(im) # inference
|
||||
|
||||
# Keras model
|
||||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
|
||||
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
|
||||
keras_model.summary()
|
||||
"""
|
||||
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
|
||||
|
||||
Args:
|
||||
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
|
||||
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
|
||||
batch_size (int): Number of images to process in a batch.
|
||||
dynamic (bool): Specifies dynamic batch size when set to True.
|
||||
|
||||
Returns:
|
||||
None: The function does not return any value. It displays model summaries and performs inference.
|
||||
|
||||
Example:
|
||||
```python
|
||||
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
|
||||
```
|
||||
|
||||
Note:
|
||||
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
|
||||
with the model configuration.
|
||||
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
|
||||
summaries. It will also perform a dummy inference to validate the export.
|
||||
"""
|
||||
|
||||
# PyTorch model
|
||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
||||
_ = model(im) # inference
|
||||
|
|
Loading…
Reference in New Issue