Skip to content

Commit

Permalink
fix: only eval uids that val receives gradients from, remove softmax…
Browse files Browse the repository at this point in the history
… scoring, improve validator logs
  • Loading branch information
distributedstatemachine committed Jan 1, 2025
1 parent c4be210 commit 00b64e5
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 58 deletions.
2 changes: 1 addition & 1 deletion hparams.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"spec_version": 5,
"project": "dough",
"sequence_length": 2048,
"pages_per_window": 2,
"pages_per_window": 5,
"batch_size": 6,
"learning_rate": 4e-4,
"blocks_per_window": 3,
Expand Down
11 changes: 11 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Default recipe: lists all available recipes when `just` is invoked with no arguments
default:
    @just --list

# Run ruff check with auto-fix, then apply ruff formatting
lint:
    ruff check --fix .
    ruff format .

# Alias for `lint`: runs both check and format via the recipe dependency above
fix: lint
10 changes: 2 additions & 8 deletions neurons/miner.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,6 @@ async def run(self):
"miner/mean_momentum_norm": sum(momentum_norms) / len(momentum_norms),
}, step=self.global_step)

# Log per-peer metrics
for peer_uid in self.peers:
self.wandb.log({
f"miner/peer_stake/{peer_uid}": self.metagraph.S[peer_uid].item(),
}, step=self.global_step)

# Reduce gradient using DeMo.
gradient = {}
xshapes = {}
Expand Down Expand Up @@ -341,8 +335,8 @@ async def run(self):
for n, p in self.model.named_parameters():
idxs_key = n + 'idxs'
vals_key = n + 'vals'
idxs = gather_result.state_dict.get(idxs_key)
vals = gather_result.state_dict.get(vals_key)
idxs = getattr(gather_result.state_dict, idxs_key, None)
vals = getattr(gather_result.state_dict, vals_key, None)
if idxs is not None and vals is not None:
# Ensure idx and val are lists of tensors
if not isinstance(idxs, (list, tuple)):
Expand Down
88 changes: 40 additions & 48 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self):

# Init scores
self.scores = torch.zeros(self.metagraph.n, dtype=torch.float32)
self.moving_avg_scores = torch.zeros(self.metagraph.n, dtype=torch.float32) # Add moving average tracking
self.moving_avg_scores = torch.zeros(self.metagraph.n, dtype=torch.float32)
self.ma_alpha = 0.95 # Moving average decay factor

# Add step tracking
Expand Down Expand Up @@ -226,7 +226,7 @@ async def run(self):
tplr.logger.error(f"Failed to create checkpoint: {e}")

