diff --git a/pyvene/data_generators/causal_model.py b/pyvene/data_generators/causal_model.py index 8b89e77e..6573edab 100644 --- a/pyvene/data_generators/causal_model.py +++ b/pyvene/data_generators/causal_model.py @@ -275,7 +275,7 @@ def check_path(total_setting): return check_path - def inputToTensor(self, setting): + def input_to_tensor(self, setting): result = [] for input in self.inputs: temp = torch.tensor(setting[input]).float() @@ -284,7 +284,7 @@ def inputToTensor(self, setting): result.append(temp) return torch.cat(result) - def outputToTensor(self, setting): + def output_to_tensor(self, setting): result = [] for output in self.outputs: temp = torch.tensor(float(setting[output])) @@ -293,7 +293,19 @@ def outputToTensor(self, setting): result.append(temp) return torch.cat(result) - def generate_factual_dataset(self, size, sampler=None, filter=None, device="cpu"): + def generate_factual_dataset( + self, + size, + sampler=None, + filter=None, + device="cpu", + inputFunction=None, + outputFunction=None + ): + if inputFunction is None: + inputFunction = self.input_to_tensor + if outputFunction is None: + outputFunction = self.output_to_tensor if sampler is None: sampler = self.sample_input X, y = [], [] @@ -301,8 +313,8 @@ def generate_factual_dataset(self, size, sampler=None, filter=None, device="cpu" while count < size: input = sampler() if filter is None or filter(input): - X.append(self.inputToTensor(input)) - y.append(self.outputToTensor(self.run_forward(input))) + X.append(inputFunction(input)) + y.append(outputFunction(self.run_forward(input))) count += 1 return torch.stack(X).to(device), torch.stack(y).to(device) @@ -315,6 +327,8 @@ def generate_counterfactual_dataset( intervention_sampler=None, filter=None, device="cpu", + inputFunction=None, + outputFunction=None ): maxlength = len( [ @@ -323,6 +337,10 @@ def generate_counterfactual_dataset( if var not in self.inputs and var not in self.outputs ] ) + if inputFunction is None: + inputFunction = self.input_to_tensor + if outputFunction is None: + outputFunction = self.output_to_tensor if sampler is None: sampler = self.sample_input if intervention_sampler is None: @@ -341,17 +359,17 @@ def generate_counterfactual_dataset( if var not in intervention: continue source = sampler() - sources.append(self.inputToTensor(source)) + sources.append(inputFunction(source)) source_dic[var] = source for _ in range(maxlength - len(sources)): - sources.append(torch.zeros(self.inputToTensor(sampler()).shape)) - example["labels"] = self.outputToTensor( + sources.append(torch.zeros(self.input_to_tensor(sampler()).shape)) + example["labels"] = outputFunction( self.run_interchange(base, source_dic) ).to(device) - example["base_labels"] = self.outputToTensor( + example["base_labels"] = outputFunction( self.run_forward(base) ).to(device) - example["input_ids"] = self.inputToTensor(base).to(device) + example["input_ids"] = inputFunction(base).to(device) example["source_input_ids"] = torch.stack(sources).to(device) example["intervention_id"] = torch.tensor( [intervention_id(intervention)]