From d4b7d9f90621ecedd49a05f8e0fcd634c0974f25 Mon Sep 17 00:00:00 2001 From: Sang Choe Date: Fri, 7 Jun 2024 09:03:52 -0400 Subject: [PATCH] Update README.md --- README.md | 48 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 190fece..f0198da 100644 --- a/README.md +++ b/README.md @@ -85,14 +85,20 @@ run.influence.compute_self_influence(test_log) # Uncertainty estimation ### HuggingFace Integration Our software design allows for the seamless integration with HuggingFace's [Transformer](https://github.com/huggingface/transformers/tree/main), a popular DL framework -that conveniently handles distributed training, data loading, etc. We plan to support more -frameworks (e.g. Lightning) in the future. +that conveniently handles distributed training, data loading, etc. ```python from transformers import Trainer, Seq2SeqTrainer from logix.huggingface import patch_trainer, LogIXArguments -logix_args = LogIXArguments(project, config, lora=True, hessian="raw", save="grad") +# Define LogIX arguments +logix_args = LogIXArguments(project="myproject", + config="config.yaml", + lora=True, + hessian="raw", + save="grad") + +# Patch HF Trainer LogIXTrainer = patch_trainer(Trainer) # Pass LogIXArguments as TrainingArguments @@ -108,6 +114,42 @@ trainer.influence() trainer.self_influence() ``` +### PyTorch Lightning Integration +Similarly, we also support the LogIX + PyTorch Lightning integration. The code example +is provided below. + +```python +from lightning import LightningModule, Trainer +from logix.lightning import patch, LogIXArguments + +class MyLitModule(LightningModule): + ... + +def data_id_extractor(batch): + return tokenizer.batch_decode(batch["input_ids"]) + +# Define LogIX arguments +logix_args = LogIXArguments(project="myproject", + config="config.yaml", + lora=True, + hessian="raw", + save="grad") + +# Patch Lightning Module and Trainer +LogIXModule, LogIXTrainer = patch(MyLitModule, + Trainer, + logix_args=logix_args, + data_id_extractor=data_id_extractor) + +# Use patched Module and Trainer as before +module = LogIXModule(user_args) +trainer = LogIXTrainer(user_args) + +# Instead of trainer.fit(module, train_loader), use +trainer.extract_log(module, train_loader) +trainer.influence(module, train_loader) +``` + Please check out [Examples](/examples) for more detailed examples!