From 5a7632aff8056fe8908296d04628d0067c2e0d8b Mon Sep 17 00:00:00 2001 From: Chris Zimmerman Date: Sun, 19 Jul 2020 08:46:36 -0400 Subject: [PATCH] tensordot and reshape bug (#37) --- src/tensor/linalg.cr | 94 ++++++++++++++++++++++++++++++++++++++++++++ src/tensor/tensor.cr | 67 +++++++++++-------------------- 2 files changed, 116 insertions(+), 45 deletions(-) diff --git a/src/tensor/linalg.cr b/src/tensor/linalg.cr index 34ea8a88..7f1ae04e 100644 --- a/src/tensor/linalg.cr +++ b/src/tensor/linalg.cr @@ -569,6 +569,100 @@ class Tensor(T) dest end + # Compute tensor dot product along specified axes. + # + # Given two tensors, a and b, and an array_like object containing two + # array_like objects, (a_axes, b_axes), sum the products of a’s and b’s + # elements (components) over the axes specified by a_axes and b_axes. + # The third argument can be a single non-negative integer_like scalar, + # N; if it is such, then the last N dimensions of a and the first N + # dimensions of b are summed over. + # + # Arguments + # --------- + # *b* : Tensor + # Right hand side of dot products + # *axes* : Array(Array(Int)) | Array(Int) | Int + # Axes of summation + # + # Examples + # -------- + # ``` + # a = Tensor.range(60.0).reshape(3, 4, 5) + # b = Tensor.range(24.0).reshape(4, 3, 2) + # puts a.tensordot(b, axes: [[1, 0], [0, 1]]) + # + # # [[4400, 4730], + # # [4532, 4874], + # # [4664, 5018], + # # [4796, 5162], + # # [4928, 5306]] + # ``` + def tensordot(b : Tensor(T), axes : Array(Array(Int))) + axes_a, axes_b = axes + na = axes_a.size + nb = axes_b.size + as_ = self.shape + nda = self.rank + bs = b.shape + ndb = b.rank + equal = na == nb + na.times do |k| + if as_[axes_a[k]] != bs[axes_b[k]] + equal = false + break + end + if axes_a[k] < 0 + axes_a[k] += nda + end + if axes_b[k] < 0 + axes_b[k] += ndb + end + end + unless equal + raise Num::Internal::ShapeError.new("Shape mismatch for sum") + end + notin = (0...nda).select do |k| + !axes_a.includes?(k) + end + newaxes_a = notin + axes_a + n2 = 1 + axes_a.each do |axis| + n2 *= as_[axis] + end + newshape_a = [(notin.map { |ax| as_[ax] }).product, n2] + olda = notin.map { |ax| as_[ax] } + + notin = (0...ndb).select do |k| + !axes_b.includes?(k) + end + newaxes_b = axes_b + notin + n2 = 1 + axes_b.each do |axis| + n2 *= bs[axis] + end + newshape_b = [n2, (notin.map { |ax| bs[ax] }).product] + oldb = notin.map { |ax| bs[ax] } + + at = self.transpose(newaxes_a).reshape(newshape_a) + bt = b.transpose(newaxes_b).reshape(newshape_b) + res = at.matmul(bt) + res.reshape(olda + oldb) + end + + # :ditto: + def tensordot(b : Tensor(T), axes : Int) + axes_a = (-axes...0).to_a + axes_b = (0...axes).to_a + self.tensordot(b, [axes_a, axes_b]) + end + + # :ditto: + def tensordot(b : Tensor(T), axes : Array(Int)) + axes_a, axes_b = axes + self.tensordot(b, [[axes_a], [axes_b]]) + end + # :nodoc: def is_matrix unless self.rank == 2 diff --git a/src/tensor/tensor.cr b/src/tensor/tensor.cr index 5286a16d..71d91b13 100644 --- a/src/tensor/tensor.cr +++ b/src/tensor/tensor.cr @@ -1217,64 +1217,41 @@ class Tensor(T) # # [3, 4]] # ``` def reshape(new_shape : Array(Int)) - result_shape = new_shape.map &.to_i - - if result_shape == @shape + newshape = new_shape.map &.to_i + if newshape == shape return self.view end - - n = 1 - c = @size - auto = -1 - - result_shape.each_with_index do |v, i| - if v < 0 - if auto >= 0 - raise Num::Internal::ValueError.new( - "Only a single dimension can be inferred" - ) + newsize = 1 + cur_size = size + autosize = -1 + newshape.each_with_index do |val, i| + if val < 0 + if autosize >= 0 + raise Num::Internal::ValueError.new("Only shape dimension can be automatic") end - auto = i + autosize = i else - n *= v + newsize *= val end end - if auto >= 0 - result_shape = result_shape.dup - result_shape[auto] = c // n - n *= result_shape[auto] + if autosize >= 0 + newshape = newshape.dup + newshape[autosize] = cur_size // newsize + newsize *= newshape[autosize] end - if n != c - raise Num::Internal::ShapeError.new( - "Shape #{@shape} cannot be reshaped to #{result_shape}" - ) + if newsize != cur_size + raise Num::Internal::ShapeError.new "Shapes #{shape} cannot be reshaped to #{newshape}" end + newstrides = Num::Internal.shape_to_strides(newshape, Num::RowMajor) + if @flags.contiguous? - new_strides = Num::Internal.shape_to_strides( - result_shape, - Num::RowMajor - ) - t = Tensor(T).new(@buffer, result_shape, new_strides) - t.flags &= ~Num::ArrayFlags::OwnData - t - elsif @flags.fortran? - new_strides = Num::Internal.shape_to_strides( - result_shape, - Num::ColMajor - ) - t = Tensor(T).new(@buffer, result_shape, new_strides) - t.flags &= ~Num::ArrayFlags::OwnData - t + self.class.new(@buffer, newshape, newstrides) else - t = dup(Num::ColMajor) - new_strides = Num::Internal.shape_to_strides( - result_shape, - Num::ColMajor - ) - t + tmp = self.dup(Num::RowMajor) + self.class.new(tmp.to_unsafe, newshape, newstrides) end end