Skip to content

Commit

Permalink
add more robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Oct 19, 2024
1 parent 69765bf commit 798b113
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions icenet/tools/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from importlib import import_module
import os
import copy
import pickle
import xgboost

Expand Down Expand Up @@ -660,7 +659,7 @@ def combine_pickle_data(args):
with open(os.path.join(cache_directory, f'output_{i}.pkl'), 'rb') as handle:
X_, Y_, W_, ids, info, genesis_args = pickle.load(handle)
if i > 0:
X = np.concatenate((X, X_), axis=0) # awkward will cast numpy automatically
Y = np.concatenate((Y, Y_), axis=0)
Expand Down Expand Up @@ -899,7 +898,7 @@ def process_data(args, data, func_factor, mvavars, runmode):
output['trn']['data'], imputer = impute_datasets(data=output['trn']['data'], features=impute_vars, args=args['imputation_param'], imputer=None)
output['val']['data'], imputer = impute_datasets(data=output['val']['data'], features=impute_vars, args=args['imputation_param'], imputer=imputer)

fmodel = f'{args["modeldir"]}/imputer.pkl'
fmodel = os.path.join(args["modeldir"], f'imputer__{args["__hash_post_genesis__"]}.pkl')

print(f'Saving imputer to: {fmodel}', 'green')
pickle.dump(imputer, open(fmodel, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
Expand All @@ -910,9 +909,9 @@ def process_data(args, data, func_factor, mvavars, runmode):
if args['reweight']:

if args["reweight_file"] is None:
fmodel = f'{args["datadir"]}/reweighter_{args["__hash_genesis__"]}.pkl'
fmodel = os.path.join(args["datadir"], f'reweighter__{args["__hash_post_genesis__"]}.pkl')
else:
fmodel = f'{args["datadir"]}/{args["reweight_file"]}'
fmodel = os.path.join(args["datadir"], args["reweight_file"])

if 'load' in args['reweight_mode']:
print(f'Loading reweighting model from: {fmodel} [runmode = {runmode}]', 'green')
Expand Down Expand Up @@ -942,7 +941,7 @@ def process_data(args, data, func_factor, mvavars, runmode):
## Imputate
if args['imputation_param']['active']:

fmodel = f'{args["modeldir"]}/imputer.pkl'
fmodel = os.path.join(args["modeldir"], f'imputer__{args["__hash_post_genesis__"]}.pkl')

print(f'Loading imputer from: {fmodel}', 'green')
imputer = pickle.load(open(fmodel, 'rb'))
Expand Down

0 comments on commit 798b113

Please sign in to comment.