mirror of https://github.com/open-mmlab/mmcv.git
fix parrots op bug (#1289)
parent
f022d57702
commit
ea3e9789bf
|
@ -1,6 +1,8 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def is_custom_op_loaded():
|
def is_custom_op_loaded():
|
||||||
flag = False
|
flag = False
|
||||||
|
@ -16,4 +18,4 @@ def is_custom_op_loaded():
|
||||||
flag = os.path.exists(ort_lib_path)
|
flag = os.path.exists(ort_lib_path)
|
||||||
except (ImportError, ModuleNotFoundError):
|
except (ImportError, ModuleNotFoundError):
|
||||||
pass
|
pass
|
||||||
return flag
|
return flag or torch.__version__ == 'parrots'
|
||||||
|
|
|
@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx,
|
||||||
auto running_var = buildATensor(ctx, outs[1]);
|
auto running_var = buildATensor(ctx, outs[1]);
|
||||||
auto norm = buildATensor(ctx, outs[2]);
|
auto norm = buildATensor(ctx, outs[2]);
|
||||||
auto std = buildATensor(ctx, outs[3]);
|
auto std = buildATensor(ctx, outs[3]);
|
||||||
auto output = buildATensor(ctx, outs[3]);
|
auto output = buildATensor(ctx, outs[4]);
|
||||||
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
|
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
|
||||||
weight, bias, norm, std, output, eps, momentum,
|
weight, bias, norm, std, output, eps, momentum,
|
||||||
group_size);
|
group_size);
|
||||||
|
|
Loading…
Reference in New Issue