Skip to content

Commit

Permalink
Merge pull request #407 from robertknight/capture-by-value
Browse files Browse the repository at this point in the history
Support passing captures by-value to subgraphs
  • Loading branch information
robertknight authored Nov 14, 2024
2 parents 24da48c + bfbb642 commit 773f728
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 58 deletions.
251 changes: 197 additions & 54 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,16 @@ impl CachedPlan {
///
/// Subgraphs used by control flow operators (`If`, `Loop` etc.) may contain
/// value nodes that capture their values from parent graphs, like a captured
/// value in a Rust closure.
/// value in a Rust closure. A `CaptureEnv` is passed to the subgraph when
/// it is executed and used to resolve these values.
///
/// `CaptureEnv`s are arranged in a hierarchy. Value lookups will attempt to
/// look up the value in the environment's associated graph. If no such node
/// exists, the value will be looked up in the parent environment and so on.
///
/// Values can be captured either by reference or
/// by value. Values that are captured by-value can potentially be used as
/// [`in-place inputs`](Operator::run_in_place).
#[derive(Clone)]
pub struct CaptureEnv<'a> {
// The parent environment to search if a node name is not found in this
Expand All @@ -504,13 +509,16 @@ pub struct CaptureEnv<'a> {
// The "local" graph for this environment. Node names are looked up in
// this graph first and if found, values are resolved from `inputs` or
// `temp_values`.
graph: &'a Graph,
graph: Option<&'a Graph>,

// Values passed as inputs to the graph run.
inputs: Option<&'a FxHashMap<NodeId, InputOrOutput<'a>>>,

// Temporary values computed during the graph run.
temp_values: Option<&'a FxHashMap<NodeId, Output>>,
// Values computed during the graph run, captured by reference.
temp_values_by_ref: Option<&'a FxHashMap<NodeId, Output>>,

// Values computed during the graph run, captured by value.
temp_values: Option<FxHashMap<NodeId, Output>>,
}

impl<'a> CaptureEnv<'a> {
Expand All @@ -519,60 +527,99 @@ impl<'a> CaptureEnv<'a> {
/// Lookups will first match nodes in `graph` and then try the `parent`
/// environment if that fails. Lookups that match constant nodes will be
/// resolved from the node directly. Lookups that match value nodes will
/// be resolved from `temp_values` first and then `inputs` if there is no
/// match there.
/// be resolved from the captured values first or the captured inputs
/// otherwise.
pub fn new(
parent: Option<&'a CaptureEnv<'a>>,
graph: &'a Graph,
inputs: Option<&'a FxHashMap<NodeId, InputOrOutput<'a>>>,
temp_values: Option<&'a FxHashMap<NodeId, Output>>,
temp_values_by_ref: Option<&'a FxHashMap<NodeId, Output>>,
temp_values: Option<FxHashMap<NodeId, Output>>,
) -> CaptureEnv<'a> {
CaptureEnv {
parent,
graph,
graph: Some(graph),
inputs,
temp_values_by_ref,
temp_values,
}
}

/// Create an empty capture environment whose parent is `self`.
///
/// The returned environment has no graph, inputs or captured values of its
/// own, so it contributes nothing to lookups beyond what `self` provides.
/// Loop operators use this to hand a fresh environment to each iteration
/// of the loop.
pub fn child(&self) -> CaptureEnv<'_> {
    CaptureEnv {
        // No local scope: neither a graph nor any values.
        graph: None,
        inputs: None,
        temp_values: None,
        temp_values_by_ref: None,
        // Everything is resolved via the parent.
        parent: Some(self),
    }
}

