### HuggingFace Integration
Our software design allows for seamless integration with HuggingFace's
[Transformers](https://github.com/huggingface/transformers/tree/main), a popular DL framework
that conveniently handles distributed training, data loading, etc.

```python
from transformers import Trainer, Seq2SeqTrainer
from logix.huggingface import patch_trainer, LogIXArguments

# Define LogIX arguments
logix_args = LogIXArguments(project="myproject",
                            config="config.yaml",
                            lora=True,
                            hessian="raw",
                            save="grad")

# Patch HF Trainer
LogIXTrainer = patch_trainer(Trainer)

# Pass LogIXArguments as TrainingArguments
# ... (trainer construction omitted here)
trainer.influence()
trainer.self_influence()
```
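The trainer construction itself is omitted in the excerpt above. Below is a minimal, hypothetical sketch of how that step might look, assuming the patched trainer keeps the standard HuggingFace `Trainer` constructor and that `logix_args` is passed where `TrainingArguments` normally goes (as the comment above suggests); `model` and `train_dataset` come from your own fine-tuning script.

```python
from transformers import AutoModelForSequenceClassification

# Hypothetical setup: any HF model and tokenized dataset from your own script.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
train_dataset = ...  # your tokenized training set

trainer = LogIXTrainer(
    model=model,
    args=logix_args,              # assumption: LogIXArguments in place of TrainingArguments
    train_dataset=train_dataset,
)

# The calls shown above (trainer.influence(), trainer.self_influence())
# then run on this patched trainer.
```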

### PyTorch Lightning Integration
Similarly, LogIX integrates with PyTorch Lightning. A code example is provided below.

```python
from lightning import LightningModule, Trainer
from logix.lightning import patch, LogIXArguments

class MyLitModule(LightningModule):
    ...

def data_id_extractor(batch):
    return tokenizer.batch_decode(batch["input_ids"])

# Define LogIX arguments
logix_args = LogIXArguments(project="myproject",
                            config="config.yaml",
                            lora=True,
                            hessian="raw",
                            save="grad")

# Patch Lightning Module and Trainer
LogIXModule, LogIXTrainer = patch(MyLitModule,
                                  Trainer,
                                  logix_args=logix_args,
                                  data_id_extractor=data_id_extractor)

# Use patched Module and Trainer as before
module = LogIXModule(user_args)
trainer = LogIXTrainer(user_args)

# Instead of trainer.fit(module, train_loader), use
trainer.extract_log(module, train_loader)
trainer.influence(module, train_loader)
```
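Note that `data_id_extractor` only needs to return one stable identifier per example in the batch; decoded text is a natural choice for language models. A hypothetical variant for a dataloader whose batches carry explicit sample indices (the `"idx"` key below is illustrative, not part of the LogIX API) could look like:

```python
def data_id_extractor(batch):
    # Assumes the collated batch includes an "idx" tensor of per-example dataset indices.
    return [str(i) for i in batch["idx"].tolist()]
```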

Please check out [Examples](/examples) for more detailed examples!

