Skip to content

Commit

Permalink
Add dev sync script
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Aug 14, 2024
1 parent 5745ff3 commit 3e6e343
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 4 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pandas",
"pyspark",
"unidecode",
"watchdog",
"legal-segmenter @ git+https://github.com/lexeme-dev/legal-segmenter@main",
]

Expand Down Expand Up @@ -75,7 +76,8 @@ build-backend = "setuptools.build_meta"
[project.scripts]
rl = "rl.cli.main:cli"
train_llm = "rl.llm.train_llm:main"
merge_lora = "rl.llm.merge_lora:merge_lora"
merge_lora = "rl.llm.merge_lora:main"
devsync = "rl.utils.dev_sync:main"

[tool.setuptools.packages]
find = {}
Expand Down
4 changes: 2 additions & 2 deletions rl/llm/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
required=False,
help="Path to write the merged model to. Will default to the same path as the LoRA model in the merged models dir.",
)
def merge_lora(base_model_id: str, lora_model_id: str, output_path: Path):
def main(base_model_id: str, lora_model_id: str, output_path: Path):
"""Merge a LoRA adapter into a PeftModelForCausalLM."""
output_path = (
_get_output_path(Path(lora_model_id)) if output_path is None else output_path
Expand Down Expand Up @@ -94,4 +94,4 @@ def _get_output_path(input_path: Path) -> Path:


if __name__ == "__main__":
merge_lora()
main()
2 changes: 1 addition & 1 deletion rl/llm/train_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def main(
del trainer, tokenizer, model
if merge_after:
merged_dir = _DEFAULT_MERGED_DIR / (name or output_dir.name)
rl.llm.merge_lora.merge_lora.callback(
rl.llm.merge_lora.main.callback(
base_model_id=base_model_id,
lora_model_id=output_dir,
output_path=merged_dir,
Expand Down
111 changes: 111 additions & 0 deletions rl/utils/dev_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env python3

import subprocess
from pathlib import Path
from threading import Timer

from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

import rl.utils.click as click


class DebounceTimer:
    """Run ``callback`` once, ``timeout`` seconds after the latest ``start()``.

    Every call to ``start()`` cancels the timer armed by the previous call (if
    any) and arms a new one, so a burst of calls results in a single callback
    invocation after the burst has been quiet for ``timeout`` seconds.
    """

    def __init__(self, timeout, callback):
        # No timer is armed until start() is called for the first time.
        self.timer = None
        self.callback = callback
        self.timeout = timeout

    def start(self):
        """(Re)arm the timer, discarding any countdown already in progress."""
        previous = self.timer
        if previous:
            previous.cancel()
        replacement = Timer(self.timeout, self.callback)
        self.timer = replacement
        replacement.start()


class SyncHandler(FileSystemEventHandler):
    """Watchdog handler that mirrors a local git checkout to a remote via rsync.

    File system events are debounced: each relevant event (re)arms a timer and
    ``sync_repo`` runs once the tree has been quiet for ``delay`` seconds.
    Files that git reports as ignored, and the ``.git`` directory itself, are
    excluded from the transfer.
    """

    # Event types that only report a file being read. Reacting to them would
    # schedule a new sync whenever something merely opens a file in the tree,
    # including the git/rsync processes started by the sync itself.
    _READ_ONLY_EVENT_TYPES = frozenset({"opened", "closed_no_write"})

    def __init__(
        self,
        local_dir: Path,
        remote_user: str,
        remote_host: str,
        remote_dir: str,
        delay: float,
    ):
        """Store the sync endpoints and set up the debounce timer.

        Args:
            local_dir: Directory to mirror; also passed to ``git -C``.
            remote_user: User name for the rsync destination.
            remote_host: Host name for the rsync destination.
            remote_dir: Destination directory on the remote host.
            delay: Quiet period in seconds before a sync is triggered.
        """
        super().__init__()
        self.local_dir = local_dir
        self.remote_user = remote_user
        self.remote_host = remote_host
        self.remote_dir = remote_dir
        self.debounce_timer = DebounceTimer(delay, self.sync_repo)

    def on_any_event(self, event):
        """Schedule a debounced sync for any event that can change the tree."""
        if getattr(event, "event_type", None) in self._READ_ONLY_EVENT_TYPES:
            return
        self.debounce_timer.start()

    @staticmethod
    def _build_exclude_input(ls_files_output: bytes) -> bytes:
        """Turn NUL-separated ``git ls-files -z`` output into rsync excludes.

        Each path gets a leading ``/`` so the pattern is anchored at the root
        of the transfer and only excludes that exact path, not every file or
        directory of the same name elsewhere in the tree. Entries stay
        NUL-terminated to match rsync's ``--from0``.
        """
        return b"".join(
            b"/" + path + b"\0" for path in ls_files_output.split(b"\0") if path
        )

    def sync_repo(self):
        """Rsync ``local_dir`` to the remote, skipping git-ignored paths.

        Raises:
            subprocess.CalledProcessError: If ``git`` or ``rsync`` exits with a
                non-zero status.
        """
        click.echo("Syncing repository...")
        # -z keeps paths verbatim (no quoting of unusual names) and separates
        # them with NUL, so names containing newlines or non-ASCII survive.
        ignored = subprocess.run(
            [
                "git",
                "-C",
                str(self.local_dir),
                "ls-files",
                "--exclude-standard",
                "--others",
                "--ignored",
                "--directory",
                "-z",
            ],
            stdout=subprocess.PIPE,
            check=True,
        )
        subprocess.run(
            [
                "rsync",
                "-avz",
                "--delete",
                # Given before --exclude-from so NUL termination is already
                # selected when the exclude list is read from stdin.
                "--from0",
                "--exclude-from=-",
                "--exclude=.git",
                f"{self.local_dir}/",
                f"{self.remote_user}@{self.remote_host}:{self.remote_dir}/",
            ],
            input=self._build_exclude_input(ignored.stdout),
            check=True,
        )
        click.echo("Sync completed.")


def _wait_for_manual_resyncs(handler, observer):
    """Serve "press Enter to resync" requests until interrupted.

    A failed manual sync is reported and the loop continues, so a transient
    rsync/git error does not end the watch session. If stdin is closed
    (``EOFError``), manual resyncs are impossible, so this just waits for the
    observer thread instead of crashing.
    """
    try:
        while True:
            if input() == "":
                try:
                    handler.sync_repo()
                except subprocess.CalledProcessError as exc:
                    click.echo(f"Sync failed: {exc}")
    except EOFError:
        while observer.is_alive():
            observer.join(timeout=1.0)


@click.command()
@click.option(
    "-l",
    "--local",
    "local_dir",
    type=click.Path(exists=True),
    required=True,
    help="Local directory to sync",
)
@click.option("-r", "--remote-dir", required=True, help="Remote directory to sync to")
@click.option("-u", "--user", required=True, help="Remote user")
@click.option("-h", "--host", required=True, help="Remote host")
@click.option("-d", "--delay", default=5.0, help="Debounce delay in seconds")
def main(local_dir: str, remote_dir: str, user: str, host: str, delay: float):
    """Sync a local directory to a remote host and keep it in sync on changes."""
    local_path = Path(local_dir).resolve()
    handler = SyncHandler(local_path, user, host, remote_dir, delay)

    handler.sync_repo()  # Initial sync

    observer = Observer()
    observer.schedule(handler, str(local_path), recursive=True)
    observer.start()

    click.echo(f"Watching {local_path}. Press Ctrl+C to stop. Press Enter to resync.")
    try:
        _wait_for_manual_resyncs(handler, observer)
    except KeyboardInterrupt:
        pass
    finally:
        # Always shut the watcher down, whatever ended the loop.
        observer.stop()
        observer.join()
        # A still-pending debounce timer would otherwise keep the process
        # alive and start one more sync after the user asked to stop.
        pending = handler.debounce_timer.timer
        if pending is not None:
            pending.cancel()


# Allow running this module directly as a script; `main` reads its arguments
# from the command line through its click option decorators.
if __name__ == "__main__":
    main()

0 comments on commit 3e6e343

Please sign in to comment.