Skip to content

Commit

Permalink
Add BatchBucketizeOp in caffe2 (pytorch#9385)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#9385

The operator transform dense features to sparse features by bucketizing. Only the feature in indices tensor will be transformed and output.

Reviewed By: bddppq

Differential Revision: D8820351

fbshipit-source-id: a66cae546b870c6b2982ac20641f198334f2e853
  • Loading branch information
nateanl authored and facebook-github-bot committed Jul 14, 2018
1 parent 099a6d5 commit 5ac8a80
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 0 deletions.
126 changes: 126 additions & 0 deletions caffe2/operators/batch_bucketize_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "batch_bucketize_op.h"

#include "caffe2/core/context.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

template <>
bool BatchBucketizeOp<CPUContext>::RunOnDevice() {
auto& feature = Input(FEATURE);
auto& indices = Input(INDICES);
auto& boundaries = Input(BOUNDARIES);
auto& lengths = Input(LENGTHS);
auto* output = Output(O);
CAFFE_ENFORCE_EQ(lengths.ndim(), 1);
CAFFE_ENFORCE_EQ(indices.ndim(), 1);
CAFFE_ENFORCE_EQ(boundaries.ndim(), 1);
CAFFE_ENFORCE_EQ(feature.ndim(), 2);
CAFFE_ENFORCE_EQ(lengths.size(), indices.size());

const auto* lengths_data = lengths.template data<int32_t>();
const auto* indices_data = indices.template data<int32_t>();
const auto* boundaries_data = boundaries.template data<float>();
const auto* feature_data = feature.template data<float>();
auto batch_size = feature.dim(0);
auto feature_dim = feature.dim(1);
auto output_dim = indices.size();

TIndex length_sum = 0;
for (TIndex i = 0; i < lengths.size(); i++) {
CAFFE_ENFORCE_GE(feature_dim, indices_data[i]);
length_sum += lengths_data[i];
}
CAFFE_ENFORCE_EQ(length_sum, boundaries.size());

TIndex lower_bound = 0;
output->Resize(batch_size, output_dim);
auto* output_data = output->template mutable_data<int32_t>();

for (TIndex i = 0; i < batch_size; i++) {
lower_bound = 0;
for (TIndex j = 0; j < output_dim; j++) {
for (TIndex k = 0; k <= lengths_data[j]; k++) {
if (k == lengths_data[j] ||
feature_data[i * feature_dim + indices_data[j]] <=
boundaries_data[lower_bound + k]) {
output_data[i * output_dim + j] = k;
break;
} else {
continue;
}
}
lower_bound += lengths_data[j];
}
}
return true;
}

REGISTER_CPU_OPERATOR(BatchBucketize, BatchBucketizeOp<CPUContext>);

OPERATOR_SCHEMA(BatchBucketize)
.NumInputs(4)
.NumOutputs(1)
.SetDoc(R"DOC(
Bucketize the float_features into sparse features.
The float_features is a N * D tensor where N is the batch_size, and D is the feature_dim.
The indices is a 1D tensor containing the indices of the features that need to be bucketized.
The lengths is a 1D tensor that splits the following 'boundaries' argument.
The boundaries is a 1D tensor containing the border list for each feature.
With in each batch, `indices` should not have duplicate number,
and the number of elements in `indices` should be less than or euqal to `D`.
Each element in `lengths` vector (lengths[`i`]) represents
the number of boundaries in the sub border list.
The sum of all elements in `lengths` must be equal to the size of `boundaries`.
If lengths[0] = 2, the first sub border list is [0.5, 1.0], which separate the
value to (-inf, 0.5], (0,5, 1.0], (1.0, inf). The bucketized feature will have
three possible values (i.e. 0, 1, 2).
For example, with input:
float_features = [[1.42, 2.07, 3.19, 0.55, 4.32],
[4.57, 2.30, 0.84, 4.48, 3.09],
[0.89, 0.26, 2.41, 0.47, 1.05],
[0.03, 2.97, 2.43, 4.36, 3.11],
[2.74, 5.77, 0.90, 2.63, 0.38]]
indices = [0, 1, 4]
lengths = [2, 3, 1]
boundaries = [0.5, 1.0, 1.5, 2.5, 3.5, 2.5]
The output is:
output =[[2, 1, 1],
[2, 1, 1],
[1, 0, 0],
[0, 2, 1],
[2, 3, 0]]
after running this operator.
)DOC")
.Input(
0,
"float_features",
"2-D dense tensor, the second dimension must be greater or equal to the indices dimension")
.Input(
1,
"indices",
"Flatten tensor, containing the indices of `float_features` to be bucketized. The datatype must be int32.")
.Input(
2,
"lengths",
"Flatten tensor, the size must be equal to that of `indices`. The datatype must be int32.")
.Input(
3,
"boundaries",
"Flatten tensor, dimension has to match the sum of lengths")
.Output(
0,
"bucktized_feat",
"2-D dense tensor, with 1st dim = float_features.dim(0), 2nd dim = size(indices)"
"in the arg list, the tensor is of the same data type as `feature`.");

NO_GRADIENT(BatchBucketize);

} // namespace caffe2
29 changes: 29 additions & 0 deletions caffe2/operators/batch_bucketize_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2004-present Facebook. All Rights Reserved.

