Skip to content

Commit

Permalink
extension point for pulsifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 5, 2023
1 parent 5d1f7eb commit 30530cc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 19 deletions.
14 changes: 7 additions & 7 deletions pulse/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Mutex;

use crate::fact::StreamInfo;
use crate::{internal::*, ops::sync_inputs};
use tract_core::model::translator::Translate;
Expand Down Expand Up @@ -107,7 +109,7 @@ impl SpecialOps<PulsedFact, Box<dyn PulsedOp>> for PulsedModel {
}
}

struct Pulsifier(Symbol, TDim, HashMap<TypeId, crate::ops::OpPulsifier>);
struct Pulsifier(Symbol, TDim, Arc<Mutex<HashMap<TypeId, crate::ops::OpPulsifier>>>);

impl std::fmt::Debug for Pulsifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -137,12 +139,10 @@ impl
.unwrap());
}

if let Some(pulsifier) = self.2.get(&node.op.type_id()) {
if let Some(pulsified) =
(pulsifier.func)(source, node, target, mapping, &self.0, &self.1)?
{
return Ok(pulsified);
}
if let Some(pulsified) =
OpPulsifier::pulsify(source, node, target, mapping, &self.0, &self.1)?
{
return Ok(pulsified);
}

let pulse_facts: TVec<PulsedFact> =
Expand Down
59 changes: 47 additions & 12 deletions pulse/src/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::any::Any;
use std::sync::Mutex;

use crate::internal::*;
use lazy_static::lazy_static;
use tract_pulse_opl::ops::Delay;

pub mod array;
Expand Down Expand Up @@ -60,22 +64,53 @@ pub struct OpPulsifier {
}

impl OpPulsifier {
pub fn inventory() -> HashMap<TypeId, OpPulsifier> {
let mut inventory = HashMap::default();
register_all(&mut inventory);
inventory
pub fn inventory() -> Arc<Mutex<HashMap<TypeId, OpPulsifier>>> {
lazy_static! {
static ref INVENTORY: Arc<Mutex<HashMap<TypeId, OpPulsifier>>> = {
let mut it = HashMap::default();
register_all(&mut it);
Arc::new(Mutex::new(it))
};
};
(*INVENTORY).clone()
}

pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
let inv = Self::inventory();
let mut inv = inv.lock().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
inv.insert(
std::any::TypeId::of::<T>(),
OpPulsifier {
type_id: std::any::TypeId::of::<T>(),
name: std::any::type_name::<T>(),
func,
},
);
Ok(())
}

pub fn pulsify(
source: &TypedModel,
node: &TypedNode,
target: &mut PulsedModel,
mapping: &HashMap<OutletId, OutletId>,
symbol: &Symbol,
pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
let inv = Self::inventory();
let inv = inv.lock().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
{
return Ok(Some(pulsified));
}
}
Ok(None)
}
}

pub trait PulsedOp:
Op
+ fmt::Debug
+ tract_core::dyn_clone::DynClone
+ Send
+ Sync
+ 'static
+ Downcast
+ EvalOp
Op + fmt::Debug + tract_core::dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
{
/// Reinterpret the PulsedOp as an Op.
fn as_op(&self) -> &dyn Op;
Expand Down

0 comments on commit 30530cc

Please sign in to comment.