From 949786161ffbdb2d735625078478d67f684b5999 Mon Sep 17 00:00:00 2001
From: Mikael Mieskolainen
Date: Mon, 18 Nov 2024 11:09:01 +0000
Subject: [PATCH] bootstrap resampling for train

---
 .github/workflows/icenet-install-test.yml |  32 +--
 configs/zee/models.yml                    |   2 +
 icedqcd/common.py                         |  12 +-
 icefit/icepeak.py                         |   5 +-
 icehgcal/common.py                        |  22 +-
 icehnl/common.py                          |   7 +-
 iceid/common.py                           |  13 +-
 icenet/__init__.py                        |   4 +-
 icenet/tools/io.py                        |   6 +-
 icenet/tools/process.py                   | 285 +++++++++++++---------
 icetrg/common.py                          |   7 +-
 icezee/common.py                          |   7 +-
 12 files changed, 246 insertions(+), 156 deletions(-)

diff --git a/.github/workflows/icenet-install-test.yml b/.github/workflows/icenet-install-test.yml
index 38283732..00841b8e 100644
--- a/.github/workflows/icenet-install-test.yml
+++ b/.github/workflows/icenet-install-test.yml
@@ -101,6 +101,21 @@ jobs:
       run: |
         source setenv-github-actions.sh && python icefit/peakfit.py --analyze --group --test_mode --fit_type dual-unitary-II --output_name dual-unitary-II
 
+    #
+    - name: Deep Learning system (runme_eid)
+      run: |
+        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid.sh
+
+    #
+    - name: Deep Learning system (runme_eid_deep)
+      run: |
+        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid_deep.sh
+
+    #
+    - name: Deep Learning system (runme_eid_visual)
+      run: |
+        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid_visual.sh
+
     # (This is run twice to test cache files)
     - name: Deep Learning system (runme_brem)
      run: |
@@ -127,7 +142,7 @@ jobs:
         source tests/runme_zee_gridtune.sh
 
         echo "yes" | source superclean.sh
-      
+
     #
     - name: Deep Learning system (runme_zee)
       run: |
@@ -151,21 +166,6 @@ jobs:
       run: |
         source setenv-github-actions.sh && maxevents=10000; source tests/runme_trg.sh
         echo "yes" | source superclean.sh
-
-    #
-    - name: Deep Learning system (runme_eid)
-      run: |
-        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid.sh
-
-    #
-    - name: Deep Learning system (runme_eid_deep)
-      run: |
-        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid_deep.sh
-
-    #
-    - name: Deep Learning system (runme_eid_visual)
-      run: |
-        source setenv-github-actions.sh && maxevents=10000; source tests/runme_eid_visual.sh
 
     ## source setenv-github-actions.sh && maxevents=10000; source tests/runme_brk.sh
     ## source setenv-github-actions.sh && maxevents=10000; source tests/runme_dqcd_vector_train.sh
diff --git a/configs/zee/models.yml b/configs/zee/models.yml
index a15436f8..9d84e237 100644
--- a/configs/zee/models.yml
+++ b/configs/zee/models.yml
@@ -193,6 +193,8 @@ iceboost_swd:
   <<: *ICEBOOST0
   label: 'ICEBOOST-SWD'
+
+  bootstrap: 3
 
 # BCE loss domains [use with custom:binary_cross_entropy]
 BCE_param:
diff --git a/icedqcd/common.py b/icedqcd/common.py
index 35c5cf93..2950fa27 100644
--- a/icedqcd/common.py
+++ b/icedqcd/common.py
@@ -353,10 +353,12 @@ def splitfactor(x, y, w, ids, args, skip_graph=True, use_dequantize=True):
             data_graph += sum(ray.get(graph_futures), []) # Join split array results
             ray.shutdown()
-        
+
        print(f'ray_results: {time.time() - start_time:0.1f} sec')
        io.showmem()
 
+        data_graph = np.array(data_graph, dtype=object) # !
+
     # -------------------------------------------------------------------------
     ## Tensor representation
     data_tensor = None
@@ -392,5 +394,9 @@ def splitfactor(x, y, w, ids, args, skip_graph=True, use_dequantize=True):
     """
     # --------------------------------------------------------------------------
-    
-    return {'data': data, 'data_MI': data_MI, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
diff --git a/icefit/icepeak.py b/icefit/icepeak.py
index 5afb64e2..a0df15fc 100644
--- a/icefit/icepeak.py
+++ b/icefit/icepeak.py
@@ -366,7 +366,7 @@ def TH1_to_numpy(hist, dtype=np.float64):
 
     #for n, v in hist.__dict__.items(): # class generated on the fly
    #    print(f'{n} {v}')
-    
+
     hh     = hist.to_numpy()
     counts = np.array(hist.values(), dtype=dtype)
     errors = np.array(hist.errors(), dtype=dtype)
@@ -1430,7 +1430,8 @@ def integral_wrapper(lambdafunc, x, edges, norm=False, N_int: int=128, EPS=1E-8,
 
     if norm: # Normalization based on a numerical integral over edge bounds
         x_fine = np.linspace(edges[0], edges[-1], N_int)
-        I = max(np.trapz(y=lambdafunc(x_fine), x=x_fine), EPS)
+        y_fine = lambdafunc(x_fine)
+        I = max(np.trapz(x=x_fine, y=y_fine), EPS)
 
         return f / I * edges2binwidth(edges)
     else:
diff --git a/icehgcal/common.py b/icehgcal/common.py
index 5159081a..f0cd0b52 100644
--- a/icehgcal/common.py
+++ b/icehgcal/common.py
@@ -242,23 +242,31 @@ def splitfactor(x, y, w, ids, args):
     ### DeepSets representation
     data_deps = None
 
+    # -------------------------------------------------------------------------
+    ### Mutual Information
+    data_MI = None
+
     # -------------------------------------------------------------------------
     ### Tensor representation
     data_tensor = None
 
     # -------------------------------------------------------------------------
     ## Graph representation
-    data_graph = None 
+    data_graph = None
 
-    data_graph = graphio.parse_graph_data_candidate(X=data.x, Y=data.y, weights=data.w, ids=data.ids, 
+    data_graph = graphio.parse_graph_data_candidate(X=data.x, Y=data.y, weights=data.w, ids=data.ids,
         features=scalar_vars, graph_param=args['graph_param'])
 
+    data_graph = np.array(data_graph, dtype=object) # !
+
     # --------------------------------------------------------------------
     ### Finally pick active scalar variables out
 
     data.x = None # To protect other routines (TBD see global impact --> comment this line)
 
-    return {'data': data, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
-
-# ========================================================================
-# ========================================================================
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
diff --git a/icehnl/common.py b/icehnl/common.py
index 742fa14c..bb5af531 100644
--- a/icehnl/common.py
+++ b/icehnl/common.py
@@ -146,4 +146,9 @@ def splitfactor(x, y, w, ids, args):
     data   = data[vars]
     data.x = data.x.astype(np.float32)
 
-    return {'data': data, 'data_MI': data_MI, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
diff --git a/iceid/common.py b/iceid/common.py
index dcbc4fbb..9af82684 100644
--- a/iceid/common.py
+++ b/iceid/common.py
@@ -150,6 +150,10 @@ def splitfactor(x, y, w, ids, args):
     data_kin   = data[vars]
     data_kin.x = data_kin.x.astype(np.float32)
 
+    # -------------------------------------------------------------------------
+    ### MI variables
+    data_MI = None
+
     # -------------------------------------------------------------------------
     ### DeepSets representation
     data_deps = None
@@ -206,6 +210,8 @@ def splitfactor(x, y, w, ids, args):
     print(f'ray_results: {time.time() - start_time:0.1f} sec')
     io.showmem()
 
+    data_graph = np.array(data_graph, dtype=object) # !
+
     # --------------------------------------------------------------------
     ### Finally pick active scalar variables out
 
     data   = data[vars]
     data.x = data.x.astype(np.float32)
 
-    return {'data': data, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
diff --git a/icenet/__init__.py b/icenet/__init__.py
index ca28bed5..4b1badaa 100644
--- a/icenet/__init__.py
+++ b/icenet/__init__.py
@@ -3,9 +3,9 @@
 import os
 import psutil
 
-__version__    = '0.1.3.6'
+__version__    = '0.1.3.7'
 __release__    = 'alpha'
-__date__       = '04/11/2024'
+__date__       = '18/11/2024'
 __author__     = 'm.mieskolainen@imperial.ac.uk'
 __repository__ = 'github.com/mieskolainen/icenet'
 __asciiart__   = \
diff --git a/icenet/tools/io.py b/icenet/tools/io.py
index bca776a6..63841218 100644
--- a/icenet/tools/io.py
+++ b/icenet/tools/io.py
@@ -391,7 +391,11 @@ def __getitem__(self, key):
             return IceXYW(x=self.x[..., col], y=self.y, w=self.w, ids=ids)
         else:
             return IceXYW(x=self.x[col], y=self.y, w=self.w, ids=ids)
-    
+
+    # length operator
+    def __len__(self):
+        return len(self.x)
+
     # + operator
     def __add__(self, other):
 
diff --git a/icenet/tools/process.py b/icenet/tools/process.py
index 815288e8..0476b0d4 100644
--- a/icenet/tools/process.py
+++ b/icenet/tools/process.py
@@ -1158,18 +1158,18 @@ def train_models(data_trn, data_val, args=None):
         prints.print_variables(data_trn['data'].x, data_trn['data'].ids, W=data_trn['data'].w, output_file=output_file)
 
     # -------------------------------------------------------------
-    
-    def set_distillation_drain(ID, param, inputs, dtype='torch'):
+
+    def set_distillation_drain(ID, param, inputs, idx, dtype='torch'):
         if 'distillation' in args and args['distillation']['drains'] is not None:
             if ID in args['distillation']['drains']:
                 print(f'Creating soft distillation drain for the model <{ID}>', 'yellow')
 
                 # By default to torch
-                inputs['y_soft'] = torch.tensor(y_soft, dtype=torch.float)
+                inputs['y_soft'] = torch.tensor(y_soft, dtype=torch.float)[idx]
 
                 if dtype == 'numpy':
                     inputs['y_soft'] = inputs['y_soft'].detach().cpu().numpy()
-    
+
     # -------------------------------------------------------------
     print(f'Training models:', 'magenta')
 
@@ -1188,149 +1188,192 @@ def set_distillation_drain(ID, param, inputs, dtype='torch'):
         ID    = args['active_models'][i]
         param = args['models'][ID]
         print(f'Training <{ID}> | {param} \n')
-        
-        try:
+
+        if 'bootstrap' not in param:
+            num_bootstrap = 1
+        elif param['bootstrap'] is None:
+            num_bootstrap = 1
+        else:
+            num_bootstrap = param['bootstrap'] + 1
+
+        ORIG_LABEL = copy.deepcopy(param['label'])
+
+        for b in range(num_bootstrap):
+
+            # Bootstrap sample
-            ## Different model
-            if param['train'] == 'torch_graph':
+            if b > 0:
+                param['label'] = f'{ORIG_LABEL}__bs_{b}'
-                
-                inputs = {'data_trn': data_trn['data_graph'],
-                          'data_val': data_val['data_graph'],
-                          'args': args,
-                          'param': param}
+                idx_trn = np.random.choice(len(data_trn['data']), size=len(data_trn['data']), replace=True)
+                idx_val = np.arange(len(data_val['data'])) # Orig
-                
-                set_distillation_drain(ID=ID, param=param, inputs=inputs)
+                print(f'Bootstrap training sample: {b} / {num_bootstrap - 1}', 'green')
-                
-                if ID in args['raytune']['param']['active']:
-                    model = train.raytune_main(inputs=inputs, train_func=train.train_torch_graph)
-                else:
-                    model = train.train_torch_graph(**inputs)
+            else:
+                idx_trn = np.arange(len(data_trn['data'])) # Orig
+                idx_val = np.arange(len(data_val['data'])) # Orig
+
+                print(f'Original training sample', 'green')
-            
-            elif param['train'] == 'xgb':
+
+            try:
+
+                ## Different model
+                if param['train'] == 'torch_graph':
+
+                    inputs = {
+                        'data_trn': [data_trn['data_graph'][i] for i in idx_trn],
+                        'data_val': [data_val['data_graph'][i] for i in idx_val],
+                        'args': args,
+                        'param': param}
+
+                    set_distillation_drain(ID=ID, param=param, inputs=inputs, idx=idx_trn)
-                
-                inputs = {'data_trn': data_trn['data'],
-                          'data_val': data_val['data'],
-                          'args': args,
-                          'data_trn_MI': data_trn['data_MI'] if 'data_MI' in data_trn else None,
-                          'data_val_MI': data_val['data_MI'] if 'data_MI' in data_val else None,
-                          'param': param}
+
+                    if ID in args['raytune']['param']['active']:
+                        model = train.raytune_main(inputs=inputs, train_func=train.train_torch_graph)
+                    else:
+                        model = train.train_torch_graph(**inputs)
+
+                elif param['train'] == 'xgb':
+
+                    inputs = {
+                        'data_trn': data_trn['data'][idx_trn],
+                        'data_val': data_val['data'][idx_val],
+                        'args': args,
+                        'data_trn_MI': data_trn['data_MI'][idx_trn] if ('data_MI' in data_trn and data_trn['data_MI'] is not None) else None,
+                        'data_val_MI': data_val['data_MI'][idx_val] if ('data_MI' in data_val and data_val['data_MI'] is not None) else None,
+                        'param': param}
+
+                    set_distillation_drain(ID=ID, param=param, inputs=inputs, idx=idx_trn, dtype='numpy')
+
+                    if ID in args['raytune']['param']['active']:
+                        model = train.raytune_main(inputs=inputs, train_func=iceboost.train_xgb)
+                    else:
+                        model = iceboost.train_xgb(**inputs)
-                
-                set_distillation_drain(ID=ID, param=param, inputs=inputs, dtype='numpy')
-                
-                if ID in args['raytune']['param']['active']:
-                    model = train.raytune_main(inputs=inputs, train_func=iceboost.train_xgb)
-                else:
-                    model = iceboost.train_xgb(**inputs)
-            
-            elif param['train'] == 'torch_deps':
-                
-                inputs = {'X_trn': torch.tensor(data_trn['data_deps'].x, dtype=torch.float),
-                          'Y_trn': torch.tensor(data_trn['data'].y, dtype=torch.long),
-                          'X_val': torch.tensor(data_val['data_deps'].x, dtype=torch.float),
-                          'Y_val': torch.tensor(data_val['data'].y, dtype=torch.long),
-                          'X_trn_2D': None,
-                          'X_val_2D': None,
-                          'trn_weights': torch.tensor(data_trn['data'].w, dtype=torch.float),
-                          'val_weights': torch.tensor(data_val['data'].w, dtype=torch.float),
-                          'data_trn_MI': data_trn['data_MI'] if 'data_MI' in data_trn else None,
-                          'data_val_MI': data_val['data_MI'] if 'data_MI' in data_val else None,
-                          'args': args,
-                          'param': param,
-                          'ids': data_trn['data_deps'].ids}
+
+                elif param['train'] == 'torch_deps':
+
+                    inputs = {
+                        'X_trn': torch.tensor(data_trn['data_deps'].x[idx_trn], dtype=torch.float),
+                        'Y_trn': torch.tensor(data_trn['data'].y[idx_trn], dtype=torch.long),
+                        'X_val': torch.tensor(data_val['data_deps'].x[idx_val], dtype=torch.float),
+                        'Y_val': torch.tensor(data_val['data'].y[idx_val], dtype=torch.long),
+                        'X_trn_2D': None,
+                        'X_val_2D': None,
+                        'trn_weights': torch.tensor(data_trn['data'].w[idx_trn], dtype=torch.float),
+                        'val_weights': torch.tensor(data_val['data'].w[idx_val], dtype=torch.float),
+                        'data_trn_MI': data_trn['data_MI'][idx_trn] if ('data_MI' in data_trn and data_trn['data_MI'] is not None) else None,
+                        'data_val_MI': data_val['data_MI'][idx_val] if ('data_MI' in data_val and data_val['data_MI'] is not None) else None,
+                        'args': args,
+                        'param': param,
+                        'ids': data_trn['data_deps'].ids}
-                
-                set_distillation_drain(ID=ID, param=param, inputs=inputs)
+
+                    set_distillation_drain(ID=ID, param=param, idx=idx_trn, inputs=inputs)
-                
-                if ID in args['raytune']['param']['active']:
-                    model = train.raytune_main(inputs=inputs, train_func=train.train_torch_generic)
-                else:
-                    model = train.train_torch_generic(**inputs)
+
+                    if ID in args['raytune']['param']['active']:
+                        model = train.raytune_main(inputs=inputs, train_func=train.train_torch_generic)
+                    else:
+                        model = train.train_torch_generic(**inputs)
-            
-            elif param['train'] == 'torch_generic':
-                
-                inputs = {'X_trn': torch.tensor(aux.red(data_trn['data'].x, data_trn['data'].ids, param, 'X'), dtype=torch.float),
-                          'Y_trn': torch.tensor(data_trn['data'].y, dtype=torch.long),
-                          'X_val': torch.tensor(aux.red(data_val['data'].x, data_val['data'].ids, param, 'X'), dtype=torch.float),
-                          'Y_val': torch.tensor(data_val['data'].y, dtype=torch.long),
-                          'X_trn_2D': None if data_trn['data_tensor'] is None else torch.tensor(data_trn['data_tensor'], dtype=torch.float),
-                          'X_val_2D': None if data_val['data_tensor'] is None else torch.tensor(data_val['data_tensor'], dtype=torch.float),
-                          'trn_weights': torch.tensor(data_trn['data'].w, dtype=torch.float),
-                          'val_weights': torch.tensor(data_val['data'].w, dtype=torch.float),
-                          'data_trn_MI': data_trn['data_MI'] if 'data_MI' in data_trn else None,
-                          'data_val_MI': data_val['data_MI'] if 'data_MI' in data_val else None,
-                          'args': args,
-                          'param': param,
-                          'ids': data_trn['data'].ids}
+
+                elif param['train'] == 'torch_generic':
+
+                    inputs = {
+                        'X_trn': torch.tensor(aux.red(data_trn['data'].x[idx_trn], data_trn['data'].ids, param, 'X'), dtype=torch.float),
+                        'Y_trn': torch.tensor(data_trn['data'].y[idx_trn], dtype=torch.long),
+                        'X_val': torch.tensor(aux.red(data_val['data'].x[idx_val], data_val['data'].ids, param, 'X'), dtype=torch.float),
+                        'Y_val': torch.tensor(data_val['data'].y[idx_val], dtype=torch.long),
+                        'X_trn_2D': None if data_trn['data_tensor'] is None else torch.tensor(data_trn['data_tensor'][idx_trn], dtype=torch.float),
+                        'X_val_2D': None if data_val['data_tensor'] is None else torch.tensor(data_val['data_tensor'][idx_val], dtype=torch.float),
+                        'trn_weights': torch.tensor(data_trn['data'].w[idx_trn], dtype=torch.float),
+                        'val_weights': torch.tensor(data_val['data'].w[idx_val], dtype=torch.float),
+                        'data_trn_MI': data_trn['data_MI'][idx_trn] if ('data_MI' in data_trn and data_trn['data_MI'] is not None) else None,
+                        'data_val_MI': data_val['data_MI'][idx_val] if ('data_MI' in data_val and data_val['data_MI'] is not None) else None,
+                        'args': args,
+                        'param': param,
+                        'ids': data_trn['data'].ids}
+
+                    set_distillation_drain(ID=ID, param=param, idx=idx_trn, inputs=inputs)
-                
-                set_distillation_drain(ID=ID, param=param, inputs=inputs)
-                
-                if ID in args['raytune']['param']['active']:
-                    model = train.raytune_main(inputs=inputs, train_func=train.train_torch_generic)
-                else:
-                    model = train.train_torch_generic(**inputs)
-            
-            elif param['train'] == 'graph_xgb':
-                
-                inputs = {'y_soft': None}
-                set_distillation_drain(ID=ID, param=param, inputs=inputs)
+
+                    if ID in args['raytune']['param']['active']:
+                        model = train.raytune_main(inputs=inputs, train_func=train.train_torch_generic)
+                    else:
+                        model = train.train_torch_generic(**inputs)
-                
-                train.train_graph_xgb(data_trn=data_trn['data_graph'], data_val=data_val['data_graph'],
-                    trn_weights=data_trn['data'].w, val_weights=data_val['data'].w, args=args, param=param, y_soft=inputs['y_soft'],
-                    feature_names=data_trn['data'].ids)
-            
-            elif param['train'] == 'flr':
-                train.train_flr(data_trn=data_trn['data'], args=args, param=param)
-            
-            elif param['train'] == 'flow':
-                train.train_flow(data_trn=data_trn['data'], data_val=data_val['data'], args=args, param=param)
+
+                elif param['train'] == 'graph_xgb':
-            
-            elif param['train'] == 'cut':
-                None
-            
-            elif param['train'] == 'cutset':
-                
-                inputs = {'data_trn': data_trn['data'],
-                          'data_val': data_val['data'],
-                          'args': args,
-                          'param': param}
+
+                    inputs = {'y_soft': None}
+                    set_distillation_drain(ID=ID, param=param, idx=idx_trn, inputs=inputs)
+
+                    train.train_graph_xgb(
+                        data_trn      = [data_trn['data_graph'][i] for i in idx_trn],
+                        data_val      = [data_val['data_graph'][i] for i in idx_val],
+                        trn_weights   = data_trn['data'].w[idx_trn],
+                        val_weights   = data_val['data'].w[idx_val],
+                        args          = args,
+                        param         = param,
+                        y_soft        = inputs['y_soft'],
+                        feature_names = data_trn['data'].ids)
-                
-                if ID in args['raytune']['param']['active']:
-                    model = train.raytune_main(inputs=inputs, train_func=train.train_cutset)
-                else:
-                    model = train.train_cutset(**inputs)
-            
-            else:
-                raise Exception(__name__ + f'.Unknown param["train"] = {param["train"]} for ID = {ID}')
-            
-            # --------------------------------------------------------
-            # If distillation
-            if 'distillation' in args and ID == args['distillation']['source']:
+
+                elif param['train'] == 'flr':
+                    train.train_flr(data_trn=data_trn['data'][idx_trn], args=args, param=param)
+
+                elif param['train'] == 'flow':
+                    train.train_flow(data_trn = data_trn['data'][idx_trn],
+                                     data_val = data_val['data'][idx_val],
+                                     args     = args,
+                                     param    = param)
-                
-                if len(args['primary_classes']) != 2:
-                    raise Exception(__name__ + f'.train_models: Distillation supported now only for 2-class classification')
+
+                elif param['train'] == 'cut':
+                    None
-                
-                print(f'Computing distillation soft targets from the source <{ID}> ', 'yellow')
+
+                elif param['train'] == 'cutset':
+
+                    inputs = {
+                        'data_trn': data_trn['data'][idx_trn],
+                        'data_val': data_val['data'][idx_val],
+                        'args': args,
+                        'param': param}
+
+                    if ID in args['raytune']['param']['active']:
+                        model = train.raytune_main(inputs=inputs, train_func=train.train_cutset)
+                    else:
+                        model = train.train_cutset(**inputs)
-                
-                if param['train'] == 'xgb':
-                    XX, XX_ids = aux.red(data_trn['data'].x, data_trn['data'].ids, param)
-                    y_soft = model.predict(xgboost.DMatrix(data=XX, feature_names=XX_ids))
-                    if len(y_soft.shape) > 1:
-                        y_soft = y_soft[:, args['signalclass']]
-                
-                elif param['train'] == 'torch_graph':
-                    y_soft = model.softpredict(data_trn['data_graph'])[:, args['signalclass']]
+
+                else:
+                    raise Exception(__name__ + f'.Unknown param["train"] = {param["train"]} for ID = {ID}')
+
+                # --------------------------------------------------------
+                # If distillation
+                if 'distillation' in args and ID == args['distillation']['source']:
+
+                    if len(args['primary_classes']) != 2:
+                        raise Exception(__name__ + f'.train_models: Distillation supported now only for 2-class classification')
+
+                    print(f'Computing distillation soft targets from the source <{ID}> ', 'yellow')
+
+                    if param['train'] == 'xgb':
+
+                        XX, XX_ids = aux.red(data_trn['data'].x, data_trn['data'].ids, param)
+                        y_soft = model.predict(xgboost.DMatrix(data=XX, feature_names=XX_ids))
+
+                        if len(y_soft.shape) > 1:
+                            y_soft = y_soft[:, args['signalclass']]
+
+                    elif param['train'] == 'torch_graph':
+                        y_soft = model.softpredict(data_trn['data_graph'])[:, args['signalclass']]
+                    else:
+                        raise Exception(__name__ + f".train_models: Unsupported distillation source <{param['train']}>")
+                # --------------------------------------------------------
+
+            except KeyboardInterrupt:
+                print(f'CTRL+C caught -- continue with the next model', 'red')
+
+            except Exception as e:
+                prints.printbar('*')
+                print(f'Exception occurred: \n {e} \n', 'red')
+                print(f"Check the model '{ID}' definition: training failed -- continue!", 'red')
+                prints.printbar('*')
+                exceptions += 1
 
     print(f'[done]', 'yellow')
diff --git a/icetrg/common.py b/icetrg/common.py
index 8fc1247d..52ae69a8 100644
--- a/icetrg/common.py
+++ b/icetrg/common.py
@@ -176,4 +176,9 @@ def splitfactor(x, y, w, ids, args):
     data   = data[vars]
     data.x = data.x.astype(np.float32)
 
-    return {'data': data, 'data_MI': data_MI, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
diff --git a/icezee/common.py b/icezee/common.py
index 812044d8..aa646aa1 100644
--- a/icezee/common.py
+++ b/icezee/common.py
@@ -293,4 +293,9 @@ def splitfactor(x, y, w, ids, args):
         # Change the variable name [+ have the same original variables in data_kin]
         data.ids[ind] = f'TRF__{v}'
 
-    return {'data': data, 'data_MI': data_MI, 'data_kin': data_kin, 'data_deps': data_deps, 'data_tensor': data_tensor, 'data_graph': data_graph}
+    return {'data': data,
+            'data_MI': data_MI,
+            'data_kin': data_kin,
+            'data_deps': data_deps,
+            'data_tensor': data_tensor,
+            'data_graph': data_graph}
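
Note on the bootstrap scheme introduced above: declaring `bootstrap: N` in a model block of configs/<tag>/models.yml (as done for iceboost_swd with `bootstrap: 3`) makes train_models() fit the model once on the original training sample and then N more times on training samples drawn with replacement, storing the extra fits under the labels <LABEL>__bs_1 ... <LABEL>__bs_N; the validation sample is never resampled. The following is a minimal, self-contained sketch of the same resampling logic in plain numpy -- train_once() and its arguments are hypothetical stand-ins for illustration, not icenet API:

    import numpy as np

    def bootstrap_train(X_trn, y_trn, X_val, y_val, train_once,
                        n_bootstrap=3, label='MODEL', seed=0):
        """Fit once on the original sample, then n_bootstrap times on
        resampled-with-replacement copies; validation data stays fixed."""
        rng    = np.random.default_rng(seed)
        n      = len(X_trn)
        models = {}
        for b in range(n_bootstrap + 1):
            if b == 0:
                idx = np.arange(n)                         # original sample
                tag = label
            else:
                idx = rng.choice(n, size=n, replace=True)  # bootstrap resample
                tag = f'{label}__bs_{b}'
            models[tag] = train_once(X_trn[idx], y_trn[idx], X_val, y_val)
        return models

The spread of the predictions of the __bs_* variants around the nominal model then provides a bootstrap estimate of the training-sample uncertainty.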