Skip to content

Commit

Permalink
Fix yield iterations for Tensors with negative strides (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
christopherzimmerman authored Nov 25, 2020
1 parent 6ba53ea commit 9d2b864
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
4 changes: 2 additions & 2 deletions shard.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
name: num
version: 0.4.2
version: 0.4.4

authors:
- Chris Zimmerman <[email protected]>

crystal: 0.34.0
crystal: 0.35.1

license: MIT

Expand Down
35 changes: 19 additions & 16 deletions src/tensor/internal/yield_iterators.cr
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
require "../tensor"

macro init_strided_iteration(coord, backstrides, t_shape, t_strides, t_rank)
macro init_strided_iteration(coord, backstrides, t_shape, t_strides, t_rank, t_data)
{{ coord.id }} = Pointer(Int32).malloc({{ t_rank }}, 0)
{{ backstrides.id }} = Pointer(Int32).malloc({{ t_rank }})
{{ t_rank }}.times do |i|
{{ backstrides.id }}[i] = {{ t_strides }}[i] * ({{ t_shape }}[i] - 1)
if {{ t_strides }}[i] < 0
{{ t_data }} += ({{ t_shape }}[i] - 1) * {{ t_strides }}[i].abs
end
end
end

Expand All @@ -31,7 +34,7 @@ def strided_iteration(t : Tensor)
end
else
t_shape, t_strides, t_rank = t.iter_attrs
init_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank)
init_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank, data)
t.size.times do |i|
yield i, data
advance_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank, data)
Expand Down Expand Up @@ -59,22 +62,22 @@ def dual_strided_iteration(t1 : Tensor, t2 : Tensor)
t2data += 1
end
elsif t1_contiguous
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
n.times do |i|
yield i, t1data, t2data
t1data += 1
advance_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
end
elsif t2_contiguous
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
n.times do |i|
yield i, t1data, t2data
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
t2data += 1
end
else
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
n.times do |i|
yield i, t1data, t2data
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
Expand Down Expand Up @@ -107,26 +110,26 @@ def tri_strided_iteration(t1 : Tensor, t2 : Tensor, t3 : Tensor)
t3data += 1
end
elsif t1_contiguous && t2_contiguous
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
n.times do |i|
yield i, t1data, t2data, t3data
t1data += 1
t2data += 1
advance_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
end
elsif t1_contiguous
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
n.times do |i|
yield i, t1data, t2data, t3data
t1data += 1
advance_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
advance_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
end
else
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
n.times do |i|
yield i, t1data, t2data, t3data
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
Expand Down Expand Up @@ -162,7 +165,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
t1data += 1
end
elsif t1_contiguous
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
n.times do
m.times do
yield index, t1data, t2data
Expand All @@ -172,7 +175,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
t1data += 1
end
elsif t2_contiguous
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
n.times do
m.times do
yield index, t1data, t2data
Expand All @@ -182,8 +185,8 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
end
else
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
n.times do
m.times do
yield index, t1data, t2data
Expand Down

0 comments on commit 9d2b864

Please sign in to comment.