Commit

bump version
distributedstatemachine committed Jan 3, 2025
1 parent 9d5e0db · commit 0aaa921
Showing 3 changed files with 5 additions and 51 deletions.

neurons/miner.py: 3 additions & 3 deletions

@@ -177,8 +177,8 @@ def __init__(self):
     # Main training loop.
     async def run(self):
         # Try to load latest checkpoint
-        validator_uid, stake = self.comms.get_highest_stake_validator()
-        if stake > 0:
+        validator_uid = self.metagraph.S.argmax()[0]
+        if validator_uid is not None:
             try:
                 # Calculate the most recent window that should have a checkpoint
                 expected_checkpoint_window = (self.current_window // self.hparams.checkpoint_frequency) * self.hparams.checkpoint_frequency
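
For context on the hunk above: the miner now picks its checkpoint source straight from the metagraph stake vector instead of going through the comms helper. A minimal sketch of that selection, assuming metagraph.S is a torch tensor of per-UID stakes; the stake values and the int() conversion below are illustrative, not the commit's exact code:

import torch

# Hypothetical stake vector: S[uid] = stake bonded to that UID.
S = torch.tensor([0.0, 12.5, 3.1, 40.2])

# Old path (removed): comms.get_highest_stake_validator() returned a
# (uid, stake) pair and the caller gated on stake > 0.
# New path: take the UID holding the most stake directly off the metagraph.
validator_uid = int(S.argmax())
print(validator_uid)  # 3
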
@@ -357,7 +357,7 @@ async def run(self):
             uids=self.peers,
             window=step_window,
             key='gradient',
-            timeout=20,
+            timeout=5,
             device=self.config.device,
             local=False,
             stale_retention=10,

neurons/validator.py: 1 addition & 47 deletions

@@ -179,52 +179,6 @@ def __init__(self):
         )

     async def run(self):
-        # Try to load the latest checkpoint
-        validator_uid, stake = self.comms.get_highest_stake_validator()
-        if stake > 0:
-            try:
-                # Calculate the most recent window that should have a checkpoint
-                expected_checkpoint_window = (self.current_window // self.hparams.checkpoint_frequency) * self.hparams.checkpoint_frequency
-
-                # Try last few windows in case of missed checkpoints
-                for window in range(expected_checkpoint_window, max(0, expected_checkpoint_window - 3 * self.hparams.checkpoint_frequency), -self.hparams.checkpoint_frequency):
-                    result = await self.comms.get(
-                        uid=str(validator_uid),
-                        window=window,
-                        key='checkpoint',
-                        timeout=240,
-                        local=False,
-                        stale_retention=10
-                    )
-                    if result is None:
-                        tplr.logger.debug(f"No checkpoint found for window {window}")
-                        continue
-
-                    checkpoint_data, global_step = result
-                    try:
-                        # Load state dicts from dictionary
-                        self.model.load_state_dict(checkpoint_data['model_state_dict'])
-                        self.optimizer.load_state_dict(checkpoint_data['optimizer_state_dict'])
-                        self.scheduler.load_state_dict(checkpoint_data['scheduler_state_dict'])
-                        self.momentum = checkpoint_data['momentum']
-                        self.global_step = checkpoint_data['global_step']
-
-                        # Update optimizer and scheduler steps to match
-                        self.optimizer._step_count = self.global_step
-                        self.scheduler.last_epoch = self.global_step
-
-                        tplr.logger.info(f"Loaded checkpoint from validator {validator_uid} at window {window}, global_step={self.global_step}")
-                        break  # Successfully loaded checkpoint, exit loop
-                    except KeyError as e:
-                        tplr.logger.error(f"Invalid checkpoint format: missing key {e}")
-                    except Exception as e:
-                        tplr.logger.error(f"Failed to load checkpoint: {e}")
-                else:
-                    tplr.logger.info("No valid checkpoints found in recent windows")
-            except Exception as e:
-                tplr.logger.warning(f"Failed to load checkpoint: {e}")
-        else:
-            tplr.logger.info("No active validators found, starting from scratch")
         # Start block listener
         self.loop = asyncio.get_running_loop()
         self.listener = threading.Thread(
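
For reference, the deleted block retried over the last few checkpoint windows before giving up. A worked sketch of that window arithmetic, using illustrative stand-in values for current_window and hparams.checkpoint_frequency:

current_window = 1037
checkpoint_frequency = 100

# Most recent window that should have a checkpoint: round down to a
# multiple of the checkpoint frequency.
expected = (current_window // checkpoint_frequency) * checkpoint_frequency  # 1000

# Windows probed, newest first, covering up to three checkpoint periods.
probed = list(range(expected, max(0, expected - 3 * checkpoint_frequency), -checkpoint_frequency))
print(probed)  # [1000, 900, 800]
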
@@ -255,7 +209,7 @@ async def run(self):
             uids=self.peers,
             window=step_window,
             key='gradient',
-            timeout=20,
+            timeout=5,
             device=self.config.device,
             local=False,
             stale_retention=10,

src/tplr/__init__.py: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
 # mypy: ignore-errors
 # type: ignore

-__version__ = "0.2.4"
+__version__ = "0.2.5"

 # Import package.
 from .chain import *