#ifndef CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_
#define CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <class Context>
class BatchBucketizeOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;

BatchBucketizeOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {}

bool RunOnDevice() override;

protected:
INPUT_TAGS(FEATURE, INDICES, BOUNDARIES, LENGTHS);
OUTPUT_TAGS(O);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_BATCH_BUCKETIZE_OP_H_
90 changes: 90 additions & 0 deletions caffe2/python/operator_test/batch_bucketize_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from caffe2.python import core
from hypothesis import given
import hypothesis.strategies as st
import caffe2.python.hypothesis_test_util as hu


class TestBatchBucketize(hu.HypothesisTestCase):
@given(**hu.gcs_cpu_only)
def test_batch_bucketize_example(self, gc, dc):
op = core.CreateOperator('BatchBucketize',
["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"],
["O"])
float_feature = np.array([[1.42, 2.07, 3.19, 0.55, 4.32],
[4.57, 2.30, 0.84, 4.48, 3.09],
[0.89, 0.26, 2.41, 0.47, 1.05],
[0.03, 2.97, 2.43, 4.36, 3.11],
[2.74, 5.77, 0.90, 2.63, 0.38]], dtype=np.float32)
indices = np.array([0, 1, 4], dtype=np.int32)
lengths = np.array([2, 3, 1], dtype=np.int32)
boundaries = np.array([0.5, 1.0, 1.5, 2.5, 3.5, 2.5], dtype=np.float32)

def ref(float_feature, indices, boundaries, lengths):
output = np.array([[2, 1, 1],
[2, 1, 1],
[1, 0, 0],
[0, 2, 1],
[2, 3, 0]], dtype=np.int32)
return (output,)

self.assertReferenceChecks(gc, op,
[float_feature, indices, boundaries, lengths],
ref)

@given(
x=hu.tensor(
min_dim=2, max_dim=2, dtype=np.float32,
elements=st.floats(min_value=0, max_value=5),
min_value=5),
seed=st.integers(min_value=2, max_value=1000),
**hu.gcs_cpu_only)
def test_batch_bucketize(self, x, seed, gc, dc):
op = core.CreateOperator('BatchBucketize',
["FEATURE", "INDICES", "BOUNDARIES", "LENGTHS"],
['O'])
np.random.seed(seed)
d = x.shape[1]
lens = np.random.randint(low=1, high=3, size=d - 3)
indices = np.random.choice(range(d), d - 3, replace=False)
indices.sort()
boundaries = []
for i in range(d - 3):
# add [0, 0] as duplicated bounary for duplicated bucketization
if lens[i] > 2:
cur_boundary = np.append(
np.random.randn(lens[i] - 2) * 5, [0, 0])
else:
cur_boundary = np.random.randn(lens[i]) * 5
cur_boundary.sort()
boundaries += cur_boundary.tolist()

lens = np.array(lens, dtype=np.int32)
boundaries = np.array(boundaries, dtype=np.float32)
indices = np.array(indices, dtype=np.int32)

def ref(x, indices, boundaries, lens):
output_dim = indices.shape[0]
ret = np.zeros((x.shape[0], output_dim)).astype(np.int32)
boundary_offset = 0
for i, l in enumerate(indices):
temp_bound = boundaries[boundary_offset : lens[i] + boundary_offset]
for j in range(x.shape[0]):
for k, bound_val in enumerate(temp_bound):
if k == len(temp_bound) - 1 and x[j, l] > bound_val:
ret[j, i] = k + 1
elif x[j, l] > bound_val:
continue
else:
ret[j, i] = k
break
boundary_offset += lens[i]
return (ret,)

self.assertReferenceChecks(gc, op, [x, indices, boundaries, lens], ref)

0 comments on commit 5ac8a80

Please sign in to comment.