
Commit

Internal change.
Change: 138451572
Dan Smilkov authored and tensorflower-gardener committed Nov 9, 2016
1 parent 62c94ea commit aac685b
Showing 32 changed files with 784 additions and 2,453 deletions.
8 changes: 7 additions & 1 deletion tensorflow/contrib/distributions/python/ops/bernoulli.py
@@ -182,7 +182,6 @@ def __init__(self,
self._parameters = parameters


@kullback_leibler.RegisterKL(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(a, b, name=None):
"""Calculate the batched KL divergence KL(a || b) with a and b Bernoulli.
@@ -200,3 +199,10 @@ def _kl_bernoulli_bernoulli(a, b, name=None):
nn.softplus(-b.logits)) +
math_ops.sigmoid(-a.logits) * (-nn.softplus(a.logits) +
nn.softplus(b.logits)))


kl_classes = [
Bernoulli,
BernoulliWithSigmoidP,
]
kullback_leibler.register_pairwise_kls(kl_classes, _kl_bernoulli_bernoulli)
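
For reference, `_kl_bernoulli_bernoulli` above is the standard batched Bernoulli KL divergence written in terms of logits, using log p = -softplus(-logits) and log(1 - p) = -softplus(logits); a restatement of the formula the code computes:

$$\mathrm{KL}(a \,\|\, b) = p_a\bigl(\log p_a - \log p_b\bigr) + (1 - p_a)\bigl(\log(1 - p_a) - \log(1 - p_b)\bigr), \qquad p_x = \sigma(\mathrm{logits}_x).$$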
5 changes: 1 addition & 4 deletions tensorflow/contrib/distributions/python/ops/beta.py
@@ -332,7 +332,4 @@ def _kl_beta_beta(d1, d2, name=None):
Beta,
BetaWithSoftplusAB,
]

for beta_aa in kl_classes:
  for beta_bb in kl_classes:
    kullback_leibler.RegisterKL(beta_aa, beta_bb)(_kl_beta_beta)
kullback_leibler.register_pairwise_kls(kl_classes, _kl_beta_beta)
17 changes: 17 additions & 0 deletions tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -109,3 +109,20 @@ def __call__(self, kl_fn):
_DIVERGENCES[self._key]))
_DIVERGENCES[self._key] = kl_fn
return kl_fn


def register_pairwise_kls(kl_classes, kl_fn):
  """Registers `kl_fn` for each pair of classes in `kl_classes`.

  Args:
    kl_classes: classes for which to register KL implementation
    kl_fn: The function to use for the KL divergence.

  Returns:
    None
  """
  for cls_a in kl_classes:
    RegisterKL(cls_a, cls_a)(kl_fn)
    for cls_b in kl_classes:
      if cls_a != cls_b:
        RegisterKL(cls_a, cls_b)(kl_fn)
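
To make the helper's effect concrete, here is a minimal sketch (assuming the names registered in bernoulli.py above are in scope) of the per-pair registrations a single call expands to; these are the same registrations that mvn.py and beta.py previously spelled out with nested loops:

```python
# Hypothetical expansion of:
#   kullback_leibler.register_pairwise_kls(
#       [Bernoulli, BernoulliWithSigmoidP], _kl_bernoulli_bernoulli)
# Every ordered pair of classes, including each class with itself, is
# registered against the same KL function.
kullback_leibler.RegisterKL(Bernoulli, Bernoulli)(_kl_bernoulli_bernoulli)
kullback_leibler.RegisterKL(Bernoulli, BernoulliWithSigmoidP)(_kl_bernoulli_bernoulli)
kullback_leibler.RegisterKL(BernoulliWithSigmoidP, Bernoulli)(_kl_bernoulli_bernoulli)
kullback_leibler.RegisterKL(BernoulliWithSigmoidP, BernoulliWithSigmoidP)(_kl_bernoulli_bernoulli)
```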
10 changes: 1 addition & 9 deletions tensorflow/contrib/distributions/python/ops/mvn.py
@@ -780,12 +780,4 @@ def _kl_mvn_mvn_brute_force(mvn_a, mvn_b, name=None):
MultivariateNormalDiag,
MultivariateNormalDiagPlusVDVT,
]


for mvn_aa in kl_classes:
  # Register when they are the same here, and do not register when they are the
  # same below because that would result in a repeated registration.
  kullback_leibler.RegisterKL(mvn_aa, mvn_aa)(_kl_mvn_mvn_brute_force)
  for mvn_bb in kl_classes:
    if mvn_bb != mvn_aa:
      kullback_leibler.RegisterKL(mvn_aa, mvn_bb)(_kl_mvn_mvn_brute_force)
kullback_leibler.register_pairwise_kls(kl_classes, _kl_mvn_mvn_brute_force)
8 changes: 7 additions & 1 deletion tensorflow/contrib/distributions/python/ops/normal.py
@@ -225,7 +225,6 @@ def __init__(self,
self._parameters = parameters


@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(n_a, n_b, name=None):
"""Calculate the batched KL divergence KL(n_a || n_b) with n_a and n_b Normal.
@@ -247,3 +246,10 @@ def _kl_normal_normal(n_a, n_b, name=None):
ratio = s_a_squared / s_b_squared
return (math_ops.square(n_a.mu - n_b.mu) / (two * s_b_squared) +
half * (ratio - one - math_ops.log(ratio)))


kl_classes = [
Normal,
NormalWithSoftplusSigma,
]
kullback_leibler.register_pairwise_kls(kl_classes, _kl_normal_normal)
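
For reference, the closed form computed by `_kl_normal_normal` above (where `ratio` is sigma_a^2 / sigma_b^2) is

$$\mathrm{KL}(n_a \,\|\, n_b) = \frac{(\mu_a - \mu_b)^2}{2\sigma_b^2} + \frac{1}{2}\left(\frac{\sigma_a^2}{\sigma_b^2} - 1 - \log\frac{\sigma_a^2}{\sigma_b^2}\right).$$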
56 changes: 14 additions & 42 deletions tensorflow/core/common_runtime/function.cc
@@ -39,12 +39,17 @@ limitations under the License.
namespace tensorflow {

// A few string constants used throughout this module.
static const char* const kArgOp = "_Arg";
static const char* const kRetOp = "_Retval";
static const char* const kGradientOp = "SymbolicGradient";
static const char* const kNodeLabel = "Func";
static const char* const kFuncAttr = "f";
static const char* const kNoInlineAttr = "_noinline";
//
// TODO(zhifengc): Dedup some of these constants into
// framework/function.h
static constexpr const char* const kArgOp = "_Arg";
static constexpr const char* const kRetOp = "_Retval";
static constexpr const char* const kGradientOp =
FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
FunctionLibraryDefinition::kFuncAttr;
static constexpr const char* const kNoInlineAttr = "_noinline";

// Represents the index-th output of a node.
struct Endpoint {
@@ -926,46 +931,13 @@ static void InlineFunctionBody(Graph* g, Node* caller,
g->RemoveNode(caller); // 'caller' is replaced with inlined nodes.
}

// Given a node's NodeDef, returns false iff the node explicitly
// specified _noinline. This gives ExpandInlineFunctions a heuristic
// to decide whether to inline the function.
bool ShouldInline(const NodeDef& ndef) {
bool noinline = false;
if (GetNodeAttr(ndef, kNoInlineAttr, &noinline).ok()) {
// If the node specifies attribute '_noinline', returns accordingly.
return !noinline;
}
if (ndef.op() != kGradientOp) {
// If the op is not SymbolicGradient, we should be free to decide
// whether to inline or not.
return true;
}
// If the node is a SymbolicGradient, we use the forward
// function's attribute '_noinline' instead.
const NameAttrList* forward_func_attrs;
Status s =
GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs);
if (!s.ok()) {
// The node def is malformed (missing attribute 'f'), we'll just
// continue and the runtime will error out.
return false;
}
s = GetNodeAttr(AttrSlice(&forward_func_attrs->attr()), kNoInlineAttr,
&noinline);
if (!s.ok()) {
// The forward function doesn't specify '_noinline' attr, we should
// be free to decide.
return true;
}
// Otherwise, make inline decision according to the attr.
return !noinline;
}

bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
std::vector<std::pair<Node*, const FunctionBody*>> candidates;
const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
for (Node* node : graph->nodes()) {
VLOG(3) << "Expanding " << node->DebugString();
if (!ShouldInline(node->def())) {
bool noinline;
if (fld->GetAttr(node->def(), kNoInlineAttr, &noinline).ok() && noinline) {
VLOG(3) << "noinline: " << node->DebugString();
continue;
}
26 changes: 26 additions & 0 deletions tensorflow/core/framework/function.cc
@@ -1012,6 +1012,32 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}

const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
// If 'ndef' calls a function and the function's def has the attr,
// returns it.
return Find(ndef.op());
}

// If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or
// Foo's attributes.
const NameAttrList* forward_func_attrs;
if (!GetNodeAttr(AttrSlice(&ndef.attr()), kFuncAttr, &forward_func_attrs)
.ok()) {
return nullptr;
}
const string& func_name = forward_func_attrs->name();
const string& grad_name = FindGradient(func_name);
// If 'func' has a user-defined gradient function, uses the grad
// function's attrs to see if noinline is specified. Otherwise,
// uses func's attrs.
if (!grad_name.empty()) {
return Find(grad_name);
}
return Find(func_name);
}

FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
FunctionDefLibrary lib;
for (const auto& f : function_defs_) {
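
The new `GetAttrImpl` decides which FunctionDef supplies the attribute. A Python-flavored sketch of that lookup order, with hypothetical dict stand-ins for the library's `function_defs_` and `func_grad_` maps (an illustration only, not TensorFlow API):

```python
def get_attr_source(op, node_attrs, function_defs, func_grad):
  """Return the function definition whose attrs should be consulted."""
  if op != "SymbolicGradient":
    # A plain function call: read attrs from the called function's def.
    return function_defs.get(op)
  # SymbolicGradient[f=Foo]: prefer Foo's registered gradient function,
  # otherwise fall back to Foo itself.
  forward_name = node_attrs.get("f")
  if forward_name is None:
    return None  # Malformed node; the runtime reports the error later.
  grad_name = func_grad.get(forward_name, "")
  return function_defs.get(grad_name if grad_name else forward_name)
```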
33 changes: 29 additions & 4 deletions tensorflow/core/framework/function.h
@@ -16,8 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
#define TENSORFLOW_FRAMEWORK_FUNCTION_H_

#include <unordered_map>

#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -26,6 +24,8 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -308,6 +308,15 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
Status LookUp(const string& op_type_name,
const OpRegistrationData** op_reg_data) const override;

static constexpr const char* const kGradientOp = "SymbolicGradient";
static constexpr const char* const kFuncAttr = "f";

// Given a node def 'ndef', inspects attributes of the callee
// function to derive the attribute 'value' for 'attr'. Returns OK
// iff the attribute is given by the function's definition.
template <typename T>
Status GetAttr(const NodeDef& ndef, const string& attr, T* value) const;

// Returns a proto representation of the state of this function library.
FunctionDefLibrary ToProto() const;

@@ -322,9 +331,13 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
};

const OpRegistryInterface* const default_registry_;
std::unordered_map<string, std::unique_ptr<FunctionDefAndOpRegistration>>
gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>, HashStr>
function_defs_;
std::unordered_map<string, string> func_grad_;
gtl::FlatMap<string, string, HashStr> func_grad_;

// Helper function for GetAttr. Returns the FunctionDef* to get the
// attr from.
const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
};

// Forward declare. Defined in common_runtime/function.h
@@ -473,6 +486,18 @@ bool RegisterOp(const string& op, Creator func);
Status GetOpGradientCreator(const string& op, Creator* creator);
};

// Implementation details.

template <typename T>
Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef,
const string& attr, T* value) const {
const FunctionDef* fdef = GetAttrImpl(ndef);
if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) {
return Status::OK();
}
return errors::InvalidArgument("Attr ", attr, " is not defined.");
}

} // end namespace tensorflow

#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
3 changes: 3 additions & 0 deletions tensorflow/core/framework/function.proto
@@ -27,6 +27,9 @@ message FunctionDef {
// attrs etc.
OpDef signature = 1;

// Attributes specific to this function definition.
map<string, AttrValue> attr = 5;

// TO BE REPLACED

// The body of the function.
80 changes: 80 additions & 0 deletions tensorflow/core/framework/function_test.cc
@@ -995,4 +995,84 @@ TEST(FunctionLibraryDefinitionTest, ToProto) {
EXPECT_EQ(f3->DebugString(), f4->DebugString());
}

TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) {
FunctionDefLibrary proto;
*proto.add_function() = test::function::XTimesTwo();
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);

NodeDef ndef;
bool annotation;

// Not a function.
ndef.set_op("Matmul");
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());

// A function. No attr defined.
ndef.set_op("XTimesTwo");
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());

// ndef defines the attr. But we don't care.
AddNodeAttr("annotation", true, &ndef);
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());
}

template <typename T>
void SetAttrValue(FunctionDef* fdef, const string& attr, const T& value) {
AttrValue attr_value;
SetAttrValue(value, &attr_value);
fdef->mutable_attr()->insert({attr, attr_value});
}

TEST(FunctionLibraryDefinitionTest, GetAttr_FuncWithAttr) {
FunctionDefLibrary proto;
auto fdef = proto.add_function();
*fdef = test::function::XTimesTwo();
SetAttrValue(fdef, "annotation", true);
SetAttrValue(fdef, "options", "some string data");
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);

NodeDef ndef;
bool annotation;

// A function. No attr defined in ndef.
ndef.set_op("XTimesTwo");
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
EXPECT_EQ(annotation, true);

string str;
TF_EXPECT_OK(lib.GetAttr(ndef, "options", &str));
EXPECT_EQ(str, "some string data");
}

TEST(FunctionLibraryDefinitionTest, GetAttr_Gradient) {
FunctionDefLibrary proto;
auto fdef = proto.add_function();
*fdef = test::function::XTimesTwo();
SetAttrValue(fdef, "annotation", true);
fdef = proto.add_function();
*fdef = test::function::WXPlusB();
SetAttrValue(fdef, "annotation", false);
auto func_grad = proto.add_gradient();
func_grad->set_function_name("XTimesTwo");
func_grad->set_gradient_func("WXPlusB");
FunctionLibraryDefinition lib(OpRegistry::Global(), proto);

NodeDef ndef;
ndef.set_op(FunctionLibraryDefinition::kGradientOp);

bool annotation;
EXPECT_FALSE(lib.GetAttr(ndef, "annotation", &annotation).ok());

NameAttrList nal;
nal.set_name("XTimesTwo");
AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
EXPECT_EQ(annotation, false); // XTimesTwo's gradient is WXPlusB.

nal.set_name("WXPlusB");
ndef.clear_attr();
AddNodeAttr(FunctionLibraryDefinition::kFuncAttr, nal, &ndef);
TF_EXPECT_OK(lib.GetAttr(ndef, "annotation", &annotation));
EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient.
}

} // end namespace tensorflow