Skip to content

Commit

Permalink
improve frostt reader
Browse files Browse the repository at this point in the history
  • Loading branch information
NoSegfault committed Apr 27, 2019
1 parent d9abe20 commit 3d33bdd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 19 additions & 5 deletions data/frostt/reader.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
import gzip
import shutil
import ctf
glob_comm = ctf.comm()
import os

modify = False

def read_from_frostt(file_name, I, J, K):

unzipped_file_name = file_name + '.tns'
exists = os.path.isfile(unzipped_file_name)

with gzip.open(file_name + '.tns.gz', 'r') as f_in:
with open(unzipped_file_name, 'w') as f_out:
shutil.copyfileobj(f_in, f_out)

if not exists:
if glob_comm.rank() == 0:
print('Creating ' + unzipped_file_name)
with gzip.open(file_name + '.tns.gz', 'r') as f_in:
with open(unzipped_file_name, 'w') as f_out:
shutil.copyfileobj(f_in, f_out)

T_start = ctf.tensor((I+1, J+1, K+1), sp=True)
if glob_comm.rank() == 0:
print('T_start initialized')
T_start.read_from_file(unzipped_file_name)
if glob_comm.rank() == 0:
print('T_start read in')
T = ctf.tensor((I,J,K), sp=True)
if glob_comm.rank() == 0:
print('T initialized')
T[:,:,:] = T_start[1:,1:,1:]
if glob_comm.rank() == 0:
print('T filled')

if modify:
T.write_to_file(unzipped_file_name)

return T
return T
2 changes: 2 additions & 0 deletions data/frostt/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@
T = reader.read_from_frostt(file_name, I, J, K)
if file_name == 'nell-2':
assert(T[0,182,606] == 1.0)
elif file_name == 'nell-1':
assert(T[0,17350,8011251] == 1.0)
elif file_name == 'amazon-reviews':
assert(T[0,305245,32024] == 1.0)

0 comments on commit 3d33bdd

Please sign in to comment.