Skip to content

Commit

Permalink
support max_guidance_steps=0 option (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelstanton authored Sep 23, 2024
1 parent f4918d8 commit c318f9e
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions cortex/optim/generative/_lambo.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,19 @@ def step(self) -> None:

delta = torch.nn.Parameter(torch.zeros_like(activations))
optimizer = torch.optim.Adam([delta], lr=self.guidance_step_size)
metrics = {"step": self._step_count}

# get initial solution before guidance
tgt_tok_idxs, tgt_obj_vals = self._update_solution(
trunk_outputs,
activations,
delta,
tgt_tok_idxs,
tgt_obj_vals,
is_corrupted,
self.tokenizer,
non_viable_idxs,
)

print("\n")
for lang_step in range(self.max_guidance_updates):
Expand Down Expand Up @@ -192,12 +205,14 @@ def step(self) -> None:
)

grad_norm = feature_grad.norm(dim=(-2, -1), keepdim=True)
print(
tgt_obj_vals.median().item(),
design_loss.item(),
obj_loss.item(),
kl_div.item(),
entropy.item(),
metrics.update(
{
"masked_design_loss": design_loss.item(),
"masked_design_loss_grad_norm": grad_norm.mean().item(),
"masked_token_loss": kl_div.item(),
"masked_obj_loss": obj_loss.item(),
"token_entropy": entropy.item(),
}
)

self._step_count += 1
Expand All @@ -209,14 +224,6 @@ def step(self) -> None:

self._buffer = pd.concat([self._buffer, df], ignore_index=True)

metrics = {
"step": self._step_count,
"masked_design_loss": design_loss.item(),
"masked_design_loss_grad_norm": grad_norm.mean().item(),
"masked_token_loss": kl_div.item(),
"masked_obj_loss": obj_loss.item(),
"token_entropy": entropy.item(),
}
return metrics

def _coordinate_selection(
Expand Down

0 comments on commit c318f9e

Please sign in to comment.