feat: penalties for increasing loss by a lot, softmax for weights
distributedstatemachine committed Dec 27, 2024
1 parent 3de6150 commit 4d362fb
Showing 4 changed files with 608 additions and 104 deletions.
135 changes: 57 additions & 78 deletions neurons/validator.py
@@ -483,7 +483,7 @@ async def run(self):
eval_start = tplr.T()
self.model.zero_grad()
total_loss = 0.0
loss_after = 0.0
step_loss_after = 0.0
full_steps = 0
total_steps = 0
exhausted_window = False
@@ -514,96 +514,81 @@ async def run(self):

# Perform forward pass with updated model (no gradients needed)
with torch.no_grad(), torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16):
outputs2 = self.model(input_ids=input_ids, labels=labels)
outputs_after = self.model(input_ids=input_ids, labels=labels)
step_loss_after += outputs_after.loss.item()

# Restore original parameters
for name_i, param_i in self.model.named_parameters():
param_i.data.copy_(original_params[name_i])

# Perform forward pass and compute loss with gradients
with torch.enable_grad(), torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16):
outputs = self.model(input_ids=input_ids, labels=labels)
loss = outputs.loss
outputs_before = self.model(input_ids=input_ids, labels=labels)
loss = outputs_before.loss
loss.backward()

total_loss += outputs.loss.item()
loss_after += outputs2.loss.item()
total_loss += loss.item()
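For reference, a minimal sketch of the evaluate-then-restore pattern used above, assuming an HF-style model whose forward returns an object with a .loss attribute; the helper name loss_before_and_after and its delta/batch arguments are illustrative, not from the repo:

import torch
import torch.nn as nn

def loss_before_and_after(model: nn.Module, delta: dict, batch: dict) -> tuple:
    # Snapshot the original parameters so the candidate delta can be undone.
    original = {n: p.detach().clone() for n, p in model.named_parameters()}

    with torch.no_grad():
        # Apply the miner's delta and measure the loss it produces.
        for n, p in model.named_parameters():
            if n in delta:
                p.add_(delta[n].to(p.device))
        loss_after = model(**batch).loss.item()

        # Restore the unmodified parameters.
        for n, p in model.named_parameters():
            p.copy_(original[n])

    # Baseline loss with gradients enabled; backward() populates the
    # validator gradients used later for the cosine-similarity check.
    outputs = model(**batch)
    outputs.loss.backward()
    return outputs.loss.item(), loss_after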

if self.current_window - offset != window:
exhausted_window = True
continue

self.optimizer.step()
self.scheduler.step()
step_loss = total_loss/(full_steps+1)
step_loss_after = loss_after/(full_steps+1)

if loss_after <= step_loss:
# Reward for loss reduction
loss_score = 1 - (step_loss_after / step_loss)
else:
# Penalize for loss increase, capped at -1
loss_score = -min(1, (step_loss_after - step_loss) / step_loss)
step_loss = total_loss / (full_steps + 1)
step_loss_after = step_loss_after / (full_steps + 1)

eval_duration = tplr.T() - eval_start
tokens_per_step = self.hparams.sequence_length * self.config.actual_batch_size * (full_steps + 1)

tokens_per_second = tokens_per_step / eval_duration

tplr.logger.info(f"{tplr.P(window, eval_duration)}: Accumulated gradients:")
tplr.logger.info(f"{tplr.P(window, eval_duration)}: \tTotal steps: [tan]{full_steps}/{total_steps}[/tan], Rate: [tan]{(full_steps/total_steps):.2f}[/tan], Target: [tan]{self.sample_rate:.2f}[/tan]")
tplr.logger.info(f"{tplr.P(window, eval_duration)}: \tTotal tokens: [tan]{tokens_per_step}[/tan], Tokens per second: [tan]{tokens_per_second:.2f}[/tan]")
tplr.logger.info(f"{tplr.P(window, eval_duration)}: \tLoss: [tan]{step_loss}[tan]")
tplr.logger.info(f"{tplr.P(window, eval_duration)}: \tLoss before applying delta: [tan]{step_loss:.4f}[/tan]")
tplr.logger.info(f"{tplr.P(window, eval_duration)}: \tLoss after applying delta: [tan]{step_loss_after:.4f}[/tan]")

if exhausted_window:
self.sample_rate = max(0.0001, self.sample_rate * 0.95)
else:
self.sample_rate = min(1, self.sample_rate * 1.05)

# Compute the score for this slice.
st = tplr.T()
score = 0.0

# Check if we have any gradients
has_grads = any(param.grad is not None for name, param in self.model.named_parameters())

if not has_grads:
tplr.logger.warning("No gradients found - setting score to 0.0")
score = 0.0
else:
# Collect all delta_i and grad_i into larger vectors
all_delta = []
all_grad = []

for i, (name_i, param_i) in enumerate(self.model.named_parameters()):
if param_i.grad is None:
continue

if name_i not in indices or name_i not in eval_slice_data:
continue
# Compute cosine similarity between miner's delta and validator's gradients
cosine_similarity = torch.nn.functional.cosine_similarity(all_delta, all_grad, dim=0).item()

Check failure on line 559 in neurons/validator.py (GitHub Actions / lint-and-test, Ruff F821):
neurons/validator.py:559:79: F821 Undefined name `all_delta`
neurons/validator.py:559:90: F821 Undefined name `all_grad`

idxs_i = indices[name_i].to(self.model.device)
grad_i = param_i.grad.view(-1).clone()[idxs_i].to(self.model.device)
slice_i = eval_slice_data[name_i].view(-1).to(self.model.device)
theta_i = param_i.data.view(-1)[idxs_i]
delta_i = theta_i - slice_i
# Set initial score to 0.0
score = 0.0

all_delta.append(delta_i)
all_grad.append(grad_i)
# Check if cosine similarity is greater than zero
if cosine_similarity > 0.0:
# Base score from cosine similarity
base_score = 0.1

if len(all_delta) > 0:
#Concatenate all parts
all_delta = torch.cat(all_delta)
all_grad = torch.cat(all_grad)
# Compute the loss difference (percentage)
loss_difference = step_loss_after - step_loss # Positive if miner's loss is worse
percentage_loss_difference = loss_difference / step_loss # Fractional change

# Compute global cosine similarity
score = torch.nn.functional.cosine_similarity(all_delta, all_grad, dim=0).item()
if percentage_loss_difference < 0: # Miner improved the loss
# Miner improved the loss, add to base score
score = base_score + (-percentage_loss_difference) # Negative because loss decreased
elif percentage_loss_difference <= 0.25:
# Loss did not improve but is not worse by more than 25%
score = base_score # Only base score
else:
tplr.logger.warning("No valid parameter tensors found - setting score to 0.0")
# Loss is worse by more than 25%, zero out their moving average score
self.scores[eval_uid] = 0.0
score = 0.0
else:
tplr.logger.info(f"Cosine similarity ({cosine_similarity:.4f}) not positive. Setting score to 0.0")
score = 0.0
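A self-contained sketch of the scoring rule the added lines implement; the 0.1 base score and the 25% tolerance come from the diff, while the function name slice_score and the standalone form are illustrative:

def slice_score(cosine_similarity: float, step_loss: float, step_loss_after: float):
    """Return (score, reset_moving_average) for one evaluated slice."""
    if cosine_similarity <= 0.0:
        # Delta points away from the validator's gradient: no reward.
        return 0.0, False
    base_score = 0.1
    # Fractional loss change caused by applying the miner's delta.
    pct = (step_loss_after - step_loss) / step_loss
    if pct < 0:
        # Loss improved: base score plus the size of the improvement.
        return base_score + (-pct), False
    if pct <= 0.25:
        # No improvement, but within 25% of the baseline loss.
        return base_score, False
    # Worse by more than 25%: zero score and wipe the moving average.
    return 0.0, True

# Example: a 10% improvement scores 0.1 + 0.10 = 0.20;
# a 30% regression scores 0.0 and resets the miner's moving average.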

tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Computed score: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]")
self.optimizer.zero_grad()

tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Computed score for miner {eval_uid}: [bold dark_sea_green]{score:.4f}[/bold dark_sea_green]")
self.optimizer.zero_grad()

# Assign and log scores.

# Apply decay to miners who did not submit slices
all_uids = set(self.metagraph.uids.tolist())
non_submitted_uids = all_uids - submitted_uids
@@ -613,46 +598,39 @@ async def run(self):
self.scores[uid] *= decay_factor

# Update the score for the evaluated miner
self.step_scores[eval_uid] = score + loss_score
self.step_loss_scores[eval_uid] = loss_score
self.step_scores[eval_uid] = score
self.scores[eval_uid] = (
(1 - self.hparams.validator_moving_alpha) * self.step_scores[eval_uid] +
(1 - self.hparams.validator_moving_alpha) * self.step_scores[eval_uid] +
self.hparams.validator_moving_alpha * self.scores[eval_uid]
)
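The update above is a standard exponential moving average; a tiny sketch with validator_moving_alpha treated as a plain float (its configured value is not visible in this hunk):

def update_moving_score(previous: float, step_score: float, alpha: float) -> float:
    # New observation weighted by (1 - alpha), history weighted by alpha.
    return (1 - alpha) * step_score + alpha * previous

# e.g. with alpha = 0.9: update_moving_score(0.50, 0.20, 0.9) == 0.9 * 0.50 + 0.1 * 0.20 == 0.47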

# Only consider positive scores for weights
positive_scores_indices = self.scores > 0
positive_scores = self.scores[positive_scores_indices]
# Prepare scores for softmax
scores_tensor = self.scores.clone()

total_positive_score = positive_scores.sum().item()
# Set scores <= 0 to a very negative value for softmax stability
scores_tensor[scores_tensor <= 0] = -float('inf')

if total_positive_score == 0.0:
tplr.logger.warning("Total positive score is zero; setting all weights to zero.")
self.weights = torch.zeros_like(self.scores)
else:
# Normalize positive scores to get weights
self.weights = torch.zeros_like(self.scores)
self.weights[positive_scores_indices] = positive_scores / total_positive_score
# Compute softmax over scores
self.weights = torch.nn.functional.softmax(scores_tensor, dim=0)
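A small sketch of this masked softmax: non-positive scores are sent to -inf so they receive exactly zero weight, and the remaining scores are normalized; the score values below are made up for illustration:

import torch

scores = torch.tensor([0.47, 0.20, 0.00, -0.10])
masked = scores.clone()
masked[masked <= 0] = float('-inf')   # exp(-inf) == 0, so these miners get no weight
weights = torch.nn.functional.softmax(masked, dim=0)
# weights ≈ [0.567, 0.433, 0.000, 0.000]
# Caveat: if every score is non-positive, all entries are -inf and softmax returns NaN.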

# Log updated scores and weights
valid_score_indices = torch.nonzero(self.scores != 0).squeeze().view(-1)
valid_score_indices = torch.nonzero(self.scores > 0).squeeze().view(-1)
for uid_i in valid_score_indices:
uid = uid_i.item()
moving_score = self.scores[uid].item()
weight = self.weights[uid].item()
step_score = self.step_scores[uid].item()
loss_score = self.step_loss_scores[uid].item()
tplr.logger.info(
f"\tuid: [dark_sea_green]{uid}[/dark_sea_green], "
f"step_score: [dark_sea_green]{step_score:.3f}[/dark_sea_green], "
f"moving_score: [dark_sea_green]{moving_score:.3f}[/dark_sea_green], "
f"weight: [dark_sea_green]{weight:.3f}[/dark_sea_green], "
f"loss_score: [dark_sea_green]{loss_score:.3f}[/dark_sea_green]"
f"weight: [dark_sea_green]{weight:.3f}[/dark_sea_green]"
)

# Apply all deltas to the model state.
st = tplr.T()
max_global_step, window_metric = await tplr.apply_slices_to_model(
model=self.model,
max_global_step, window_metric = await tplr.apply_slices_to_model(
model=self.model,
window=window,
seed=window,
compression=self.hparams.compression,
@@ -667,8 +645,8 @@ async def run(self):
st = tplr.T()
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='state')
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='delta')
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'state' )
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'delta' )
await tplr.delete_files_from_bucket_before_window(bucket=tplr.config.BUCKET_SECRETS["bucket_name"], window_max=window - self.hparams.max_history, key='state')
await tplr.delete_files_from_bucket_before_window(bucket=tplr.config.BUCKET_SECRETS["bucket_name"], window_max=window - self.hparams.max_history, key='delta')
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Cleaned file history.")

# Finish step.
@@ -678,14 +656,15 @@ async def run(self):
window_time_delta = self.window_time - gs_end
window_delta_str = f"[red]{window_time_delta:.2f}[/red]" if window_time_delta < 0 else f"[green]+{window_time_delta:.2f}[/green]"
tplr.logger.info(f"{tplr.P(window, gs_end - gs_start)}[{window_delta_str}]: Finished step.")

# Log main metrics
wandb.log({
"validator/loss": step_loss,
"validator/tokens_per_step": sum([slice_metric['tokens_per_step'] for _, slice_metric in window_metric.items()]),
"validator/tokens_per_second": sum([slice_metric['tokens_per_second'] for _, slice_metric in window_metric.items()]),
"validator/tokens_per_step": sum(slice_metric['tokens_per_step'] for _, slice_metric in window_metric.items()),
"validator/tokens_per_second": sum(slice_metric['tokens_per_second'] for _, slice_metric in window_metric.items()),
"validator/sample_rate": self.sample_rate,
"validator/utilization": eval_duration / (gs_end - gs_start),
"validator/global_batch_size": sum([slice_metric['batch_size'] for _, slice_metric in window_metric.items()]),
"validator/global_batch_size": sum(slice_metric['batch_size'] for _, slice_metric in window_metric.items()),
}, step=self.global_step)

for hotkey, slice_metric in window_metric.items():
