forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BatchBucketizeOp in caffe2 (pytorch#9385)
Summary: Pull Request resolved: pytorch#9385 The operator transform dense features to sparse features by bucketizing. Only the feature in indices tensor will be transformed and output. Reviewed By: bddppq Differential Revision: D8820351 fbshipit-source-id: a66cae546b870c6b2982ac20641f198334f2e853
- Loading branch information
1 parent
099a6d5
commit 5ac8a80
Showing
3 changed files
with
245 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#include "batch_bucketize_op.h" | ||
|
||
#include "caffe2/core/context.h" | ||
#include "caffe2/core/tensor.h" | ||
|
||
namespace caffe2 { | ||
|
||
template <> | ||
bool BatchBucketizeOp<CPUContext>::RunOnDevice() { | ||
auto& feature = Input(FEATURE); | ||
auto& indices = Input(INDICES); | ||
auto& boundaries = Input(BOUNDARIES); | ||
auto& lengths = Input(LENGTHS); | ||
auto* output = Output(O); | ||
CAFFE_ENFORCE_EQ(lengths.ndim(), 1); | ||
CAFFE_ENFORCE_EQ(indices.ndim(), 1); | ||
CAFFE_ENFORCE_EQ(boundaries.ndim(), 1); | ||
CAFFE_ENFORCE_EQ(feature.ndim(), 2); | ||
CAFFE_ENFORCE_EQ(lengths.size(), indices.size()); | ||
|
||
const auto* lengths_data = lengths.template data<int32_t>(); | ||
const auto* indices_data = indices.template data<int32_t>(); | ||
const auto* boundaries_data = boundaries.template data<float>(); | ||
const auto* feature_data = feature.template data<float>(); | ||
auto batch_size = feature.dim(0); | ||
auto feature_dim = feature.dim(1); | ||
auto output_dim = indices.size(); | ||
|
||
TIndex length_sum = 0; | ||
for (TIndex i = 0; i < lengths.size(); i++) { | ||
CAFFE_ENFORCE_GE(feature_dim, indices_data[i]); | ||
length_sum += lengths_data[i]; | ||
} | ||
CAFFE_ENFORCE_EQ(length_sum, boundaries.size()); | ||
|
||
TIndex lower_bound = 0; | ||
output->Resize(batch_size, output_dim); | ||
auto* output_data = output->template mutable_data<int32_t>(); | ||
|
||
for (TIndex i = 0; i < batch_size; i++) { | ||
lower_bound = 0; | ||
for (TIndex j = 0; j < output_dim; j++) { | ||
for (TIndex k = 0; k <= lengths_data[j]; k++) { | ||
if (k == lengths_data[j] || | ||
feature_data[i * feature_dim + indices_data[j]] <= | ||
boundaries_data[lower_bound + k]) { | ||
output_data[i * output_dim + j] = k; | ||
break; | ||
} else { | ||
continue; | ||
} | ||
} | ||
lower_bound += lengths_data[j]; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
REGISTER_CPU_OPERATOR(BatchBucketize, BatchBucketizeOp<CPUContext>); | ||
|
||
OPERATOR_SCHEMA(BatchBucketize) | ||
.NumInputs(4) | ||
.NumOutputs(1) | ||
.SetDoc(R"DOC( | ||
Bucketize the float_features into sparse features. | ||
The float_features is a N * D tensor where N is the batch_size, and D is the feature_dim. | ||
The indices is a 1D tensor containing the indices of the features that need to be bucketized. | ||
The lengths is a 1D tensor that splits the following 'boundaries' argument. | ||
The boundaries is a 1D tensor containing the border list for each feature. | ||
With in each batch, `indices` should not have duplicate number, | ||
and the number of elements in `indices` should be less than or euqal to `D`. | ||
Each element in `lengths` vector (lengths[`i`]) represents | ||
the number of boundaries in the sub border list. | ||
The sum of all elements in `lengths` must be equal to the size of `boundaries`. | ||
If lengths[0] = 2, the first sub border list is [0.5, 1.0], which separate the | ||
value to (-inf, 0.5], (0,5, 1.0], (1.0, inf). The bucketized feature will have | ||
three possible values (i.e. 0, 1, 2). | ||
For example, with input: | ||
float_features = [[1.42, 2.07, 3.19, 0.55, 4.32], | ||
[4.57, 2.30, 0.84, 4.48, 3.09], | ||
[0.89, 0.26, 2.41, 0.47, 1.05], | ||
[0.03, 2.97, 2.43, 4.36, 3.11], | ||
[2.74, 5.77, 0.90, 2.63, 0.38]] | ||
indices = [0, 1, 4] | ||
lengths = [2, 3, 1] | ||
boundaries = [0.5, 1.0, 1.5, 2.5, 3.5, 2.5] | ||
The output is: | ||
output =[[2, 1, 1], | ||
[2, 1, 1], | ||
[1, 0, 0], | ||
[0, 2, 1], | ||
[2, 3, 0]] | ||
after running this operator. | ||
)DOC") | ||
.Input( | ||
0, | ||
"float_features", | ||
"2-D dense tensor, the second dimension must be greater or equal to the indices dimension") | ||
.Input( | ||
1, | ||
"indices", | ||
"Flatten tensor, containing the indices of `float_features` to be bucketized. The datatype must be int32.") | ||
.Input( | ||
2, | ||
"lengths", | ||
"Flatten tensor, the size must be equal to that of `indices`. The datatype must be int32.") | ||
.Input( | ||
3, | ||
"boundaries", | ||
"Flatten tensor, dimension has to match the sum of lengths") | ||
.Output( | ||
0, | ||
"bucktized_feat", | ||
"2-D dense tensor, with 1st dim = float_features.dim(0), 2nd dim = size(indices)" | ||
"in the arg list, the tensor is of the same data type as `feature`."); | ||
|
||
NO_GRADIENT(BatchBucketize); | ||
|
||
} // namespace caffe2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Copyright 2004-present Facebook. All Rights Reserved. | ||
|
||
#ifndef CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_ | ||
#define CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_ | ||
|
||
#include "caffe2/core/context.h" | ||
#include "caffe2/core/operator.h" | ||
#include "caffe2/utils/math.h" | ||
|
||
namespace caffe2 { | ||
|
||
template <class Context> | ||
class BatchBucketizeOp final : public Operator<Context> { | ||
public: | ||
USE_OPERATOR_CONTEXT_FUNCTIONS; | ||
|
||
BatchBucketizeOp(const OperatorDef& operator_def, Workspace* ws) | ||
: Operator<Context>(operator_def, ws) {} | ||
|
||
bool RunOnDevice() override; | ||
|
||
protected: | ||
INPUT_TAGS(FEATURE, INDICES, BOUNDARIES, LENGTHS); | ||
OUTPUT_TAGS(O); | ||
}; | ||
|
||
} // namespace caffe2 | ||
|
||
#endif // CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
import numpy as np | ||
|
||
from caffe2.python import core | ||
from hypothesis import given | ||
import hypothesis.strategies as st | ||
import caffe2.python.hypothesis_test_util as hu | ||
|
||
|
||
class TestBatchBucketize(hu.HypothesisTestCase): | ||
@given(**hu.gcs_cpu_only) | ||
def test_batch_bucketize_example(self, gc, dc): | ||
op = core.CreateOperator('BatchBucketize', | ||
["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"], | ||
["O"]) | ||
float_feature = np.array([[1.42, 2.07, 3.19, 0.55, 4.32], | ||
[4.57, 2.30, 0.84, 4.48, 3.09], | ||
[0.89, 0.26, 2.41, 0.47, 1.05], | ||
[0.03, 2.97, 2.43, 4.36, 3.11], | ||
[2.74, 5.77, 0.90, 2.63, 0.38]], dtype=np.float32) | ||
indices = np.array([0, 1, 4], dtype=np.int32) | ||
lengths = np.array([2, 3, 1], dtype=np.int32) | ||
boundaries = np.array([0.5, 1.0, 1.5, 2.5, 3.5, 2.5], dtype=np.float32) | ||
|
||
def ref(float_feature, indices, boundaries, lengths): | ||
output = np.array([[2, 1, 1], | ||
[2, 1, 1], | ||
[1, 0, 0], | ||
[0, 2, 1], | ||
[2, 3, 0]], dtype=np.int32) | ||
return (output,) | ||
|
||
self.assertReferenceChecks(gc, op, | ||
[float_feature, indices, boundaries, lengths], | ||
ref) | ||
|
||
@given( | ||
x=hu.tensor( | ||
min_dim=2, max_dim=2, dtype=np.float32, | ||
elements=st.floats(min_value=0, max_value=5), | ||
min_value=5), | ||
seed=st.integers(min_value=2, max_value=1000), | ||
**hu.gcs_cpu_only) | ||
def test_batch_bucketize(self, x, seed, gc, dc): | ||
op = core.CreateOperator('BatchBucketize', | ||
["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"], | ||
['O']) | ||
np.random.seed(seed) | ||
d = x.shape[1] | ||
lens = np.random.randint(low=1, high=3, size=d - 3) | ||
indices = np.random.choice(range(d), d - 3, replace=False) | ||
indices.sort() | ||
boundaries = [] | ||
for i in range(d - 3): | ||
# add [0, 0] as duplicated bounary for duplicated bucketization | ||
if lens[i] > 2: | ||
cur_boundary = np.append( | ||
np.random.randn(lens[i] - 2) * 5, [0, 0]) | ||
else: | ||
cur_boundary = np.random.randn(lens[i]) * 5 | ||
cur_boundary.sort() | ||
boundaries += cur_boundary.tolist() | ||
|
||
lens = np.array(lens, dtype=np.int32) | ||
boundaries = np.array(boundaries, dtype=np.float32) | ||
indices = np.array(indices, dtype=np.int32) | ||
|
||
def ref(x, indices, boundaries, lens): | ||
output_dim = indices.shape[0] | ||
ret = np.zeros((x.shape[0], output_dim)).astype(np.int32) | ||
boundary_offset = 0 | ||
for i, l in enumerate(indices): | ||
temp_bound = boundaries[boundary_offset : lens[i] + boundary_offset] | ||
for j in range(x.shape[0]): | ||
for k, bound_val in enumerate(temp_bound): | ||
if k == len(temp_bound) - 1 and x[j, l] > bound_val: | ||
ret[j, i] = k + 1 | ||
elif x[j, l] > bound_val: | ||
continue | ||
else: | ||
ret[j, i] = k | ||
break | ||
boundary_offset += lens[i] | ||
return (ret,) | ||
|
||
self.assertReferenceChecks(gc, op, [x, indices, boundaries, lens], ref) |