From a684770ea456001b5eff0107da2bb52433352930 Mon Sep 17 00:00:00 2001 From: Hubert de La Jonquiere Date: Fri, 22 Nov 2024 16:45:50 +0100 Subject: [PATCH] Improve code for memory schema --- metal/src/fact.rs | 8 ++- metal/src/memory/pool.rs | 4 +- metal/src/memory/schema.rs | 114 ++++++++++++++++++++------------- metal/src/ops/change_axes.rs | 9 +-- metal/src/ops/sync.rs | 3 +- metal/src/session_handler.rs | 6 +- metal/src/tensor/arena_view.rs | 35 ---------- metal/src/tensor/mod.rs | 29 +-------- 8 files changed, 83 insertions(+), 125 deletions(-) diff --git a/metal/src/fact.rs b/metal/src/fact.rs index 8f2302e16b..faca43346b 100644 --- a/metal/src/fact.rs +++ b/metal/src/fact.rs @@ -4,11 +4,13 @@ use tract_core::internal::*; /// Origin of the metal tensor #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum MetalOrigin { - /// Metal tensor outputted by a GPU ops + /// Metal tensor outputted by a GPU operator /// Can be either: Owned or ArenaView + /// Note: Tensors marked as FromGPU are from asynchronous operations. FromGpu, - /// Metal tensor built from a CPU tensor - /// Can be only Owned Metal tensor + /// Metal tensor built from a CPU tensor (CPU op output or Const) + /// Can be only Owned Metal tensor. + /// Note: Tensors marked as FromCPU are from synchronous operations. FromCpu, } diff --git a/metal/src/memory/pool.rs b/metal/src/memory/pool.rs index bd07e85715..c1895e19dc 100644 --- a/metal/src/memory/pool.rs +++ b/metal/src/memory/pool.rs @@ -35,11 +35,11 @@ impl MetalMemoryPool { dt: DatumType, shape: &[usize], ) -> Result { - // unsafe { Tensor::uninitialized_dt(dt, shape)?.into_metal() } - // ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. Maybe the memory pool was not reset properly.", node_id); + ensure!(!self.node_seen.borrow().contains(&node_id), "Tensor for node {:?} was already requested. 
Maybe the memory pool was not reset properly.", node_id); let alignment = dt.alignment(); (self.alignment % alignment == 0) .then(|| self.resolved_schema.offsets_by_node[node_id]) + .flatten() .map(|offset| { // self.node_seen.borrow_mut().insert(node_id); Ok(MetalArenaView { diff --git a/metal/src/memory/schema.rs b/metal/src/memory/schema.rs index baa4b4d6fb..3c5e68b6e6 100644 --- a/metal/src/memory/schema.rs +++ b/metal/src/memory/schema.rs @@ -3,21 +3,22 @@ use std::fmt; use std::fmt::Debug; use tract_core::internal::*; +/// Requirement for node outputs from a memory perspective. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ScopedNodeMemory { +pub struct NodeMemReq { pub node: usize, - pub scope: Scope, + pub lifetime: Lifetime, pub mem_size: TDim, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Scope { +pub struct Lifetime { pub start: usize, pub end: usize, } -impl Scope { - pub fn is_disjoint(&self, other: &Scope) -> bool { +impl Lifetime { + pub fn is_disjoint(&self, other: &Lifetime) -> bool { self.start >= other.end || other.start >= self.end } @@ -34,10 +35,10 @@ impl Scope { } } -pub fn eval_metal_scope_node_mem( +pub fn eval_metal_mem_req_for_nodes( model: &TypedModel, order: &[usize], -) -> TractResult> { +) -> TractResult> { let outputs = model.output_outlets()?.to_vec(); let flush_lists = order::build_flush_list(model, order, &outputs, |node| { let Ok(facts) = model.node_output_facts(node.id) else { return false }; @@ -47,14 +48,15 @@ pub fn eval_metal_scope_node_mem( let mut scoped_nodes = tvec![]; for (step, n) in order.iter().enumerate() { - let scope_start = step; - let scope_end = flush_lists + let lifetime_start = step; + let lifetime_end = flush_lists .iter() .enumerate() .find(|(_step, flush_list)| flush_list.contains(n)) .map(|it| usize::min(it.0 + 1, order.len())); - let Some(scope_end) = scope_end else { + // Ignore nodes that won't be flushed from gpu. 
+ let Some(lifetime_end) = lifetime_end else { continue; }; @@ -69,9 +71,9 @@ pub fn eval_metal_scope_node_mem( continue; } - scoped_nodes.push(ScopedNodeMemory { + scoped_nodes.push(NodeMemReq { node: *n, - scope: Scope { start: scope_start, end: scope_end }, + lifetime: Lifetime { start: lifetime_start, end: lifetime_end }, mem_size: out_metal_tmp_facts.iter().map(|it| it.mem_size()).sum::(), }) } @@ -79,9 +81,11 @@ Ok(scoped_nodes) } +/// A partition is a list of nodes that have disjoint memory requirements from a lifetime +/// perspective. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Partition { - pub nodes: Vec, + pub nodes: Vec, } impl Partition { @@ -100,55 +104,67 @@ impl Partition { TDim::Max(self.nodes.iter().map(|s| s.mem_size.clone()).collect()) } - pub fn is_disjoint(&self, scope: &Scope) -> bool { - self.nodes.iter().all(|n| n.scope.is_disjoint(scope)) + pub fn has_no_conflict_with_lifetime(&self, lifetime: &Lifetime) -> bool { + self.nodes.iter().all(|n| n.lifetime.is_disjoint(lifetime)) } - pub fn find_node_alive_at_step(&self, step: usize) -> Option<&ScopedNodeMemory> { - self.nodes.iter().find(|it| it.scope.is_alive_at_step(step)) + pub fn find_node_alive_at_step(&self, step: usize) -> Option<&NodeMemReq> { + self.nodes.iter().find(|it| it.lifetime.is_alive_at_step(step)) } } +/// This struct represents a resolved memory schema for a model that contains +/// Metal operators. This schema is concrete. #[derive(Debug, Clone, PartialEq, Eq)] pub struct MetalResolvedMemSchema { - pub offsets_by_node: Vec, + pub offsets_by_node: Vec>, pub memory_size: usize, } +/// This struct represents a memory schema for node output memory that is handled +/// by Metal GPU. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MetalMemSchema { + /// Total number of nodes in the model. + pub model_num_nodes: usize, pub by_partition: Vec, - pub by_steps: Vec>>, + // vec![vec![Option; num_partitions]; num_steps]. 
+ pub by_steps: Vec>>, } impl MetalMemSchema { - pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult> { - self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect() - } - + /// Returns memory size of each inner partitions. pub fn size_by_partition(&self) -> Vec { self.by_partition.iter().map(|it| it.size()).collect() } + /// Evaluate memory size by partition for given symbol values. + pub fn eval_size_by_partition(&self, symbols: &SymbolValues) -> TractResult> { + self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).collect() + } + + /// Returns total memory size required for the schema. pub fn memory_size(&self) -> TDim { self.by_partition.iter().map(|it| it.size()).sum() } + /// Evaluate memory size required for the schema for given symbol values. pub fn eval_memory_size(&self, symbols: &SymbolValues) -> TractResult { self.by_partition.iter().map(|it| it.eval_size_to_i64(symbols)).sum() } + /// Compute offsets for each node for given symbols. Node ids + /// are indexes in the returned vector. pub fn compute_offset_by_node( &self, - num_nodes: usize, symbols: &SymbolValues, - ) -> TractResult> { + ) -> TractResult>> { let mut cursor = 0; - let mut offset_by_node = vec![0; num_nodes]; + let mut offset_by_node = vec![None; self.model_num_nodes]; for partition in self.by_partition.iter() { for node_mem in partition.nodes.iter() { - offset_by_node[node_mem.node] = cursor; + offset_by_node[node_mem.node] = Some(cursor); } cursor += partition.eval_size_to_i64(symbols)? as usize; } @@ -156,6 +172,9 @@ impl MetalMemSchema { Ok(offset_by_node) } + /// Evaluate peak memory size for given symbols. The return value is lower or equal to the memory + /// size of the schema. The difference between peak memory size and memory size represents the + /// memory fragmentation introduced by the schema. 
pub fn eval_peak_memory_size(&self, symbols: &SymbolValues) -> TractResult { Ok(self .by_steps @@ -174,6 +193,9 @@ impl MetalMemSchema { .unwrap_or(0)) } + /// Evaluate the usage for given symbols as the ratio between + /// schema memory size and peak memory size. A value of 1.0 means + /// that the schema doesn't introduce memory fragmentation. pub fn eval_usage(&self, symbols: &SymbolValues) -> TractResult { let memory_size = self.eval_memory_size(symbols)? as f32; let peak_memory_size = self.eval_peak_memory_size(symbols)? as f32; @@ -205,47 +227,47 @@ impl fmt::Display for MetalMemSchema { } impl MetalMemSchema { - pub fn resolve( - &self, - num_nodes: usize, - symbols: &SymbolValues, - ) -> TractResult { + /// Resolve Memory schema with given symbols. + pub fn resolve(&self, symbols: &SymbolValues) -> TractResult { Ok(MetalResolvedMemSchema { - offsets_by_node: self.compute_offset_by_node(num_nodes, symbols)?, + offsets_by_node: self.compute_offset_by_node(symbols)?, memory_size: self.eval_memory_size(symbols)?.try_into()?, }) } + /// Build a memory schema for given model and execution order. The hint is used to optimize + /// the memory schema because it is based on symbolic dimensions. That doesn't mean it will be + /// optimal for all possible values for symbolic dimensions. 
pub fn build( model: &TypedModel, order: &[usize], hint: &SymbolValues, ) -> TractResult { - let mut scoped_nodes_mem = eval_metal_scope_node_mem(model, order)?; + let mut nodes_mem_req = eval_metal_mem_req_for_nodes(model, order)?; - let hinted_mem_size = scoped_nodes_mem + let hinted_mem_size = nodes_mem_req .iter() .map(|node_mem| Ok((node_mem.node, node_mem.mem_size.eval_to_i64(hint)?))) .collect::>>()?; - scoped_nodes_mem.sort_by(|lhs, rhs| { + nodes_mem_req.sort_by(|lhs, rhs| { let lhs_hint_mem_size = hinted_mem_size.get(&lhs.node); let rhs_hint_mem_size = hinted_mem_size.get(&rhs.node); - lhs.scope + lhs.lifetime .end - .cmp(&rhs.scope.end) + .cmp(&rhs.lifetime.end) .reverse() - .then(lhs.scope.len().cmp(&rhs.scope.len()).reverse()) + .then(lhs.lifetime.len().cmp(&rhs.lifetime.len()).reverse()) .then(lhs_hint_mem_size.cmp(&rhs_hint_mem_size).reverse()) }); let mut partitions: Vec = vec![]; - for node_mem in scoped_nodes_mem { - // Find partitions where node scope is disjoint from existing. + for node_mem in nodes_mem_req { + // Find partitions where node lifetime is disjoint from existing. 
let mut available = partitions .iter_mut() - .filter(|it| it.is_disjoint(&node_mem.scope)) + .filter(|it| it.has_no_conflict_with_lifetime(&node_mem.lifetime)) .collect::>(); available.sort_by_cached_key(|n| { @@ -260,7 +282,7 @@ impl MetalMemSchema { } } - let by_steps: Vec>> = (0..order.len()) + let by_steps: Vec>> = (0..order.len()) .map(|step| { let mem_step: Vec<_> = partitions.iter().map(|p| p.find_node_alive_at_step(step).cloned()).collect(); @@ -269,6 +291,10 @@ impl MetalMemSchema { }) .collect::>>()?; - Ok(MetalMemSchema { by_partition: partitions, by_steps }) + Ok(MetalMemSchema { + model_num_nodes: model.nodes().len(), + by_partition: partitions, + by_steps, + }) } } diff --git a/metal/src/ops/change_axes.rs b/metal/src/ops/change_axes.rs index 47670d0648..50cce23227 100644 --- a/metal/src/ops/change_axes.rs +++ b/metal/src/ops/change_axes.rs @@ -207,16 +207,11 @@ impl MetalEvalOp for MetalIntoShape { let input = opaque.to_metal_tensor()?; ensure!(input.len() == self.0.len); let output = - crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), input.shape())?; + crate::ops::make_tensor_for_node(session, node_id, input.datum_type(), &self.0.dims)?; Memcpy.dispatch_eval(context, input, 0, &output)?; - unsafe { - Ok(tvec![output - .reshaped_with_geometry_unchecked(self.0.dims.clone(), self.0.strides.clone()) - .into_opaque_tensor() - .into_tvalue()]) - } + Ok(tvec![output.into_opaque_tensor().into_tvalue()]) } } diff --git a/metal/src/ops/sync.rs b/metal/src/ops/sync.rs index 7014dc3c11..a6d8660948 100644 --- a/metal/src/ops/sync.rs +++ b/metal/src/ops/sync.rs @@ -74,8 +74,7 @@ impl TypedOp for MetalSync { .into_typed_fact()]), MetalSyncKind::ToGpu => { ensure!(input.datum_type != DatumType::Opaque, "Cannot sync Opaque Tensor to GPU"); - Ok(tvec![TypedFact::dt_scalar(DatumType::Opaque) - .with_opaque_fact(MetalFact::from_cpu(input.clone())?)]) + Ok(tvec![MetalFact::from_cpu(input.clone())?.into_opaque_fact()]) } } } diff --git 
a/metal/src/session_handler.rs b/metal/src/session_handler.rs index c48bf19055..c226ccae8b 100644 --- a/metal/src/session_handler.rs +++ b/metal/src/session_handler.rs @@ -6,7 +6,6 @@ use tract_core::internal::*; #[derive(Debug, Clone)] pub struct MetalSessionHandler { mem_schema: MetalMemSchema, - num_nodes: usize, } impl MetalSessionHandler { @@ -20,14 +19,13 @@ impl MetalSessionHandler { plan.borrow().order_without_consts(), memory_hint, )?; - Ok(Self { num_nodes: plan.borrow().model().nodes().len(), mem_schema }) + Ok(Self { mem_schema }) } } impl SessionStateHandler for MetalSessionHandler { fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> { - let resolved_mem_schema = - self.mem_schema.resolve(self.num_nodes, &session_state.resolved_symbols)?; + let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?; let memory_pool = objc::rc::autoreleasepool(|| { crate::METAL_CONTEXT .with_borrow(|context| MetalMemoryPool::from_schema(context, resolved_mem_schema)) diff --git a/metal/src/tensor/arena_view.rs b/metal/src/tensor/arena_view.rs index c0608c2c13..e205e2fb20 100644 --- a/metal/src/tensor/arena_view.rs +++ b/metal/src/tensor/arena_view.rs @@ -1,5 +1,4 @@ use crate::MetalContext; -use anyhow::Result; use metal::Buffer; use metal::MTLResourceOptions; use num_traits::AsPrimitive; @@ -97,40 +96,6 @@ impl MetalArenaView { self.shape().iter().product() } - /// Reshaped tensor with given shape. - pub fn reshaped(&self, shape: impl Into>) -> Result { - let shape = shape.into(); - if self.len() != shape.iter().product::() { - bail!("Invalid reshape {:?} to {:?}", self.shape(), shape); - } - if shape.as_slice() != self.shape() { - Ok(Self { - dt: self.dt, - arena: Arc::clone(&self.arena), - strides: Tensor::natural_strides(&shape), - shape, - offset_bytes: self.offset_bytes, - }) - } else { - Ok(self.clone()) - } - } - - /// Reshaped tensor with given shape and strides, no consistency check. 
- pub unsafe fn reshaped_with_geometry_unchecked( - &self, - shape: impl Into>, - strides: impl Into>, - ) -> Self { - Self { - dt: self.dt, - arena: Arc::clone(&self.arena), - strides: strides.into(), - shape: shape.into(), - offset_bytes: self.offset_bytes, - } - } - pub fn as_bytes(&self) -> &[u8] { &self.arena.tensor().as_bytes() [self.offset_bytes..self.offset_bytes + self.len() * self.dt.size_of()] diff --git a/metal/src/tensor/mod.rs b/metal/src/tensor/mod.rs index d1e6e78d87..b94d7a5c84 100644 --- a/metal/src/tensor/mod.rs +++ b/metal/src/tensor/mod.rs @@ -150,7 +150,7 @@ impl MetalTensor { pub fn reshaped(&self, shape: impl Into>) -> Result { match self { Self::Owned(t) => Ok(Self::Owned(t.reshaped(shape)?)), - Self::ArenaView(t) => Ok(Self::ArenaView(t.reshaped(shape)?)), + Self::ArenaView(_t) => bail!("Reshape a Metal Arena View is not supported"), } } @@ -165,33 +165,6 @@ impl MetalTensor { self } - /// Reshaped tensor with given shape and strides, no consistency check. - pub unsafe fn reshaped_with_geometry_unchecked( - &self, - shape: impl Into>, - strides: impl Into>, - ) -> Self { - match self { - Self::Owned(t) => Self::Owned(t.reshaped_with_geometry_unchecked(shape, strides)), - Self::ArenaView(t) => { - Self::ArenaView(t.reshaped_with_geometry_unchecked(shape, strides)) - } - } - } - - // pub fn assert_sane_floats(&self) -> Result<()> { - // if let Ok(floats) = self.inner.view().as_slice::() { - // if let Some(pos) = floats.iter().position(|f| !f.is_finite()) { - // bail!("Found {} in at position {:?}", floats[pos], pos); - // } - // } else if let Ok(floats) = self.inner.view().as_slice::() { - // if let Some(pos) = floats.iter().position(|f| !f.is_finite()) { - // bail!("Found {} in at position {:?}", floats[pos], pos); - // } - // } - // Ok(()) - // } - /// Convert Metal tensor to Opaque Tensor. pub fn into_opaque_tensor(self) -> Tensor { tensor0::(self.into())