/// Look up a node by name in this environment.
pub fn get_node(&self, name: &str) -> Option<&'a Node> {
if let Some(node_id) = self.graph.get_node_id(name) {
// If a node by this name exists in this graph, but is a placeholder
// for a value captured from a parent graph, then ignore it.
if !self.graph.captures().contains(&node_id) {
return self.graph.get_node(node_id);
if let Some(graph) = self.graph {
if let Some(node_id) = graph.get_node_id(name) {
// If a node by this name exists in this graph, but is a placeholder
// for a value captured from a parent graph, then ignore it.
if !graph.captures().contains(&node_id) {
return graph.get_node(node_id);
}
}
}

self.parent.and_then(|parent| parent.get_node(name))
}

/// Look up an operator input value by name in this environment.
pub fn get_input(&self, name: &str) -> Option<Input<'a>> {
if let Some(node_id) = self.graph.get_node_id(name) {
// If a node by this name exists in this graph, but is a placeholder
// for a value captured from a parent graph, then ignore it.
if !self.graph.captures().contains(&node_id) {
// Otherwise, get the value from this scope.
return match self.graph.get_node(node_id) {
Some(Node::Constant(c)) => Some(c.as_input()),
Some(Node::Value(_)) => self
.temp_values
.and_then(|tv| tv.get(&node_id))
.map(|i| i.as_input())
.or_else(|| {
self.inputs
.and_then(|i| i.get(&node_id))
.map(|i| i.as_input())
}),
_ => None,
};
pub fn get_input(&self, name: &str) -> Option<Input> {
if let Some(graph) = self.graph {
if let Some(node_id) = graph.get_node_id(name) {
// If a node by this name exists in this graph, but is a placeholder
// for a value captured from a parent graph, then ignore it.
if !graph.captures().contains(&node_id) {
// Otherwise, get the value from this scope.
return match graph.get_node(node_id) {
Some(Node::Constant(c)) => Some(c.as_input()),
Some(Node::Value(_)) => self
.temp_values_by_ref
.and_then(|tv| tv.get(&node_id))
.map(|i| i.as_input())
.or_else(|| {
self.temp_values
.as_ref()
.and_then(|tv| tv.get(&node_id))
.map(|o| o.as_input())
})
.or_else(|| {
self.inputs
.and_then(|i| i.get(&node_id))
.map(|i| i.as_input())
}),
_ => None,
};
}
}
}

self.parent.and_then(|parent| parent.get_input(name))
}

/// Take ownership of the by-value capture for the node called `name`.
///
/// The name is resolved against this environment's own graph only; parent
/// environments are not consulted. Returns `None` if this environment has
/// no graph, the graph has no node with that name, there is no map of
/// by-value captures, or the map has no entry for the node.
pub fn take_input(&mut self, name: &str) -> Option<Output> {
    let graph = self.graph?;
    let node_id = graph.get_node_id(name)?;
    let by_value = self.temp_values.as_mut()?;
    by_value.remove(&node_id)
}

/// Take ownership of the whole map of by-value captures.
///
/// After this call the environment holds no by-value captures. Returns
/// `None` if there was no map to take.
pub fn take_all_inputs(&mut self) -> Option<FxHashMap<NodeId, Output>> {
    std::mem::take(&mut self.temp_values)
}
}

/// Options that control logging and other behaviors when executing a
Expand Down Expand Up @@ -946,7 +993,7 @@ impl Graph {
&self,
inputs: Vec<(NodeId, InputOrOutput)>,
outputs: &[NodeId],
captures: &CaptureEnv,
captures: CaptureEnv,
pool: Option<&TensorPool>,
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
Expand Down Expand Up @@ -990,15 +1037,15 @@ impl Graph {
mut inputs: Vec<(NodeId, InputOrOutput)>,
plan: &[NodeId],
outputs: &[NodeId],
captures: Option<&CaptureEnv>,
mut captures: Option<CaptureEnv>,
pool: Option<&TensorPool>,
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
let opts = opts.unwrap_or_default();

let mut temp_values: FxHashMap<NodeId, Output> = FxHashMap::default();

// Extract all the owned tensor inputs into the temp value map.
// Extract all owned tensor inputs into the owned value map.
//
// This enables these inputs to be used for in-place operations or
// returned directly as outputs.
Expand All @@ -1025,10 +1072,14 @@ impl Graph {
}
};

let get_value_from_capture = |node_id: NodeId| -> Option<Input> {
let name = self.nodes.get(node_id.as_usize()).and_then(|n| n.name())?;
captures.as_ref().and_then(|cap| cap.get_input(name))
};
/// Resolve the value of `node_id` through the capture environment, if any.
///
/// The node's name is read from `nodes` and then looked up in `captures`.
/// Yields `None` when the ID is out of range for `nodes`, the node has no
/// name, no environment was supplied, or the lookup finds nothing.
fn get_value_from_capture<'a>(
    nodes: &[Node],
    captures: Option<&'a CaptureEnv>,
    node_id: NodeId,
) -> Option<Input<'a>> {
    let node = nodes.get(node_id.as_usize())?;
    let name = node.name()?;
    captures?.get_input(name)
}

// Count how often each temporary output is used, so we can free them
// when no longer needed.
Expand Down Expand Up @@ -1109,16 +1160,43 @@ impl Graph {
None
};

// Take a value for passing to an operator as an owned value, if
// it won't be needed by other operators in future.
let mut take_value = |node_id| {
if temp_value_refcount.count(node_id) == 1 {
if let Some(value) = temp_values.remove(&node_id) {
Some(value)
} else if self.captures.contains(&node_id) {
let name = self.nodes.get(node_id.as_usize()).and_then(|n| n.name())?;
captures.as_mut().and_then(|cap| cap.take_input(name))
} else {
None
}
} else {
None
}
};

// If the operator can run in place, check if we have a tensor
// that can be used as the output. This requires that the tensor
// is not a constant (eg. weights) and is not going to be used by
// other ops in future.
let in_place_input = in_place_input_id.and_then(|input| {
if temp_value_refcount.count(input) == 1 {
temp_values.remove(&input)
} else {
None
let in_place_input = in_place_input_id.and_then(&mut take_value);

// Extract values used by the operator's subgraphs which can be
// passed by value.
let has_subgraph = op_node.operator.has_subgraph();
let by_value_captures = has_subgraph.then(|| {
let mut by_value_captures = FxHashMap::default();
for node_id in self.operator_dependencies(op_node) {
if op_node.inputs.contains(&Some(node_id)) {
continue;
}
if let Some(tensor) = take_value(node_id) {
by_value_captures.insert(node_id, tensor);
}
}
by_value_captures
});

// Collect all or remaining inputs for the operator
Expand All @@ -1134,7 +1212,9 @@ impl Graph {
op_inputs.push(Some(value));
} else if let Some(value) = temp_values.get(node_id) {
op_inputs.push(Some(value.as_input()));
} else if let Some(value) = get_value_from_capture(*node_id) {
} else if let Some(value) =
get_value_from_capture(&self.nodes, captures.as_ref(), *node_id)
{
op_inputs.push(Some(value))
} else {
// If this is reached, there was a bug in plan creation.
Expand Down Expand Up @@ -1175,16 +1255,20 @@ impl Graph {
.run_in_place(pool, input, InputList::from_optional(&op_inputs))
.map(|out| [out].into())
.map_err(op_error_to_run_error)
} else if op_node.operator.has_subgraph() {
let capture_env =
CaptureEnv::new(captures, self, Some(&inputs_by_id), Some(&temp_values));
let result = op_node.operator.run_subgraph(
} else if has_subgraph {
let capture_env = CaptureEnv::new(
captures.as_ref(),
self,
Some(&inputs_by_id),
Some(&temp_values),
by_value_captures,
);
op_node.operator.run_subgraph(
pool,
InputList::from_optional(&op_inputs),
&capture_env,
capture_env,
Some(opts.clone()),
);
result
)
} else {
op_node
.operator
Expand Down Expand Up @@ -1252,7 +1336,9 @@ impl Graph {
.map(|output_id| {
if let Some(value) = get_value_from_constant_or_input(*output_id) {
value.to_output()
} else if let Some(value) = get_value_from_capture(*output_id) {
} else if let Some(value) =
get_value_from_capture(&self.nodes, captures.as_ref(), *output_id)
{
value.to_output()
} else {
// During execution planning we verified that each output
Expand All @@ -1261,6 +1347,15 @@ impl Graph {
}
})
.collect();

// Release any unused captured values back into the pool for use by
// parent graphs.
if let Some(values) = captures.and_then(|mut cap| cap.take_all_inputs()) {
for (_, value) in values {
value.add_to_pool(pool);
}
}

Ok(result)
}

Expand Down Expand Up @@ -2535,7 +2630,7 @@ mod tests {
&self,
pool: &TensorPool,
inputs: InputList,
captures: &CaptureEnv,
captures: CaptureEnv,
options: Option<RunOptions>,
) -> Result<OutputList, RunError> {
let inputs = self
Expand Down Expand Up @@ -2784,4 +2879,52 @@ mod tests {
let result: Tensor<f32> = result.remove(0).try_into().unwrap();
assert_eq!(result.item(), Some(&3.));
}

#[test]
fn test_captures_by_value_if_possible() {
    // Inner graph: a usage-tracking Identity op that reads a value node named
    // "input". The node is registered as a capture, so its value has to come
    // from the enclosing graph.
    let mut inner = Graph::new();
    let captured = inner.add_value(Some("input"), None);
    inner.set_captures(&[captured]);

    let tracked_identity = TrackUsage::new(Identity {});
    let usage = tracked_identity.metrics();
    let (_, inner_out) = inner.add_simple_op("Id", tracked_identity, &[captured]);
    inner.set_output_ids(&[inner_out]);

    // Outer graph: owns the "input" value node and runs the inner graph via
    // a `Subgraph` operator that has no explicit inputs.
    let mut outer = Graph::new();
    let outer_input = outer.add_value(Some("input"), None);
    let (_, outer_out) = outer.add_simple_op("Subgraph", Subgraph { graph: inner }, &[]);

    // Snapshot of (normal runs, in-place runs) recorded for the Identity op.
    let usage_counts = || {
        let metrics = usage.lock().unwrap();
        (metrics.run_count, metrics.run_in_place_count)
    };

    // First run: the input is an owned tensor, so it can be captured by
    // value and the Identity op is expected to run in-place.
    let owned = Tensor::from(42.);
    let mut outputs = outer
        .run(vec![(outer_input, owned.into())], &[outer_out], None)
        .unwrap();
    let value: Tensor<f32> = outputs.remove(0).try_into().unwrap();
    assert_eq!(value.item(), Some(&42.));
    assert_eq!(usage_counts(), (0, 1));

    // Second run: the input is only a view, so the Identity op is expected
    // to take the normal (not in-place) path this time.
    let backing = Tensor::from(42.);
    let mut outputs = outer
        .run(vec![(outer_input, backing.view().into())], &[outer_out], None)
        .unwrap();
    let value: Tensor<f32> = outputs.remove(0).try_into().unwrap();
    assert_eq!(value.item(), Some(&42.));
    assert_eq!(usage_counts(), (1, 1));
}
}
2 changes: 1 addition & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ impl Model {
optimize,
capture_env,
} = &subgraph_opts;
let capture_env = CaptureEnv::new(*capture_env, graph, None, None);
let capture_env = CaptureEnv::new(*capture_env, graph, None, None, None);
Self::load_graph(
g,
registry,
Expand Down
Loading

0 comments on commit 773f728

Please sign in to comment.