Skip to content

Commit

Permalink
add more robustness 2
Browse files Browse the repository at this point in the history
  • Loading branch information
mieskolainen committed Oct 19, 2024
1 parent 798b113 commit 98d8b32
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion icedqcd/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def process_data(args):
## 1. Impute data
if args['imputation_param']['active']:

fmodel = f'{args["modeldir"]}/imputer.pkl'
fmodel = os.path.join(args["modeldir"], 'imputer.pkl')
imputer = pickle.load(open(fmodel, 'rb'))
data['data'], _ = process.impute_datasets(data=data['data'], features=None, args=args['imputation_param'], imputer=imputer)

Expand Down
10 changes: 5 additions & 5 deletions icenet/tools/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,9 @@ def process_data(args, data, func_factor, mvavars, runmode):
if args['reweight']:

if args["reweight_file"] is None:
fmodel = f'{args["datadir"]}/reweighter__{args["__hash_post_genesis__"]}.pkl'
fmodel = os.path.join(args["datadir"], f'reweighter__{args["__hash_genesis__"]}.pkl')
else:
fmodel = f'{args["datadir"]}/{args["reweight_file"]}'
fmodel = os.path.join(args["datadir"], args["reweight_file"])

if 'load' in args['reweight_mode']:
print(f'Loading reweighting model from: {fmodel} [runmode = {runmode}]', 'green')
Expand Down Expand Up @@ -898,7 +898,7 @@ def process_data(args, data, func_factor, mvavars, runmode):
output['trn']['data'], imputer = impute_datasets(data=output['trn']['data'], features=impute_vars, args=args['imputation_param'], imputer=None)
output['val']['data'], imputer = impute_datasets(data=output['val']['data'], features=impute_vars, args=args['imputation_param'], imputer=imputer)

fmodel = os.path.join(args["modeldir"], f'imputer__{args["__hash_post_genesis__"]}.pkl')
fmodel = os.path.join(args["modeldir"], 'imputer.pkl')

print(f'Saving imputer to: {fmodel}', 'green')
pickle.dump(imputer, open(fmodel, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
Expand All @@ -909,7 +909,7 @@ def process_data(args, data, func_factor, mvavars, runmode):
if args['reweight']:

if args["reweight_file"] is None:
fmodel = os.path.join(args["datadir"], f'reweighter__{args["__hash_post_genesis__"]}.pkl')
fmodel = os.path.join(args["datadir"], f'reweighter__{args["__hash_genesis__"]}.pkl')
else:
fmodel = os.path.join(args["datadir"], args["reweight_file"])

Expand Down Expand Up @@ -941,7 +941,7 @@ def process_data(args, data, func_factor, mvavars, runmode):
## Imputate
if args['imputation_param']['active']:

fmodel = os.path.join(args["modeldir"], f'imputer__{args["__hash_post_genesis__"]}.pkl')
fmodel = os.path.join(args["modeldir"], f'imputer.pkl')

print(f'Loading imputer from: {fmodel}', 'green')
imputer = pickle.load(open(fmodel, 'rb'))
Expand Down

0 comments on commit 98d8b32

Please sign in to comment.