Skip to content

Commit

Permalink
ENH: add load_dictionary to load Xy #13
Browse files Browse the repository at this point in the history
  • Loading branch information
JinpengLI committed Aug 26, 2013
1 parent d1a90eb commit e95d73a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 10 deletions.
1 change: 1 addition & 0 deletions epac/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class conf:
BEST_PARAMS = "best_params"
RESULT_SET = "result_set"
ML_CLASSIFICATION_MODE = None # Set to True to force classification mode
DICT_INDEX_FILE = "dict_index.txt"

@classmethod
def init_ml(cls, **Xy):
Expand Down
49 changes: 39 additions & 10 deletions epac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,32 +328,61 @@ def train_test_merge(Xy_train, Xy_test):
return Xy_train


def save_dictionary(dataset_dir, **Xy):
    '''Save a dictionary to a directory

    Save a dictionary to a directory. This dictionary may contain
    numpy array, numpy.memmap

    Every value is written with ``np.save`` to ``<dataset_dir>/<key>.npy``.
    An index file named ``conf.DICT_INDEX_FILE`` is written in the same
    directory: its first line is the number of entries, followed, for
    each entry, by one line holding the key and one line holding the
    path of the matching ``.npy`` file.

    Parameters
    ----------
    dataset_dir: string
        Directory receiving the files; it is created when missing.

    Xy: dictionary
        Keyword arguments to save. Keys are used as file names and as
        lines of the index file.

    Example
    -------
    from sklearn import datasets
    from epac.utils import save_dictionary
    X, y = datasets.make_classification(n_samples=50,
                                        n_features=10000,
                                        n_informative=2,
                                        random_state=1)
    Xy = dict(X=X, y=y)
    save_dictionary("/tmp/save_datasets_data", **Xy)
    '''
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)
    index_filepath = os.path.join(dataset_dir, conf.DICT_INDEX_FILE)
    # Compute each target path once so that the index file and the saved
    # arrays cannot disagree.
    filepaths = dict((key, os.path.join(dataset_dir, key + ".npy"))
                     for key in Xy)
    # The with-block closes the index file even if a write fails.
    with open(index_filepath, "w+") as file_dict_index:
        file_dict_index.write(repr(len(Xy)) + "\n")
        for key in Xy:
            file_dict_index.write(key + "\n")
            file_dict_index.write(filepaths[key] + "\n")
    for key in Xy:
        np.save(filepaths[key], Xy[key])


# Name used before the rename, kept as an alias so that existing callers
# keep working.
save_datasets = save_dictionary


def load_dictionary(dataset_dir):
    '''Load a dictionary

    Load a dictionary from save_dictionary. The index file named
    ``conf.DICT_INDEX_FILE`` in ``dataset_dir`` is read: its first line
    gives the number of entries, then each entry is one line holding the
    key and one line holding the path of the file passed to ``np.load``.

    Parameters
    ----------
    dataset_dir: string
        Directory that contains the index file.

    Return
    ------
    A dictionary mapping each key of the index file to the loaded data,
    or None when ``dataset_dir`` or its index file does not exist.

    Example
    -------
    from epac.utils import load_dictionary
    Xy = load_dictionary("/tmp/save_datasets_data")
    '''
    if not os.path.exists(dataset_dir):
        return None
    index_filepath = os.path.join(dataset_dir, conf.DICT_INDEX_FILE)
    if not os.path.isfile(index_filepath):
        return None
    res = {}
    # The with-block guarantees that the index file is closed, including
    # when np.load or the integer conversion raises.
    with open(index_filepath, "r") as file_dict_index:
        len_dict = int(file_dict_index.readline())
        for _ in range(len_dict):
            # Only the trailing newline is removed, so keys and paths
            # that contain spaces are kept intact.
            key = file_dict_index.readline().strip("\n")
            filepath = file_dict_index.readline().strip("\n")
            res[key] = np.load(filepath)
    return res

0 comments on commit e95d73a

Please sign in to comment.