Commit
Improve code for memory schema
hubertdelajonquieresonos authored and kali committed Nov 26, 2024
1 parent 37348db commit a684770
Showing 8 changed files with 83 additions and 125 deletions.
8 changes: 5 additions & 3 deletions metal/src/fact.rs
@@ -4,11 +4,13 @@ use tract_core::internal::*;
/// Origin of the metal tensor
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MetalOrigin {
/// Metal tensor outputted by a GPU ops
/// Metal tensor outputted by a GPU operator
/// Can be either: Owned or ArenaView
/// Note: Tensors marked as FromGPU are from asynchronous operations.
FromGpu,
/// Metal tensor built from a CPU tensor
/// Can be only Owned Metal tensor
/// Metal tensor built from a CPU tensor (CPU op output or Const)
/// Can be only Owned Metal tensor.
/// Note: Tensors marked as FromCPU are from synchronous operations.
FromCpu,
}

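A minimal standalone sketch of how a caller might consume the asynchronous/synchronous distinction documented above; the needs_gpu_completion helper is hypothetical and not part of this commit:

// Illustrative sketch: mirrors MetalOrigin to show the documented semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MetalOrigin {
    FromGpu, // produced asynchronously by a GPU operator (Owned or ArenaView)
    FromCpu, // built synchronously from a CPU tensor (always Owned)
}

// Hypothetical helper: only tensors coming out of GPU operators may still be in
// flight, so only they would require waiting on GPU completion before CPU access.
fn needs_gpu_completion(origin: MetalOrigin) -> bool {
    matches!(origin, MetalOrigin::FromGpu)
}

fn main() {
    assert!(needs_gpu_completion(MetalOrigin::FromGpu));
    assert!(!needs_gpu_completion(MetalOrigin::FromCpu));
}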
4 changes: 2 additions & 2 deletions metal/src/memory/pool.rs
@@ -35,11 +35,11 @@ impl MetalMemoryPool {
dt: DatumType,
shape: &[usize],
) -> Result<MetalTensor> {
// unsafe { Tensor::uninitialized_dt(dt, shape)?.into_metal() }
// ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id);
ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id);
let alignment = dt.alignment();
(self.alignment % alignment == 0)
.then(|| self.resolved_schema.offsets_by_node[node_id])
.flatten()
.map(|offset| {
// self.node_seen.borrow_mut().insert(node_id);
Ok(MetalArenaView {
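The lookup above can be read on its own: a node gets an arena view only when the datum type's alignment divides the pool's alignment and the resolved schema carries an offset for that node (now an Option, with None meaning the output is not arena-allocated). A simplified standalone version of that check, with tract types replaced by plain integers:

// Simplified sketch of the offset lookup performed by the memory pool.
// resolved_offsets stands in for MetalResolvedMemSchema::offsets_by_node.
fn arena_offset_for_node(
    pool_alignment: usize,
    dt_alignment: usize,
    resolved_offsets: &[Option<usize>],
    node_id: usize,
) -> Option<usize> {
    // A view is only handed out if the requested alignment divides the pool alignment.
    (pool_alignment % dt_alignment == 0)
        .then(|| resolved_offsets[node_id])
        .flatten()
}

fn main() {
    let offsets = vec![Some(0), None, Some(256)];
    assert_eq!(arena_offset_for_node(256, 4, &offsets, 0), Some(0));
    assert_eq!(arena_offset_for_node(256, 4, &offsets, 1), None); // not arena-allocated
    assert_eq!(arena_offset_for_node(2, 4, &offsets, 2), None);   // alignment mismatch
}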
114 changes: 70 additions & 44 deletions metal/src/memory/schema.rs
@@ -3,21 +3,22 @@ use std::fmt;
use std::fmt::Debug;
use tract_core::internal::*;

/// Requirement for node outputs from a memory perspective.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScopedNodeMemory {
pub struct NodeMemReq {
pub node: usize,
pub scope: Scope,
pub lifetime: Lifetime,
pub mem_size: TDim,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Scope {
pub struct Lifetime {
pub start: usize,
pub end: usize,
}

impl Scope {
pub fn is_disjoint(&self, other: &Scope) -> bool {
impl Lifetime {
pub fn is_disjoint(&self, other: &Lifetime) -> bool {
self.start >= other.end || other.start >= self.end
}

@@ -34,10 +35,10 @@ }
}
}
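// Illustrative note (not part of this diff): a lifetime is the half-open step
// interval [start, end) during which a node's output must stay resident. Two
// outputs can share memory only if their lifetimes are disjoint, which is what
// is_disjoint above checks:
//
//     let a = Lifetime { start: 0, end: 3 };
//     let b = Lifetime { start: 3, end: 5 };
//     let c = Lifetime { start: 2, end: 4 };
//     assert!(a.is_disjoint(&b));  // [0,3) and [3,5) never overlap
//     assert!(!a.is_disjoint(&c)); // both alive at step 2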

pub fn eval_metal_scope_node_mem(
pub fn eval_metal_mem_req_for_nodes(
model: &TypedModel,
order: &[usize],
) -> TractResult<TVec<ScopedNodeMemory>> {
) -> TractResult<TVec<NodeMemReq>> {
let outputs = model.output_outlets()?.to_vec();
let flush_lists = order::build_flush_list(model, order, &outputs, |node| {
let Ok(facts) = model.node_output_facts(node.id) else { return false };
@@ -47,14 +48,15 @@ pub fn eval_metal_scope_node_mem(
let mut scoped_nodes = tvec![];

for (step, n) in order.iter().enumerate() {
let scope_start = step;
let scope_end = flush_lists
let lifetime_start = step;
let lifetime_end = flush_lists
.iter()
.enumerate()
.find(|(_step, flush_list)| flush_list.contains(n))
.map(|it| usize::min(it.0 + 1, order.len()));

let Some(scope_end) = scope_end else {
// Ignore nodes that won't be flushed from gpu.
let Some(lifetime_end) = lifetime_end else {
continue;
};

@@ -69,19 +71,21 @@ pub fn eval_metal_scope_node_mem(
continue;
}

scoped_nodes.push(ScopedNodeMemory {
scoped_nodes.push(NodeMemReq {
node: *n,
scope: Scope { start: scope_start, end: scope_end },
lifetime: Lifetime { start: lifetime_start, end: lifetime_end },
mem_size: out_metal_tmp_facts.iter().map(|it| it.mem_size()).sum::<TDim>(),
})
}

Ok(scoped_nodes)
}

/// A partition is a list of nodes that have disjoint memory requirements from a lifetime
/// perspective.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Partition {
pub nodes: Vec<ScopedNodeMemory>,
pub nodes: Vec<NodeMemReq>,
}

impl Partition {
@@ -100,62 +104,77 @@ impl Partition {
TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect())
}

pub fn is_disjoint(&self, scope: &Scope) -> bool {
self.nodes.iter().all(|n| n.scope.is_disjoint(scope))
pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
}

pub fn find_node_alive_at_step(&self, step: usize) -> Option<&ScopedNodeMemory> {
self.nodes.iter().find(|it| it.scope.is_alive_at_step(step))
pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> {
self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step))
}
}
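// Illustrative note (not part of this diff): since the members of a partition have
// pairwise-disjoint lifetimes, they can all reuse one arena slot, and that slot only
// needs to fit the largest member. With concrete byte sizes in place of symbolic
// TDims, the partition size computed above is simply a maximum:
//
//     fn partition_size(member_sizes: &[usize]) -> usize {
//         member_sizes.iter().copied().max().unwrap_or(0)
//     }
//
//     assert_eq!(partition_size(&[1024, 4096, 512]), 4096);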

/// This struct represents a resolved memory schema for a model that contains
/// Metal operators. This schema is concrete.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetalResolvedMemSchema {
pub offsets_by_node: Vec<usize>,
pub offsets_by_node: Vec<Option<usize>>,
pub memory_size: usize,
}

/// This struct represents a memory schema for node output memory that is handled
/// by the Metal GPU.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MetalMemSchema {
/// Total number of nodes in the model.
pub model_num_nodes: usize,
pub by_partition: Vec<Partition>,
pub by_steps: Vec<Vec<Option<ScopedNodeMemory>>>,
// vec![vec![Option<NodeMemReq>; num_partitions]; num_steps].
pub by_steps: Vec<Vec<Option<NodeMemReq>>>,
}

impl MetalMemSchema {
pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
}

/// Returns the memory size of each inner partition.
pub fn size_by_partition(&self) -> Vec<TDim> {
self.by_partition.iter().map(|it| it.size()).collect()
}

/// Evaluate memory size by partition for given symbol values.
pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult<Vec<i64>> {
self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect()
}

/// Returns total memory size required for the schema.
pub fn memory_size(&self) -> TDim {
self.by_partition.iter().map(|it| it.size()).sum()
}

/// Evaluate memory size required for the schema for given symbol values.
pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum()
}

/// Compute offsets for each node for given symbols. Node ids
/// are indexes in the returned vector.
pub fn compute_offset_by_node(
&self,
num_nodes: usize,
symbols: &SymbolValues,
) -> TractResult<Vec<usize>> {
) -> TractResult<Vec<Option<usize>>> {
let mut cursor = 0;
let mut offset_by_node = vec![0; num_nodes];
let mut offset_by_node = vec![None; self.model_num_nodes];

for partition in self.by_partition.iter() {
for node_mem in partition.nodes.iter() {
offset_by_node[node_mem.node] = cursor;
offset_by_node[node_mem.node] = Some(cursor);
}
cursor += partition.eval_size_to_i64(symbols)? as usize;
}

Ok(offset_by_node)
}

/// Evaluate peak memory size for given symbols. The return value is lower than or equal to the memory
/// size of the schema. The difference between peak memory size and memory size represents the
/// memory fragmentation introduced by the schema.
pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult<i64> {
Ok(self
.by_steps
@@ -174,6 +193,9 @@ impl MetalMemSchema {
.unwrap_or(0))
}
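// Illustrative worked example (not part of this diff), assuming peak memory is the
// largest per-step sum of live outputs across partitions: take one partition holding
// a 4096-byte and a 1024-byte output with disjoint lifetimes, and another holding a
// single 2048-byte output. The schema memory size is 4096 + 2048 = 6144 bytes (sum of
// per-partition maxima). If the 4096-byte and 2048-byte outputs are never live at the
// same step, the peak is max(4096, 1024 + 2048) = 4096 bytes, and the remaining
// 6144 - 4096 = 2048 bytes are fragmentation introduced by the schema. The usage
// metric below compares these two quantities and reaches 1.0 only when they coincide.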

/// Evaluate the usage for given symbols as the ratio between
/// schema memory size and peak memory size. A value of 1.0 means
/// that the schema doesn't introduce memory fragmentation.
pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult<f32> {
let memory_size = self.eval_memory_size(symbols)? as f32;
let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32;
@@ -205,47 +227,47 @@ impl fmt::Display for MetalMemSchema {
}

impl MetalMemSchema {
pub fn resolve(
&self,
num_nodes: usize,
symbols: &SymbolValues,
) -> TractResult<MetalResolvedMemSchema> {
/// Resolve the memory schema with the given symbols.
pub fn resolve(&self, symbols: &SymbolValues) -> TractResult<MetalResolvedMemSchema> {
Ok(MetalResolvedMemSchema {
offsets_by_node: self.compute_offset_by_node(num_nodes, symbols)?,
offsets_by_node: self.compute_offset_by_node(symbols)?,
memory_size: self.eval_memory_size(symbols)?.try_into()?,
})
}

/// Build a memory schema for the given model and execution order. The schema is optimized
/// using the hint, but since it is based on symbolic dimensions, it is not guaranteed to be
/// optimal for all possible values of those symbolic dimensions.
pub fn build(
model: &TypedModel,
order: &[usize],
hint: &SymbolValues,
) -> TractResult<MetalMemSchema> {
let mut scoped_nodes_mem = eval_metal_scope_node_mem(model, order)?;
let mut nodes_mem_req = eval_metal_mem_req_for_nodes(model, order)?;

let hinted_mem_size = scoped_nodes_mem
let hinted_mem_size = nodes_mem_req
.iter()
.map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?)))
.collect::<TractResult<HashMap<usize, i64>>>()?;

scoped_nodes_mem.sort_by(|lhs, rhs| {
nodes_mem_req.sort_by(|lhs, rhs| {
let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node);
let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node);

lhs.scope
lhs.lifetime
.end
.cmp(&rhs.scope.end)
.cmp(&rhs.lifetime.end)
.reverse()
.then(lhs.scope.len().cmp(&rhs.scope.len()).reverse())
.then(lhs.lifetime.len().cmp(&rhs.lifetime.len()).reverse())
.then(lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse())
});

let mut partitions: Vec<Partition> = vec![];
for node_mem in scoped_nodes_mem {
// Find partitions where node scope is disjoint from existing.
for node_mem in nodes_mem_req {
// Find partitions where node lifetime is disjoint from existing.
let mut available = partitions
.iter_mut()
.filter(|it| it.is_disjoint(&node_mem.scope))
.filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime))
.collect::<Vec<_>>();

available.sort_by_cached_key(|n| {
Expand All @@ -260,7 +282,7 @@ impl MetalMemSchema {
}
}

let by_steps: Vec<Vec<Option<ScopedNodeMemory>>> = (0..order.len())
let by_steps: Vec<Vec<Option<NodeMemReq>>> = (0..order.len())
.map(|step| {
let mem_step: Vec<_> =
partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect();
@@ -269,6 +291,10 @@ impl MetalMemSchema {
})
.collect::<TractResult<Vec<_>>>()?;

Ok(MetalMemSchema { by_partition: partitions, by_steps })
Ok(MetalMemSchema {
model_num_nodes: model.nodes().len(),
by_partition: partitions,
by_steps,
})
}
}
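At its core, the build routine above performs a greedy interval-partitioning pass: memory requests are sorted (latest-ending lifetime first, then longest lifetime, then largest hinted size) and each request is placed into an existing partition whose lifetimes are all disjoint from its own, or into a new partition if none qualifies. A minimal standalone sketch of that loop, with plain integer sizes standing in for tract's symbolic TDims and with the tie-breaking among candidate partitions simplified:

// Minimal sketch of the greedy lifetime-partitioning behind MetalMemSchema::build.
#[derive(Clone, Debug)]
struct Lifetime { start: usize, end: usize }

impl Lifetime {
    fn is_disjoint(&self, other: &Lifetime) -> bool {
        self.start >= other.end || other.start >= self.end
    }
}

#[derive(Clone, Debug)]
struct NodeMemReq { node: usize, lifetime: Lifetime, mem_size: usize }

#[derive(Debug)]
struct Partition { nodes: Vec<NodeMemReq> }

impl Partition {
    fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool {
        self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime))
    }
    fn size(&self) -> usize {
        self.nodes.iter().map(|n| n.mem_size).max().unwrap_or(0)
    }
}

fn build_partitions(mut reqs: Vec<NodeMemReq>) -> Vec<Partition> {
    // Latest-ending lifetimes first, then longest lifetimes, then largest sizes.
    reqs.sort_by(|l, r| {
        l.lifetime.end.cmp(&r.lifetime.end).reverse()
            .then((l.lifetime.end - l.lifetime.start).cmp(&(r.lifetime.end - r.lifetime.start)).reverse())
            .then(l.mem_size.cmp(&r.mem_size).reverse())
    });
    let mut partitions: Vec<Partition> = vec![];
    for req in reqs {
        // Reuse the first partition whose existing lifetimes are all disjoint from req's.
        match partitions.iter().position(|p| p.has_no_conflict_with_lifetime(&req.lifetime)) {
            Some(idx) => partitions[idx].nodes.push(req),
            None => partitions.push(Partition { nodes: vec![req] }),
        }
    }
    partitions
}

fn main() {
    let reqs = vec![
        NodeMemReq { node: 0, lifetime: Lifetime { start: 0, end: 2 }, mem_size: 4096 },
        NodeMemReq { node: 1, lifetime: Lifetime { start: 1, end: 3 }, mem_size: 2048 },
        NodeMemReq { node: 2, lifetime: Lifetime { start: 2, end: 4 }, mem_size: 1024 },
    ];
    let partitions = build_partitions(reqs);
    // Nodes 0 and 2 have disjoint lifetimes and share one slot; node 1 overlaps both.
    let grouped: Vec<Vec<usize>> =
        partitions.iter().map(|p| p.nodes.iter().map(|n| n.node).collect()).collect();
    assert_eq!(grouped, vec![vec![2, 0], vec![1]]);
    assert_eq!(partitions.iter().map(|p| p.size()).sum::<usize>(), 4096 + 2048);
}

Sorting so that the latest-ending, longest-lived, largest requests are placed first tends to keep the number of partitions, and therefore the arena size, small.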
9 changes: 2 additions & 7 deletions metal/src/ops/change_axes.rs
@@ -207,16 +207,11 @@ impl MetalEvalOp for MetalIntoShape {
let input = opaque.to_metal_tensor()?;
ensure!(input.len() == self.0.len);
let output =
crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), input.shape())?;
crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), &self.0.dims)?;

Memcpy.dispatch_eval(context, input, 0, &output)?;

unsafe {
Ok(tvec![output
.reshaped_with_geometry_unchecked(self.0.dims.clone(), self.0.strides.clone())
.into_opaque_tensor()
.into_tvalue()])
}
Ok(tvec![output.into_opaque_tensor().into_tvalue()])
}
}

3 changes: 1 addition & 2 deletions metal/src/ops/sync.rs
@@ -74,8 +74,7 @@ impl TypedOp for MetalSync {
.into_typed_fact()]),
MetalSyncKind::ToGpu => {
ensure!(input.datum_type != DatumType::Opaque, "Cannot sync Opaque Tensor to GPU");
Ok(tvec![TypedFact::dt_scalar(DatumType::Opaque)
.with_opaque_fact(MetalFact::from_cpu(input.clone())?)])
Ok(tvec![MetalFact::from_cpu(input.clone())?.into_opaque_fact()])
}
}
}
6 changes: 2 additions & 4 deletions metal/src/session_handler.rs
@@ -6,7 +6,6 @@ use tract_core::internal::*;
#[derive(Debug, Clone)]
pub struct MetalSessionHandler {
mem_schema: MetalMemSchema,
num_nodes: usize,
}

impl MetalSessionHandler {
@@ -20,14 +19,13 @@ impl MetalSessionHandler {
plan.borrow().order_without_consts(),
memory_hint,
)?;
Ok(Self { num_nodes: plan.borrow().model().nodes().len(), mem_schema })
Ok(Self { mem_schema })
}
}

impl SessionStateHandler for MetalSessionHandler {
fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
let resolved_mem_schema =
self.mem_schema.resolve(self.num_nodes, &session_state.resolved_symbols)?;
let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
let memory_pool = objc::rc::autoreleasepool(|| {
crate::METAL_CONTEXT
.with_borrow(|context| MetalMemoryPool::from_schema(context, resolved_mem_schema))