diff --git a/data/frostt/reader.py b/data/frostt/reader.py index 721bb50..355514a 100644 --- a/data/frostt/reader.py +++ b/data/frostt/reader.py @@ -1,23 +1,37 @@ import gzip import shutil import ctf +glob_comm = ctf.comm() +import os modify = False def read_from_frostt(file_name, I, J, K): unzipped_file_name = file_name + '.tns' + exists = os.path.isfile(unzipped_file_name) - with gzip.open(file_name + '.tns.gz', 'r') as f_in: - with open(unzipped_file_name, 'w') as f_out: - shutil.copyfileobj(f_in, f_out) - + if not exists: + if glob_comm.rank() == 0: + print('Creating ' + unzipped_file_name) + with gzip.open(file_name + '.tns.gz', 'r') as f_in: + with open(unzipped_file_name, 'w') as f_out: + shutil.copyfileobj(f_in, f_out) + T_start = ctf.tensor((I+1, J+1, K+1), sp=True) + if glob_comm.rank() == 0: + print('T_start initialized') T_start.read_from_file(unzipped_file_name) + if glob_comm.rank() == 0: + print('T_start read in') T = ctf.tensor((I,J,K), sp=True) + if glob_comm.rank() == 0: + print('T initialized') T[:,:,:] = T_start[1:,1:,1:] + if glob_comm.rank() == 0: + print('T filled') if modify: T.write_to_file(unzipped_file_name) - return T \ No newline at end of file + return T diff --git a/data/frostt/test.py b/data/frostt/test.py index 683baa3..a117095 100644 --- a/data/frostt/test.py +++ b/data/frostt/test.py @@ -10,5 +10,7 @@ T = reader.read_from_frostt(file_name, I, J, K) if file_name == 'nell-2': assert(T[0,182,606] == 1.0) +elif file_name == 'nell-1': + assert(T[0,17350,8011251] == 1.0) elif file_name == 'amazon-reviews': assert(T[0,305245,32024] == 1.0)