diff --git a/mmcv/onnx/info.py b/mmcv/onnx/info.py index 060efafe2..e59997368 100644 --- a/mmcv/onnx/info.py +++ b/mmcv/onnx/info.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os +import torch + def is_custom_op_loaded(): flag = False @@ -16,4 +18,4 @@ def is_custom_op_loaded(): flag = os.path.exists(ort_lib_path) except (ImportError, ModuleNotFoundError): pass - return flag + return flag or torch.__version__ == 'parrots' diff --git a/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp b/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp index 6130c3ff5..0b1855abd 100644 --- a/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp +++ b/mmcv/ops/csrc/parrots/sync_bn_parrots.cpp @@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx, auto running_var = buildATensor(ctx, outs[1]); auto norm = buildATensor(ctx, outs[2]); auto std = buildATensor(ctx, outs[3]); - auto output = buildATensor(ctx, outs[3]); + auto output = buildATensor(ctx, outs[4]); sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var, weight, bias, norm, std, output, eps, momentum, group_size);