Skip to content

Commit

Permalink
clip
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Jan 2, 2024
1 parent b26e04a commit 41fa13c
Show file tree
Hide file tree
Showing 19 changed files with 44 additions and 55 deletions.
2 changes: 1 addition & 1 deletion cli/src/dump.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fn annotate_with_onnx_model(
if let Some(id) = model
.node_id_by_name(&gnode.name)
.ok()
.or_else(|| gnode.output.get(0).and_then(|n| model.node_id_by_name(n).ok()))
.or_else(|| gnode.output.first().and_then(|n| model.node_id_by_name(n).ok()))
{
let mut v = vec![];
for a in gnode.attribute.iter() {
Expand Down
2 changes: 1 addition & 1 deletion cli/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ impl Parameters {
if let Some(value) = tensors_values
.by_name(konst)
.and_then(|tv| tv.values.as_ref())
.and_then(|v| v.get(0))
.and_then(|v| v.first())
{
let value = value.clone().into_arc_tensor();
let id = raw_model.node_id_by_name(konst)?;
Expand Down
22 changes: 7 additions & 15 deletions core/src/model/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub trait SpecialOps<F, O> {
pub struct Graph<F, O>
where
F: Fact + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
/// all nodes in the model
pub nodes: Vec<Node<F, O>>,
Expand All @@ -44,7 +44,7 @@ where
impl<F, O> Default for Graph<F, O>
where
F: Fact + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
fn default() -> Graph<F, O> {
Graph {
Expand Down Expand Up @@ -76,7 +76,7 @@ where
impl<F, O> Graph<F, O>
where
F: Fact + Clone + 'static,
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static
O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
pub fn add_node(
&mut self,
Expand Down Expand Up @@ -548,7 +548,6 @@ where
+ AsRef<dyn Op>
+ AsMut<dyn Op>
+ Clone

+ 'static,
{
pub fn add_const(
Expand All @@ -570,19 +569,12 @@ where
{
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
for i in 0..self.nodes.len() {
let input_1 = self.nodes[i]
.inputs
.get(0)
.map(|o| format!("{o:?}"))
.unwrap_or_default();
let input_2 = self.nodes[i]
.inputs
.get(1)
.map(|o| format!("{o:?}"))
.unwrap_or_default();
let input_1 =
self.nodes[i].inputs.first().map(|o| format!("{o:?}")).unwrap_or_default();
let input_2 = self.nodes[i].inputs.get(1).map(|o| format!("{o:?}")).unwrap_or_default();
let output_1 = self
.outlet_successors(OutletId::new(i, 0))
.get(0)
.first()
.map(|o| format!("{o:?}"))
.unwrap_or_default();
let output_2 = self
Expand Down
2 changes: 1 addition & 1 deletion core/src/ops/einsum/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub(super) fn ensure_mkn_axes<'a>(
// TODO: handle case where multiple consecutive k in the same order in both input.
bail!("Multiple k-axis candidate found");
} else {
non_trivial_k_axis.get(0).copied().or_else(|| candidate_k_axes.get(0)).copied()
non_trivial_k_axis.first().copied().or_else(|| candidate_k_axes.first()).copied()
};
let Some(k_axis) = k_axis else {
return Ok(AxesOrPatch::Patch(inject_k_axis(op, model, node)?));
Expand Down
2 changes: 1 addition & 1 deletion harness/core-proptest-pulse/src/conv_plus_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl ConvPlusConvProblem {
pub fn model(ops: &[ConvOp]) -> TypedModel {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let wire = model.add_source("a", f32::fact(dims!(1, 1, s)).into()).unwrap();
let wire = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap();
let mut wire = tvec!(wire);
for (ix, cv) in ops.iter().enumerate() {
wire = cv.chain(&format!("conv{ix}"), &mut model, &wire);
Expand Down
2 changes: 1 addition & 1 deletion harness/core-proptest-pulse/src/delay_plus_downsample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl DelayPlusDownsampleProblem {
pub fn run(&self) -> TestCaseResult {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(dims!(1, s, 1)).into()).unwrap();
let a = model.add_source("a", f32::fact(dims!(1, s, 1))).unwrap();
let crop =
// model.wire_node("delay", expand(array::Crop::new(1, self.delay, 0)), &[a]).unwrap();
model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion harness/core-proptest-pulse/src/delay_plus_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl DelayPlusPoolProblem {
pub fn run(&self) -> TestCaseResult {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(dims!(1, s, 1)).into()).unwrap();
let a = model.add_source("a", f32::fact(dims!(1, s, 1))).unwrap();
let crop = model.wire_node("delay", Slice::new(1, self.delay, s), &[a]).unwrap();
let pool_spec = PoolSpec::new(
DataFormat::NHWC,
Expand Down
10 changes: 5 additions & 5 deletions harness/core-proptest-pulse/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ proptest! {
let full_len = input_len + begin + end;
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
let a = model.add_source("a", f32::fact(&[s])).unwrap();
let slice = model.wire_node("slice", Slice::new(0, begin as usize, (input_len + begin) as usize), &[a]).unwrap();
model.set_output_outlets(&slice).unwrap();

Expand All @@ -151,7 +151,7 @@ proptest! {
fn proptest_pad(pulse in 1i32..3, input_len in 0i32..10, begin in 0i32..3, end in 0i32..3) {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
let a = model.add_source("a", f32::fact(&[s])).unwrap();
let pad = model.wire_node("pad", Pad::new(vec![(begin as _, end as _)],
PadMode::Constant(Arc::new(Tensor::from(-1f32)))), &[a]).unwrap();
model.set_output_outlets(&pad).unwrap();
Expand All @@ -171,7 +171,7 @@ fn test_simple_conv() {
let mut model = TypedModel::default();
let kernel = rctensor3(&[[[0.5f32, 1.0, -0.1]]]);
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(dims!(1, 1, s)).into()).unwrap();
let a = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap();
let kernel = model.add_const("kernel", kernel).unwrap();
let bias = model.add_const("bias", tensor0(0f32)).unwrap();

Expand Down Expand Up @@ -205,7 +205,7 @@ fn test_simple_conv() {
fn test_pad_before_1() {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
let a = model.add_source("a", f32::fact(&[s])).unwrap();
model
.wire_node(
"pad",
Expand All @@ -223,7 +223,7 @@ fn test_pad_before_1() {
fn test_pad_before_2() {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let a = model.add_source("a", f32::fact(&[s]).into()).unwrap();
let a = model.add_source("a", f32::fact(&[s])).unwrap();
model
.wire_node(
"pad",
Expand Down
2 changes: 1 addition & 1 deletion harness/core-proptest-pulse/src/pad_plus_conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl PadPlusConvProblem {
pub fn run(&self) -> TestCaseResult {
let mut model = TypedModel::default();
let s = model.symbol_table.sym("S");
let mut wire = model.add_source("a", f32::fact(dims!(1, 1, s)).into()).unwrap();
let mut wire = model.add_source("a", f32::fact(dims!(1, 1, s))).unwrap();
if self.pad_before > 0 || self.pad_after > 0 {
wire = model
.wire_node(
Expand Down
4 changes: 2 additions & 2 deletions harness/tfl-mobilenet-v2-q/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ mod tests {
fn declutter() -> TractResult<()> {
let tfd = tract_tflite::tflite()
.model_for_path(mobilenet_v2())?
.with_input_fact(0, input_dt().fact([1, 224, 224, 3]).into())?
.with_input_fact(0, input_dt().fact([1, 224, 224, 3]))?
.into_decluttered()?
.into_runnable()?;
run(tfd)
Expand All @@ -97,7 +97,7 @@ mod tests {
fn optimized() -> TractResult<()> {
let tfd = tract_tflite::tflite()
.model_for_path(mobilenet_v2())?
.with_input_fact(0, input_dt().fact([1, 224, 224, 3]).into())?
.with_input_fact(0, input_dt().fact([1, 224, 224, 3]))?
.into_optimized()?
.into_runnable()?;
run(tfd)
Expand Down
4 changes: 2 additions & 2 deletions linalg/src/frame/leaky_relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ pub mod test {
f32: AsPrimitive<T>,
T: AsPrimitive<f32>,
{
let data = tract_data::prelude::tensor1(&values);
let data = tract_data::prelude::tensor1(values);
let data = data.cast_to::<T>().unwrap();
let data = data.as_slice::<T>().unwrap();
let alpha: T = tract_data::prelude::tensor0(alpha).cast_to_scalar::<T>().unwrap();
crate::frame::element_wise::test::test_element_wise_params::<K, T, _, T>(
&data,
data,
|x: T| {
if x > T::zero() {
x
Expand Down
8 changes: 4 additions & 4 deletions nnef/src/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<'mb> ModelBuilder<'mb> {
.iter()
.find(|f| &f.decl.id.0 == "tract_core_properties")
.and_then(|f| f.body.as_ref())
.and_then(|body| body.get(0))
.and_then(|body| body.first())
{
let properties: TVec<(String, Arc<Tensor>)> =
properties.right.resolve(self, &[])?.to(self)?;
Expand Down Expand Up @@ -438,7 +438,7 @@ impl RValue {
let out_dt = builder.model.node(outlet_id.node).outputs[outlet_id.slot]
.fact
.datum_type;
if let Some(Some(dt)) = dt.get(0) {
if let Some(Some(dt)) = dt.first() {
if out_dt.unquantized() != dt.unquantized() {
return Err(format_err!(
"Mismatched types expected {:?}, got {:?}",
Expand Down Expand Up @@ -493,7 +493,7 @@ impl RValue {
RValue::Array(array) => Ok(Value::Array(
array
.iter()
.zip(std::iter::repeat(&dt.get(0).copied().flatten()))
.zip(std::iter::repeat(&dt.first().copied().flatten()))
.map(|(i, dt)| i.resolve(builder, &[*dt]))
.collect::<TractResult<_>>()?,
)),
Expand Down Expand Up @@ -538,7 +538,7 @@ impl RValue {
RValue::Literal(Literal::Array(array)) => Ok(Value::Array(
array
.iter()
.zip(std::iter::repeat(&dt.get(0).copied().flatten()))
.zip(std::iter::repeat(&dt.first().copied().flatten()))
.map(|(i, dt)| RValue::Literal(i.clone()).resolve(builder, &[*dt]))
.collect::<TractResult<_>>()?,
)),
Expand Down
2 changes: 1 addition & 1 deletion nnef/src/ops/core/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn cast_dump(ast: &mut IntoAst, node: &TypedNode, op: &Cast) -> TractResult<Opti

fn cast_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let input = invocation.named_arg_as(builder, "input")?;
let invocation_dt = invocation.dt_from_quant_file.get(0).copied().flatten();
let invocation_dt = invocation.dt_from_quant_file.first().copied().flatten();
let to = if let Ok(s) = invocation.named_arg_as::<String>(builder, "to") {
let dt: DatumType = s.parse()?;
if let Some(invocation_dt) = invocation_dt {
Expand Down
2 changes: 1 addition & 1 deletion nnef/src/ops/core/qmatmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fn qmatmul_load(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) ->
let bias: OutletId = invocation.named_arg_as(builder, "bias")?;
let qparams = qparams_as_outlets(builder, invocation)?;
let inputs: Vec<OutletId> = [a, b, bias].into_iter().chain(qparams).collect();
let c_dt = if let Some(c) = invocation.dt_from_quant_file.get(0).cloned().flatten() {
let c_dt = if let Some(c) = invocation.dt_from_quant_file.first().cloned().flatten() {
c
} else {
DatumType::from_str(&invocation.named_arg_as::<String>(builder, "output_type")?)?
Expand Down
2 changes: 1 addition & 1 deletion nnef/src/ops/core/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ fn external_load(
let shape: TVec<TDim> =
builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "shape"))?;
let mut dt: DatumType = invocation.named_arg_as::<String>(builder, "datum_type")?.parse()?;
if let Some(Some(qdt)) = invocation.dt_from_quant_file.get(0) {
if let Some(Some(qdt)) = invocation.dt_from_quant_file.first() {
dt = *qdt;
}
Ok(Value::Wire(builder.model.add_source("", dt.fact(&*shape))?))
Expand Down
14 changes: 7 additions & 7 deletions nnef/src/ops/nnef/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::deser::{ModelBuilder, ResolvedInvocation};
// fragment external<? = scalar>( shape: integer[] ) -> ( output: tensor<?> );
pub fn external(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let type_name = invocation.invocation.generic_type_name.unwrap_or(TypeName::Scalar);
let dt = if let Some(Some(dt)) = invocation.dt_from_quant_file.get(0) {
let dt = if let Some(Some(dt)) = invocation.dt_from_quant_file.first() {
*dt
} else if type_name == TypeName::Scalar {
f32::datum_type()
Expand All @@ -45,7 +45,7 @@ pub fn variable(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) ->
.or_else(|| tensors.get(&Identifier(label.0.trim_start_matches('/').to_owned())))
.ok_or_else(|| format_err!("No data for tensor {:?}", label))?,
);
if let Some(Some(dt)) = invocation.dt_from_quant_file.get(0) {
if let Some(Some(dt)) = invocation.dt_from_quant_file.first() {
if dt.size_of() != tensor.datum_type().size_of() {
bail!(
"Mismatched tensor type for tensor {}: expected {:?}, got {:?}",
Expand Down Expand Up @@ -120,7 +120,7 @@ pub fn transpose(
pub fn concat(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let axis: usize = invocation.named_arg_as(builder, "axis")?;
let mut values: TVec<OutletId> = invocation.named_arg_as(builder, "values")?;
if let Some(Some(dt)) = invocation.dt_from_quant_file.get(0) {
if let Some(Some(dt)) = invocation.dt_from_quant_file.first() {
for value in &mut values {
if builder.model.node(value.node).outputs[value.slot].fact.datum_type != *dt {
*value = builder.wire_as_outlets(ops::cast::cast(*dt), &[*value])?[0];
Expand Down Expand Up @@ -349,7 +349,7 @@ pub fn conv_or_deconv(

let output_dt: Option<DatumType> = if input_fact.datum_type.is_float() {
None
} else if let Some(dt) = invocation.dt_from_quant_file.get(0).cloned().flatten() {
} else if let Some(dt) = invocation.dt_from_quant_file.first().cloned().flatten() {
Some(dt)
} else {
Some(DatumType::I32)
Expand Down Expand Up @@ -563,7 +563,7 @@ pub fn matmul(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tr
scale: a_dt.zp_scale().1 * b_dt.zp_scale().1,
zero_point: 0,
});
let c_dt = invocation.dt_from_quant_file.get(0).cloned().flatten().unwrap_or(accum_dt);
let c_dt = invocation.dt_from_quant_file.first().cloned().flatten().unwrap_or(accum_dt);

let a_qp = a_dt.qparams().unwrap_or_default().zp_scale();
let b_qp = b_dt.qparams().unwrap_or_default().zp_scale();
Expand Down Expand Up @@ -625,7 +625,7 @@ pub fn leaky_relu(
pub fn stack(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let axis: usize = invocation.named_arg_as(builder, "axis")?;
let mut values: TVec<OutletId> = invocation.named_arg_as(builder, "values")?;
if let Some(Some(dt)) = invocation.dt_from_quant_file.get(0) {
if let Some(Some(dt)) = invocation.dt_from_quant_file.first() {
for value in &mut values {
if builder.model.node(value.node).outputs[value.slot].fact.datum_type != *dt {
*value = builder.wire_as_outlets(ops::cast::cast(*dt), &[*value])?[0];
Expand Down Expand Up @@ -683,7 +683,7 @@ pub fn softmax(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> T
let quant_output_dt = if input_fact.datum_type.is_float() {
None
} else {
invocation.dt_from_quant_file.get(0).cloned().flatten()
invocation.dt_from_quant_file.first().cloned().flatten()
};

builder.wire(ops::nn::Softmax { axes, quant_output_dt }, &[x])
Expand Down
4 changes: 2 additions & 2 deletions nnef/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl Registry {
tract_core::ops::element_wise::ElementWiseOp(ew.1.clone()),
&[input],
)?;
if let Some(Some(assumed_out_dt)) = dt.get(0) {
if let Some(Some(assumed_out_dt)) = dt.first() {
let out_dt = builder.model.outlet_fact(outlet[0])?.datum_type;
if out_dt != *assumed_out_dt {
return Ok(Some(
Expand Down Expand Up @@ -227,7 +227,7 @@ impl Registry {
let inputs = multicast(builder, &[a, b])?;
let mut wire = builder
.wire_as_outlets(tract_core::ops::binary::TypedBinOp(bin.1.clone()), &inputs)?[0];
if let Some(Some(out_dt)) = dt.get(0) {
if let Some(Some(out_dt)) = dt.first() {
if out_dt != &a_dt {
wire =
builder.wire_as_outlets(tract_core::ops::cast::cast(*out_dt), &[wire])?[0];
Expand Down
9 changes: 3 additions & 6 deletions onnx-opl/src/ml/category_mapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ impl ReverseLookup {
let keys = keys.as_slice_unchecked::<T>();
let mut hashmap = HashMap::<u64, SmallVec<[i32; 1]>>::default();
for (ix, k) in keys.iter().enumerate() {
let mut hasher = hashmap.hasher().build_hasher();
k.hash(&mut hasher);
let u = hasher.finish();
let u = hashmap.hasher().hash_one(k);
hashmap.entry(u).or_default().push(ix as i32);
}
hashmap
Expand All @@ -120,9 +118,8 @@ impl ReverseLookup {

unsafe fn search_t<T: Datum + Hash>(&self, needle: &T) -> Option<i32> {
let keys = self.keys.as_slice_unchecked::<T>();
let mut hasher = self.index.hasher().build_hasher();
needle.hash(&mut hasher);
let u = hasher.finish();

let u = self.index.hasher().hash_one(needle);
if let Some(candidates) = self.index.get(&u) {
for candidate in candidates {
if &keys[*candidate as usize] == needle {
Expand Down
4 changes: 2 additions & 2 deletions test-rt/test-tflite/src/tflite_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl State for TfliteState {
_ => bail!("unknown type in tract tflitec test Runtime"),
};
let tensor = unsafe {
Tensor::from_raw_dt(dt, &output_tensor.shape().dimensions(), output_tensor.data())?
Tensor::from_raw_dt(dt, output_tensor.shape().dimensions(), output_tensor.data())?
};
outputs.push(tensor.into_tvalue());
}
Expand All @@ -110,7 +110,7 @@ mod tests {
#[test]
fn test_trivial() -> TractResult<()> {
let mut model = TypedModel::default();
let wire = model.add_source("x", f32::fact(&[1]))?;
let wire = model.add_source("x", f32::fact([1]))?;
model.set_output_outlets(&[wire])?;
let out = runtime().prepare(model)?.run(tvec!(tensor1(&[0f32]).into_tvalue()))?.remove(0);
assert_eq!(out, tensor1(&[0f32]).into_tvalue());
Expand Down

0 comments on commit 41fa13c

Please sign in to comment.