Merge pull request #406 from robertknight/prefer-tensor-from
Prefer `Tensor::from` for creating vectors from array literals
robertknight authored Nov 13, 2024
2 parents 19d2147 + bbbe0b7 commit 2faf5e3
Showing 7 changed files with 27 additions and 29 deletions.
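
All seven files make the same mechanical change: where a tensor is built from an array literal or a scalar, the explicit shape argument to `from_data` is dropped in favour of the `From` conversions, which infer the shape. A minimal before/after sketch (a standalone illustration, not part of the diff; it assumes the usual `rten_tensor` exports and mirrors the calls exercised in the tests below):

```rust
use rten_tensor::{NdTensor, Tensor};

fn main() {
    // Before: `from_data` takes an explicit shape alongside the elements.
    let a = Tensor::from_data(&[2], vec![2., 3.]);

    // After: `From` infers the 1D shape from the array literal.
    let b = Tensor::from([2., 3.]);
    assert_eq!(a.into_data(), b.into_data());

    // The same conversion covers scalars (rank-0 tensors).
    let s = NdTensor::from(5.);
    assert_eq!(s.item(), Some(&5.));

    // `from_data` is still the tool when the shape isn't a literal,
    // e.g. a 2x2 matrix built from a flat vector.
    let m = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
    assert_eq!(m.item(), None);
}
```

The one variation in this commit is in `src/ops/reduce.rs`, where a borrowed view is built with `NdTensorView::from(axes)` rather than `from_data([1], axes)`.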
14 changes: 7 additions & 7 deletions rten-tensor/src/tensor.rs
@@ -2921,7 +2921,7 @@ mod tests {

     #[test]
     fn test_into_data() {
-        let tensor = NdTensor::from_data([2], vec![2., 3.]);
+        let tensor = NdTensor::from([2., 3.]);
         assert_eq!(tensor.into_data(), vec![2., 3.]);

         let mut tensor = NdTensor::from_data([2, 2], vec![1., 2., 3., 4.]);
@@ -3043,18 +3043,18 @@ mod tests {

     #[test]
     fn test_item() {
-        let tensor = NdTensor::from_data([], vec![5.]);
+        let tensor = NdTensor::from(5.);
         assert_eq!(tensor.item(), Some(&5.));
-        let tensor = NdTensor::from_data([1], vec![6.]);
+        let tensor = NdTensor::from([6.]);
         assert_eq!(tensor.item(), Some(&6.));
-        let tensor = NdTensor::from_data([2], vec![2., 3.]);
+        let tensor = NdTensor::from([2., 3.]);
         assert_eq!(tensor.item(), None);

-        let tensor = Tensor::from_data(&[], vec![5.]);
+        let tensor = Tensor::from(5.);
         assert_eq!(tensor.item(), Some(&5.));
-        let tensor = Tensor::from_data(&[1], vec![6.]);
+        let tensor = Tensor::from([6.]);
         assert_eq!(tensor.item(), Some(&6.));
-        let tensor = Tensor::from_data(&[2], vec![2., 3.]);
+        let tensor = Tensor::from([2., 3.]);
         assert_eq!(tensor.item(), None);
     }

16 changes: 8 additions & 8 deletions src/graph.rs
@@ -1715,7 +1715,7 @@ mod tests {
     fn test_graph_node_debug_names() {
         let mut g = Graph::new();

-        let weights = Tensor::from_data(&[1], vec![0.3230]);
+        let weights = Tensor::from([0.3230]);
         let weights_id = g.add_constant(Some("weights"), weights.clone());
         let input_id = g.add_value(Some("input"), None);
         let relu_out_id = g.add_value(Some("relu_out"), None);
@@ -1820,18 +1820,18 @@ mod tests {
         // op_d is the same as op_c, but input order is reversed
         let (_, op_d_out) = g.add_simple_op("op_d", Concat { axis: 0 }, &[op_b_out, op_a_out]);

-        let input = Tensor::from_data(&[1], vec![1.]);
+        let input = Tensor::from([1.]);

         let results = g
             .run(vec![(input_id, input.view().into())], &[op_c_out], None)
             .unwrap();
-        let expected = Tensor::from_data(&[2], vec![2., 3.]);
+        let expected = Tensor::from([2., 3.]);
         expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

         let results = g
             .run(vec![(input_id, input.into())], &[op_d_out], None)
             .unwrap();
-        let expected = Tensor::from_data(&[2], vec![3., 2.]);
+        let expected = Tensor::from([3., 2.]);
         expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

         Ok(())
@@ -1865,7 +1865,7 @@ mod tests {
     fn test_graph_many_steps() -> Result<(), Box<dyn Error>> {
         let mut g = Graph::new();

-        let input = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
+        let input = Tensor::from([1., 2., 3., 4., 5.]);
         let input_id = g.add_value(Some("input"), None);

         let mut prev_output = input_id;
@@ -1884,7 +1884,7 @@ mod tests {
             .run(vec![(input_id, input.into())], &[prev_output], None)
             .unwrap();

-        let expected = Tensor::from_data(&[5], vec![101., 102., 103., 104., 105.]);
+        let expected = Tensor::from([101., 102., 103., 104., 105.]);
         expect_equal(&results[0].as_tensor_view().unwrap(), &expected.view())?;

         Ok(())
@@ -1894,7 +1894,7 @@ mod tests {
     fn test_noop_graph() -> Result<(), Box<dyn Error>> {
         let mut g = Graph::new();

-        let input = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
+        let input = Tensor::from([1., 2., 3., 4., 5.]);
         let input_id = g.add_value(Some("input"), None);

         let results = g
@@ -1910,7 +1910,7 @@ mod tests {
     fn test_constant_graph() -> Result<(), Box<dyn Error>> {
         let mut g = Graph::new();

-        let value = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
+        let value = Tensor::from([1., 2., 3., 4., 5.]);
         let const_id = g.add_constant(Some("weight"), value.clone());

         let results = g.run(vec![], &[const_id], None).unwrap();
4 changes: 2 additions & 2 deletions src/ops/binary_elementwise.rs
@@ -1085,7 +1085,7 @@ mod tests {
         // Simple case where comparing ordering of tensor shapes tells us
         // target shape.
         let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
-        let b = Tensor::from_data(&[1], vec![10.]);
+        let b = Tensor::from([10.]);
         let expected = Tensor::from_data(&[2, 2], vec![11., 12., 13., 14.]);
         let result = add(&pool, a.view(), b.view()).unwrap();
         expect_equal(&result, &expected)?;
@@ -1096,7 +1096,7 @@ mod tests {

         // Case where the length of tensor shapes needs to be compared before
         // the ordering, since ([5] > [1,5]).
-        let a = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
+        let a = Tensor::from([1., 2., 3., 4., 5.]);
         let b = Tensor::from_data(&[1, 5], vec![1., 2., 3., 4., 5.]);
         let expected = Tensor::from_data(&[1, 5], vec![2., 4., 6., 8., 10.]);

6 changes: 3 additions & 3 deletions src/ops/conv.rs
@@ -799,7 +799,7 @@ mod tests {
         expect_eq_1e4(&result, &expected_with_no_padding)?;

         let expected_with_bias = Tensor::from_data(&[1, 1, 1, 1], vec![3.6358]);
-        let bias = Tensor::from_data(&[1], vec![1.0]);
+        let bias = Tensor::from([1.0]);
         let result = check_conv(
             input.view(),
             kernel.view(),
@@ -981,7 +981,7 @@ mod tests {
                 0.4273, 0.4180, 0.4338,
             ],
         );
-        let bias = Tensor::from_data(&[3], vec![0.1, 0.2, 0.3]);
+        let bias = Tensor::from([0.1, 0.2, 0.3]);
         let expected = Tensor::from_data(
             &[1, 3, 1, 1],
             vec![
@@ -1341,7 +1341,7 @@ mod tests {
         for eb in expected_with_bias.iter_mut() {
             *eb += 1.234;
         }
-        let bias = Tensor::from_data(&[1], vec![1.234]);
+        let bias = Tensor::from([1.234]);
         let result = conv_transpose(
             &pool,
             input.view(),
8 changes: 4 additions & 4 deletions src/ops/layout.rs
@@ -778,7 +778,7 @@ mod tests {
         expect_equal(&result, &expected)?;

         // Case where copied input dim is also zero.
-        let input = Tensor::<f32>::from_data(&[0], vec![]);
+        let input = Tensor::from([0.; 0]);
         let shape = NdTensor::from([0]);
         let expected = input.to_shape([0].as_slice());
         let result = reshape(
@@ -791,7 +791,7 @@ mod tests {
         expect_equal(&result, &expected)?;

         // Case where there is no corresponding input dim.
-        let input = Tensor::from_data(&[1], vec![5.]);
+        let input = Tensor::from([5.]);
         let shape = NdTensor::from([1, 0]);
         let result = reshape(
             &pool,
@@ -859,7 +859,7 @@ mod tests {
         assert_eq!(result.err(), expected_err);

         // Case when allow_zero is true
-        let input = Tensor::from_data(&[1], vec![1]);
+        let input = Tensor::from([1]);
         let shape = NdTensor::from([0, -1]);
         let result = reshape(
             &pool,
@@ -890,7 +890,7 @@ mod tests {
     fn test_reshape_op() -> Result<(), Box<dyn Error>> {
         let pool = new_pool();
         let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
-        let shape = Tensor::from_data(&[1], vec![4]);
+        let shape = Tensor::from([4]);
         let expected = input.to_shape([4].as_slice());

         let op = Reshape { allow_zero: false };
2 changes: 1 addition & 1 deletion src/ops/reduce.rs
@@ -48,7 +48,7 @@ fn select_max_index<T, Cmp: Fn(&T, &T) -> std::cmp::Ordering>(

     if !keep_dims {
         let axes = &[resolved_axis as i32];
-        let axes = NdTensorView::from_data([1], axes);
+        let axes = NdTensorView::from(axes);
         squeeze_in_place(&mut reduced, Some(axes)).expect("Invalid axis");
     }

6 changes: 2 additions & 4 deletions src/ops/unary_elementwise.rs
@@ -1141,10 +1141,8 @@ mod tests {
     #[test]
     fn test_sigmoid() -> Result<(), Box<dyn Error>> {
         let pool = new_pool();
-        let input: Tensor<f32> = Tensor::from_data(
-            &[9],
-            vec![-500.0, -3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0, 500.0],
-        );
+        let input: Tensor<f32> =
+            Tensor::from([-500.0, -3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0, 500.0]);
         let expected = input.map(|x| reference_sigmoid(*x));

         let result = sigmoid(&pool, input.view());
