diff --git a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
index 5030fed0e7..3f3bc5a047 100644
--- a/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
+++ b/mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
@@ -4,6 +4,21 @@ using namespace std;
 
 void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor output_y = output;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    output_y = output.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
@@ -12,24 +27,26 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
     target_y = at::add(target_y, 1.0);
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
   target_y = target_y.to(at::kInt);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
-      .Output(output)
+      .Output(output_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    output_y = output_y.to(at::kHalf);
+  }
+  output.copy_(output_y);
 }
 
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -38,34 +55,51 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 
 void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor grad_input, float gamma,
                                      float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
+  int64_t weight_size = weight.size(0);
+  at::Tensor weight_y = at::ones_like(input_y);
+  if (weight_size > 0) {
+    weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::ones_like(input);
   if (n_class == 1) {
     target_y = at::reshape(target, input.sizes());
   } else {
     target_y = at::one_hot(target, n_class);
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
     target_y = at::mul(target_y, -1.0);
     target_y = at::add(target_y, 1.0);
   }
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
-  int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
-  if (weight_size > 0) {
-    weight_y = at::broadcast_to(weight, input.sizes());
-  }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SigmoidFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
@@ -74,19 +108,30 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
 
 void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor output, float gamma, float alpha) {
+  at::Tensor input_y = input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
-  at::Tensor op_output = at::ones_like(input);
+  at::Tensor op_output = at::ones_like(input_y);
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLoss")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(weight_y)
       .Output(op_output)
@@ -94,6 +139,9 @@ void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    op_output = op_output.to(at::kHalf);
+  }
   int64_t n_batch = input.size(0);
   c10::SmallVector<int64_t, N> offsets = {0, 0};
   c10::SmallVector<int64_t, N> sizes = {n_batch, 1};
@@ -124,27 +172,44 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
 
 void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                      Tensor buff, Tensor grad_input,
                                      float gamma, float alpha) {
+  at::Tensor input_y = input;
+  at::Tensor grad_input_y = grad_input;
+  bool is_half = input.scalar_type() == at::kHalf;
+  if (is_half) {
+    input_y = input.to(at::kFloat);
+    grad_input_y = grad_input.to(at::kFloat);
+  }
   int64_t n_class = input.size(1);
   at::Tensor target_y = at::one_hot(target, n_class);
   target_y = target_y.to(at::kInt);
   at::Tensor grad_up = at::ones_like(input);
   int64_t weight_size = weight.size(0);
-  at::Tensor weight_y = at::ones_like(input);
+  at::Tensor weight_y = at::ones_like(input_y);
   if (weight_size > 0) {
     weight_y = at::broadcast_to(weight, input.sizes());
+    if (is_half) {
+      weight_y = weight_y.to(at::kFloat);
+    }
+    weight_y = at::mul(weight_y, target_y);
+    weight_y = at::sum(weight_y, 1, true);
+    weight_y = at::broadcast_to(weight_y, input.sizes());
   }
   OpCommand cmd;
   string reduction = "none";
   cmd.Name("SoftmaxFocalLossGrad")
-      .Input(input)
+      .Input(input_y)
       .Input(target_y)
       .Input(grad_up)
       .Input(weight_y)
-      .Output(grad_input)
+      .Output(grad_input_y)
       .Attr("gamma", gamma)
       .Attr("alpha", alpha)
       .Attr("reduction", reduction)
       .Run();
+  if (is_half) {
+    grad_input_y = grad_input_y.to(at::kHalf);
+  }
+  grad_input.copy_(grad_input_y);
 }
 
 void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
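For context, here is a minimal standalone sketch (not part of the patch) of the fp16 handling pattern each kernel wrapper above applies: cast half-precision tensors to float32, run the operator, cast the result back, and copy it into the caller-provided tensor. The names `run_op` and `forward_with_half_support` are hypothetical and used only for illustration; plain ATen calls stand in for the NPU `OpCommand` launch.

```cpp
// Sketch of the half-precision round trip introduced by this patch.
// Assumes only ATen; run_op() is a placeholder for the NPU operator call
// (e.g. SigmoidFocalLoss), not a real mmcv function.
#include <ATen/ATen.h>

void run_op(const at::Tensor& input, at::Tensor& output) {
  // Placeholder kernel: writes some float32 result into `output`.
  output.copy_(at::sigmoid(input));
}

void forward_with_half_support(const at::Tensor& input, at::Tensor& output) {
  const bool is_half = input.scalar_type() == at::kHalf;

  // Work in float32 when the caller passed float16 tensors.
  at::Tensor input_y = is_half ? input.to(at::kFloat) : input;
  at::Tensor output_y = is_half ? output.to(at::kFloat) : output;

  run_op(input_y, output_y);

  if (is_half) {
    // Cast the float32 result back to float16 ...
    output_y = output_y.to(at::kHalf);
  }
  // ... and write it into the original output tensor so the caller sees it.
  output.copy_(output_y);
}
```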