# Log checkpoint creation
if self.current_window % 500 == 0:
if self.global_step % 500 == 0:
self.wandb.log({
"checkpoint_window": self.current_window,
"global_step": self.global_step,
Expand Down Expand Up @@ -293,7 +293,7 @@ async def run(self):
self.wandb.log({"lr": self.scheduler.get_last_lr()[0]}, step=self.global_step)

# Get a random peer to eval on their gradient at self.sync_window + 1
eval_uid = random.choice(self.peers)
eval_uid = random.choice(step_grads.uids)
# Get the pages for the window in front of the current sync window
pages = await tplr.dataset.DatasetLoader.next_pages(
offset=self.sync_window + 1,
Expand Down Expand Up @@ -385,7 +385,7 @@ async def run(self):

# Compute score
score = loss_before - loss_after
tplr.logger.info(f'score: {score}')
tplr.logger.info(f'score: {score}, loss_before: {loss_before_per_token:.4f}, loss_after: {loss_after_per_token:.4f}, loss_improvement: {loss_improvement:.4f}, improvement_percentage: {improvement_percentage:.2f}%, uid: {eval_uid}')

# Log comprehensive metrics
self.wandb.log({
Expand All @@ -396,51 +396,50 @@ async def run(self):
"validator/eval_count": self.eval_count,
"validator/tokens_evaluated": n_tokens,
"validator/learning_rate": self.scheduler.get_last_lr()[0],
"validator/window": self.current_window,
"validator/global_step": self.global_step,
"validator/current_score": score,
}, step=self.global_step)

# Update counters
self.global_step += 1
self.eval_count += 1

# Set weights if needed
if self.sync_window % self.hparams.windows_per_weights == 0:
# Update scores with new score
self.scores[eval_uid] = self.hparams.scores_alpha * score + (1 - self.hparams.scores_alpha) * self.scores[eval_uid]
# Update moving average scores
self.moving_avg_scores[eval_uid] = self.ma_alpha * self.moving_avg_scores[eval_uid] + (1 - self.ma_alpha) * score
# Compute weights from moving average scores
weights = torch.softmax(self.moving_avg_scores, dim=0)

# Log per-UID metrics
valid_score_indices = torch.nonzero(self.scores > 0).squeeze().view(-1)
for uid_i in valid_score_indices:
uid = uid_i.item()
self.wandb.log({
f"validator/scores/{uid}": self.scores[uid_i].item(),
f"validator/moving_avg_scores/{uid}": self.moving_avg_scores[uid_i].item(),
f"validator/weights/{uid}": weights[uid_i].item(),
f"validator/stakes/{uid}": self.metagraph.S[uid_i].item(),
f"validator/current_score/{uid}": score if uid == eval_uid else 0,
}, step=self.global_step)

# Log aggregate network statistics
# Update scores with new score
self.scores[eval_uid] = self.hparams.scores_alpha * score + (1 - self.hparams.scores_alpha) * self.scores[eval_uid]
# Update moving average scores
self.moving_avg_scores[eval_uid] = self.ma_alpha * self.moving_avg_scores[eval_uid] + (1 - self.ma_alpha) * score
# Compute weights from moving average scores
# Zero out negative scores and apply softmax only on positive scores
positive_scores = torch.where(self.moving_avg_scores > 0, self.moving_avg_scores, torch.zeros_like(self.moving_avg_scores))
weights = positive_scores / positive_scores.sum() if positive_scores.sum() > 0 else torch.zeros_like(positive_scores)

# Log per-UID metrics
valid_score_indices = torch.nonzero(self.scores > 0).squeeze().view(-1)
for uid_i in valid_score_indices:
uid = uid_i.item()
self.wandb.log({
"validator/active_miners": len(valid_score_indices),
"validator/mean_score": self.scores[valid_score_indices].mean().item(),
"validator/mean_moving_avg_score": self.moving_avg_scores[valid_score_indices].mean().item(),
"validator/max_score": self.scores.max().item(),
"validator/min_score": self.scores.min().item(),
"validator/max_moving_avg_score": self.moving_avg_scores.max().item(),
"validator/min_moving_avg_score": self.moving_avg_scores.min().item(),
"validator/mean_weight": weights[valid_score_indices].mean().item(),
"validator/weight_std": weights[valid_score_indices].std().item(),
"validator/score_std": self.scores[valid_score_indices].std().item(),
"validator/moving_avg_score_std": self.moving_avg_scores[valid_score_indices].std().item(),
f"validator/scores/{uid}": self.scores[uid_i].item(),
f"validator/moving_avg_scores/{uid}": self.moving_avg_scores[uid_i].item(),
f"validator/weights/{uid}": weights[uid_i].item(),
}, step=self.global_step)

# Log aggregate network statistics
self.wandb.log({
"validator/active_miners": len(valid_score_indices),
"validator/mean_score": self.scores[valid_score_indices].mean().item(),
"validator/mean_moving_avg_score": self.moving_avg_scores[valid_score_indices].mean().item(),
"validator/max_score": self.scores.max().item(),
"validator/min_score": self.scores.min().item(),
"validator/max_moving_avg_score": self.moving_avg_scores.max().item(),
"validator/min_moving_avg_score": self.moving_avg_scores.min().item(),
"validator/mean_weight": weights[valid_score_indices].mean().item(),
"validator/weight_std": weights[valid_score_indices].std().item(),
"validator/score_std": self.scores[valid_score_indices].std().item(),
"validator/moving_avg_score_std": self.moving_avg_scores[valid_score_indices].std().item(),
"validator/max_weight": weights.max().item(),
"validator/min_weight": weights.min().item(),
}, step=self.global_step)


if self.sync_window % self.hparams.windows_per_weights == 0:
# Set weights on chain
self.subtensor.set_weights(
wallet=self.wallet,
Expand All @@ -452,14 +451,7 @@ async def run(self):
)
tplr.logger.info(f'Set weights on chain for window {self.sync_window}')

# Log weight update metrics
self.wandb.log({
"validator/weight_update_window": self.sync_window,
"validator/mean_weight": weights.mean().item(),
"validator/max_weight": weights.max().item(),
"validator/min_weight": weights.min().item(),
"validator/weight_std": weights.std().item(),
}, step=self.global_step)


# Apply the optimizer step
tplr.logger.info("Finish and step.")
Expand Down
2 changes: 1 addition & 1 deletion src/tplr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# mypy: ignore-errors
# type: ignore

__version__ = "0.2.1"
__version__ = "0.2.2"

# Import package.
from .chain import *
Expand Down

0 comments on commit 00b64e5

Please sign in to comment.