diff --git a/models/tf.py b/models/tf.py index 626d0f4bb..2fa1c25ac 100644 --- a/models/tf.py +++ b/models/tf.py @@ -1769,46 +1769,31 @@ def run( batch_size=1, # batch size dynamic=False, # dynamic batch size ): - # PyTorch model - """Def run(weights=ROOT / "yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False):""" - Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation. - - Args: - weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file). - imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images. - batch_size (int): Number of images to process in a batch. - dynamic (bool): Specifies dynamic batch size when set to True. - - Returns: - None: The function does not return any value. It displays model summaries and performs inference. - - Example: - ```python - run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False) - ``` - - Note: - - Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible - with the model configuration. - - The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model - summaries. It will also perform a dummy inference to validate the export. - """ - # PyTorch model - im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image - model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False) - _ = model(im) # inference - model.info() - - # TensorFlow model - im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image - tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) - _ = tf_model.predict(im) # inference - - # Keras model - im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size) - keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im)) - keras_model.summary() """ + Exports YOLOv5 model from PyTorch to TensorFlow/Keras formats and performs inference for validation. + + Args: + weights (str | Path): Path to the pre-trained YOLOv5 weights file (typically a .pt file). + imgsz (tuple[int, int]): Tuple specifying the inference size (height, width) of the input images. + batch_size (int): Number of images to process in a batch. + dynamic (bool): Specifies dynamic batch size when set to True. + + Returns: + None: The function does not return any value. It displays model summaries and performs inference. + + Example: + ```python + run(weights="yolov5s.pt", imgsz=(640, 640), batch_size=1, dynamic=False) + ``` + + Note: + - Ensure that the specified weight file path points to a valid YOLOv5 weight file, and the weights are compatible + with the model configuration. + - The function will load the PyTorch model, convert it to TensorFlow/Keras formats, and display the model + summaries. It will also perform a dummy inference to validate the export. + """ + + # PyTorch model im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image model = attempt_load(weights, device=torch.device("cpu"), inplace=True, fuse=False) _ = model(im) # inference