
wrap up multi hop rag
liyin2015 committed Dec 16, 2024
1 parent 0439354 commit 6c5841a
Showing 4 changed files with 56 additions and 28 deletions.
79 changes: 53 additions & 26 deletions adalflow/adalflow/optim/trainer/trainer.py
@@ -94,6 +94,7 @@ class Trainer(Component):
max_error_samples: Optional[int] = 2
max_correct_samples: Optional[int] = 2
debug: bool = False
sequential_order: List[str] = ["text", "demo"]

def __init__(
self,
@@ -119,6 +120,7 @@ def __init__(
exclude_input_fields_from_bootstrap_demos: bool = False,
debug: bool = False,
save_traces: bool = False, # save traces in the few-shot demos
sequential_order: List[str] = ["text", "demo"],
*args,
**kwargs,
) -> None:
@@ -161,6 +163,7 @@ def __init__(
self.exclude_input_fields_from_bootstrap_demos = (
exclude_input_fields_from_bootstrap_demos
)
self.sequential_order = sequential_order

# TODO: need to support checkpoint resume too!
def diagnose(self, dataset: Any, split: str = "train"):
@@ -503,7 +506,6 @@ def fit(
and len(self.text_optimizers) > 0
):
if self.strategy == "random":

self._fit_text_grad_demo_mix_random(
train_loader,
train_dataset,
@@ -525,37 +527,62 @@
raise ValueError(f"Strategy {self.strategy} not supported")

else: # sequential, text first and demo second
if len(self.text_optimizers) > 0:
if self.strategy == "random":
trainer_results = self._fit_text_grad_random(
train_loader,
val_dataset,
test_dataset,
trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
elif self.strategy == "constrained":
trainer_results = self._fit_text_grad_constraint(

def run_text_optimizers(starting_step: int, trainer_results: TrainerResult):
if len(self.text_optimizers) > 0:
if self.strategy == "random":
trainer_results = self._fit_text_grad_random(
train_loader,
val_dataset,
test_dataset,
trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
elif self.strategy == "constrained":
trainer_results = self._fit_text_grad_constraint(
train_loader,
val_dataset,
test_dataset,
trainer_results=trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
else:
raise ValueError(f"Strategy {self.strategy} not supported")

def run_demo_optimizers(starting_step: int, trainer_results: TrainerResult):
if len(self.demo_optimizers) > 0:
self.adaltask.configure_teacher_generator()
self._fit_demos_random(
train_loader,
train_dataset,
val_dataset,
test_dataset,
trainer_results=trainer_results,
starting_step=starting_step,
)
starting_step += self.max_steps
else:
raise ValueError(f"Strategy {self.strategy} not supported")
if len(self.demo_optimizers) > 0:
self.adaltask.configure_teacher_generator() # attempt to use the newest teacher
self._fit_demos_random(
train_loader,
train_dataset,
val_dataset,
test_dataset,
trainer_results=trainer_results,
starting_step=starting_step,
)

if self.sequential_order == ["text", "demo"]:
run_text_optimizers(starting_step, trainer_results)
run_demo_optimizers(starting_step, trainer_results)
else:
run_demo_optimizers(starting_step, trainer_results)
run_text_optimizers(starting_step, trainer_results)
# if len(self.text_optimizers) > 0:
# run_text_optimizers(starting_step, trainer_results)

# if len(self.demo_optimizers) > 0:
# run_demo_optimizers(starting_step, trainer_results)
# self.adaltask.configure_teacher_generator() # attemp to use the newest teacher as
# self._fit_demos_random(
# train_loader,
# train_dataset,
# val_dataset,
# test_dataset,
# trainer_results=trainer_results,
# starting_step=starting_step,
# )

end_time = time.time()
print(f"Training time: {end_time - start_time}s")
2 changes: 1 addition & 1 deletion benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
@@ -255,7 +255,7 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
name=f"few_shot_demos_{i}",
data=None,
role_desc="To provide few shot demos to the language model",
requires_opt=False,
requires_opt=True,
param_type=ParameterType.DEMOS,
),
"task_desc_str": Parameter(
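
For context, a hedged sketch of the parameter whose requires_opt flag is flipped above (and again in the vanilla build below): a DEMOS-typed Parameter with requires_opt=True is marked trainable, so the demo optimizer can fill it with bootstrapped few-shot examples during fit(). Field names come from this hunk; the adal import style is borrowed from build_vanilla_rag.py, and the variable name is illustrative.

import adalflow as adal

# The multi-hop build creates one such parameter per hop (name=f"few_shot_demos_{i}").
few_shot_demos = adal.Parameter(
    name="few_shot_demos_0",
    data=None,  # starts empty; the demo optimizer populates it during training
    role_desc="To provide few shot demos to the language model",
    requires_opt=True,  # the change in this commit: make the demos optimizable
    param_type=adal.ParameterType.DEMOS,
)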
2 changes: 1 addition & 1 deletion benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
@@ -161,7 +161,7 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):
),
"few_shot_demos": adal.Parameter(
data=None,
requires_opt=False,
requires_opt=True,
role_desc="To provide few shot demos to the language model",
param_type=adal.ParameterType.DEMOS,
),
1 change: 1 addition & 0 deletions benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag.py
@@ -130,6 +130,7 @@ def train(
weighted_sampling=True,
optimization_order=optimization_order,
exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
sequential_order=["text", "demo"],
)
print(trainer)

