Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deadlock pulsification register #1235

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .travis/bundle-entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ net_bench speaker_id pulse8 $CACHEDIR/speaker-id-2019-03.onnx -i 1,S,40,f32 --ou
net_bench voicecom_fake_quant 2sec $CACHEDIR/snips-voice-commands-cnn-fake-quant.pb -i 200,10,f32
net_bench voicecom_float 2sec $CACHEDIR/snips-voice-commands-cnn-float.pb -i 200,10,f32

net_bench trunet pulse1_f32 $CACHEDIR/trunet_dummy.nnef.tgz --nnef-tract-core --pulse 1
net_bench trunet pulse1_f16 $CACHEDIR/trunet_dummy.nnef.tgz --nnef-tract-core --half-floats --pulse 1

. $PRIVATE

end=$(date +%s)
Expand Down
6 changes: 6 additions & 0 deletions cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,12 @@ impl Parameters {
dec.optimize(&mut m)?;
Ok(m)
});
#[cfg(not(feature = "pulse"))]
{
if matches.value_of("pulse").is_some() {
bail!("This build of tract has pulse disabled.")
}
}
#[cfg(feature = "pulse")]
{
if let Some(spec) = matches.value_of("pulse") {
Expand Down
8 changes: 6 additions & 2 deletions pulse/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Mutex;
use std::sync::RwLock;

use crate::fact::StreamInfo;
use crate::{internal::*, ops::sync_inputs};
Expand Down Expand Up @@ -109,7 +109,7 @@ impl SpecialOps<PulsedFact, Box<dyn PulsedOp>> for PulsedModel {
}
}

struct Pulsifier(Symbol, TDim, Arc<Mutex<HashMap<TypeId, crate::ops::OpPulsifier>>>);
struct Pulsifier(Symbol, TDim, Arc<RwLock<HashMap<TypeId, crate::ops::OpPulsifier>>>);

impl std::fmt::Debug for Pulsifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand Down Expand Up @@ -138,10 +138,12 @@ impl
)?
.unwrap());
}
log::debug!("Pulsifying node {node}");

if let Some(pulsified) =
OpPulsifier::pulsify(source, node, target, mapping, &self.0, &self.1)?
{
log::debug!("Pulsified node {node} with adhoc pulsifier");
return Ok(pulsified);
}

Expand All @@ -150,6 +152,7 @@ impl
if pulse_facts.iter().all(|pf| pf.stream.is_none()) {
let pulse_op = NonPulsingWrappingOp(node.op.clone());
let inputs: TVec<OutletId> = node.inputs.iter().map(|i| mapping[i]).collect();
log::debug!("Pulsified node {node} with NonPulsingWrappingOp");
return target.wire_node(&node.name, pulse_op, &inputs);
}

Expand All @@ -162,6 +165,7 @@ impl
if axis_info.outputs[0].len() == 1 {
let pulse_op = PulseWrappingOp(node.op.clone());
let inputs = sync_inputs(node, target, mapping)?;
log::debug!("Pulsified node {node} with PulsingWrappingOp");
return target.wire_node(&node.name, pulse_op, &inputs);
}

Expand Down
12 changes: 6 additions & 6 deletions pulse/src/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::any::Any;
use std::sync::Mutex;
use std::sync::RwLock;

use crate::internal::*;
use lazy_static::lazy_static;
Expand Down Expand Up @@ -64,20 +64,20 @@ pub struct OpPulsifier {
}

impl OpPulsifier {
pub fn inventory() -> Arc<Mutex<HashMap<TypeId, OpPulsifier>>> {
pub fn inventory() -> Arc<RwLock<HashMap<TypeId, OpPulsifier>>> {
lazy_static! {
static ref INVENTORY: Arc<Mutex<HashMap<TypeId, OpPulsifier>>> = {
static ref INVENTORY: Arc<RwLock<HashMap<TypeId, OpPulsifier>>> = {
let mut it = HashMap::default();
register_all(&mut it);
Arc::new(Mutex::new(it))
Arc::new(RwLock::new(it))
};
};
(*INVENTORY).clone()
}

pub fn register<T: Any>(func: PulsifierFn) -> TractResult<()> {
let inv = Self::inventory();
let mut inv = inv.lock().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
let mut inv = inv.write().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
inv.insert(
std::any::TypeId::of::<T>(),
OpPulsifier {
Expand All @@ -98,7 +98,7 @@ impl OpPulsifier {
pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
let inv = Self::inventory();
let inv = inv.lock().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
let inv = inv.read().map_err(|e| anyhow!("Fail to lock inventory {e}"))?;
if let Some(pulsifier) = inv.get(&(*node.op).type_id()) {
if let Some(pulsified) = (pulsifier.func)(source, node, target, mapping, symbol, pulse)?
{
Expand Down