fix parrots op bug ()

pull/1238/head
pc 2021-08-23 14:28:01 +08:00 committed by GitHub
parent f022d57702
commit ea3e9789bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 2 deletions
mmcv
ops/csrc/parrots

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import torch
def is_custom_op_loaded():
flag = False
@ -16,4 +18,4 @@ def is_custom_op_loaded():
flag = os.path.exists(ort_lib_path)
except (ImportError, ModuleNotFoundError):
pass
return flag
return flag or torch.__version__ == 'parrots'

View File

@ -45,7 +45,7 @@ void sync_bn_forward_output_cuda_parrots(CudaContext& ctx,
auto running_var = buildATensor(ctx, outs[1]);
auto norm = buildATensor(ctx, outs[2]);
auto std = buildATensor(ctx, outs[3]);
auto output = buildATensor(ctx, outs[3]);
auto output = buildATensor(ctx, outs[4]);
sync_bn_forward_output_cuda(input, mean, var, running_mean, running_var,
weight, bias, norm, std, output, eps, momentum,
group_size);