Merge pull request #25 from tplr-ai/fix/checkpointing_2

Fix/validator_sync

distributedstatemachine authored Jan 6, 2025
2 parents 1c18cd3 + 4d138c2 commit 943c24a
Showing 9 changed files with 386 additions and 292 deletions.
8 changes: 4 additions & 4 deletions hparams.json
@@ -5,9 +5,9 @@
"pages_per_window": 5,
"batch_size": 6,
"learning_rate": 4e-4,
"blocks_per_window": 3,
"blocks_per_window": 4,
"windows_per_sync": 100,
"windows_per_weights": 10,
"windows_per_weights": 100,
"momentum_decay": 0.999,
"topk_compression": 32,
"target_chunk": 64,
@@ -25,7 +25,7 @@
"alpha_f": 0.1,
"t_max": 20000,
"validator_offset": 4,
"checkpoint_frequency": 50,
"checkpoint_frequency": 1000,
"topk_peers": 20,
"minimum_peers": 5
"minimum_peers": 5
}
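For a rough sense of the new cadence: checkpoints now land every 1000 windows instead of every 50, and a window is now 4 blocks. A minimal sketch of the arithmetic, assuming the subtensor's ~12 s average block time (an assumption, not stated in this diff):

    blocks_per_window = 4        # from hparams.json above
    block_time_s = 12            # assumed average subtensor block time
    checkpoint_frequency = 1000  # windows between checkpoints

    window_s = blocks_per_window * block_time_s  # ~48 s per window
    checkpoint_interval_h = checkpoint_frequency * window_s / 3600
    print(f"checkpoint roughly every {checkpoint_interval_h:.1f} h")  # ~13.3 h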
7 changes: 4 additions & 3 deletions neurons/miner.py
@@ -116,7 +116,7 @@ def __init__(self):
self.scheduler = SequentialLR(
self.optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
- milestones=[10],
+ milestones=[250],
)

# Init compression
@@ -177,8 +177,9 @@ def __init__(self):
# Main training loop.
async def run(self):
# Try to load latest checkpoint
- validator_uid, stake = self.comms.get_highest_stake_validator()
- if stake > 0:
+ validator_uid = self.metagraph.S.argmax().item()
+ tplr.logger.info(f"Found validator with highest stake: {validator_uid}")
+ if validator_uid is not None:
try:
# Calculate the most recent window that should have a checkpoint
expected_checkpoint_window = (self.current_window // self.hparams.checkpoint_frequency) * self.hparams.checkpoint_frequency
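For reference, the expected_checkpoint_window line above snaps the current window down to the last checkpoint boundary via floor division. A self-contained sketch with illustrative values:

    current_window = 2375
    checkpoint_frequency = 1000  # from hparams.json

    # Floor-divide, then scale back up: 2375 // 1000 * 1000 == 2000.
    expected_checkpoint_window = (current_window // checkpoint_frequency) * checkpoint_frequency
    assert expected_checkpoint_window == 2000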
481 changes: 267 additions & 214 deletions neurons/validator.py

Large diffs are not rendered by default.

43 changes: 38 additions & 5 deletions scripts/start.sh
@@ -1,9 +1,42 @@
#!/bin/bash
# Load environment variables
- source .env_test
+ source .env

- # pm2 delete all
- pm2 start neurons/miner.py --interpreter python3 --name TM0 -- --wallet.name ${WALLET_NAME} --wallet.hotkey ${WALLET_HOTKEY}_M1 --device ${CUDA_DEVICE} --subtensor.network ${NETWORK} --debug --netuid ${NETUID} --project ${1:-default}
+ # Stop any existing processes
+ pm2 delete all

- pm2 start neurons/miner.py --interpreter python3 --name TM1 -- --wallet.name ${WALLET_NAME} --wallet.hotkey ${WALLET_HOTKEY}_M2 --device ${CUDA_DEVICE} --subtensor.network ${NETWORK} --use_wandb --debug --netuid ${NETUID} --project ${1:-default}
+ # # Generate random suffix for project name
+ # RANDOM_SUFFIX=$(cat /dev/urandom | tr -dc 'a-z0-9' | fold -w 4 | head -n 1)
+ # PROJECT_NAME="test_${RANDOM_SUFFIX}"

+ # Start miners and validator with matching configurations
+ pm2 start neurons/miner.py --interpreter python3 --name TM1 -- \
+     --wallet.name Bistro \
+     --wallet.hotkey M111 \
+     --device cuda:3 \
+     --subtensor.network test \
+     --debug \
+     --netuid 268 \
+     --use_wandb
+ # --project "${PROJECT_NAME}"

+ pm2 start neurons/miner.py --interpreter python3 --name TM2 -- \
+     --wallet.name Bistro \
+     --wallet.hotkey M222 \
+     --device cuda:1 \
+     --subtensor.network test \
+     --debug \
+     --netuid 268 \
+     --use_wandb
+ # --project "${PROJECT_NAME}"

+ pm2 start neurons/validator.py --interpreter python3 --name TV1 -- \
+     --wallet.name Bistro \
+     --wallet.hotkey V11 \
+     --device cuda:2 \
+     --subtensor.network test \
+     --debug \
+     --netuid 268 \
+ # --use_wandb \
+ # --project "${PROJECT_NAME}"

- pm2 start neurons/validator.py --interpreter python3 --name TV1 -- --wallet.name ${WALLET_NAME} --wallet.hotkey ${WALLET_HOTKEY} --device ${CUDA_DEVICE} --subtensor.network ${NETWORK} --use_wandb --debug --netuid ${NETUID} --project ${1:-default}
2 changes: 1 addition & 1 deletion src/tplr/__init__.py
@@ -20,7 +20,7 @@
# mypy: ignore-errors
# type: ignore

__version__ = "0.2.5"
__version__ = "0.2.6"

# Import package.
from .chain import *
2 changes: 1 addition & 1 deletion src/tplr/chain.py
@@ -166,7 +166,7 @@ def handler(event, _u, _s):
except Exception:
time.sleep(1)

- async def commit(self, wallet: "bt.wallet", bucket: Bucket) -> None:
+ def commit(self, wallet: "bt.wallet", bucket: Bucket) -> None:
"""Commits bucket configuration to the chain.
Args:
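Since commit is now synchronous, callers would drop the await; a one-line sketch, assuming a ChainManager-style instance named chain (the instance name is an assumption):

    chain.commit(wallet, bucket)  # previously: await chain.commit(wallet, bucket)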
33 changes: 6 additions & 27 deletions src/tplr/comms.py
@@ -4,7 +4,6 @@
import asyncio
import aiofiles
import tempfile
- import numpy as np
import bittensor as bt
from typing import List, Dict, Optional, Tuple
from types import SimpleNamespace
@@ -325,8 +324,8 @@ async def get(
) -> Optional[Tuple[dict, int]]:
"""GET operation: Retrieve state_dict and global_step."""
filename = f"{key}-{window}-{uid}-v{__version__}.pt"
- full_key = f"{uid}/{window}/{filename}"
- tplr.logger.debug(f"GET {full_key} -->")
+ # full_key = f"{uid}/{window}/{filename}"
+ tplr.logger.debug(f"GET {filename} -->")

try:
if local:
@@ -354,6 +353,7 @@

# Get the peer's bucket from commitments
peer_bucket = self.commitments.get(int(uid))
tplr.logger.debug(f"getting {key} from peer : {peer_bucket}")
if not peer_bucket:
tplr.logger.debug(f"No bucket found for UID {uid}")
return None
@@ -413,16 +413,16 @@ async def get(
return state_dict, global_step
except Exception as e:
tplr.logger.debug(
f"Error loading data from {full_key}: {e}"
f"Error loading data from {filename}: {e}"
)
return None

except Exception as e:
tplr.logger.debug(f"GET error {full_key}: {e}")
tplr.logger.debug(f"GET error {filename}: {e}")
return None

finally:
tplr.logger.debug(f"GET {full_key} <--")
tplr.logger.debug(f"GET {filename} <--")

async def get_with_retry(
self,
@@ -718,24 +718,3 @@ async def cleanup_old_checkpoints(self, keep_last: int = 3):

except Exception as e:
tplr.logger.error(f"Error cleaning up old checkpoints: {e}")

- def get_highest_stake_validator(self) -> Tuple[Optional[int], float]:
-     """Returns the UID and stake of the neuron with the highest stake."""
-     stakes = self.metagraph.S
-
-     # Convert numpy array to torch tensor if needed
-     if isinstance(stakes, np.ndarray):
-         stakes = torch.from_numpy(stakes)
-
-     # Check if any stakes are non-zero
-     if torch.all(stakes == 0):
-         return None, 0.0
-
-     highest_stake_uid = torch.argmax(stakes).item()
-     stake = stakes[highest_stake_uid].item()
-
-     # Validate the stake is actually non-zero
-     if stake == 0:
-         return None, 0.0
-
-     return highest_stake_uid, stake
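Note that the deleted helper guarded against an all-zero stake vector, while the miner now inlines a bare argmax. A minimal, self-contained sketch of the selection the helper performed (illustrative values only):

    import torch
    from typing import Optional

    def highest_stake_uid(stakes: torch.Tensor) -> Optional[int]:
        # Mirrors the removed comms helper: argmax over stakes,
        # returning None when every stake is zero.
        if torch.all(stakes == 0):
            return None
        return int(torch.argmax(stakes).item())

    print(highest_stake_uid(torch.tensor([0.0, 12.5, 3.0])))  # 1
    print(highest_stake_uid(torch.zeros(3)))                  # None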
2 changes: 1 addition & 1 deletion src/tplr/logging.py
@@ -58,7 +58,7 @@ def P(window: int, duration: float) -> str:
rich_tracebacks=True,
highlighter=NullHighlighter(),
show_level=False,
- show_time=False,
+ show_time=True,
show_path=False,
)
],
100 changes: 64 additions & 36 deletions src/tplr/wandb.py
@@ -11,46 +11,31 @@
def initialize_wandb(
run_prefix: str, uid: str, config: any, group: str, job_type: str
) -> Run:
"""Initialize WandB run with persistence and resumption capabilities.
Args:
run_prefix (str): Prefix for the run name (e.g., 'V' for validator, 'M' for miner)
uid (str): Unique identifier for the run
config (any): Configuration object containing project and other settings
group (str): Group name for organizing runs
job_type (str): Type of job (e.g., 'validation', 'training')
Returns:
Run: Initialized WandB run object
"""
# Ensure the wandb directory exists
"""Initialize WandB run with version tracking for unified workspace management."""
wandb_dir = os.path.join(os.getcwd(), "wandb")
os.makedirs(wandb_dir, exist_ok=True)

- # Define the run ID file path inside the wandb directory
- run_id_file = os.path.join(
-     wandb_dir, f"wandb_run_id_{run_prefix}{uid}_{__version__}.txt"
- )
+ # Modified run ID file to not include version
+ run_id_file = os.path.join(wandb_dir, f"wandb_run_id_{run_prefix}{uid}.txt")

- # Check for existing run and verify it still exists in wandb
+ # Check for existing run
run_id = None
if os.path.exists(run_id_file):
with open(run_id_file, "r") as f:
run_id = f.read().strip()

- # Verify if run still exists in wandb
try:
api = wandb.Api()
api.run(f"tplr/{config.project}-v{__version__}/{run_id}")
api.run(f"tplr/{config.project}/{run_id}")
logger.info(f"Found existing run ID: {run_id}")
except Exception:
logger.info(f"Previous run {run_id} not found in WandB, starting new run")
run_id = None
os.remove(run_id_file)

- # Initialize WandB
+ # Initialize WandB with version as a tag
run = wandb.init(
project=f"{config.project}-v{__version__}",
project=config.project,
entity="tplr",
id=run_id,
resume="must" if run_id else "never",
@@ -59,30 +44,73 @@ def initialize_wandb(
group=group,
job_type=job_type,
dir=wandb_dir,
tags=[f"v{__version__}"],
settings=wandb.Settings(
init_timeout=300,
_disable_stats=True,
),
)

- # Special handling for evaluator
- if run_prefix == "E":
-     tasks = config.tasks.split(",")
-     for task in tasks:
-         metric_name = f"eval/{task}"
-         wandb.define_metric(
-             name=metric_name, step_metric="global_step", plot=True, summary="max"
-         )
+ # Add version history to run config
+ if "version_history" not in run.config:
+     run.config.update({"version_history": [__version__]}, allow_val_change=True)
+ elif __version__ not in run.config.version_history:
+     version_history = run.config.version_history + [__version__]
+     run.config.update({"version_history": version_history}, allow_val_change=True)

+ # Keep current version in config
+ run.config.update({"current_version": __version__}, allow_val_change=True)

+ # Track the last step seen for each version
+ version_steps = {}

+ # Get the current global step from WandB if resuming
+ if run_id:
+     try:
+         api = wandb.Api()
+         run_data = api.run(f"tplr/{config.project}/{run_id}")
+         history = run_data.scan_history()
+         global_step = max((row.get("_step", 0) for row in history), default=0)
+         version_steps["global"] = global_step
+     except Exception:
+         version_steps["global"] = 0
+ else:
+     version_steps["global"] = 0

+ # Create a wrapper for wandb.log that automatically adds version
+ original_log = run.log

+ def log_with_version(metrics, **kwargs):
+     # Increment global step
+     version_steps["global"] += 1
+     current_step = version_steps["global"]

+     # Initialize version step if needed
+     if __version__ not in version_steps:
+         version_steps[__version__] = current_step

+     # Use version-specific step counter
+     versioned_metrics = {}
+     for k, v in metrics.items():
+         # Add metric under current version
+         versioned_metrics[f"v{__version__}/{k}"] = v
+         # Also log under "latest/{k}" with version-specific step
+         versioned_metrics[f"latest/{k}"] = v

+     # Add version-specific step counter
+     versioned_metrics[f"v{__version__}/step"] = current_step

+     # Always use the global step for logging
+     kwargs["step"] = current_step

+     # Log metrics
+     original_log(versioned_metrics, **kwargs)

+ run.log = log_with_version

# Save run ID for future resumption
if not run_id:
with open(run_id_file, "w") as f:
f.write(run.id)

return run


- # TODO: Add error handling for network issues
- # TODO: Add retry mechanism for wandb initialization
- # TODO: Add cleanup mechanism for old run ID files
- # TODO: Add support for custom wandb settings
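To make the new logging wrapper concrete, here is a hedged, self-contained mock of what the patched run.log emits; no real WandB run is involved, and original_log simply prints what would be sent:

    __version__ = "0.2.6"
    version_steps = {"global": 0}

    def original_log(metrics, **kwargs):  # stands in for the real run.log
        print(metrics, kwargs)

    def log_with_version(metrics, **kwargs):
        version_steps["global"] += 1
        current_step = version_steps["global"]
        version_steps.setdefault(__version__, current_step)
        versioned = {}
        for k, v in metrics.items():
            versioned[f"v{__version__}/{k}"] = v  # version-namespaced metric
            versioned[f"latest/{k}"] = v          # always-current alias
        versioned[f"v{__version__}/step"] = current_step
        kwargs["step"] = current_step             # force the global step
        original_log(versioned, **kwargs)

    log_with_version({"loss": 0.41})
    # {'v0.2.6/loss': 0.41, 'latest/loss': 0.41, 'v0.2.6/step': 1} {'step': 1}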
