Skip to content

Commit

Permalink
bug: fixed transpose function and fixed reciprocal-cell calculation
Browse files Browse the repository at this point in the history
Now the transpose function works for 2D and 3D arrays.

The reciprocal cell calculation was wrong for skewed lattices.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Jan 23, 2018
1 parent c721842 commit 756672a
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 49 deletions.
107 changes: 63 additions & 44 deletions flos/num/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ end


function Array:initialize(...)

local sh = nil

-- One may also initialize an Array by passing a
Expand All @@ -75,7 +75,7 @@ function Array:initialize(...)
if sh == nil then
sh = shape.Shape(...)
end

-- Create the shape container
rawset(self, "shape", sh)

Expand Down Expand Up @@ -200,7 +200,7 @@ function Array.range(i1, i2, step)
elseif is < 0 and i1 < i2 then
error("flos.Array range with negative step-length and i1 < i2 is not allowed.")
end

local new = Array.empty( 1 )
local j = 0
for i = i1, i2, is do
Expand Down Expand Up @@ -266,7 +266,7 @@ function Array:set_linear(i, v)

-- If we are at the last dimension, return immediately.
if #self.shape == 1 then

self[i] = v

return
Expand All @@ -278,7 +278,7 @@ function Array:set_linear(i, v)
local j = m.tointeger( m.ceil(i / n_dim) )
-- Transform i into the linear index in the underlying array
self[j]:set_linear( m.tointeger(i - (j-1) * n_dim), v)

end


Expand Down Expand Up @@ -323,7 +323,7 @@ function Array:reshape(...)
if #arg == 0 then
arg[1] = 0
end

-- In case a shape is passed
local sh
if shape.isShape(arg[1]) then
Expand All @@ -334,7 +334,7 @@ function Array:reshape(...)
if sh == nil then
error("flos.Array cannot align shapes, incompatible dimensions")
end

-- Create the new array
local new = Array( sh )

Expand Down Expand Up @@ -522,9 +522,9 @@ function Array:norm(axis)
norm = norm + self[i]:norm(0) ^ 2
end
norm = m.sqrt(norm)

elseif #self.shape == 1 then

norm = 0.
for i = 1, #self do
norm = norm + self[i] * self[i]
Expand All @@ -539,9 +539,9 @@ function Array:norm(axis)
for i = 1, #norm do
norm[i] = self[i]:norm()
end

end

return norm
end

Expand All @@ -560,14 +560,14 @@ function Array:scalar_project(P, axis)
local ax = ax_(axis)

if ax == 0 then

-- Calculate norm of the projection vector
return self:flatten():dot( P:flatten() ) / P:norm(0)

else

error("flos.Array could not project on anything but flattened array")

end

end
Expand All @@ -587,15 +587,15 @@ function Array:project(P, axis)
local ax = ax_(axis)

if ax == 0 then

-- Calculate norm of the projection vector
local dnorm2 = P:norm(0) ^ 2
return self:flatten():dot( P:flatten() ) / dnorm2 * P

else

error("flos.Array could not project on anything but flattened array")

end

end
Expand Down Expand Up @@ -631,7 +631,7 @@ function Array:min(axis)

error("NotimplementedYet")
end

return min
end

Expand Down Expand Up @@ -666,7 +666,7 @@ function Array:max(axis)

error("NotimplementedYet")
end

return max
end

Expand All @@ -679,7 +679,7 @@ function Array:sum(axis)

local sum
if ax == 0 then

-- Special case for the 1D case
if #self.shape == 1 then
sum = self[1]
Expand Down Expand Up @@ -753,7 +753,7 @@ function Array.cross(lhs, rhs)
local cross = Array( sh )

if #cross.shape == 1 then

cross[1] = lhs[2] * rhs[3] - lhs[3] * rhs[2]
cross[2] = lhs[3] * rhs[1] - lhs[1] * rhs[3]
cross[3] = lhs[1] * rhs[2] - lhs[2] * rhs[1]
Expand Down Expand Up @@ -814,7 +814,7 @@ function Array.dot(lhs, rhs)
if lhs.shape ~= rhs.shape then
error("flos.Array dot dimensions for 1D dot product are not the same")
end

-- This is a element wise product and sum
dot = 0.
for i = 1, #lhs do
Expand Down Expand Up @@ -872,11 +872,11 @@ function Array.dot(lhs, rhs)
else

error("flos.Array dot for arrays with anything but 1 or 2 dimensions is not implemented yet")

end

return dot

end


Expand All @@ -885,24 +885,43 @@ end
function Array:transpose()

-- Check dimensions, we cannot transpose a 1D array
if #self.shape == 1 then
error("flos.Array cannot transpose a vector, reshape, then transpose")
end
local nd = #self.shape
local new = nil

-- First create the reversed shape
local sh = self.shape:reverse()
if nd == 1 then
return self:copy()

-- Create return array
local new = Array( sh )
elseif nd == 2 then

-- Create return array
new = Array( self.shape:reverse() )

for i = 1 , self.shape[1] do
for j = 1 , self.shape[2] do
new[j][i] = self[i][j]
end
end

elseif nd == 3 then

-- Create return array
new = Array( self.shape:reverse() )

for i = 1 , self.shape[1] do
for j = 1 , self.shape[2] do
for k = 1 , self.shape[3] do
new[k][j][i] = self[i][j][k]
end
end
end

else

error("flos.Array transpose only works up to 3D arrays")

-- Perform transpose
local size = self:size()
for i = 1, size do
new:set_linear(size-i+1, self:get_linear(i))
end

return new

end


Expand All @@ -925,14 +944,14 @@ function Array.__add(lhs, rhs)
if sh == nil then
error("flos.Array + requires the same shape for two different Arrays")
end

-- Create the return array
ret = Array( lhs.shape )
-- Element-wise additions
for i = 1, #lhs do
ret[i] = lhs[i] + rhs[i]
end

elseif isArray(lhs) then

ret = Array( lhs.shape )
Expand All @@ -943,7 +962,7 @@ function Array.__add(lhs, rhs)
end

elseif isArray(rhs) then

-- Add scalar to all RHS
ret = Array( rhs.shape )

Expand Down Expand Up @@ -1135,7 +1154,7 @@ function Array.__pow(lhs, rhs)
if sh == nil then
error("flos.Array ^ requires the same shape for two different Arrays")
end

ret = Array( lhs.shape )
for i = 1, #lhs do
ret[i] = lhs[i] ^ rhs[i]
Expand All @@ -1145,7 +1164,7 @@ function Array.__pow(lhs, rhs)
if type(rhs) == "string" and rhs == "T" then
return lhs:transpose()
end

ret = Array( lhs.shape )
for i = 1, #lhs do
ret[i] = lhs[i] ^ rhs
Expand All @@ -1155,7 +1174,7 @@ function Array.__pow(lhs, rhs)
if type(lhs) == "string" and lhs == "T" then
return rhs:transpose()
end

ret = Array( rhs.shape )
for i = 1, #rhs do
ret[i] = lhs ^ rhs[i]
Expand Down
6 changes: 1 addition & 5 deletions flos/optima/lattice.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ function Lattice:update_reciprocal()
self.rcell[3][3] = c[1][1]*c[2][2] - c[1][2]*c[2][1]

for i = 1, 3 do
local n = c[i][1]*self.rcell[i][1] +
c[i][2]*self.rcell[i][2] +
c[i][3]*self.rcell[i][3]

self.rcell[i] = self.rcell[i] / n
self.rcell[i] = self.rcell[i] / c[i]:dot(self.rcell[i])
end

end
Expand Down

0 comments on commit 756672a

Please sign in to comment.