diff --git a/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu index 028f6c0c9..7624379b6 100644 --- a/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/focal_loss_sigmoid_mlu_kernel.mlu @@ -24,8 +24,21 @@ __mlu_func__ void loadInput(char *nram_input, T *dram_input, const int32_t size, const int32_t dst_stride = 0, const int32_t src_stride = 0, const int32_t count = 1) { - __memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride, - src_stride, count - 1); + if (dst_stride == src_stride) { + __memcpy_async(nram_input, dram_input, size * count, GDRAM2NRAM); + } else { + __memcpy_async(nram_input, dram_input, size, GDRAM2NRAM, dst_stride, + src_stride, count - 1); + } +} + +template +__mlu_func__ void loadWeight(char *nram_input, T *dram_input, const int32_t t, + const int32_t c, const int32_t has_weight, + const int32_t partition_nc) { + if (has_weight && partition_nc && t >= 0 && t < c) { + __memcpy_async(nram_input, (T *)dram_input + t, sizeof(T), GDRAM2NRAM); + } } template @@ -33,152 +46,117 @@ __mlu_func__ void storeOutput(T *dram_output, char *nram_output, const int32_t size, const int32_t dst_stride = 0, const int32_t src_stride = 0, const int32_t count = 1) { - __memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride, - src_stride, count - 1); + if (dst_stride == src_stride) { + __memcpy_async(dram_output, nram_output, size * count, NRAM2GDRAM); + } else { + __memcpy_async(dram_output, nram_output, size, NRAM2GDRAM, dst_stride, + src_stride, count - 1); + } } template __mlu_func__ void compute(T *input, const int32_t *target, const T *weight, - const int32_t has_weight, const int32_t deal_num, - const int32_t n_seg, const int32_t C, float alpha, - float gamma, T *scalar_temp, T *tensor_max, - T *tensor_temp, T *output) { - const int32_t scalar_elem_num = NFU_ALIGN_SIZE / sizeof(T); + const int32_t has_weight, const int32_t partition_nc, + const int32_t deal_num, const int32_t n_seg, + const int32_t c, const int32_t c_seg, + const int32_t c_start_index, const float alpha, + const float gamma, T *compute_a, T *compute_b, + T *output) { + // set params + const int32_t c_num = + has_weight ? PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)) : c_seg; + const int32_t c_end_index = c_start_index + c_seg; + const int32_t half_epsilon = 0x0400; + const T epsilon_f = + sizeof(T) == sizeof(float) ? FLT_MIN : *((half *)&half_epsilon); - // 0. n_max = max(0, x) - __nramset((T *)tensor_max, deal_num, (T)0); - __bang_cycle_maxequal((T *)tensor_max, (T *)tensor_max, (T *)input, deal_num, - deal_num); - - // 1. ln(1+e^x) = ln(e^(-max) + e^(x-max)) + max - __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); - __bang_cycle_mul((T *)tensor_temp, (T *)tensor_max, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_cycle_add((T *)output, (T *)input, (T *)tensor_temp, deal_num, - deal_num); - __bang_active_exphp((T *)output, (T *)output, deal_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - __bang_cycle_add((T *)output, (T *)output, (T *)tensor_temp, deal_num, - deal_num); - __bang_active_loghp((T *)output, (T *)output, deal_num); - __bang_cycle_add((T *)output, (T *)output, (T *)tensor_max, deal_num, - deal_num); - - // 2. 
temp = [1 + e^(-x)] ^ (-r) - __nramset((T *)scalar_temp, scalar_elem_num, (T)-1); - __bang_cycle_mul((T *)tensor_temp, (T *)input, (T *)scalar_temp, deal_num, - scalar_elem_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - __nramset((T *)scalar_temp, scalar_elem_num, (T)1); - __bang_cycle_add((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_active_loghp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - __nramset((T *)scalar_temp, scalar_elem_num, (T)(-gamma)); - __bang_cycle_mul((T *)tensor_temp, (T *)tensor_temp, (T *)scalar_temp, - deal_num, scalar_elem_num); - __bang_active_exphp((T *)tensor_temp, (T *)tensor_temp, deal_num); - - // 3.1 output: target != j - __nramset((T *)scalar_temp, scalar_elem_num, (T)(1 - alpha)); - __bang_cycle_mul((T *)output, (T *)output, (T *)scalar_temp, deal_num, - scalar_elem_num); - __bang_cycle_mul((T *)output, (T *)output, (T *)tensor_temp, deal_num, - deal_num); - - // 3.2 output: target == j - const int32_t c_align_size = PAD_UP((sizeof(T) * C), NFU_ALIGN_SIZE); + // 0. alpha_t * p_t^r = alpha * (1 - p) ^ gamma if t == c_i + // = (1 - alpha) * p ^ gamma if t != c_i + __nramset((T *)output, deal_num, (T)(1 - alpha)); + __bang_active_sigmoid((T *)compute_b, (T *)input, deal_num); for (int32_t i = 0; i < n_seg; ++i) { - const int32_t target_value = *((int32_t *)target + i); - if (target_value >= 0 && target_value < C) { - const int32_t offset = i * c_align_size + target_value * sizeof(T); - char *addr_input = (char *)input + offset; - char *addr_output = (char *)output + offset; - const float x = *(T *)addr_input; - const float p = 1. / (1. + exp(-x)); - *(T *)addr_output = -alpha * pow(1. - p, gamma) * log(fmax(p, FLT_MIN)); + const int32_t t = *((uint32_t *)target + i); + if (t >= c_start_index && t < c_end_index) { + const uint32_t index = i * c_num + t - c_start_index; + *((T *)input + index) = -1.0 * (*((T *)input + index)); + *((T *)compute_b + index) = 1.0 - (*((T *)compute_b + index)) + epsilon_f; + *((T *)output + index) = alpha; } } + if (sizeof(T) == sizeof(half)) { + __bang_half2float((float *)compute_a, (half *)compute_b, deal_num); + __bang_active_loghp((float *)compute_a, (float *)compute_a, deal_num); + __bang_mul_const((float *)compute_a, (float *)compute_a, (float)gamma, + deal_num); + __bang_active_exphp((float *)compute_a, (float *)compute_a, deal_num); + __bang_float2half_rd((half *)compute_a, (float *)compute_a, deal_num); + } else { + __bang_active_loghp((T *)compute_a, (T *)compute_b, deal_num); + __bang_mul_const((T *)compute_a, (T *)compute_a, (T)gamma, deal_num); + __bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num); + } + __bang_mul((T *)output, (T *)compute_a, (T *)output, deal_num); - // with weight - if (has_weight > 0) { - int32_t row_num_elem = deal_num / n_seg; + // 1. max = max(0, -x) if t == c_i + // = max(0, x) if t != c_i + __nramset((T *)compute_b, deal_num, (T)0); + __bang_maxequal((T *)compute_b, (T *)compute_b, (T *)input, deal_num); + + // 2. 
-log(p_t) = ln(e^(-max)+ e^(-max-x) + max if t == c_i + // = ln(e^(-max)+ e^(-max+x) + max if t != c_i + __bang_mul_const((T *)compute_a, (T *)compute_b, (T)-1.0, deal_num); + __bang_add((T *)input, (T *)compute_a, (T *)input, deal_num); + + __bang_active_exphp((T *)compute_a, (T *)compute_a, deal_num); + __bang_active_exphp((T *)input, (T *)input, deal_num); + __bang_add((T *)compute_a, (T *)compute_a, (T *)input, deal_num); + __bang_active_loghp((T *)compute_a, (T *)compute_a, deal_num); + __bang_add((T *)input, (T *)compute_a, (T *)compute_b, deal_num); + + // 3. output = alpha_t * p_t^r * [-log(p_t)] + __bang_mul((T *)output, (T *)output, (T *)input, deal_num); + + // 4. with weight + if (has_weight) { for (int32_t i = 0; i < n_seg; ++i) { - const int32_t t = *((int32_t *)target + i); - __nramset((T *)scalar_temp, scalar_elem_num, *((T *)weight + t)); - __bang_cycle_mul((T *)output + i * row_num_elem, - (T *)output + i * row_num_elem, (T *)scalar_temp, - row_num_elem, scalar_elem_num); + int32_t t = *((int32_t *)target + i); + if (t >= 0 && t < c) { + t = partition_nc ? 0 : t; + __bang_mul_const((T *)output + i * c_num, (T *)output + i * c_num, + *((T *)weight + t), c_num); + } } } } template -__mlu_func__ void focalLossSigmoidForwardBlock( +__mlu_func__ void startPipeline( const T *input, const int32_t *target, const T *weight, - const int32_t row_num, const int32_t C, const float alpha, - const float gamma, T *output) { - /* - * NRAM partition - * |-----------------------------------------------------------------------| - * | scalar | - * |-----------------------------------------------------------------------| - * | weight | - * |------------------------------- COMPUTE -------------------------------| - * | | | - * | computeA | computeB | - * | | | - * |------------- PING ------------------------------- PONG ---------------| - * | | | - * | input | input | - * | | | - * |-----------------------------------|-----------------------------------| - * | | | - * | output | output | - * | | | - * |-----------------------------------|-----------------------------------| - * | target | target | - * |-----------------------------------|-----------------------------------| - * - * split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output), - * PONG(input,output). - * split_target_num is 2: PING(target), PONG(target). 
- */ - const int32_t c_align = PAD_UP(C, NFU_ALIGN_SIZE / sizeof(T)); - const int32_t c_align_size = c_align * sizeof(T); - const int32_t scalar_size = NFU_ALIGN_SIZE; - const int32_t weight_size = (weight != NULL) * c_align_size; - const int32_t split_pipeline_num = 6; - const int32_t split_target_num = 2; + char *nram_compute_a, char *nram_compute_b, char *nram_input, + char *nram_target, char *nram_weight, char *nram_output, + const int32_t has_weight, const int32_t partition_nc, + const int32_t pingpong_offset, const int32_t pingpong_weight_offset, + const int32_t c_offset_num, const int32_t n, const int32_t n_seg, + const int32_t c, const int32_t c_seg, const float alpha, const float gamma, + T *output) { + // with offset + input = (T *)((char *)input + c_offset_num * sizeof(T)); + output = (T *)((char *)output + c_offset_num * sizeof(T)); - const int32_t remain_size = MAX_NRAM_SIZE - scalar_size - weight_size; - const int32_t n_seg = remain_size / (split_pipeline_num * c_align_size + - split_target_num * sizeof(int32_t)); - const int32_t deal_num = n_seg * c_align_size / sizeof(T); - const int32_t target_size = n_seg * sizeof(int32_t); + const int32_t c_seg_align_num = PAD_UP(c_seg, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t c_num = has_weight ? c_seg_align_num : c_seg; + const int32_t deal_num = PAD_UP(n_seg * c_num, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t load_size = c_seg * sizeof(T); + const int32_t dram_stride = c * sizeof(T); + const int32_t nram_stride = c_num * sizeof(T); - // nram scalar,weight - char *nram_scalar = (char *)nram_buffer; - char *nram_weight = (char *)nram_scalar + scalar_size; - if (weight_size > 0) { - loadInput(nram_weight, (T *)weight, C * sizeof(T)); - __asm__ volatile("sync;"); + if (has_weight && !partition_nc) { + loadInput(nram_weight, (T *)weight, load_size, nram_stride, dram_stride, + 1); + __asm__ volatile("sync;\n\t"); } - - // nram COMPUTE - const int32_t compute_size = 2 * c_align_size * n_seg; - char *nram_compute_a = (char *)nram_weight + weight_size; - char *nram_compute_b = (char *)nram_compute_a + c_align_size * n_seg; - - // nram PING/PONG - const int32_t pingpong_offset = (remain_size - compute_size) / 2; - char *nram_input = (char *)nram_compute_a + 2 * c_align_size * n_seg; - char *nram_output = (char *)nram_compute_a + 3 * c_align_size * n_seg; - char *nram_target = (char *)nram_compute_a + 4 * c_align_size * n_seg; - - const int32_t repeat = row_num / n_seg; - const int32_t remain = row_num % n_seg; + const int32_t repeat = n / n_seg; + const int32_t remain = n % n_seg; /* * Pipeline: The pipeline is processed in three stages: Load, Compute, Store. 
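/*
 * Reference sketch, not part of this patch: the rewritten compute() above
 * evaluates, per element, the sigmoid focal loss
 *   FL = alpha_t * (1 - p_t)^gamma * (-log(p_t)),
 * where p_t = sigmoid(x) when t == c_i and 1 - sigmoid(x) otherwise, and
 * -log(p_t) uses the max trick log(1 + e^s) = log(e^(-m) + e^(s - m)) + m
 * with m = max(0, s), as described in the kernel comments. A minimal C++
 * sketch of the same per-element formula (hypothetical helper, needs
 * <cmath>, intended only as a numerical cross-check) could look like:
 *
 *   float sigmoid_focal_loss_elem(float x, bool is_target, float alpha,
 *                                 float gamma) {
 *     float p = 1.f / (1.f + expf(-x));   // sigmoid(x)
 *     float p_t = is_target ? p : 1.f - p;
 *     float alpha_t = is_target ? alpha : 1.f - alpha;
 *     float s = is_target ? -x : x;       // -log(p_t) = log(1 + e^s)
 *     float m = fmaxf(0.f, s);
 *     float neg_log_pt = logf(expf(-m) + expf(s - m)) + m;
 *     return alpha_t * powf(1.f - p_t, gamma) * neg_log_pt;
 *   }
 */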
@@ -206,80 +184,214 @@ __mlu_func__ void focalLossSigmoidForwardBlock( // diagram of PINGPONG: L0 if (repeat > 0) { - loadInput(nram_input, (T *)input, C * sizeof(T), c_align * sizeof(T), - C * sizeof(T), n_seg); - loadInput(nram_target, (int32_t *)target, target_size); - __asm__ volatile("sync;"); + loadInput(nram_input, (T *)input, load_size, nram_stride, dram_stride, + n_seg); + loadInput(nram_target, (int32_t *)target, n_seg * sizeof(int32_t)); + loadWeight(nram_weight, (T *)weight, *((int32_t *)target), c, has_weight, + partition_nc); + __asm__ volatile("sync;\n\t"); } // diagram of PINGPONG: C0 and L1 if (repeat > 1) { - loadInput(nram_input + pingpong_offset, (T *)input + C * n_seg, - C * sizeof(T), c_align * sizeof(T), C * sizeof(T), n_seg); - loadInput(nram_target + pingpong_offset, (int32_t *)target + n_seg, - target_size); compute((T *)nram_input, (int32_t *)nram_target, (T *)nram_weight, - weight_size, deal_num, n_seg, C, alpha, gamma, (T *)nram_scalar, - (T *)nram_compute_a, (T *)nram_compute_b, (T *)nram_output); - __asm__ volatile("sync;"); + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)nram_output); + loadInput((char *)nram_input + pingpong_offset, (T *)input + c * n_seg, + load_size, nram_stride, dram_stride, n_seg); + loadInput((char *)nram_target + pingpong_offset, + (int32_t *)target + n_seg, n_seg * sizeof(int32_t)); + loadWeight((char *)nram_weight + pingpong_weight_offset, (T *)weight, + *((int32_t *)target + n_seg), c, has_weight, partition_nc); + __asm__ volatile("sync;\n\t"); } for (int32_t i = 0; i < repeat - 2; ++i) { - storeOutput((T *)output + i * C * n_seg, - nram_output + (i % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), n_seg); - loadInput(nram_input + (i % 2) * pingpong_offset, - (T *)input + (i + 2) * C * n_seg, C * sizeof(T), - c_align * sizeof(T), C * sizeof(T), n_seg); - loadInput(nram_target + (i % 2) * pingpong_offset, - (int32_t *)target + (i + 2) * n_seg, target_size); + storeOutput((T *)output + i * c * n_seg, + nram_output + (i % 2) * pingpong_offset, load_size, + dram_stride, nram_stride, n_seg); + loadInput((char *)nram_input + (i % 2) * pingpong_offset, + (T *)(input) + (i + 2) * c * n_seg, load_size, nram_stride, + dram_stride, n_seg); + loadInput((char *)nram_target + (i % 2) * pingpong_offset, + (int32_t *)target + (i + 2) * n_seg, + n_seg * sizeof(int32_t)); + loadWeight((char *)nram_weight + (i % 2) * pingpong_weight_offset, + (T *)weight, *((int32_t *)target + (i + 2) * n_seg), c, + has_weight, partition_nc); compute((T *)(nram_input + ((i + 1) % 2) * pingpong_offset), (int32_t *)(nram_target + ((i + 1) % 2) * pingpong_offset), - (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, - (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * ((i + 1) % 2) * pingpong_weight_offset), + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + ((i + 1) % 2) * pingpong_offset)); - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); } if (repeat > 1) { - storeOutput((T *)output + (repeat - 2) * C * n_seg, - nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), n_seg); + storeOutput((T *)output + (repeat - 2) * c * n_seg, + (char *)nram_output + (repeat % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, n_seg); } + if (remain > 
0) { - loadInput(nram_input + (repeat % 2) * pingpong_offset, - (T *)input + repeat * C * n_seg, C * sizeof(T), - c_align * sizeof(T), C * sizeof(T), remain); - loadInput(nram_target + (repeat % 2) * pingpong_offset, + loadInput((char *)nram_input + (repeat % 2) * pingpong_offset, + (T *)input + repeat * c * n_seg, load_size, nram_stride, + dram_stride, remain); + loadInput((char *)nram_target + (repeat % 2) * pingpong_offset, (int32_t *)target + repeat * n_seg, remain * sizeof(int32_t)); + loadWeight((char *)nram_weight + (repeat % 2) * pingpong_weight_offset, + (T *)weight, *((int32_t *)target + repeat * n_seg), c, + has_weight, partition_nc); } + if (repeat > 0) { compute((T *)(nram_input + ((repeat - 1) % 2) * pingpong_offset), (int32_t *)(nram_target + ((repeat - 1) % 2) * pingpong_offset), - (T *)nram_weight, weight_size, deal_num, n_seg, C, alpha, gamma, - (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * ((repeat - 1) % 2) * pingpong_weight_offset), + has_weight, partition_nc, deal_num, n_seg, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + ((repeat - 1) % 2) * pingpong_offset)); } - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); if (repeat > 0) { - storeOutput((T *)output + (repeat - 1) * C * n_seg, - nram_output + ((repeat - 1) % 2) * pingpong_offset, - C * sizeof(T), C * sizeof(T), c_align * sizeof(T), n_seg); + storeOutput((T *)output + (repeat - 1) * c * n_seg, + (char *)nram_output + ((repeat - 1) % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, n_seg); } + if (remain > 0) { - int rem_deal_num = remain * c_align_size / sizeof(T); + int32_t rem_num = PAD_UP(remain * c_num, NFU_ALIGN_SIZE / sizeof(T)); compute((T *)(nram_input + (repeat % 2) * pingpong_offset), (int32_t *)(nram_target + (repeat % 2) * pingpong_offset), - (T *)nram_weight, weight_size, rem_deal_num, remain, C, alpha, - gamma, (T *)nram_scalar, (T *)nram_compute_a, (T *)nram_compute_b, + (T *)(nram_weight + + partition_nc * (repeat % 2) * pingpong_weight_offset), + has_weight, partition_nc, rem_num, remain, c, c_seg, c_offset_num, + alpha, gamma, (T *)nram_compute_a, (T *)nram_compute_b, (T *)(nram_output + (repeat % 2) * pingpong_offset)); - __asm__ volatile("sync;"); + __asm__ volatile("sync;\n\t"); - storeOutput((T *)output + repeat * C * n_seg, - nram_output + (repeat % 2) * pingpong_offset, C * sizeof(T), - C * sizeof(T), c_align * sizeof(T), remain); + storeOutput((T *)output + repeat * c * n_seg, + (char *)nram_output + (repeat % 2) * pingpong_offset, + load_size, dram_stride, nram_stride, remain); + } + __asm__ volatile("sync;\n\t"); +} + +template +__mlu_func__ void focalLossSigmoidForwardBlock( + const T *input, const int32_t *target, const T *weight, const int32_t n, + const int32_t c, const float alpha, const float gamma, T *output) { + /* + * NRAM partition + * |-----------------------------------------------------------------------| + * | weight | + * |------------------------------- COMPUTE -------------------------------| + * | | | + * | computeA | computeB | + * | | | + * |------------- PING ------------------------------- PONG ---------------| + * | | | + * | input | input | + * | | | + * |-----------------------------------|-----------------------------------| + * | | | + * | output | output | + * | | | + * |-----------------------------------|-----------------------------------| + * | target | target | + * 
|-----------------------------------|-----------------------------------| + * + * split_pipeline_num is 6: COMPUTE(computeA,computeB), PING(input,output), + * PONG(input,output). + * split_target_num is 2: PING(target), PONG(target). + * weight is not NULL: + * The nram-size of weight is equal to c_align_size when partition input-N. + * The nram-size of weight is equal to NFU_ALIGN_SIZE when partition + * input-NC. + */ + + // calculate threshold of c + const int32_t split_pipeline_num = 6; + const int32_t split_target_num = 2; + const int32_t has_weight = weight != NULL; + const int32_t threshold_c = + PAD_DOWN((MAX_NRAM_SIZE - split_target_num * sizeof(int32_t)) / + (split_pipeline_num + has_weight), + NFU_ALIGN_SIZE) / + sizeof(T); + const int32_t c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); + const int32_t c_align_size = c_align * sizeof(T); + + if (c <= threshold_c) { + // partition inputN + int32_t c_num = c; + int32_t reservered_align_size = + (split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE; + int32_t weight_size = 0; + if (has_weight) { + c_num = c_align; + reservered_align_size = split_target_num * NFU_ALIGN_SIZE; + weight_size = c_align_size; + } + + const int32_t remain_size = + MAX_NRAM_SIZE - weight_size - reservered_align_size; + const int32_t n_seg = + remain_size / (split_pipeline_num * c_num * sizeof(T) + + split_target_num * sizeof(int32_t)); + const int32_t split_pipeline_size = + PAD_UP(c_num * n_seg * sizeof(T), NFU_ALIGN_SIZE); + const int32_t compute_size = 2 * split_pipeline_size; + const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2; + + char *nram_weight = (char *)nram_buffer; + char *nram_compute_a = nram_weight + has_weight * c_align_size; + char *nram_compute_b = nram_compute_a + split_pipeline_size; + char *nram_input = nram_compute_b + split_pipeline_size; + char *nram_output = nram_input + split_pipeline_size; + char *nram_target = nram_output + split_pipeline_size; + + startPipeline(input, target, weight, nram_compute_a, nram_compute_b, + nram_input, nram_target, nram_weight, nram_output, + has_weight, 0, pingpong_offset, 0, 0, n, n_seg, c, c, + alpha, gamma, output); + } else { + // partition inputNC + const int32_t weight_size = has_weight * NFU_ALIGN_SIZE; + const int32_t remain_size = MAX_NRAM_SIZE - weight_size; + const int32_t split_pipeline_size = PAD_DOWN( + (remain_size - split_target_num * NFU_ALIGN_SIZE) / split_pipeline_num, + NFU_ALIGN_SIZE); + const int32_t c_seg = split_pipeline_size / sizeof(T); + const int32_t n_seg = 1; + const int32_t compute_size = 2 * split_pipeline_size; + const int32_t pingpong_offset = (MAX_NRAM_SIZE - weight_size - compute_size) / 2; + const int32_t pingpong_weight_offset = weight_size / 2; + + char *nram_weight = (char *)nram_buffer; + char *nram_compute_a = nram_weight + weight_size; + char *nram_compute_b = nram_compute_a + split_pipeline_size; + char *nram_input = nram_compute_b + split_pipeline_size; + char *nram_output = nram_input + split_pipeline_size; + char *nram_target = nram_output + split_pipeline_size; + + const int32_t loop_num = (c + c_seg - 1) / c_seg; + const int32_t partition_nc = 1; + for (int32_t i = 0; i < loop_num; ++i) { + const int32_t c_index = i * c_seg; + const int32_t c_seg_curr = i == (loop_num - 1) ? 
c - c_index : c_seg; + startPipeline(input, target, weight, nram_compute_a, nram_compute_b, + nram_input, nram_target, nram_weight, nram_output, + has_weight, partition_nc, pingpong_offset, + pingpong_weight_offset, c_index, n, n_seg, c, c_seg_curr, + alpha, gamma, output); + } } } diff --git a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu index 1095da870..7cb16bb10 100644 --- a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (C) 2021 by Cambricon. + * Copyright (C) 2021 Cambricon. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF @@ -15,6 +15,8 @@ #define COORD_DIM (4) #define MEMORY_CORE (0x80) #define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score +#define REDUCE_NUM \ + (7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input) #define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024) #define SIZE_SRAM_BUF (MAX_SRAM_SIZE) @@ -551,7 +553,7 @@ __mlu_func__ void nms_detection( } // for keepNum } -__mlu_global__ void MLUKernelNMS( +__mlu_global__ void MLUUnion1KernelNMS( const void *input_boxes, const void *input_confidence, const int input_num_boxes, const int input_stride, const int max_output_size, const float iou_threshold, @@ -635,15 +637,525 @@ __mlu_global__ void MLUKernelNMS( } } +template +__mlu_func__ void nms_detection_ux( + int32_t *loop_end_flag, uint32_t &output_box_num, OUT_DT *output_dram, + IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram, + const int input_layout, const int input_num_boxes, const int input_stride, + const int max_output_size, const float thresh_iou, const float thresh_score, + const float offset, const int output_mode, const int algo) { + loop_end_flag[0] = 0; + IN_DT *sram = (IN_DT *)sram_buffer; + + // score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2 + int nms_buffer_count1 = 9; + // temp nram buffer to store selected target. + int nram_save_limit_count = 256; + float div_thresh_iou = 1.0 / thresh_iou; + + // input data ptr + IN_DT *input_score_ptr; + const IN_DT *input_x1_ptr; + const IN_DT *input_y1_ptr; + const IN_DT *input_x2_ptr; + const IN_DT *input_y2_ptr; + input_score_ptr = score_data; + input_x1_ptr = boxes_data; + input_y1_ptr = input_x1_ptr + input_stride; + input_x2_ptr = input_y1_ptr + input_stride; + input_y2_ptr = input_x2_ptr + input_stride; + + int limit = 0; // find limit when GDRAM or SRAM + int max_seg_pad = 0; // the max length every repeat + int repeat = 0; + int remain = 0; + int remain_pad = 0; + int nram_save_count = 0; + + if (output_mode == 0) { + limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * sizeof(OUT_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } else { + limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - + nram_save_limit_count * INFO_NUM * sizeof(OUT_DT)) / + (nms_buffer_count1 * sizeof(IN_DT)); + } + + // data split + int avg_cluster = input_num_boxes / clusterDim; + int rem_cluster = input_num_boxes % clusterDim; + int len_cluster = avg_cluster + (clusterId < rem_cluster ? 1 : 0); + int cluster_offset = avg_cluster * clusterId + + (clusterId <= rem_cluster ? 
clusterId : rem_cluster); + + int avg_core = len_cluster / coreDim; + int rem_core = len_cluster % coreDim; + int len_core = avg_core + (coreId < rem_core ? 1 : 0); + int core_offset = + avg_core * coreId + (coreId <= rem_core ? coreId : rem_core); + int input_offset = cluster_offset + core_offset; + + max_seg_pad = PAD_DOWN(limit, NMS_SIZE); + + // core 0 of each cluster calculate the max score index + int max_index_avg_core = input_num_boxes / clusterDim; + int max_index_rem_core = input_num_boxes % clusterDim; + int max_index_len_core = + max_index_avg_core + (clusterId < max_index_rem_core ? 1 : 0); + int max_index_input_offset = + max_index_avg_core * clusterId + + (clusterId <= max_index_rem_core ? clusterId : max_index_rem_core); + repeat = max_index_len_core / max_seg_pad; + remain = max_index_len_core % max_seg_pad; + remain_pad = PAD_UP(remain, NMS_SIZE); + + // if datatype is fp16, we should cvt to fp32 when compute iou + int max_seg_iou_compute = + PAD_DOWN(max_seg_pad / (sizeof(float) / sizeof(IN_DT)), NMS_SIZE); + int repeat_iou_compute = len_core / max_seg_iou_compute; + int remain_iou_compute = len_core % max_seg_iou_compute; + int remain_pad_iou_compute = PAD_UP(remain_iou_compute, NMS_SIZE); + + // init the nram ptr + IN_DT *score = (IN_DT *)nram_buffer; + IN_DT *x1 = score + max_seg_pad; + IN_DT *y1 = x1 + max_seg_pad; + IN_DT *x2 = y1 + max_seg_pad; + IN_DT *y2 = x2 + max_seg_pad; + IN_DT *inter_x1 = y2 + max_seg_pad; + IN_DT *inter_y1 = inter_x1 + max_seg_pad; + IN_DT *inter_x2 = inter_y1 + max_seg_pad; + IN_DT *inter_y2 = inter_x2 + max_seg_pad; + IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2 + OUT_DT *nram_save = + (OUT_DT *)((char *)max_box + + NFU_ALIGN_SIZE); // offset two line from max_box + + mluMemcpyDirection_t input_load_dir = SRAM2NRAM; + mluMemcpyDirection_t input_store_dir = NRAM2SRAM; + input_load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM; + input_store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM; + + for (int keep = 0; keep < max_output_size; + keep++) { // loop until the max_score <= 0 + __sync_all(); + + /******FIND MAX START******/ + int max_index = 0; + int global_max_index = 0; // for Ux + float max_area = 0; // the max socre area + max_box[0] = 0; // init 0 + + if (coreId == 0) { + for (int i = 0; i <= repeat; i++) { + if (i == repeat && remain == 0) { + break; + } + + int seg_len = (i == repeat) + ? remain_pad + : max_seg_pad; // the length every nms compute + // check seg_len exceeds the limit of fp16 or not. 65536 is the largest + // num + // that fp16 could express. + if (sizeof(IN_DT) == sizeof(half) && seg_len > 65536) { + return; + } + int cpy_len = (i == repeat) + ? 
remain + : max_seg_pad; // the length every nms memcpy + + /******NMS LOAD START******/ + __bang_write_zero(score, seg_len); + __memcpy(score, + input_score_ptr + max_index_input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + + /******NMS LOAD END******/ + + __bang_max(inter_x1, score, seg_len); + if (inter_x1[0] > max_box[0]) { + max_box[0] = inter_x1[0]; + if (sizeof(IN_DT) == sizeof(half)) { + max_index = + ((uint16_t *)inter_x1)[1] + max_index_input_offset + + i * max_seg_pad; // offset start from head of input_data + } else if (sizeof(IN_DT) == sizeof(float)) { + max_index = + ((uint32_t *)inter_x1)[1] + max_index_input_offset + + i * max_seg_pad; // offset start from head of input_data + } + } + } // for repeat + + // the max box's x1, y1, x2, y2 on every cluster + max_box[1] = input_x1_ptr[max_index]; + max_box[2] = input_y1_ptr[max_index]; + max_box[3] = input_x2_ptr[max_index]; + max_box[4] = input_y2_ptr[max_index]; + ((uint32_t *)(max_box + 5))[0] = max_index; + // copy max box info to sram + __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); + } + __sync_all(); + // copy all partial max to the sram of cluster 0 + if (clusterId != 0) { + __memcpy(sram + REDUCE_NUM * clusterId, sram, REDUCE_NUM * sizeof(IN_DT), + SRAM2SRAM, 0); + } + __sync_all(); + + // reduce between clusters to get the global max box + if (clusterId == 0) { + if (coreId == 0) { + __bang_write_zero(inter_x1, NMS_SIZE); + __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT), + REDUCE_NUM * sizeof(IN_DT), clusterDim - 1); + __bang_max(max_box, inter_x1, NMS_SIZE); + int max_cluster = (sizeof(IN_DT) == sizeof(half)) + ? ((uint16_t *)max_box)[1] + : ((uint32_t *)max_box)[1]; + __memcpy(max_box, sram + max_cluster * REDUCE_NUM, + REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); + __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); + } + __sync_cluster(); + if (coreId == 0x80 && clusterDim > 1) { + // broadcast global max box to each cluster's sram + for (int cluster_idx = 1; cluster_idx < clusterDim; ++cluster_idx) { + __memcpy(sram, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2SRAM, + cluster_idx); + } + } + __sync_cluster(); + } + __sync_all(); + + // copy the global max box to max_box + __memcpy(max_box, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); + if (algo == 0 || offset == 0.0) { + max_area = ((float)max_box[3] - (float)max_box[1]) * + ((float)max_box[4] - (float)max_box[2]); + } else { + max_area = ((float)max_box[3] - (float)max_box[1] + offset) * + ((float)max_box[4] - (float)max_box[2] + offset); + } + global_max_index = ((uint32_t *)(max_box + 5))[0]; + if (coreId != 0x80) { + input_score_ptr[global_max_index] = 0; + } + // by now, we get: max_score|max_index|max_box|max_area + /******FIND MAX END******/ + + /******NMS STORE START******/ + // store to nram + if (float(max_box[0]) > thresh_score) { + OUT_DT *save_ptr; + int save_offset = 0; + int save_str_num = 0; + save_ptr = nram_save; + save_offset = nram_save_count; + save_str_num = nram_save_limit_count; + if (clusterId == 0 && coreId == 0) { + if (output_mode == 0) { // index1, index2, ... 
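+          // max_box layout is [score, x1, y1, x2, y2, index]; the index of
+          // the selected box was written as a uint32_t right after the
+          // INFO_NUM box fields, so it is read back from max_box + INFO_NUM.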
+ save_ptr[save_offset] = ((uint32_t *)(max_box + INFO_NUM))[0]; + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + __memcpy(save_ptr + save_offset * INFO_NUM, max_box, + INFO_NUM * sizeof(IN_DT), NRAM2NRAM, + INFO_NUM * sizeof(IN_DT), INFO_NUM * sizeof(IN_DT), 0); + } else if (output_mode == 2) { // score---, x1---, y1---, x2---, y2--- + __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), + NRAM2NRAM, save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), + 4); + } + } + nram_save_count++; + output_box_num++; + } + + // store to sram/gdram + if (output_box_num != 0) { + if ((nram_save_count == nram_save_limit_count) || + (float(max_box[0]) <= thresh_score) || keep == max_output_size - 1) { + if (nram_save_count != 0) { + if (clusterId == 0 && coreId == 0) { + if (output_mode == 0) { // index1, index2, ... + pvLock(); + __memcpy(output_dram, nram_save, + nram_save_count * sizeof(uint32_t), NRAM2GDRAM); + pvUnlock(); + output_dram += nram_save_count; + } else if (output_mode == 1) { // score, x1, y1, x2, y2 + pvLock(); + __memcpy(output_dram, nram_save, + nram_save_count * INFO_NUM * sizeof(IN_DT), NRAM2GDRAM); + pvUnlock(); + output_dram += nram_save_count * INFO_NUM; + } else if (output_mode == + 2) { // score---, x1---, y1---, x2---, y2--- + pvLock(); + __memcpy(output_dram, nram_save, nram_save_count * sizeof(IN_DT), + NRAM2GDRAM, max_output_size * sizeof(IN_DT), + nram_save_limit_count * sizeof(IN_DT), 4); + pvUnlock(); + output_dram += nram_save_count; + } + nram_save_count = 0; + } + } + } // if move data nram->sram/gdram + } // if dst + + if (float(max_box[0]) <= thresh_score) { + if (clusterId == 0 && coreId == 0) { + loop_end_flag[0] = 1; // dram + } + } + __sync_all(); + if (loop_end_flag[0] == 1) { + break; + } + /******NMS STORE END******/ + + // To solve fp16 accuracy, we convert fp16 to fp32 to calculate IoU. + for (int i = 0; i <= repeat_iou_compute; i++) { + if (i == repeat_iou_compute && remain_iou_compute == 0) { + break; + } + int seg_len = (i == repeat_iou_compute) ? remain_pad_iou_compute + : max_seg_iou_compute; + int cpy_len = + (i == repeat_iou_compute) ? 
remain_iou_compute : max_seg_iou_compute; + + /******NMS LOAD START******/ + __nramset((float *)score, seg_len, 0.0f); + int dt_offset = 0; + if (sizeof(IN_DT) == sizeof(float)) { + __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + dt_offset = 0; + } else if (sizeof(IN_DT) == sizeof(half)) { + __nramset(x1, seg_len, half(0)); + __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), input_load_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + __bang_half2float((float *)score, (half *)x1, seg_len); + dt_offset = max_seg_iou_compute; + } + + __memcpy(x1 + dt_offset, + input_x1_ptr + input_offset + i * max_seg_iou_compute, + cpy_len * sizeof(IN_DT), input_load_dir, + max_seg_pad * sizeof(IN_DT), input_num_boxes * sizeof(IN_DT), 3); + /******NMS LOAD END******/ + + /******NMS COMPUTE START******/ + if (sizeof(IN_DT) == sizeof(half)) { + __bang_half2float((float *)x1, (half *)x1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y1, (half *)y1 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)x2, (half *)x2 + max_seg_iou_compute, + seg_len); + __bang_half2float((float *)y2, (half *)y2 + max_seg_iou_compute, + seg_len); + } + // 1、 compute IOU + // get the area_I + __nramset((float *)inter_y1, seg_len, float(max_box[1])); // max_x1 + __bang_maxequal((float *)inter_x1, (float *)x1, (float *)inter_y1, + seg_len); // inter_x1 + __nramset((float *)inter_y2, seg_len, float(max_box[3])); // max_x2 + __bang_minequal((float *)inter_x2, (float *)x2, (float *)inter_y2, + seg_len); // inter_x2 + __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_x1, (float *)inter_x1, offset, seg_len); + } + __bang_active_relu((float *)inter_x1, (float *)inter_x1, + seg_len); // inter_w + __nramset((float *)inter_x2, seg_len, float(max_box[2])); // max_y1 + __bang_maxequal((float *)inter_y1, (float *)y1, (float *)inter_x2, + seg_len); // inter_y1 + __nramset((float *)inter_x2, seg_len, float(max_box[4])); // max_y2 + __bang_minequal((float *)inter_y2, (float *)y2, (float *)inter_x2, + seg_len); // inter_y2 + __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1, + seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + } + __bang_active_relu((float *)inter_y1, (float *)inter_y1, + seg_len); // inter_h + __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1, + seg_len); // area_I + // get the area of input_box: area = (x2 - x1) * (y2 - y1); + __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len); + __bang_sub((float *)inter_y2, (float *)y2, (float *)y1, seg_len); + if (algo == 1 && offset != 0.0) { + __bang_add_const((float *)inter_y1, (float *)inter_y1, offset, seg_len); + __bang_add_const((float *)inter_y2, (float *)inter_y2, offset, seg_len); + } + __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2, + seg_len); // area + // get the area_U: area + max_area - area_I + __bang_add_const((float *)inter_x2, (float *)inter_x2, float(max_area), + seg_len); + __bang_sub((float *)inter_x2, (float *)inter_x2, (float *)inter_x1, + seg_len); // area_U + // 2、 select the box + // if IOU greater than thres, set the score to zero, abort it: area_U > + // area_I * (1 / thresh)? 
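+      // Rather than forming IoU = area_I / area_U and comparing it with
+      // thresh (one division per box), compare area_U with
+      // area_I * (1 / thresh) (or area_U * thresh with area_I when
+      // thresh_iou <= 0). __bang_ge then produces a 0/1 mask that is 1 only
+      // when IoU <= thresh, and the multiply below zeroes the scores of
+      // boxes that overlap the current max box beyond the threshold.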
+ if (thresh_iou > 0.0) { + __bang_mul_const((float *)inter_x1, (float *)inter_x1, div_thresh_iou, + seg_len); + } else { + __bang_mul_const((float *)inter_x2, (float *)inter_x2, thresh_iou, + seg_len); + } + __bang_ge((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, + seg_len); + __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len); + /******NMS COMPUTE END******/ + + if (sizeof(IN_DT) == 2) { + __bang_float2half_rd((half *)score, (float *)score, seg_len); + } + pvLock(); + __memcpy(input_score_ptr + input_offset + i * max_seg_iou_compute, score, + cpy_len * sizeof(IN_DT), input_store_dir, + cpy_len * sizeof(IN_DT), cpy_len * sizeof(IN_DT), 0); + pvUnlock(); + } // for repeat + } // for max_output_size +} + +__mlu_global__ void MLUUionXKernelNMS( + const void *input_boxes, const void *input_confidence, + const int input_num_boxes, const int input_layout, const int input_stride, + const int max_output_size, const float iou_threshold, + const float confidence_threshold, const float offset, + const cnrtDataType_t data_type_input, const int output_mode, const int algo, + void *workspace, void *result_num, void *output) { + int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2; + int32_t *loop_end_flag = + (int32_t *)((char *)workspace + + INFO_NUM * input_num_boxes * input_dwidth); + int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth; + int availbale_sram_size = SIZE_SRAM_BUF - reduce_sram_size; + + int cluster_score_size = input_num_boxes * input_dwidth; + int cluster_boxes_size = input_num_boxes * 4 * input_dwidth; + char *sram_score = (char *)sram_buffer + reduce_sram_size; + char *sram_boxes = + (char *)sram_buffer + reduce_sram_size + cluster_score_size; + Addr input_ram = GDRAM; + if ((cluster_score_size + cluster_boxes_size) < availbale_sram_size) { + input_ram = SRAM; + __memcpy(sram_score, input_confidence, cluster_score_size, GDRAM2SRAM); + __memcpy(sram_boxes, input_boxes, cluster_boxes_size, GDRAM2SRAM); + } else { + __memcpy(workspace, input_confidence, cluster_score_size, GDRAM2GDRAM); + } + __sync_cluster(); + uint32_t output_box_num = 0; + if (output_mode == 0) { + uint32_t *output_dram = (uint32_t *)output; + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *score_data; + half *boxes_data; + score_data = + (input_ram == SRAM) ? (half *)sram_score : (half *)workspace; + boxes_data = + (input_ram == SRAM) ? (half *)sram_boxes : (half *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + case CNRT_FLOAT32: { + float *score_data; + float *boxes_data; + score_data = + (input_ram == SRAM) ? (float *)sram_score : (float *)workspace; + boxes_data = + (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + } + } else { + switch (data_type_input) { + default: { return; } + case CNRT_FLOAT16: { + half *output_dram = (half *)output; + half *score_data; + half *boxes_data; + score_data = + (input_ram == SRAM) ? 
(half *)sram_score : (half *)workspace; + boxes_data = + (input_ram == SRAM) ? (half *)sram_boxes : (half *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + case CNRT_FLOAT32: { + float *output_dram = (float *)output; + float *score_data; + float *boxes_data; + score_data = + (input_ram == SRAM) ? (float *)sram_score : (float *)workspace; + boxes_data = + (input_ram == SRAM) ? (float *)sram_boxes : (float *)input_boxes; + nms_detection_ux(loop_end_flag, output_box_num, output_dram, score_data, + boxes_data, input_ram, input_layout, input_num_boxes, + input_stride, max_output_size, iou_threshold, + confidence_threshold, offset, output_mode, algo); + ((uint32_t *)result_num)[0] = output_box_num; + }; break; + } + } +} + void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t data_type_input, const void *boxes_ptr, const void *scores_ptr, const int input_num_boxes, const int input_stride, const int max_output_boxes, const float iou_threshold, const float offset, void *workspace_ptr, void *output_size_ptr, void *output_ptr) { - MLUKernelNMS<<>>( - boxes_ptr, scores_ptr, input_num_boxes, input_stride, max_output_boxes, - iou_threshold, /*confidence_threshold=*/0.0, /*output_mode=*/0, - /*input_layout=*/0, workspace_ptr, output_size_ptr, output_ptr, - data_type_input, offset, /*algo=*/1); + switch (k_type) { + default: { return; } + case CNRT_FUNC_TYPE_BLOCK: + case CNRT_FUNC_TYPE_UNION1: { + MLUUnion1KernelNMS<<>>( + boxes_ptr, scores_ptr, input_num_boxes, input_stride, + max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, + /*output_mode=*/0, + /*input_layout=*/1, workspace_ptr, output_size_ptr, output_ptr, + data_type_input, offset, /*algo=*/1); + }; break; + case CNRT_FUNC_TYPE_UNION2: + case CNRT_FUNC_TYPE_UNION4: + case CNRT_FUNC_TYPE_UNION8: + case CNRT_FUNC_TYPE_UNION16: { + MLUUionXKernelNMS<<>>( + boxes_ptr, scores_ptr, input_num_boxes, /*input_layout=*/1, + input_stride, max_output_boxes, iou_threshold, + /*confidence_threshold=*/0.0, offset, data_type_input, + /*output_mode=*/0, /*algo=*/1, workspace_ptr, output_size_ptr, + output_ptr); + }; break; + } } diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu index e11aa4c57..55df914ab 100644 --- a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu @@ -14,8 +14,7 @@ __nram__ char buffer[MAX_NRAM_SIZE]; #define ALIGN_SIZE 64 -#define MAX_ELEMENTS_FLOAT (50 * 1024) -#define MAX_ELEMENTS_HALF (100 * 1024) +#define BUFFER_SIZE (MAX_NRAM_SIZE * 480 / 512) #define ROI_OFFSET 5 #define SAMPLING_NUM 4 @@ -24,59 +23,101 @@ __nram__ char buffer[MAX_NRAM_SIZE]; namespace forward { template -__mlu_func__ void bilinearInterpolate(int input_height, int input_width, - float y, float x, T *w1, T *w2, T *w3, - T *w4, int *x_low, int *x_high, - int *y_low, int *y_high, int *empty, - T zero_sign) { - // deal with cases that inverse elements are of feature map boundary - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - *empty = 1; - return; +__mlu_func__ void bilinearInterpolate( + T *tmp_sum, T *nram_in, T *offset_bottom_data, const int roi_bin_grid_h, + const int roi_bin_grid_w, const T bin_size_h, const T 
bin_size_w, + const int input_height, const int input_width, const int channels, + const int channel_align, const int cyc_channel, T y_pre, T x_pre, + T zero_sign_tmp, bool is_normal_c, int index) { + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + T y = (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)) <= 0.0 + ? 0.0 + : (y_pre + ((T)(iy + 0.5) * bin_size_h) / (T)(roi_bin_grid_h)); + int y_low = int(y); + int y_high; + if (y_low >= input_height - 1) { + y_high = y_low = input_height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + T ly = y - y_low; + T hy = 1.0 - ly; + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + T x = (x_pre + ((ix + 0.5) * bin_size_w) / (T)(roi_bin_grid_w)) <= 0.0 + ? 0.0 + : (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)); + T zero_sign = + (T)(x >= -1.0 && x <= input_width && y >= -1.0 && y <= input_height) * + zero_sign_tmp; + int x_low = int(x); + int x_high; + if (x_low >= input_width - 1) { + x_high = x_low = input_width - 1; + x = T(x_low); + } else { + x_high = x_low + 1; + } + T lx = x - x_low; + T hx = 1.0 - lx; + + T w1 = hy * hx * zero_sign; + T w2 = hy * lx * zero_sign; + T w3 = ly * hx * zero_sign; + T w4 = ly * lx * zero_sign; + + // load + int cpy_len = (x_high - x_low) * channels; + int temp_size = cyc_channel < (channels - index * cyc_channel) + ? cyc_channel + : channels - index * cyc_channel; + int cpy_size = is_normal_c ? channels * sizeof(T) : temp_size * sizeof(T); + + int32_t offset1 = (y_low * input_width + x_low) * channels; + int32_t offset2 = (y_high * input_width + x_low) * channels; + + T *tmp1 = is_normal_c + ? offset_bottom_data + offset1 + : offset_bottom_data + offset1 + cyc_channel * index; + T *tmp2 = is_normal_c + ? offset_bottom_data + offset2 + : offset_bottom_data + offset2 + cyc_channel * index; + + T *tmp_cyc1 = nram_in; + T *tmp_cyc2 = nram_in + cyc_channel; + T *tmp_cyc3 = nram_in + cyc_channel * 2; + T *tmp_cyc4 = nram_in + cyc_channel * 3; + + __asm__ volatile("sync;"); + if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { + __nramset(nram_in, channel_align, T(0)); + } else { + __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); + __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); + __asm__ volatile("sync;"); + __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); + __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); + __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); + __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); + __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, 1, + SAMPLING_NUM, 1, 1); + } + __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); + } } - - if (y <= 0) y = 0; - if (x <= 0) x = 0; - - *y_low = int(y); - *x_low = int(x); - - if (*y_low >= input_height - 1) { - *y_high = *y_low = input_height - 1; - y = (T)(*y_low); - } else { - *y_high = *y_low + 1; - } - - if (*x_low >= input_width - 1) { - *x_high = *x_low = input_width - 1; - x = (T)(*x_low); - } else { - *x_high = *x_low + 1; - } - - T ly = y - *y_low; - T lx = x - *x_low; - T hy = 1.0 - ly; - T hx = 1.0 - lx; - - *w1 = hy * hx * zero_sign; - *w2 = hy * lx * zero_sign; - *w3 = ly * hx * zero_sign; - *w4 = ly * lx * zero_sign; - - return; } template -__mlu_func__ void roialignForwardKernel( - T *input, T *rois, T *output, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, 
- const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const int max_elements) { +__mlu_func__ void roialignForwardNpartKernel( + T *input, T *rois, T *output, T *nram_buffer, const bool aligned, + const int channels, const int pooled_height, const int pooled_width, + const int input_height, const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, const int max_elements) { /* * NRAM partition - * |----------------------NRAM------ -----------------| + * |----------------------NRAM------------------------| * | | * | output | * |--------------------------------------------------| @@ -92,8 +133,6 @@ __mlu_func__ void roialignForwardKernel( */ int channel_align = PAD_UP(channels, ALIGN_SIZE); - int height = 0; - int width = 0; int samp_channel_align = channel_align * SAMPLING_NUM; int samp_channel = channels * SAMPLING_NUM; @@ -103,7 +142,7 @@ __mlu_func__ void roialignForwardKernel( int offset_length; int task_length; - // the length dealt by every core and the offset of taskid + // the length dealt by every core and the offset of taskId if (taskId < rem_num) { task_length = inter_num + 1; offset_length = taskId * (inter_num + 1); @@ -112,215 +151,95 @@ __mlu_func__ void roialignForwardKernel( offset_length = rem_num * (inter_num + 1) + (taskId - rem_num) * inter_num; } - int max_size = max_elements >> 1; - T *nram_out = (T *)buffer; - T *nram_in = nram_out + max_size; - T *nram_rois = nram_in + max_elements; + int max_size = max_elements; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size * 2; int pooled_size = pooled_height * pooled_width; - // output and roi data ptr T *top_data = output + offset_length * pooled_size * channels; T *task_rois = rois + offset_length * ROI_OFFSET; for (int roi_id = 0; roi_id < task_length; roi_id++) { // For each roi, find the corresponding feature map which it belongs to, // and compute the scaling_factor to map it to that feature map. - height = input_height; - width = input_width; T offset = aligned ? (T)0.5 : (T)0; - T *roi_id_tmp = task_rois + roi_id * ROI_OFFSET; - __bang_write_zero(nram_rois, ALIGN_SIZE); - __memcpy((void *)nram_rois, (void *)roi_id_tmp, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - int batch_id = nram_rois[0]; - T roi_xmin = nram_rois[1]; - T roi_ymin = nram_rois[2]; - T roi_xmax = nram_rois[3]; - T roi_ymax = nram_rois[4]; + int batch_id = roi_id_tmp[0]; + T roi_xmin = roi_id_tmp[1]; + T roi_ymin = roi_id_tmp[2]; + T roi_xmax = roi_id_tmp[3]; + T roi_ymax = roi_id_tmp[4]; - roi_xmin = roi_xmin * spatial_scale - offset; - roi_ymin = roi_ymin * spatial_scale - offset; - roi_xmax = roi_xmax * spatial_scale - offset; - roi_ymax = roi_ymax * spatial_scale - offset; + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; - float roi_width = roi_xmax - roi_xmin; - float roi_height = roi_ymax - roi_ymin; + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; if (!aligned) { - // Force malformed ROIs to be 1x1 - roi_width = roi_width > 1 ? roi_width : 1.0; - roi_height = roi_height > 1 ? roi_height : 1.0; + roi_width = roi_width > 1.0 ? roi_width : 1.0; + roi_height = roi_height > 1.0 ? 
roi_height : 1.0; } - float bin_size_h = (float)roi_height / pooled_height; - float bin_size_w = (float)roi_width / pooled_width; + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; - // input data ptr - T *offset_bottom_data = input + batch_id * channels * width * height; T *tmp_sum = nram_out; __bang_write_zero(nram_out, max_size); // We use roi_bin_grid to sample the grid, and perform average pooling // inside a bin. When the grid is empty, then output zeros. - int roi_bin_grid_h = - (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_h); - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : __float2int_up(bin_size_w); - float count = roi_bin_grid_h * roi_bin_grid_w; - float zero_sign_tmp = 1.0f / count; + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); for (int ph = 0; ph < pooled_height; ph++) { - float y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid for (int pw = 0; pw < pooled_width; pw++) { - float x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid // Bilinear interpolatation - if (samp_channel_align < max_elements) { - // One aligned channel data can be computed at one time - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - float y = - (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 - ? 0 - : (y_pre + - ((iy + 0.5) * bin_size_h) / - (roi_bin_grid_h)); // center_point y - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - float x = - (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 - ? 
0 - : (x_pre + - ((ix + 0.5) * bin_size_w) / - (roi_bin_grid_w)); // center_point x - T zero_sign = - (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * - zero_sign_tmp; - - int empty = 0; - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, - &w3, &w4, &x_low, &x_high, &y_low, &y_high, - &empty, zero_sign); - - // load - int cpy_len = (x_high - x_low) * channels; - int cpy_size = channels * sizeof(T); - - int offset1 = (y_low * width + x_low) * channels; - int offset2 = (y_high * width + x_low) * channels; - - T *tmp1 = offset_bottom_data + offset1; - T *tmp2 = offset_bottom_data + offset2; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + channel_align; - T *tmp_cyc3 = nram_in + channel_align * 2; - T *tmp_cyc4 = nram_in + channel_align * 3; - __asm__ volatile("sync;"); - if (empty == 1) { - __nramset(nram_in, channel_align, T(0)); - } else { - // load gdram to nram - __memcpy_async(tmp_cyc1, tmp1, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, cpy_size, GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, cpy_size, GDRAM2NRAM); - __asm__ volatile("sync;"); - // roialign_forward compute - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, channel_align); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, channel_align); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, channel_align); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, channel_align); - __bang_sumpool(nram_in, nram_in, channel_align, 1, SAMPLING_NUM, - 1, SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, channel_align); - } - } + if (is_normal_c) { + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, channel_align, channel_align, + y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); } else { // One aligned channel data cannot be computed at one time - int cyc_num = samp_channel / max_elements + - (int)(samp_channel % max_elements != 0); - int cyc_channel = max_elements / SAMPLING_NUM; + int cyc_num = + samp_channel / (max_elements * SAMPLING_NUM) + + (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); + int cyc_channel = max_elements; for (int i = 0; i < cyc_num; ++i) { - int real_channel = - (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; - int align_channel = - (i == cyc_num - 1) - ? PAD_UP((channel_align - i * cyc_channel), ALIGN_SIZE) - : cyc_channel; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - float y = - (y_pre + ((iy + 0.5) * bin_size_h) / (roi_bin_grid_h)) <= 0 - ? 0 - : (y_pre + - ((iy + 0.5) * bin_size_h) / - (roi_bin_grid_h)); // center_point y - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - float x = - (x_pre + ((ix + 0.5) * bin_size_w) / (roi_bin_grid_w)) <= 0 - ? 
0 - : (x_pre + - ((ix + 0.5) * bin_size_w) / - (roi_bin_grid_w)); // center_point x - - T zero_sign = - (T)(x >= -1.0 && x <= width && y >= -1.0 && y <= height) * - zero_sign_tmp; - - int empty = 0; - T w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - - bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, - &w3, &w4, &x_low, &x_high, &y_low, &y_high, - &empty, zero_sign); - - // load - int cpy_len = (x_high - x_low) * channels; - - int offset1 = (y_low * width + x_low) * channels; - int offset2 = (y_high * width + x_low) * channels; - - T *tmp1 = offset_bottom_data + offset1 + cyc_channel * i; - T *tmp2 = offset_bottom_data + offset2 + cyc_channel * i; - - T *tmp_cyc1 = nram_in; - T *tmp_cyc2 = nram_in + cyc_channel; - T *tmp_cyc3 = nram_in + cyc_channel * 2; - T *tmp_cyc4 = nram_in + cyc_channel * 3; - __asm__ volatile("sync;"); - if (empty == 1) { // exits abnormal values - __nramset(nram_in, align_channel, T(0)); - } else { - __memcpy_async(tmp_cyc1, tmp1, align_channel * sizeof(T), - GDRAM2NRAM); - __memcpy_async(tmp_cyc2, tmp1 + cpy_len, - align_channel * sizeof(T), GDRAM2NRAM); - __memcpy_async(tmp_cyc3, tmp2, align_channel * sizeof(T), - GDRAM2NRAM); - __memcpy_async(tmp_cyc4, tmp2 + cpy_len, - align_channel * sizeof(T), GDRAM2NRAM); - __asm__ volatile("sync;"); - __bang_mul_const(tmp_cyc1, tmp_cyc1, w1, align_channel); - __bang_mul_const(tmp_cyc2, tmp_cyc2, w2, align_channel); - __bang_mul_const(tmp_cyc3, tmp_cyc3, w3, align_channel); - __bang_mul_const(tmp_cyc4, tmp_cyc4, w4, align_channel); - __bang_sumpool(nram_in, nram_in, cyc_channel, 1, SAMPLING_NUM, - 1, SAMPLING_NUM, 1, 1); - } - __bang_add(tmp_sum, tmp_sum, nram_in, align_channel); - } - } + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); __memcpy(top_data + cyc_channel * i, tmp_sum, real_channel * sizeof(T), NRAM2GDRAM); __bang_write_zero(nram_out, max_size); } } // copy output data to ddr when channel num is not aligned with 64 - if (samp_channel_align < max_elements) { + if (is_normal_c) { __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); __bang_write_zero(nram_out, max_size); } @@ -330,25 +249,665 @@ __mlu_func__ void roialignForwardKernel( } // loop for num_roi } +template +__mlu_func__ void roialignForwardHpartKernel( + T *input, T *rois, T *output, T *nram_buffer, const bool aligned, + const int channels, const int pooled_height, const int pooled_width, + const int input_height, const int input_width, const int sampling_ratio, + const float spatial_scale, const int num_rois, const int max_elements) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int taskdim_cyc = taskDim / num_rois > 1 ? 
taskDim / num_rois : 1; + int roi_id = taskId / taskdim_cyc; + if (taskId >= taskdim_cyc * num_rois) { + return; + } + + // multi-core params + int inter_num = pooled_height / taskdim_cyc; + int rem_num = pooled_height % taskdim_cyc; + int offset_length; + int task_length; + + if ((taskId % taskdim_cyc) < rem_num) { + task_length = inter_num + 1; + offset_length = (taskId % taskdim_cyc) * (inter_num + 1); + } else { + task_length = inter_num; + offset_length = rem_num * (inter_num + 1) + + ((taskId % taskdim_cyc) - rem_num) * inter_num; + } + + int max_size = max_elements * 2; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + + int pooled_size = pooled_height * pooled_width; + T *top_data = + output + (roi_id * pooled_size + offset_length * pooled_width) * channels; + T offset = aligned ? (T)0.5 : (T)0; + T *roi_id_tmp = rois + roi_id * ROI_OFFSET; + + int batch_id = roi_id_tmp[0]; + T roi_xmin = roi_id_tmp[1]; + T roi_ymin = roi_id_tmp[2]; + T roi_xmax = roi_id_tmp[3]; + T roi_ymax = roi_id_tmp[4]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < (max_elements * SAMPLING_NUM); + + for (int ph = offset_length; ph < (offset_length + task_length); ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, + bin_size_w, input_height, input_width, channels, + channel_align, channel_align, y_pre, x_pre, + zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / (max_elements * SAMPLING_NUM) + + (int)(samp_channel % (max_elements * SAMPLING_NUM) != 0); + int cyc_channel = max_elements; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? 
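The `inter_num` / `rem_num` arithmetic above is the usual remainder-balancing split of `pooled_height` output rows across the cores that share one ROI: the first `rem_num` cores take one extra row each, so every row is produced exactly once. A minimal host-side sketch of that bookkeeping (plain C++; `splitPooledRows` and `local_id` are illustrative names standing for the kernel's `taskId % taskdim_cyc` logic):

```cpp
#include <utility>

// Sketch of the Hpart row split: divide `pooled_height` rows among
// `taskdim_cyc` cores; the first `rem_num` cores each take one extra row.
// Returns {task_length, offset_length} for the core with index `local_id`.
std::pair<int, int> splitPooledRows(int pooled_height, int taskdim_cyc,
                                    int local_id) {
  int inter_num = pooled_height / taskdim_cyc;  // rows every core gets
  int rem_num = pooled_height % taskdim_cyc;    // leftover rows
  int task_length, offset_length;
  if (local_id < rem_num) {
    task_length = inter_num + 1;
    offset_length = local_id * (inter_num + 1);
  } else {
    task_length = inter_num;
    offset_length =
        rem_num * (inter_num + 1) + (local_id - rem_num) * inter_num;
  }
  return {task_length, offset_length};  // rows handled, first row index
}
```

Summing `task_length` over all local ids gives back `pooled_height`, so the per-ROI output is covered with no overlap.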
PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph +} + __mlu_global__ void MLUUnion1KernelRoialign( const void *input, const void *rois, const int channels, const bool aligned, const int pooled_height, const int pooled_width, const int input_height, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, const cnrtDataType_t data_type, void *output) { - int max_elements = - (data_type == CNRT_FLOAT32) ? MAX_ELEMENTS_FLOAT : MAX_ELEMENTS_HALF; + size_t data_type_size = + (data_type == CNRT_FLOAT32) ? sizeof(float) : sizeof(half); + int max_elements = PAD_DOWN( + (BUFFER_SIZE / (int)data_type_size) / (ROI_OFFSET + 1), ALIGN_SIZE); + + if (taskDim < num_rois || (num_rois * pooled_height < taskDim)) { + switch (data_type) { + case CNRT_FLOAT16: { + half *nram_buffer = (half *)buffer; + roialignForwardNpartKernel( + (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_elements); + }; break; + case CNRT_FLOAT32: { + float *nram_buffer = (float *)buffer; + roialignForwardNpartKernel( + (float *)input, (float *)rois, (float *)output, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + default: + break; + } + } else { + switch (data_type) { + case CNRT_FLOAT16: { + half *nram_buffer = (half *)buffer; + roialignForwardHpartKernel( + (half *)input, (half *)rois, (half *)output, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_elements); + }; break; + case CNRT_FLOAT32: { + float *nram_buffer = (float *)buffer; + roialignForwardHpartKernel( + (float *)input, (float *)rois, (float *)output, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_elements); + }; break; + default: + break; + } + } + return; +} + +template +__mlu_func__ void buSelection(T *rois_count, T *nram_temp, const int num_rois) { + for (int i = 0; i < num_rois; ++i) { + for (int j = 1; j < num_rois; ++j) { + if (rois_count[(j - 1) * 2] < rois_count[j * 2]) { + nram_temp[0] = rois_count[(j - 1) * 2]; + rois_count[(j - 1) * 2] = rois_count[j * 2]; + rois_count[j * 2] = nram_temp[0]; + nram_temp[1] = rois_count[(j - 1) * 2 + 1]; + rois_count[(j - 1) * 2 + 1] = rois_count[j * 2 + 1]; + rois_count[j * 2 + 1] = nram_temp[1]; + } + } + } +} + +template +__mlu_func__ void getPatitionList(T *h_nram, T *n_nram, T *roi_count, + int pooled_height, int num_rois, T sum, + int split_num, int &h_flag, int &n_flag) { + T avg_sum = sum / split_num; + T *h_nram_temp = h_nram; + T *n_nram_temp = n_nram; + + int n_index = 0; + T n_sum 
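`buSelection` above is a straightforward pair-wise bubble sort: `rois_count` holds interleaved {workload, original ROI index} pairs, and the pass leaves them ordered by decreasing workload while each index stays attached to its count. A standalone restatement in plain C++ (illustrative function name, small driver only for demonstration):

```cpp
#include <cstdio>

// Sort interleaved {count, original_index} pairs in descending order of
// count, keeping each index paired with its count (same ordering as the
// kernel's buSelection).
void sortRoiCountDescending(float *rois_count, int num_rois) {
  for (int i = 0; i < num_rois; ++i) {
    for (int j = 1; j < num_rois; ++j) {
      if (rois_count[(j - 1) * 2] < rois_count[j * 2]) {
        // swap the {count, index} pair at j-1 with the pair at j
        for (int k = 0; k < 2; ++k) {
          float tmp = rois_count[(j - 1) * 2 + k];
          rois_count[(j - 1) * 2 + k] = rois_count[j * 2 + k];
          rois_count[j * 2 + k] = tmp;
        }
      }
    }
  }
}

int main() {
  float rois_count[] = {4.f, 0.f, 9.f, 1.f, 1.f, 2.f};  // three ROIs
  sortRoiCountDescending(rois_count, 3);
  for (int i = 0; i < 3; ++i)
    std::printf("count=%g roi=%g\n", rois_count[2 * i], rois_count[2 * i + 1]);
  return 0;
}
```

The sorted counts are what `mergeAndSplitQuantity` later uses to rebuild `rois_sort` and to derive the H/N partition.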
= 0; + h_flag = 0; + n_flag = 0; + int list_align = PAD_UP(ALIGN_SIZE * 5, ALIGN_SIZE); + __bang_write_zero(h_nram, list_align); + for (int i = 0; i < num_rois; i++) { + if (roi_count[2 * i] >= avg_sum) { + int h_num = std::ceil(roi_count[2 * i] / avg_sum); + int h_split = pooled_height / h_num; + int h_rem = pooled_height % h_num; + T h_sum = 0.0; + + for (int j = 0; j < h_num; j++) { + h_nram_temp[0] = i; + h_nram_temp[1] = h_sum; + h_nram_temp[2] = (j < h_rem) ? (h_split + 1) : h_split; + h_sum += h_nram_temp[2]; + h_nram_temp += 3; + n_nram_temp += 2; + h_flag++; + } + } else { + if (roi_count[2 * i] + n_sum > avg_sum) { + n_nram_temp[0] = i - n_index; + n_nram_temp[1] = i - 1; + n_sum = 0.0; + n_index = 0; + n_nram_temp += 2; + i--; + n_flag++; + } else { + n_index++; + n_sum += roi_count[2 * i]; + } + } + } + if (n_flag == 0 && n_index != 0) { + n_flag = 1; + n_nram[(h_flag + n_flag - 1) * 2] = num_rois - 1; + } + + n_nram[(h_flag + n_flag) * 2 - 1] = num_rois - 1; + + if (h_flag + n_flag > taskDim) { + getPatitionList(h_nram, n_nram, roi_count, pooled_height, num_rois, sum, + split_num - 1, h_flag, n_flag); + } + return; +} + +template +__mlu_func__ void mergeAndSplitQuantity( + T *rois, T *rois_sort, T *split_list, T *roi_count, T *nram_rois, + const bool aligned, const int pooled_height, const int pooled_width, + const int sampling_ratio, const float spatial_scale, const int num_rois, + int &h_split_num, int &n_split_num) { + /* take the coordinates out of ROIS and actually calculate the actual + * calculation size. The sorted calculation scale is partition, large scale + * is split H, small is N. + */ + T *h_tem = split_list; + T *n_tem = split_list + 3 * ALIGN_SIZE; + int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 1), ALIGN_SIZE); + int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); + __bang_write_zero(nram_rois, num_rois_align); + T sum = 0.0; + int temp_offset = 0; + __memcpy((void *)(nram_rois + 1), (void *)rois, ROI_OFFSET * sizeof(T), + GDRAM2NRAM, (ROI_OFFSET + 1) * sizeof(T), ROI_OFFSET * sizeof(T), + (num_rois - 1)); + T *nram_temp = roi_count + count_align; + for (int roi_id = 0; roi_id < num_rois; roi_id++) { + T offset = aligned ? (T)0.5 : (T)0; + + T roi_xmin = nram_rois[temp_offset + 2]; + T roi_ymin = nram_rois[temp_offset + 3]; + T roi_xmax = nram_rois[temp_offset + 4]; + T roi_ymax = nram_rois[temp_offset + 5]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? 
sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + sum += count; + *(roi_count + 2 * roi_id) = count; + *(roi_count + 2 * roi_id + 1) = roi_id; + + *(nram_rois + roi_id * (ROI_OFFSET + 1)) = count; + temp_offset += (ROI_OFFSET + 1); + } + + buSelection(roi_count, nram_temp, num_rois); + + temp_offset = 0; + for (int i = 0; i < num_rois; i++) { + for (int j = 0; j < num_rois; j++) { + if (roi_count[2 * i] == nram_rois[j * (ROI_OFFSET + 1)]) { + rois_sort[temp_offset] = nram_rois[j * (ROI_OFFSET + 1)]; + rois_sort[temp_offset + 1] = nram_rois[j * (ROI_OFFSET + 1) + 1]; + rois_sort[temp_offset + 2] = nram_rois[j * (ROI_OFFSET + 1) + 2]; + rois_sort[temp_offset + 3] = nram_rois[j * (ROI_OFFSET + 1) + 3]; + rois_sort[temp_offset + 4] = nram_rois[j * (ROI_OFFSET + 1) + 4]; + rois_sort[temp_offset + 5] = nram_rois[j * (ROI_OFFSET + 1) + 5]; + nram_rois[j * (ROI_OFFSET + 1)] = -1.0; + break; + } + } + temp_offset += (ROI_OFFSET + 1); + } + getPatitionList(h_tem, n_tem, roi_count, pooled_height, num_rois, sum, + taskDim, h_split_num, n_split_num); +} + +template +__mlu_func__ void roialignForwardNpartKernelForBinPart( + T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, + T *nram_buffer, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const int max_size) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int max_elements = max_size * SAMPLING_NUM; + int offset_length; + int task_length; + + T *n_split_nram = split_list + 3 * ALIGN_SIZE + 2 * taskId; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + T *task_rois = rois_sort + (int)n_split_nram[0] * (ROI_OFFSET + 1); + + offset_length = (int)n_split_nram[0]; + task_length = n_split_nram[1] - n_split_nram[0] + 1; + int pooled_size = pooled_height * pooled_width; + + for (int roi_id = offset_length; roi_id < offset_length + task_length; + roi_id++) { + // For each roi, find the corresponding feature map which it belongs to, + // and compute the scaling_factor to map it to that feature map. + T offset = aligned ? (T)0.5 : (T)0; + int rea_out_id = rois_count[roi_id * 2 + 1]; + T *top_data = output + rea_out_id * pooled_size * channels; + T *nram_rois = task_rois + (roi_id - offset_length) * (ROI_OFFSET + 1); + + int batch_id = nram_rois[1]; + T roi_xmin = nram_rois[2]; + T roi_ymin = nram_rois[3]; + T roi_xmax = nram_rois[4]; + T roi_ymax = nram_rois[5]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1.0 ? roi_width : 1.0; + roi_height = roi_height > 1.0 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_in, max_elements); + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. 
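The partitioning above keys every ROI by its sampling workload, i.e. the number of bilinear samples one bin needs (`roi_bin_grid_h * roi_bin_grid_w`): ROIs at or above the average workload are split along `pooled_height`, and smaller ROIs are grouped until their summed workload reaches the average. A host-side sketch of that metric (plain C++; names are illustrative, box layout {batch, x1, y1, x2, y2} as in the kernel):

```cpp
#include <cmath>

// Per-ROI workload used for partitioning: the number of sampling points in
// one output bin, roi_bin_grid_h * roi_bin_grid_w.
float roiSampleCount(const float box[5], float spatial_scale, bool aligned,
                     int pooled_height, int pooled_width, int sampling_ratio) {
  float offset = aligned ? 0.5f : 0.0f;
  float x1 = box[1] * spatial_scale - offset;
  float y1 = box[2] * spatial_scale - offset;
  float x2 = box[3] * spatial_scale - offset;
  float y2 = box[4] * spatial_scale - offset;
  float roi_w = x2 - x1;
  float roi_h = y2 - y1;
  if (!aligned) {  // force malformed ROIs to be at least 1x1
    roi_w = roi_w > 1.f ? roi_w : 1.f;
    roi_h = roi_h > 1.f ? roi_h : 1.f;
  }
  float bin_h = roi_h / pooled_height;
  float bin_w = roi_w / pooled_width;
  int grid_h = sampling_ratio > 0 ? sampling_ratio : (int)std::ceil(bin_h);
  int grid_w = sampling_ratio > 0 ? sampling_ratio : (int)std::ceil(bin_w);
  return (float)(grid_h * grid_w);
}
```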
When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < max_elements; + + for (int ph = 0; ph < pooled_height; ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, channel_align, channel_align, + y_pre, x_pre, zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / max_elements + + (int)(samp_channel % max_elements != 0); + int cyc_channel = max_elements / SAMPLING_NUM; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph + } // loop for num_roi +} + +template +__mlu_func__ void roialignForwardHpartKernelForBinPart( + T *input, T *rois, T *output, T *rois_sort, T *split_list, T *rois_count, + T *nram_buffer, const bool aligned, const int channels, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const int max_size) { + int channel_align = PAD_UP(channels, ALIGN_SIZE); + int samp_channel_align = channel_align * SAMPLING_NUM; + int samp_channel = channels * SAMPLING_NUM; + int max_elements = max_size * SAMPLING_NUM; + + T *h_split_nram = split_list; + T *nram_out = nram_buffer; + T *nram_in = nram_out + max_size; + T *nram_rois = rois_sort + (int)h_split_nram[taskId * 3] * (ROI_OFFSET + 1); + + int offset_length = (int)h_split_nram[taskId * 3 + 1]; + int task_length = (int)h_split_nram[taskId * 3 + 2]; + int rea_out_id = (int)h_split_nram[taskId * 3]; + + rea_out_id = rois_count[rea_out_id * 2 + 1]; + int pooled_size = pooled_height * pooled_width; + T *top_data = + output + + (rea_out_id * pooled_size + offset_length * pooled_width) * channels; + + T offset = aligned ? 
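For orientation, the value each `(ph, pw)` bin finally holds is the ordinary RoIAlign average of `roi_bin_grid_h * roi_bin_grid_w` bilinear samples, which is what the `zero_sign_tmp = 1 / count` scaling implements. A generic single-channel CPU reference of that formula (not the vectorized `bilinearInterpolate` helper used by the kernels; purely illustrative):

```cpp
// Scalar RoIAlign reference for one output bin (ph, pw) of a single-channel
// feature map `feat` stored row-major as feat[h * width + w]. Samples outside
// [-1, height] x [-1, width] contribute zero, matching the zero_sign handling;
// the result is the mean over all grid samples.
float roiAlignBinAverage(const float *feat, int height, int width,
                         float roi_ymin, float roi_xmin, float bin_h,
                         float bin_w, int ph, int pw, int grid_h, int grid_w) {
  float sum = 0.f;
  for (int iy = 0; iy < grid_h; ++iy) {
    float y = roi_ymin + ph * bin_h + (iy + 0.5f) * bin_h / grid_h;
    for (int ix = 0; ix < grid_w; ++ix) {
      float x = roi_xmin + pw * bin_w + (ix + 0.5f) * bin_w / grid_w;
      if (y < -1.f || y > height || x < -1.f || x > width) continue;
      float ys = y < 0.f ? 0.f : y;
      float xs = x < 0.f ? 0.f : x;
      int y_low = (int)ys, x_low = (int)xs, y_high, x_high;
      if (y_low >= height - 1) {
        y_high = y_low = height - 1;
        ys = (float)y_low;
      } else {
        y_high = y_low + 1;
      }
      if (x_low >= width - 1) {
        x_high = x_low = width - 1;
        xs = (float)x_low;
      } else {
        x_high = x_low + 1;
      }
      float ly = ys - y_low, lx = xs - x_low;
      float hy = 1.f - ly, hx = 1.f - lx;
      // hy*hx, hy*lx, ly*hx, ly*lx play the role of w1..w4 in the kernels
      sum += hy * hx * feat[y_low * width + x_low] +
             hy * lx * feat[y_low * width + x_high] +
             ly * hx * feat[y_high * width + x_low] +
             ly * lx * feat[y_high * width + x_high];
    }
  }
  return sum / (grid_h * grid_w);  // average pooling inside the bin
}
```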
(T)0.5 : (T)0; + + int batch_id = nram_rois[1]; + T roi_xmin = nram_rois[2]; + T roi_ymin = nram_rois[3]; + T roi_xmax = nram_rois[4]; + T roi_ymax = nram_rois[5]; + + roi_xmin = roi_xmin * (T)spatial_scale - offset; + roi_ymin = roi_ymin * (T)spatial_scale - offset; + roi_xmax = roi_xmax * (T)spatial_scale - offset; + roi_ymax = roi_ymax * (T)spatial_scale - offset; + + T roi_width = roi_xmax - roi_xmin; + T roi_height = roi_ymax - roi_ymin; + + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = roi_width > 1 ? roi_width : 1.0; + roi_height = roi_height > 1 ? roi_height : 1.0; + } + + T bin_size_h = roi_height / (T)pooled_height; + T bin_size_w = roi_width / (T)pooled_width; + T *offset_bottom_data = + input + batch_id * channels * input_width * input_height; + + T *tmp_sum = nram_out; + __bang_write_zero(nram_in, max_elements); + __bang_write_zero(nram_out, max_size); + + // We use roi_bin_grid to sample the grid, and perform average pooling + // inside a bin. When the grid is empty, then output zeros. + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_h)); + int roi_bin_grid_w = (sampling_ratio > 0) + ? sampling_ratio + : (int)std::ceil((float)(bin_size_w)); + T count = roi_bin_grid_h * roi_bin_grid_w; + T zero_sign_tmp = 1.0f / count; + bool is_normal_c = samp_channel_align < max_elements; + + for (int ph = offset_length; ph < (offset_length + task_length); ph++) { + T y_pre = roi_ymin + ph * bin_size_h; // ymin in each grid + for (int pw = 0; pw < pooled_width; pw++) { + T x_pre = roi_xmin + pw * bin_size_w; // xmin in each grid + // Bilinear interpolatation + if (is_normal_c) { + bilinearInterpolate((T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, + bin_size_w, input_height, input_width, channels, + channel_align, channel_align, y_pre, x_pre, + zero_sign_tmp, is_normal_c, 0); + } else { + // One aligned channel data cannot be computed at one time + int cyc_num = samp_channel / max_elements + + (int)(samp_channel % max_elements != 0); + int cyc_channel = max_elements / SAMPLING_NUM; + for (int i = 0; i < cyc_num; ++i) { + int real_channel = cyc_channel < (channels - i * cyc_channel) + ? cyc_channel + : channels - i * cyc_channel; + int align_channel = (i == cyc_num - 1) + ? 
PAD_UP(real_channel, ALIGN_SIZE) + : cyc_channel; + bilinearInterpolate( + (T *)tmp_sum, (T *)nram_in, (T *)offset_bottom_data, + roi_bin_grid_h, roi_bin_grid_w, bin_size_h, bin_size_w, + input_height, input_width, channels, align_channel, cyc_channel, + y_pre, x_pre, zero_sign_tmp, is_normal_c, i); + + __memcpy(top_data + cyc_channel * i, tmp_sum, + real_channel * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + } + + // copy output data to ddr when channel num is not aligned with 64 + if (is_normal_c) { + __memcpy(top_data, nram_out, channels * sizeof(T), NRAM2GDRAM); + __bang_write_zero(nram_out, max_size); + } + top_data += channels; + } // loop for pw + } // loop for ph +} + +__mlu_global__ void MLUUnion1KernelBinPartRoialign( + const void *input, const void *rois, const int channels, const bool aligned, + const int pooled_height, const int pooled_width, const int input_height, + const int input_width, const int sampling_ratio, const float spatial_scale, + const int num_rois, const cnrtDataType_t data_type, void *output) { + int h_split_num = 0; + int n_split_num = 0; + int num_rois_align = PAD_UP(num_rois * (ROI_OFFSET + 4), ALIGN_SIZE); + int count_align = PAD_UP(num_rois * 2, ALIGN_SIZE); + int list_align = ALIGN_SIZE * 5; + int sum_size = num_rois_align + count_align + list_align; + + if (coreId == 0x80) { + return; + } + switch (data_type) { case CNRT_FLOAT16: { - roialignForwardKernel((half *)input, (half *)rois, (half *)output, - aligned, channels, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); + int max_channel = + PAD_DOWN((BUFFER_SIZE / sizeof(half) - sum_size) / (ROI_OFFSET + 1), + ALIGN_SIZE); + half *rois_sort = (half *)buffer; + __bang_write_zero(rois_sort, sum_size); + half *rois_count = (half *)(rois_sort + num_rois_align); + half *split_list = (half *)(rois_count + count_align); + half *nram_rois = (half *)(split_list + list_align); + mergeAndSplitQuantity((half *)rois, (half *)rois_sort, (half *)split_list, + (half *)rois_count, (half *)nram_rois, aligned, + pooled_height, pooled_width, sampling_ratio, + spatial_scale, num_rois, h_split_num, n_split_num); + half *nram_buffer = (half *)nram_rois; + __bang_write_zero(nram_rois, num_rois_align); + + if (taskId < h_split_num) { + roialignForwardHpartKernelForBinPart( + (half *)input, (half *)rois, (half *)output, (half *)rois_sort, + (half *)split_list, (half *)rois_count, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_channel); + } else { + if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { + roialignForwardNpartKernelForBinPart( + (half *)input, (half *)rois, (half *)output, (half *)rois_sort, + (half *)split_list, (half *)rois_count, (half *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, + max_channel); + } else { + return; + } + } }; break; case CNRT_FLOAT32: { - roialignForwardKernel((float *)input, (float *)rois, (float *)output, - aligned, channels, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, - spatial_scale, num_rois, max_elements); + int max_channel = + PAD_DOWN((BUFFER_SIZE / sizeof(float) - sum_size) / (ROI_OFFSET + 1), + ALIGN_SIZE); + float *rois_sort = (float *)buffer; + __bang_write_zero(rois_sort, sum_size); + float *rois_count = (float *)(rois_sort + num_rois_align); + float *split_list = 
(float *)(rois_count + count_align); + float *nram_rois = (float *)(split_list + list_align); + mergeAndSplitQuantity((float *)rois, (float *)rois_sort, + (float *)split_list, (float *)rois_count, + (float *)nram_rois, aligned, pooled_height, + pooled_width, sampling_ratio, spatial_scale, + num_rois, h_split_num, n_split_num); + float *nram_buffer = (float *)nram_rois; + __bang_write_zero(nram_rois, num_rois_align); + + if (taskId < h_split_num) { + roialignForwardHpartKernelForBinPart( + (float *)input, (float *)rois, (float *)output, (float *)rois_sort, + (float *)split_list, (float *)rois_count, (float *)nram_buffer, + aligned, channels, pooled_height, pooled_width, input_height, + input_width, sampling_ratio, spatial_scale, num_rois, max_channel); + } else { + if (n_split_num > 0 && (n_split_num + h_split_num) > taskId) { + roialignForwardNpartKernelForBinPart( + (float *)input, (float *)rois, (float *)output, + (float *)rois_sort, (float *)split_list, (float *)rois_count, + (float *)nram_buffer, aligned, channels, pooled_height, + pooled_width, input_height, input_width, sampling_ratio, + spatial_scale, num_rois, max_channel); + } else { + return; + } + } }; break; default: break; @@ -400,14 +959,17 @@ __mlu_func__ void unionRoiAlignBp( const int wi, const int c, const int no, const int ho, const int wo, const float spatial_scale, const int sampling_ratio, const bool aligned) { int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); - int deal_this_core = - boxes_num / taskDim + (int)(taskId < boxes_num % taskDim); + int deal_all = boxes_num * hi * wi; + int deal_this_core = deal_all / taskDim + (int)(taskId < deal_all % taskDim); for (int i = 0; i < deal_this_core; ++i) { - int box_id = i * taskDim + taskId; - T *box = boxes + box_id * DIM_BOX; - T *grads_offset = grads + box_id * hi * wi * c; + int bhw_id = i * taskDim + taskId; + int box_id = bhw_id / (hi * wi); + int ih = (bhw_id / wi) % hi; + int iw = bhw_id % wi; + T *box = boxes + box_id * 5; int image_id = (int)box[0]; T *image_offset = grads_image + image_id * ho * wo * c; + T *grads_ = grads + box_id * hi * wi * c + ih * wi * c + iw * c; float offset = aligned ? 0.5 : 0.0; float x1 = box[1] * spatial_scale - offset; @@ -427,108 +989,113 @@ __mlu_func__ void unionRoiAlignBp( (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi); int roi_grid_w = (sampling_ratio > 0) ? 
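The backward rewrite above flattens the work into `boxes_num * hi * wi` independent output cells, deals them round-robin across the launch, and decodes each flat id back into `(box_id, ih, iw)`. A minimal sketch of that bookkeeping (plain C++, illustrative names):

```cpp
// One work item per (box, ih, iw) output cell, dealt round-robin across
// `task_dim` cores; decode the flat id back into its three coordinates.
struct BhwIndex {
  int box_id, ih, iw;
};

BhwIndex decodeBhw(int bhw_id, int hi, int wi) {
  BhwIndex out;
  out.box_id = bhw_id / (hi * wi);
  out.ih = (bhw_id / wi) % hi;
  out.iw = bhw_id % wi;
  return out;
}

// Number of items handled by core `task_id` out of `task_dim` cores
// (the first deal_all % task_dim cores take one extra item).
int itemsThisCore(int boxes_num, int hi, int wi, int task_id, int task_dim) {
  int deal_all = boxes_num * hi * wi;
  return deal_all / task_dim + (task_id < deal_all % task_dim ? 1 : 0);
}
```

Item `i` on core `task_id` then maps to the flat id `i * task_dim + task_id`, exactly as in the kernel's loop.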
sampling_ratio : std::ceil(roi_width / wi); - const int count = roi_grid_h * roi_grid_w; - if (c_align * sizeof(T) * BLOCK_INPUT_OUTPUT <= MAX_NRAM_SIZE) { - for (int ih = 0; ih < hi; ++ih) { - for (int iw = 0; iw < wi; ++iw) { - T *grads_ = grads_offset + ih * wi * c + iw * c; - for (int iy = 0; iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, - &x_low, &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w1 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w2 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_high * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w3 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_const((T *)buffer + c_align, (T *)buffer, - (T)(w4 / count), c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_high * c, - (T *)buffer + c_align, c); - } // x_low && y_low - } // ix - } // iy - } // iw - } // ih + const T count = roi_grid_h * roi_grid_w; + if (c_align * sizeof(T) * 2 <= MAX_NRAM_SIZE) { + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, + &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w1, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w2, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_low * wo * c + x_high * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w3, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_low * c, + (T *)buffer + c_align, c); + __bang_mul_const((T *)buffer + c_align, (T *)buffer, (T)w4, + c_align); + __bang_mul_const((T *)buffer + c_align, (T *)buffer + c_align, + 1 / count, c_align); + __bang_atomic_add((T *)buffer + c_align, + image_offset + y_high * wo * c + x_high * c, + (T *)buffer + c_align, c); + } // x_low && y_low + } // ix + } // iy } else { - for (int ih = 0; ih < hi; ++ih) { - for (int iw = 0; iw < wi; ++iw) { - T *grads_ = grads_offset + ih * wi * c + iw * c; - for (int iy = 0; 
iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, - &x_low, &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - int deal_once = PAD_DOWN(MAX_NRAM_SIZE / BLOCK_INPUT_OUTPUT, - NFU_ALIGN_SIZE) / - sizeof(T); - int c_repeat = c / deal_once + (int)(c % deal_once != 0); - for (int i = 0; i < c_repeat; ++i) { - int deal_c = deal_once; - int align_c = deal_once; - if (i == c_repeat - 1) { - deal_c = c - i * deal_once; - align_c = c_align - i * deal_once; - } - __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), - GDRAM2NRAM); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w1 / count), align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_low * wo * c + x_low * c + i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w2 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_low * wo * c + x_high * c + - i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w3 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_high * wo * c + x_low * c + - i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_const((T *)buffer + align_c, (T *)buffer, - (T)(w4 / count), align_c); - __bang_atomic_add((T *)buffer + align_c, - image_offset + y_high * wo * c + - x_high * c + i * deal_once, - (T *)buffer + align_c, deal_c); - } // for c_repeat - } // x_low >= 0 && y_low >= 0 - } // ix - } // iy - } // iw - } // ih - } // if c - } // i + for (int iy = 0; iy < roi_grid_h; ++iy) { + const float y = + y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; + for (int ix = 0; ix < roi_grid_w; ++ix) { + const float x = + x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; + float w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, + &x_high, &y_low, &y_high); + if (x_low >= 0 && y_low >= 0) { + int deal_once = + PAD_DOWN(MAX_NRAM_SIZE / 2, NFU_ALIGN_SIZE) / sizeof(T); + int c_repeat = c / deal_once + (int)(c % deal_once != 0); + for (int i = 0; i < c_repeat; ++i) { + int deal_c = deal_once; + int align_c = deal_once; + if (i == c_repeat - 1) { + deal_c = c - i * deal_once; + align_c = c_align - i * deal_once; + } + __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), + GDRAM2NRAM); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w1, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_low * wo * c + x_low * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w2, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_low * wo * c + x_high * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w3, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_high * 
wo * c + x_low * c + i * deal_once, + (T *)buffer + align_c, deal_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer, (T)w4, + align_c); + __bang_mul_const((T *)buffer + align_c, (T *)buffer + align_c, + 1 / count, align_c); + __bang_atomic_add( + (T *)buffer + align_c, + image_offset + y_high * wo * c + x_high * c + i * deal_once, + (T *)buffer + align_c, deal_c); + } // for c_repeat + } // x_low >= 0 && y_low >= 0 + } // ix + } // iy + } // if c + } // i } __mlu_global__ void MLUUnion1KernelRoiAlignBackward( @@ -564,9 +1131,21 @@ void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, const int input_width, const int sampling_ratio, const float spatial_scale, const int num_rois, void *output) { - forward::MLUUnion1KernelRoialign<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); + // set thresholds for degradation caused by sorting + const int sort_border = 100; // threshold of num_rois + const int sort_cluster_num = 16; // threshold of cluster + + if (num_rois > sort_border || k_dim.y < sort_cluster_num) { + forward::MLUUnion1KernelRoialign<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, spatial_scale, num_rois, + d_type, output); + } else { + forward::MLUUnion1KernelBinPartRoialign<<>>( + input, rois, channels, aligned, pooled_height, pooled_width, + input_height, input_width, sampling_ratio, spatial_scale, num_rois, + d_type, output); + } } void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, diff --git a/mmcv/ops/csrc/pytorch/focal_loss.cpp b/mmcv/ops/csrc/pytorch/focal_loss.cpp index a0d878ff3..ea8249717 100644 --- a/mmcv/ops/csrc/pytorch/focal_loss.cpp +++ b/mmcv/ops/csrc/pytorch/focal_loss.cpp @@ -96,8 +96,10 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(target); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(target); + CHECK_MLU_INPUT(weight); + CHECK_MLU_INPUT(output); sigmoid_focal_loss_forward_mlu(input, target, weight, output, gamma, alpha); #endif } else { @@ -121,10 +123,10 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(target); - CHECK_MLU(weight); - CHECK_MLU(grad_input); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(target); + CHECK_MLU_INPUT(weight); + CHECK_MLU_INPUT(grad_input); sigmoid_focal_loss_backward_mlu(input, target, weight, grad_input, gamma, alpha); diff --git a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp index 044e8dd01..b003e5150 100644 --- a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp @@ -37,24 +37,38 @@ static void policyFuncForward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, auto N = input.size(0); auto C = input.size(1); - auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - auto c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE); + const size_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + const size_t c_align_size = PAD_UP((C * input.itemsize()), NFU_ALIGN_SIZE); const int split_target_num = 2; const int split_pipeline_num = 6; - auto scalar_size = NFU_ALIGN_SIZE; - auto weight_size = 
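The forward launcher above only switches to the bin-partition path when sorting is likely to pay off: few ROIs and a launch that is at least `sort_cluster_num` tasks tall. A one-line restatement of that predicate (illustrative helper name):

```cpp
// Dispatch sketch: the sorting/partitioning variant is used only when the ROI
// count is small and the launch is wide enough; otherwise the plain per-ROI
// kernel runs. Thresholds mirror sort_border / sort_cluster_num above.
bool useBinPartKernel(int num_rois, unsigned int k_dim_y) {
  const int sort_border = 100;      // max num_rois for the sorting variant
  const int sort_cluster_num = 16;  // min k_dim.y for the sorting variant
  return !(num_rois > sort_border || k_dim_y < sort_cluster_num);
}
```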
c_align_size; + const int has_weight = weight.data_ptr() != nullptr; const int target_data_width = target.scalar_type() == at::kLong ? target.itemsize() / 2 : target.itemsize(); + const int threshold_c = + PAD_DOWN((nram_size - split_target_num * sizeof(int)) / + (split_pipeline_num + has_weight), + NFU_ALIGN_SIZE) / + input.itemsize(); - // n_seg * c_align_size * split_pipeline_num + - // n_seg * target.itemsize() * split_target_num + - // weight_size + scalar_size <= nram_size - auto n_seg = (nram_size - weight_size - scalar_size) / - (c_align_size * split_pipeline_num + - target_data_width * split_target_num); - auto seg_num = (N + n_seg - 1) / n_seg; - + int n_seg = 1; + if (C <= threshold_c) { + int c_size = C * input.itemsize(); + int reserved_align_size = + (split_target_num + split_pipeline_num) * NFU_ALIGN_SIZE; + int weight_size = 0; + if (has_weight) { + c_size = c_align_size; + reserved_align_size = split_target_num * NFU_ALIGN_SIZE; + weight_size = c_align_size; + } + // n_seg * c_size * split_pipeline_num + n_seg * target.itemsize() * + // split_target_num + // + weight_size + reserved_align_size <= nram_size + n_seg = (nram_size - weight_size - reserved_align_size) / + (split_pipeline_num * c_size + split_target_num * sizeof(int32_t)); + } + auto seg_num = n_seg == 0 ? N : (N + n_seg - 1) / n_seg; auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); auto core_num = core_dim * cluster_num; @@ -103,31 +117,8 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, CNLOG(INFO) << "weight is a empty tensor."; } - // check C - auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - auto input_N = input.size(0); - auto input_C = input.size(1); - const int split_target_num = 2; - const int split_pipeline_num = 6; - const int has_weight = (int)(weight.data_ptr() != nullptr); - - // target supports only INT on MLU device while it keeps LONG on host side, - // so target.itemsize() / 2 - const int target_data_width = target.scalar_type() == at::kLong - ? target.itemsize() / 2 - : target.itemsize(); - auto threshold_C = PAD_DOWN((nram_size - NFU_ALIGN_SIZE - - split_target_num * target_data_width) / - (split_pipeline_num + has_weight), - NFU_ALIGN_SIZE) / - input.itemsize(); - - TORCH_CHECK(threshold_C >= input_C, - "input.size(1) should be in the range of [0, ", threshold_C, - "]. 
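The N-split above fits as many samples per pass as the per-core NRAM allows: `n_seg` copies of the pipelined row buffers plus the optional weight row and an alignment reserve must not exceed `nram_size`. A host-side sketch of that budget (plain C++; `computeNSeg` and `padUp` are illustrative names, NFU_ALIGN_SIZE is assumed to be the usual 128 bytes, and only the `C <= threshold_c` branch shown above is modelled, otherwise `n_seg` stays 1):

```cpp
#include <cstddef>
#include <cstdint>

// Largest n_seg such that
//   n_seg * (split_pipeline_num * c_size + split_target_num * sizeof(int32_t))
//     + weight_size + reserved_align_size <= nram_size.
constexpr int kNfuAlign = 128;  // assumed NFU_ALIGN_SIZE
constexpr int padUp(int x, int align) { return (x + align - 1) / align * align; }

int computeNSeg(std::size_t nram_size, int C, int item_bytes, bool has_weight,
                int split_pipeline_num = 6, int split_target_num = 2) {
  int c_size = C * item_bytes;
  int reserved_align_size = (split_target_num + split_pipeline_num) * kNfuAlign;
  int weight_size = 0;
  if (has_weight) {
    c_size = padUp(C * item_bytes, kNfuAlign);  // weight path works on padded C
    reserved_align_size = split_target_num * kNfuAlign;
    weight_size = c_size;
  }
  return (int)((nram_size - weight_size - reserved_align_size) /
               (std::size_t)(split_pipeline_num * c_size +
                             split_target_num * (int)sizeof(int32_t)));
}
```

`seg_num` then follows as the ceiling of `N / n_seg`, which is what the launcher uses to size the launch.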
", "But now input.size(1) is ", input_C, "."); - + // return if zero-element if (input.numel() == 0 || target.numel() == 0 || output.numel() == 0) { - // return if zero-element return; } @@ -158,8 +149,8 @@ void SigmoidFocalLossForwardMLUKernelLauncher(Tensor input, Tensor target, << k_dim.z << ">>>"; // launch kernel KernelFocalLossSigmoidForward(k_dim, k_type, queue, d_type, input_ptr, - target_ptr, weight_ptr, input_N, input_C, alpha, - gamma, output_ptr); + target_ptr, weight_ptr, input.size(0), + input.size(1), alpha, gamma, output_ptr); } void getDealNAndThresholdC(const int compute_data_bytes, diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp index af193fce3..e26819999 100644 --- a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp @@ -19,6 +19,16 @@ void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const float iou_threshold, const float offset, void *workspace_ptr, void *output_size_ptr, void *output_ptr); +int selectUnionType(uint32_t use_job, int box_num_per_core) { + // the box_num_per_core should be at least 256, otherwise the real IO + // bandwidth would be very low + while (box_num_per_core < 256 && use_job >= 4) { + box_num_per_core *= 2; + use_job /= 2; + } + return use_job; +} + Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, int offset) { // dimension parameters check @@ -42,32 +52,57 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, } int input_num_boxes = boxes.size(0); - int input_stride = boxes.size(1); + int input_stride = boxes.size(0); int max_output_boxes = boxes.size(0); - cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1; - int core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - uint32_t dim_x = core_dim; - cnrtDim3_t k_dim = {dim_x, 1, 1}; - cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); + cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); + cnrtDim3_t k_dim; + cnrtJobType_t k_type; + uint32_t union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + uint32_t job_limit = union_number * core_dim; + uint32_t core_number = union_number * core_dim; + int box_num_per_core = (input_num_boxes + core_number - 1) / core_number; + // initiate k_type as Union1 + k_dim.x = core_dim; + k_dim.y = 1; + k_dim.z = 1; + k_type = CNRT_FUNC_TYPE_UNION1; + int use_job = selectUnionType(job_limit, box_num_per_core); + if (use_job < 4) { + k_dim.x = 1; + k_type = CNRT_FUNC_TYPE_BLOCK; + } else if (use_job == 4) { + k_dim.x = core_dim; + k_type = CNRT_FUNC_TYPE_UNION1; + } else { + k_dim.x = use_job; + k_type = (cnrtFunctionType_t)use_job; + } + + // transpose boxes (n, 4) to (4, n) for better performance + auto boxes_t = boxes.transpose(0, 1); + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t); + auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong)); auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); // workspace + const int info_num = 5; // x1, x2, y1, y2 and score size_t space_size = 0; if (boxes.scalar_type() == at::kHalf) { - space_size = input_num_boxes * sizeof(int16_t); + space_size = input_num_boxes * sizeof(int16_t) * info_num + sizeof(float); } else { - space_size = input_num_boxes * sizeof(float); + space_size = input_num_boxes * sizeof(float) * info_num + 
sizeof(float); } auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte)); // get compute queue auto queue = torch_mlu::getCurQueue(); - auto boxes_impl = torch_mlu::getMluTensorImpl(boxes); + auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); auto boxes_ptr = boxes_impl->cnnlMalloc(); - auto scores_impl = torch_mlu::getMluTensorImpl(scores); + auto scores_impl = torch_mlu::getMluTensorImpl(scores_); auto scores_ptr = scores_impl->cnnlMalloc(); auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); auto workspace_ptr = workspace_impl->cnnlMalloc(); @@ -76,20 +111,11 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); auto output_size_ptr = output_size_impl->cnnlMalloc(); - switch (k_type) { - default: { - TORCH_CHECK(false, "[nms_mlu]:Failed to choose kernel to launch"); - } - case CNRT_FUNC_TYPE_BLOCK: - case CNRT_FUNC_TYPE_UNION1: { - CNLOG(INFO) << "Launch Kernel MLUUnion1 or Block NMS<<>>"; - KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, - input_num_boxes, input_stride, max_output_boxes, iou_threshold, - offset, workspace_ptr, output_size_ptr, output_ptr); - }; break; - } + CNLOG(INFO) << "Launch Kernel MLUUnionX NMS<<>>"; + KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, + input_num_boxes, input_stride, max_output_boxes, iou_threshold, + offset, workspace_ptr, output_size_ptr, output_ptr); int output_num = *static_cast(output_size.cpu().data_ptr()); return output.slice(0, 0, output_num); diff --git a/mmcv/ops/csrc/pytorch/roi_align.cpp b/mmcv/ops/csrc/pytorch/roi_align.cpp index 5c3a1dd30..a9cdcdbcd 100644 --- a/mmcv/ops/csrc/pytorch/roi_align.cpp +++ b/mmcv/ops/csrc/pytorch/roi_align.cpp @@ -122,11 +122,11 @@ void roi_align_forward(Tensor input, Tensor rois, Tensor output, #endif #ifdef MMCV_WITH_MLU } else if (input.device().type() == at::kMLU) { - CHECK_MLU(input); - CHECK_MLU(rois); - CHECK_MLU(output); - CHECK_MLU(argmax_y); - CHECK_MLU(argmax_x); + CHECK_MLU_INPUT(input); + CHECK_MLU_INPUT(rois); + CHECK_MLU_INPUT(output); + CHECK_MLU_INPUT(argmax_y); + CHECK_MLU_INPUT(argmax_x); roi_align_forward_mlu(input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, spatial_scale, @@ -164,11 +164,11 @@ void roi_align_backward(Tensor grad_output, Tensor rois, Tensor argmax_y, #endif #ifdef MMCV_WITH_MLU } else if (grad_output.device().type() == at::kMLU) { - CHECK_MLU(grad_output); - CHECK_MLU(rois); - CHECK_MLU(argmax_y); - CHECK_MLU(argmax_x); - CHECK_MLU(grad_input); + CHECK_MLU_INPUT(grad_output); + CHECK_MLU_INPUT(rois); + CHECK_MLU_INPUT(argmax_y); + CHECK_MLU_INPUT(argmax_x); + CHECK_MLU_INPUT(grad_input); roi_align_backward_mlu(grad_output, rois, argmax_y, argmax_x, grad_input, aligned_height, aligned_width, spatial_scale,