Skip to content

Commit

Permalink
Merge pull request #404 from robertknight/nodeid-usize
Browse files Browse the repository at this point in the history
Streamline NodeId -> usize conversions
  • Loading branch information
robertknight authored Nov 11, 2024
2 parents aaaa8ce + 7d1f287 commit 19d2147
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion rten-generate/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ mod tests {
}

fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
self.nodes.get(id.as_u32() as usize).cloned()
self.nodes.get(id.as_usize()).cloned()
}

fn input_ids(&self) -> &[NodeId] {
Expand Down
32 changes: 17 additions & 15 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ impl NodeId {
self.0.get() - 1
}

/// Return the underlying ID value as a usize, for slice indexing.
pub fn as_usize(self) -> usize {
self.as_u32() as usize
}

/// Construct a node ID from a u32 value.
///
/// Panics if the value exceeds `i32::MAX`.
Expand Down Expand Up @@ -406,14 +411,14 @@ impl NodeRefCount {
/// Increment ref count of node. If the refcount reaches `u8::MAX` it
/// will become "sticky" and never decrement.
fn inc(&mut self, id: NodeId) {
let rc = &mut self.rc[id.as_u32() as usize];
let rc = &mut self.rc[id.as_usize()];
*rc = rc.saturating_add(1);
}

/// Decrement ref count of node and return new count, or `None` if the
/// ref count was already zero.
fn dec(&mut self, id: NodeId) -> Option<usize> {
let rc = &mut self.rc[id.as_u32() as usize];
let rc = &mut self.rc[id.as_usize()];

// If the refcount reaches the max value, it becomes sticky.
if *rc == u8::MAX {
Expand All @@ -427,7 +432,7 @@ impl NodeRefCount {
}

fn count(&self, id: NodeId) -> usize {
self.rc[id.as_u32() as usize] as usize
self.rc[id.as_usize()] as usize
}
}

Expand Down Expand Up @@ -719,7 +724,7 @@ impl Graph {
let node_id = NodeId::from_u32(self.nodes.len() as u32);
self.nodes.push(node);

if let Some(name) = self.nodes[node_id.as_u32() as usize].name() {
if let Some(name) = self.nodes[node_id.as_usize()].name() {
self.node_id_from_name.insert(name.to_string(), node_id);
}

Expand Down Expand Up @@ -855,7 +860,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node(&self, id: NodeId) -> Option<&Node> {
self.nodes.get(id.as_u32() as usize)
self.nodes.get(id.as_usize())
}

/// Look up a node ID given its unique name
Expand All @@ -875,7 +880,7 @@ impl Graph {

/// Retrieve a node by ID
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
self.nodes.get_mut(id.as_u32() as usize)
self.nodes.get_mut(id.as_usize())
}

/// Return the total number of parameters in all constant nodes in this
Expand Down Expand Up @@ -995,7 +1000,7 @@ impl Graph {

let inputs_by_id: FxHashMap<NodeId, InputOrOutput> = inputs.iter().cloned().collect();
let get_value_from_constant_or_input = |node_id: NodeId| -> Option<Input> {
match self.nodes.get(node_id.as_u32() as usize) {
match self.nodes.get(node_id.as_usize()) {
Some(Node::Constant(constant)) => Some(constant.as_input()),
Some(Node::Value(_)) => inputs_by_id.get(&node_id).map(|input| input.as_input()),
_ => {
Expand All @@ -1005,24 +1010,21 @@ impl Graph {
};

let get_value_from_capture = |node_id: NodeId| -> Option<Input> {
let name = self
.nodes
.get(node_id.as_u32() as usize)
.and_then(|n| n.name())?;
let name = self.nodes.get(node_id.as_usize()).and_then(|n| n.name())?;
captures.as_ref().and_then(|cap| cap.get_input(name))
};

// Count how often each temporary output is used, so we can free them
// when no longer needed.
let mut temp_value_refcount = NodeRefCount::with_capacity(self.nodes.len());
for &op_node_id in plan.iter() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_usize()) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
};
for node_id in self.operator_dependencies(op_node) {
if let Some(Node::Value(_)) = self.nodes.get(node_id.as_u32() as usize) {
if let Some(Node::Value(_)) = self.nodes.get(node_id.as_usize()) {
temp_value_refcount.inc(node_id);
}
}
Expand Down Expand Up @@ -1054,7 +1056,7 @@ impl Graph {
let mut op_start = Instant::now();

for (step, &op_node_id) in plan.iter().enumerate() {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_u32() as usize) else {
let Some(Node::Operator(op_node)) = self.nodes.get(op_node_id.as_usize()) else {
return Err(RunError::PlanningError(
"operator node not found".to_string(),
));
Expand Down Expand Up @@ -1375,7 +1377,7 @@ impl Graph {
// Walk forwards through the plan and prune away steps that cannot be
// computed due to missing inputs.
for &node_id in plan {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id.as_u32() as usize) else {
let Some(Node::Operator(op_node)) = self.nodes.get(node_id.as_usize()) else {
continue;
};

Expand Down

0 comments on commit 19d2147

Please sign in to comment.