Skip to content

Commit

Permalink
Python DefFun now creates functions that include NodeDefs
Browse files Browse the repository at this point in the history
(in addition to the older FunctionDef::Node format).  Add an
optional out_names argument to Defun and names to Declare's
arguments, so that signatures for forward-declared DefFuns
can have signatures in the name (required for this change).
Change: 139919259
  • Loading branch information
tensorflower-gardener committed Nov 22, 2016
1 parent be3e778 commit 5591ca5
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 41 deletions.
191 changes: 152 additions & 39 deletions tensorflow/python/framework/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ def _make_argname_from_tensor_name(name):
return re.sub(":0$", "", name).replace(":", "_o")


def _tensor_to_argdef(t):
def _tensor_to_argdef(t, name=None):
arg = op_def_pb2.OpDef.ArgDef()
arg.name = _make_argname_from_tensor_name(t.name)
if name is None:
arg.name = _make_argname_from_tensor_name(t.name)
else:
arg.name = name
arg.type = t.dtype.as_datatype_enum
return arg


def _get_node_def_attr(op):
# pylint: disable=protected-access
return op._node_def.attr
# pylint: enable=protected-access
def _get_node_def(op):
return op._node_def # pylint: disable=protected-access


def _add_input_array(op, start, limit, dtype, func):
Expand Down Expand Up @@ -122,17 +123,66 @@ def _add_output_list(op, start, limit, dtype_lst, func):
return ret_name


def _add_op_node(op, func):
"""Converts an op to a function def node and add it to `func`."""
node = function_pb2.FunctionDef.Node()
node.op = op.type
def _get_op_def(op):
# pylint: disable=protected-access
if hasattr(op, "_sig"):
op_def = getattr(op, "_sig")
return getattr(op, "_sig")
else:
op_def = op_def_registry.get_registered_ops()[op.type]
return op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access
attrs = _get_node_def_attr(op)


def _is_in_placeholders(op, func_arg_placeholders):
return op.values() and (op.values()[0].name in func_arg_placeholders)


def _create_input_dict(function_graph, func_arg_placeholders):
"""Create a mapping from graph tensor names to function tensor names."""
input_dict = {}
for op in function_graph.get_operations():
if _is_in_placeholders(op, func_arg_placeholders):
input_dict[op.values()[0].name] = op.values()[0].name
input_dict[op.name] = op.name
else:
op_def = _get_op_def(op)
attrs = _get_node_def(op).attr
o = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
num = attrs[arg_def.number_attr].i
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
else:
num = 1
for i in range(num):
result = "%s:%s:%d" % (op.name, arg_def.name, i)
input_dict[op.values()[o].name] = result
if o == 0:
input_dict[op.name] = result
o += 1
return input_dict


def _add_op_node(op, func, input_dict):
"""Converts an op to a function def node and add it to `func`."""
# Add an entry in func.node_def

