Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

top k with dyn k #1205

Merged
merged 1 commit into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions core/src/ops/array/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::internal::*;
pub struct Topk {
pub axis: usize,
pub largest: bool,
pub k: usize,
pub fallback_k: TDim,
}

impl Op for Topk {
Expand All @@ -24,9 +24,10 @@ impl EvalOp for Topk {
}

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs);
let (input, k) = args_2!(inputs);
let mut output_shape: TVec<usize> = input.shape().into();
output_shape[self.axis] = self.k;
let k = k.cast_to_scalar::<i64>()? as usize;
output_shape[self.axis] = k;
let mut output_values = Tensor::zero::<f32>(&output_shape)?;
let mut output_indices = Tensor::zero::<i64>(&output_shape)?;
let mut iterating_shape = output_shape.clone();
Expand All @@ -47,7 +48,7 @@ impl EvalOp for Topk {
.map(|x| if self.largest { -x } else { x })
.enumerate()
.sorted_by(|a, b| a.1.total_cmp(&b.1))
.take(self.k)
.take(k)
.map(|(pos, val)| if self.largest { (pos, -val) } else { (pos, val) })
.enumerate()
{
Expand All @@ -64,8 +65,13 @@ impl TypedOp for Topk {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut fact_values = inputs[0].without_value();
let mut fact_indices = inputs[0].without_value();
fact_values.shape.set(self.axis, self.k.to_dim());
fact_indices.shape.set(self.axis, self.k.to_dim());
let k: TDim = if let Some(k) = &inputs[1].konst {
k.cast_to_scalar::<i64>()?.into()
} else {
self.fallback_k.clone()
};
fact_values.shape.set(self.axis, k.clone());
fact_indices.shape.set(self.axis, k);
fact_indices.datum_type = i64::datum_type();
Ok(tvec!(fact_values, fact_indices))
}
Expand Down
10 changes: 6 additions & 4 deletions nnef/src/ops/core/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub fn register(registry: &mut Registry) {
"tract_core_topk",
&[
TypeName::Scalar.tensor().named("input"),
TypeName::Integer.named("k"),
TypeName::Integer.tensor().named("k"),
TypeName::Integer.named("axis"),
TypeName::Logical.named("largest"),
],
Expand All @@ -20,10 +20,11 @@ pub fn register(registry: &mut Registry) {
fn ser_topk(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
let op = node.op().downcast_ref::<ops::array::Topk>().unwrap();
let input = ast.mapping[&node.inputs[0]].clone();
let k = ast.mapping[&node.inputs[1]].clone();
Ok(Some(invocation(
"tract_core_topk",
&[input],
&[("k", numeric(op.k)), ("largest", logical(op.largest)), ("axis", numeric(op.axis))],
&[input, k],
&[("largest", logical(op.largest)), ("axis", numeric(op.axis))],
)))
}

Expand All @@ -32,5 +33,6 @@ fn de_topk(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> Tract
let k = invocation.named_arg_as(builder, "k")?;
let axis = invocation.named_arg_as(builder, "axis")?;
let largest = invocation.named_arg_as(builder, "largest")?;
builder.wire(ops::array::Topk { largest, k, axis }, &[input])
let fallback_k = builder.model.symbol_table.new_with_prefix("k").into();
builder.wire(ops::array::Topk { largest, fallback_k, axis }, &[input, k])
}
17 changes: 8 additions & 9 deletions onnx/src/ops/array/topk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,14 @@ impl Expansion for Topk {
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
let input = model.outlet_fact(inputs[0])?;
let k = model.outlet_fact(inputs[1])?;
if let Some(k) = &k.konst {
let rank = input.rank();
let k = k.as_slice::<i64>()?[0] as usize;
let axis = if self.axis >= 0 { self.axis } else { self.axis + rank as i64 } as usize;
model.wire_node(prefix, tract_core::ops::array::Topk { axis, k, largest: self.largest }, &[inputs[0]])
} else {
bail!("tract only suports TopK with a known constant K");
}
let rank = input.rank();
let axis = if self.axis >= 0 { self.axis } else { self.axis + rank as i64 } as usize;
let fallback_k = model.symbol_table.new_with_prefix("k").into();
model.wire_node(
prefix,
tract_core::ops::array::Topk { axis, fallback_k, largest: self.largest },
&[inputs[0], inputs[1]],
)
}

fn nboutputs(&self) -> TractResult<usize> {
Expand Down
6 changes: 3 additions & 3 deletions test-rt/suite-onnx/node.txt
Original file line number Diff line number Diff line change
Expand Up @@ -787,9 +787,9 @@ test_thresholdedrelu_example_expanded_ver18
test_thresholdedrelu_expanded_ver18
test_tile input:x
test_tile_precomputed input:x
test_top_k input:x since:10
test_top_k_negative_axis input:x
test_top_k_smallest input:x
test_top_k since:10
test_top_k_negative_axis
test_top_k_smallest
test_transpose_all_permutations_0
test_transpose_all_permutations_1
test_transpose_all_permutations_2
Expand Down