mirror of https://github.com/open-mmlab/mmcv.git
fix parrots op bug (#1289)
parent
f022d57702
commit
ea3e9789bf
mmcv
onnx
ops/csrc/parrots
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_custom_op_loaded():
|
||||
flag = False
|
||||
|
@ -16,4 +18,4 @@ def is_custom_op_loaded():
|
|||
flag = os.path.exists(ort_lib_path)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pass
|
||||
return flag
|
||||
return flag or torch.__version__ == 'parrots'
|
||||
|
|
|
@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx,
|
|||
auto running_var = buildATensor(ctx, outs[1]);
|
||||
auto norm = buildATensor(ctx, outs[2]);
|
||||
auto std = buildATensor(ctx, outs[3]);
|
||||
auto output = buildATensor(ctx, outs[3]);
|
||||
auto output = buildATensor(ctx, outs[4]);
|
||||
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
|
||||
weight, bias, norm, std, output, eps, momentum,
|
||||
group_size);
|
||||
|
|
Loading…
Reference in New Issue