diff --git a/include/caffe/layers/softmax_layer.hpp b/include/caffe/layers/softmax_layer.hpp index 7478a02c..10a1b63c 100644 --- a/include/caffe/layers/softmax_layer.hpp +++ b/include/caffe/layers/softmax_layer.hpp @@ -45,6 +45,8 @@ class SoftmaxLayer : public Layer { Blob scale_; Dtype input_scale_; //CUSTOMIZATION Dtype output_scale_; //CUSTOMIZATION + int input_zero_point_; //CUSTOMIZATION + int output_zero_point_; //CUSTOMIZATION }; } // namespace caffe diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp index 70ab6947..e4eb8cca 100644 --- a/src/caffe/layers/softmax_layer.cpp +++ b/src/caffe/layers/softmax_layer.cpp @@ -23,11 +23,19 @@ void SoftmaxLayer::Reshape(const vector*>& bottom, scale_.Reshape(scale_dims); input_scale_ = this->layer_param_.softmax_param().input_scale(); //CUSTOMIZATION output_scale_ = this->layer_param_.softmax_param().output_scale(); //CUSTOMIZATION + input_zero_point_ = this->layer_param_.softmax_param().input_zero_point(); //CUSTOMIZATION + output_zero_point_ = this->layer_param_.softmax_param().output_zero_point(); //CUSTOMIZATION } template void SoftmaxLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { + const bool quant_in = (input_scale_ != Dtype(1.0) || input_zero_point_ != 0); + const bool quant_out = (output_scale_ != Dtype(1.0) || output_zero_point_ != 0); + if (quant_in) { + caffe_cpu_dequantize(bottom[0]->count(), bottom[0]->mutable_cpu_data(), + input_scale_, input_zero_point_); + } const Dtype* bottom_data = bottom[0]->cpu_data(); Dtype* top_data = top[0]->mutable_cpu_data(); Dtype* scale_data = scale_.mutable_cpu_data(); @@ -59,6 +67,14 @@ void SoftmaxLayer::Forward_cpu(const vector*>& bottom, top_data += inner_num_; } } + if (quant_out) { + // do not reuse "top_data"; it is shifted during the computation + caffe_cpu_quantize(top[0]->count(), top[0]->mutable_cpu_data(), output_scale_, output_zero_point_); + } + if (quant_in) { + caffe_cpu_quantize(bottom[0]->count(), bottom[0]->mutable_cpu_data(), + input_scale_, input_zero_point_); + } } template