Skip to content

Commit

Permalink
clips
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Nov 22, 2024
1 parent f82ff88 commit 92ed0a2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 10 deletions.
8 changes: 4 additions & 4 deletions core/src/ops/array/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,16 @@ impl Gather {
ensure!(self.axis == 0);
ensure!(data.fact.shape.rank() == 2);
let data_shape = data.fact.shape.as_concrete().unwrap();
let output_shape = &*self.compute_output_shape(&data_shape, indices.shape())?;
let output_shape = &*self.compute_output_shape(data_shape, indices.shape())?;
let mut output = unsafe { Tensor::uninitialized::<f16>(output_shape)? };
let indices_slice = indices.as_slice::<i64>()?;
let vector_len = data_shape[1];
let output_slice = output.as_slice_mut::<f16>()?;
for (pos, ix) in indices_slice.iter().enumerate() {
let slice = &mut output_slice[pos * vector_len..][..vector_len];
for i in 0..vector_len {
for (i, slot) in slice.iter_mut().enumerate() {
let offset = data_shape[1] * *ix as usize + i;
slice[i] = data.fact.format.extract_at_offset_f16(&data.value, offset)
*slot = data.fact.format.extract_at_offset_f16(&data.value, offset)
}
}
Ok(output)
Expand All @@ -76,7 +76,7 @@ impl TypedOp for Gather {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(inputs[1].datum_type == i64::datum_type());
if inputs[0].datum_type.is_opaque() {
let data_shape = block_quant_aware_input_shape(&inputs[0])?;
let data_shape = block_quant_aware_input_shape(inputs[0])?;
Ok(tvec!(f16::fact(&*self.compute_output_shape(&data_shape, &inputs[1].shape)?)))
} else {
Ok(tvec!(inputs[0]
Expand Down
1 change: 1 addition & 0 deletions linalg/src/frame/block_quant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ dyn_clone::clone_trait_object!(BlockQuant);
dyn_hash::hash_trait_object!(BlockQuant);
impl_downcast!(BlockQuant);

#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
pub bq: Box<dyn BlockQuant>,
Expand Down
7 changes: 1 addition & 6 deletions linalg/src/frame/block_quant/q4_0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,7 @@ impl<const QK: usize> BlockQuant for BaseQ4_0<QK> {
}
}
Ok(EagerPackedInput {
format: Box::new(PackedBlockQuantFormat {
bq: Box::new(self.clone()),
r,
zip,
scales_at_end,
}),
format: Box::new(PackedBlockQuantFormat { bq: Box::new(*self), r, zip, scales_at_end }),
packed: blob.into(),
mn: m,
k,
Expand Down

0 comments on commit 92ed0a2

Please sign in to comment.