mirror of
https://github.com/ultralytics/yolov5.git
synced 2025-06-03 14:49:29 +08:00
Update tf.py
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
05e1da6105
commit
d9569c5f01
63
models/tf.py
63
models/tf.py
@ -1769,46 +1769,31 @@ def run(
|
|||||||
batch_size=1, # batch size
|
batch_size=1, # batch size
|
||||||
dynamic=False, # dynamic batch size
|
dynamic=False, # dynamic batch size
|
||||||
):
|
):
|
||||||
# PyTorch model
|
|
||||||
"""Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):"""
|
|
||||||
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
|
|
||||||
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
|
|
||||||
batch_size (int): Number of images to process in a batch.
|
|
||||||
dynamic (bool): Specifies dynamic batch size when set to True.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None: The function does not return any value. It displays model summaries and performs inference.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
Note:
|
|
||||||
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
|
|
||||||
with the model configuration.
|
|
||||||
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
|
|
||||||
summaries. It will also perform a dummy inference to validate the export.
|
|
||||||
"""
|
|
||||||
# PyTorch model
|
|
||||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
|
||||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
|
||||||
_ = model(im) # inference
|
|
||||||
model.info()
|
|
||||||
|
|
||||||
# TensorFlow model
|
|
||||||
im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
|
|
||||||
tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
|
|
||||||
_ = tf_model.predict(im) # inference
|
|
||||||
|
|
||||||
# Keras model
|
|
||||||
im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
|
|
||||||
keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
|
|
||||||
keras_model.summary()
|
|
||||||
"""
|
"""
|
||||||
|
Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file).
|
||||||
|
imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images.
|
||||||
|
batch_size (int): Number of images to process in a batch.
|
||||||
|
dynamic (bool): Specifies dynamic batch size when set to True.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: The function does not return any value. It displays model summaries and performs inference.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible
|
||||||
|
with the model configuration.
|
||||||
|
- The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model
|
||||||
|
summaries. It will also perform a dummy inference to validate the export.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# PyTorch model
|
||||||
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
|
||||||
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False)
|
||||||
_ = model(im) # inference
|
_ = model(im) # inference
|
||||||
|
Loading…
x
Reference in New Issue
Block a user