diff --git a/pyproject.toml b/pyproject.toml index a84d327..9a6431e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pandas", "pyspark", "unidecode", + "watchdog", "legal-segmenter @ git+https://github.com/lexeme-dev/legal-segmenter@main", ] @@ -75,7 +76,8 @@ build-backend = "setuptools.build_meta" [project.scripts] rl = "rl.cli.main:cli" train_llm = "rl.llm.train_llm:main" -merge_lora = "rl.llm.merge_lora:merge_lora" +merge_lora = "rl.llm.merge_lora:main" +devsync = "rl.utils.dev_sync:main" [tool.setuptools.packages] find = {} diff --git a/rl/llm/merge_lora.py b/rl/llm/merge_lora.py index c1f7f1c..4a62e9e 100644 --- a/rl/llm/merge_lora.py +++ b/rl/llm/merge_lora.py @@ -30,7 +30,7 @@ required=False, help="Path to write the merged model to. Will default to the same path as the LoRA model in the merged models dir.", ) -def merge_lora(base_model_id: str, lora_model_id: str, output_path: Path): +def main(base_model_id: str, lora_model_id: str, output_path: Path): """Merge a LoRA adapter into a PeftModelForCausalLM.""" output_path = ( _get_output_path(Path(lora_model_id)) if output_path is None else output_path @@ -94,4 +94,4 @@ def _get_output_path(input_path: Path) -> Path: if __name__ == "__main__": - merge_lora() + main() diff --git a/rl/llm/train_llm.py b/rl/llm/train_llm.py index 118d09f..24623b3 100644 --- a/rl/llm/train_llm.py +++ b/rl/llm/train_llm.py @@ -226,7 +226,7 @@ def main( del trainer, tokenizer, model if merge_after: merged_dir = _DEFAULT_MERGED_DIR / (name or output_dir.name) - rl.llm.merge_lora.merge_lora.callback( + rl.llm.merge_lora.main.callback( base_model_id=base_model_id, lora_model_id=output_dir, output_path=merged_dir, diff --git a/rl/utils/dev_sync.py b/rl/utils/dev_sync.py new file mode 100644 index 0000000..39cae6c --- /dev/null +++ b/rl/utils/dev_sync.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +import subprocess +from pathlib import Path +from threading import Timer + +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer + +import rl.utils.click as click + + +class DebounceTimer: + def __init__(self, timeout, callback): + self.timeout = timeout + self.callback = callback + self.timer = None + + def start(self): + if self.timer: + self.timer.cancel() + self.timer = Timer(self.timeout, self.callback) + self.timer.start() + + +class SyncHandler(FileSystemEventHandler): + def __init__( + self, + local_dir: Path, + remote_user: str, + remote_host: str, + remote_dir: str, + delay: float, + ): + self.local_dir = local_dir + self.remote_user = remote_user + self.remote_host = remote_host + self.remote_dir = remote_dir + self.debounce_timer = DebounceTimer(delay, self.sync_repo) + + def on_any_event(self, event): + self.debounce_timer.start() + + def sync_repo(self): + click.echo("Syncing repository...") + exclude_list = subprocess.run( + [ + "git", + "-C", + str(self.local_dir), + "ls-files", + "--exclude-standard", + "-oi", + "--directory", + "--others", + ], + stdout=subprocess.PIPE, + text=True, + check=True, + ) + subprocess.run( + [ + "rsync", + "-avz", + "--delete", + "--exclude-from=-", + "--exclude=.git", + f"{self.local_dir}/", + f"{self.remote_user}@{self.remote_host}:{self.remote_dir}/", + ], + input=exclude_list.stdout.encode(), + check=True, + ) + click.echo("Sync completed.") + + +@click.command() +@click.option( + "-l", + "--local", + "local_dir", + type=click.Path(exists=True), + required=True, + help="Local directory to sync", +) +@click.option("-r", "--remote-dir", required=True, help="Remote directory to sync to") +@click.option("-u", "--user", required=True, help="Remote user") +@click.option("-h", "--host", required=True, help="Remote host") +@click.option("-d", "--delay", default=5.0, help="Debounce delay in seconds") +def main(local_dir: str, remote_dir: str, user: str, host: str, delay: float): + local_path = Path(local_dir).resolve() + handler = SyncHandler(local_path, user, host, remote_dir, delay) + + handler.sync_repo() # Initial sync + + observer = Observer() + observer.schedule(handler, str(local_path), recursive=True) + observer.start() + + click.echo(f"Watching {local_path}. Press Ctrl+C to stop. Press Enter to resync.") + try: + while True: + if input() == "": + handler.sync_repo() + except KeyboardInterrupt: + observer.stop() + observer.join() + + +if __name__ == "__main__": + main()