# Note that extend() makes a copy in this case, see:
# https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
func.node_def.extend([_get_node_def(op)])
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
assert node_def.input[i] in input_dict, (
"%s missing from %s" % (node_def.input[i], input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]

# To support legacy consumers, add an entry in func.node.
# TODO(josh11b): Delete this.
node = function_pb2.FunctionDef.Node()
node.op = op.type
op_def = _get_op_def(op)
attrs = node_def.attr
if not op_def.output_arg:
node.ret.append(_make_argname_from_tensor_name(op.name))
else:
Expand Down Expand Up @@ -174,12 +224,31 @@ def _add_op_node(op, func):
inp_index += 1
node.dep.extend(
[_make_argname_from_tensor_name(x.name) for x in op.control_inputs])
for k, v in _get_node_def_attr(op).items():
for k, v in attrs.items():
node.attr[k].CopyFrom(v)
func.node.extend([node])


def _graph_to_function_def(graph, inputs, outputs):
def _replace_ret(func, original, replacement):
for n in func.node:
for i, r in enumerate(n.ret):
if r == original:
n.ret[i] = replacement
return
raise ValueError("Could not find ret == '%s'" % original)


def _replace_arg(func, original, replacement):
for n in func.node:
for i, a in enumerate(n.arg):
if a == original:
n.arg[i] = replacement
for i, d in enumerate(n.dep):
if d == original:
n.dep[i] = replacement


def _graph_to_function_def(graph, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`](
Expand All @@ -195,19 +264,47 @@ def _graph_to_function_def(graph, inputs, outputs):
graph: Graph.
inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs.
Returns:
A FunctionDef protocol buffer.
Raises:
ValueError: if out_names is specified and the wrong length.
"""
func = function_pb2.FunctionDef()
func.signature.name = "_"
func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
if out_names is None:
func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
(len(out_names), len(outputs), ", ".join(out_names)))
else:
func.signature.output_arg.extend([
_tensor_to_argdef(o, n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders)

for op in graph.get_operations():
if op.values() and (op.values()[0].name in func_arg_placeholders):
if _is_in_placeholders(op, func_arg_placeholders):
continue
_add_op_node(op, func)
_add_op_node(op, func, input_dict)

if out_names is None:
for o in outputs:
k = _make_argname_from_tensor_name(o.name)
func.ret[k] = input_dict[o.name]
else:
for o, n in zip(outputs, out_names):
func.ret[n] = input_dict[o.name]
# TODO(josh11b): Delete this once we switch fully to NodeDefs for
# function bodies.
k = _make_argname_from_tensor_name(o.name)
_replace_ret(func, k, n)
_replace_arg(func, k, n)

return func


Expand Down Expand Up @@ -251,7 +348,6 @@ def _call(sig, *inputs, **kwargs):
Raises:
ValueError: if the arguments are invalid.
"""
if len(inputs) != len(sig.input_arg):
raise ValueError("Expected number of arguments: %d, received: %d" %
Expand Down Expand Up @@ -301,7 +397,6 @@ class _FuncGraph(ops.Graph):
Each captured input's corresponding place holder is converted into a
function argument and the caller passes in the captured tensor.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -385,7 +480,6 @@ def get_extra_inputs():
returned list of tensors are those accessed inside the function body
but defined outside the function body so far. Otherwise, returns an
empty list.
"""
g = ops.get_default_graph()
if isinstance(g, _FuncGraph):
Expand All @@ -402,7 +496,6 @@ def get_extra_args():
returned list of place holders are those used inside the function
body corresponding those returned by get_extra_inputs(). Otherwise,
returns an empty list.
"""
g = ops.get_default_graph()
if isinstance(g, _FuncGraph):
Expand All @@ -429,6 +522,7 @@ def __init__(self,
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
**kwargs):
"""Creates _DefinedFunction.
Expand All @@ -443,6 +537,8 @@ def __init__(self,
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: An optional list of strings for the function return value
names.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
Expand All @@ -455,6 +551,7 @@ def __init__(self,
self._func_name = func_name
self._grad_func = grad_func
self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs
self._definition = None # Constructed lazily.

Expand Down Expand Up @@ -531,7 +628,8 @@ def _create_definition_if_needed(self):
inputs.extend(temp_graph.extra_args)

# Build the FunctionDef
self._definition = _graph_to_function_def(temp_graph, inputs, outputs)
self._definition = _graph_to_function_def(
temp_graph, inputs, outputs, out_names=self._out_names)

# Extra kwargs are treated as attrs on the function def.
kwargs_attr = _parse_kwargs_as_attrs(**self._extra_kwargs)
Expand All @@ -556,6 +654,7 @@ def update_strs(slist):
for s in slist:
update_str(s)

# TODO(josh11b): Switch .node to .node_def
for n in sorted(self._definition.node, key=lambda n: n.ret[0]):
update_strs(n.ret)
update_str(n.op)
Expand Down Expand Up @@ -661,6 +760,7 @@ def __init__(self,
func_name=None,
grad_func=None,
python_grad_func=None,
out_names=None,
**kwargs):
"""Creates _DefinedFunction.
Expand All @@ -673,6 +773,7 @@ def __init__(self,
to None.
python_grad_func: A python callable implementing the gradient of
the function python-side.
out_names: A list of strings for the function return value names.
**kwargs: The keyword arguments. **kwargs is passed to every call
site of this function.
Expand All @@ -686,6 +787,7 @@ def __init__(self,
assert grad_func is None or isinstance(grad_func, _OverloadedFunction)
self._grad_func = grad_func
self._python_grad_func = python_grad_func
self._out_names = out_names
self._extra_kwargs = kwargs
self._overload = {}

Expand All @@ -709,6 +811,7 @@ def instantiate(self, input_types):
name = "_".join([name, key])
defined = _DefinedFunction(self._func, self._argnames, input_types, name,
None, self._python_grad_func,
out_names=self._out_names,
**self._extra_kwargs)
_ = defined.name # Fully instantiate the function definition.
if self._grad_func:
Expand Down Expand Up @@ -802,11 +905,15 @@ def __init__(self, *input_types, **kwargs):
This will be called by tf.gradients to add the gradient ops
to the graph. At most one of grad_func and python_grad_func
can be specified.
out_names = (optional). A list of strings, one per output
tensor.
"""
self._input_types = input_types
self._func_name = kwargs.pop("func_name", None)
self._grad_func = kwargs.pop("grad_func", None)
self._python_grad_func = kwargs.pop("python_grad_func", None)
self._out_names = kwargs.pop("out_names", None)
self._extra_kwargs = kwargs

def __call__(self, func):
Expand All @@ -833,25 +940,28 @@ def __call__(self, func):

if self._input_types:
# If Defun is given a list of types for the inputs, the number
# of of input types should be compatible with 'func'.
# of input types should be compatible with 'func'.
num = len(self._input_types)
if num < min_args or num > max_args:
raise ValueError(
"The function has fewer arguments than the number of specified "
"input types.")
return _DefinedFunction(func, argnames, self._input_types,
self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)

# 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0:
return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)

# Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called.
return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
self._python_grad_func, **self._extra_kwargs)
self._python_grad_func,
out_names=self._out_names, **self._extra_kwargs)


class Declare(object):
Expand All @@ -861,38 +971,41 @@ class Declare(object):
later during a graph construction.
For example,
# Declares a function Foo, which takes a tf.int32 and a
# tf.float32 as inputs and returns a tf.float32 as its output.
foo = Declare("Foo", [tf.int32, tf.float32], [tf.float32])
# Declares a function Foo, which takes a tf.int32 named "n" and a
# tf.float32 named "n" as inputs and returns a tf.float32 named "z"
# as its output.
foo = Declare("Foo", [("n", tf.int32), ("x", tf.float32)],
[("z", tf.float32)])
# Defines a function Bar calls Foo.
@tf.Defun(tf.float32)
def Bar(x):
return foo(6, x)
# Defines Foo.
@tf.Defun(tf.int32, tf.float32)
# Defines Foo, with output named "z".
@tf.Defun(tf.int32, tf.float32, out_names=["z"])
def Foo(n, x):
... # Calculation.
return result
"""

def __init__(self, func_name, input_types, output_types):
def __init__(self, func_name, inputs, outputs):
"""Creates a `Declare` object.
Args:
func_name: The name of the function.
input_types: A list of data types of function arguments.
output_types: A list of data types of function return values.
inputs: A list of (name, data type) pairs of function arguments.
outputs: A list of (name, data type) pairs of function return values.
"""
self._sig = op_def_pb2.OpDef()
self._sig.name = func_name

def _to_argdef_list(types):
return [op_def_pb2.OpDef.ArgDef(type=_.as_datatype_enum) for _ in types]
def _to_argdef_list(args):
return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
for n, t in args]

self._sig.input_arg.extend(_to_argdef_list(input_types))
self._sig.output_arg.extend(_to_argdef_list(output_types))
self._sig.input_arg.extend(_to_argdef_list(inputs))
self._sig.output_arg.extend(_to_argdef_list(outputs))

def __call__(self, *inputs, **kwargs):
inputs = [ops.convert_to_tensor(_) for _ in inputs]
Expand Down
Loading

0 comments on commit 5591ca5

Please sign in to comment.