Merge pull request #6 from pomonam/examples
Final Release Code Refactor
pomonam authored Mar 22, 2024
2 parents da1646c + 1e9cbc3 commit 09e38a8
Showing 57 changed files with 3,823 additions and 1,407 deletions.
6 changes: 5 additions & 1 deletion .gitignore
@@ -159,8 +159,12 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# MacOS
.DS_Store

# Checkpoints and influence outputs
checkpoints/
analyses/
data/
*.pth
*.pth
*.pt
62 changes: 31 additions & 31 deletions DOCUMENTATION.md
@@ -1,5 +1,7 @@
# Kronfluence: Technical Documentation & FAQs

For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.

## Requirements

Kronfluence has been tested on the following versions of [PyTorch](https://pytorch.org/):
@@ -10,7 +12,7 @@ Kronfluence has been tested on the following versions of [PyTorch](https://pytor

Kronfluence supports:
- Computing influence functions on selected PyTorch modules. At the moment, we support `nn.Linear` and `nn.Conv2d`;
- Computing influence functions with several strategies: `identity`, `diagonal`, `KFAC`, and `EKFAC`;
- Computing influence functions with several Hessian approximation strategies: `identity`, `diagonal`, `KFAC`, and `EKFAC`;
- Computing pairwise and self-influence scores.

> [!NOTE]
@@ -42,8 +44,8 @@ query_dataset = prepare_query_dataset()

**Define a Task.**
To compute influence scores, you need to define a [`Task`](https://github.com/pomonam/kronfluence/blob/main/kronfluence/task.py) class.
This class encapsulates information about the trained model and how influence scores will be computed:
(1) how to compute the training loss; (2) how to compute the measurement (f(θ) in the [paper](https://arxiv.org/abs/2308.03296));
This class contains information about the trained model and how influence scores will be computed:
(1) how to compute the training loss; (2) how to compute the measurable quantity (f(θ) in the [paper](https://arxiv.org/abs/2308.03296); see Equation 5);
(3) which modules to use for influence function computations; and (4) whether the model used [attention mask](https://huggingface.co/docs/transformers/en/glossary#attention-mask).

```python
@@ -59,13 +61,15 @@ class YourTask(Task):
model: nn.Module,
sample: bool = False,
) -> torch.Tensor:
# This will be used for computing the training gradient.
# TODO: Complete this method.

def compute_measurement(
self,
batch: Any,
model: nn.Module,
) -> torch.Tensor:
# This will be used for computing the measurable quantity.
# TODO: Complete this method.

def tracked_modules(self) -> Optional[List[str]]:
@@ -129,7 +133,7 @@ scores = analyzer.compute_pairwise_scores(
...
```

You can organize all factors and scores for the specific model with `factor_name` and `score_name`.
You can organize all factors and scores for the specific model with `factors_name` and `scores_name`.

### FAQs

@@ -146,15 +150,15 @@ inspect `model.named_modules()` to determine what modules to use. You can specif

> [!NOTE]
> If the embedding layers for transformers are defined with `nn.Linear`, you must write
> `task.tracked_modules` to avoid influence computations embedding matrices.
> `task.tracked_modules` to avoid influence computations on embedding matrices (it is too expensive).
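
For illustration, a minimal sketch of such a `tracked_modules` implementation is shown below; the module names are hypothetical placeholders, so inspect `model.named_modules()` to find the right names for your own model.

```python
from typing import List, Optional


def tracked_modules(self) -> Optional[List[str]]:
    # Hypothetical names for a 12-block transformer whose MLP layers are `nn.Linear`;
    # embedding modules are deliberately left off this list.
    tracked = []
    for block in range(12):
        tracked.append(f"transformer.h.{block}.mlp.c_fc")
        tracked.append(f"transformer.h.{block}.mlp.c_proj")
    return tracked
```
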
**How should I implement Task.compute_train_loss?**
Implement the loss function used to train the model. Note that the function should return
the summed loss over the examples in the batch (and over tokens, for sequence models), and it should not include regularization terms such as weight decay.
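
As a minimal sketch, assuming a classification model whose batches are `(inputs, labels)` tuples (adapt the unpacking to your own data pipeline):

```python
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_train_loss(self, batch: Any, model: nn.Module, sample: bool = False) -> torch.Tensor:
    inputs, labels = batch
    logits = model(inputs)
    if not sample:
        # Summed (not averaged) cross-entropy, with no weight decay or other regularizers.
        return F.cross_entropy(logits, labels, reduction="sum")
    # When `sample=True`, labels are drawn from the model's own predictive distribution
    # (used for the true Fisher, as opposed to the empirical Fisher).
    with torch.no_grad():
        sampled_labels = torch.multinomial(logits.softmax(dim=-1), num_samples=1).flatten()
    return F.cross_entropy(logits, sampled_labels, reduction="sum")
```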

**How should I implement Task.compute_measurement?**
It depends on the analysis you would like to perform. Influence functions approximate the [local effect of downweighting/upweighting
a data point on the query's measurable quantity](https://arxiv.org/abs/2209.05364). You can use the loss, [margin](https://arxiv.org/abs/2303.14186) (for classification),
It depends on the analysis you would like to perform. Influence functions approximate the [effect of downweighting/upweighting
a training data point on the query's measurable quantity](https://arxiv.org/abs/2209.05364). You can use the loss, [margin](https://arxiv.org/abs/2303.14186) (for classification),
or [conditional log-likelihood](https://arxiv.org/abs/2308.03296) (for language modeling).
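
For example, a rough sketch of the margin measurement for classification; sign and reduction conventions vary across papers, and the batch format is again an assumption:

```python
from typing import Any

import torch
import torch.nn as nn


def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
    inputs, labels = batch
    logits = model(inputs)
    batch_index = torch.arange(labels.shape[0], device=labels.device)
    correct_logits = logits[batch_index, labels]
    masked_logits = logits.clone()
    masked_logits[batch_index, labels] = float("-inf")
    # Margin: correct-class logit minus the log-sum-exp of the remaining logits.
    margins = correct_logits - masked_logits.logsumexp(dim=-1)
    return -margins.sum()
```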

**I encounter TrackedModuleNotFoundError when using DDP or FSDP.**
@@ -188,7 +192,7 @@ import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
strategy="ekfac", # Choose from "identity", "diagonal", "KFAC", or "EKFAC".
strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac".
use_empirical_fisher=False,
immediate_gradient_removal=False,
ignore_bias=False,
@@ -217,8 +221,8 @@ analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset, f
```

You can change:
- `strategy`: Selects the preconditioning strategy (`identity`, `diagonal`, `KFAC`, or `EKFAC`).
- `use_empirical_fisher`: Determines whether to approximate the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
- `strategy`: Selects the Hessian approximation strategy (`identity`, `diagonal`, `KFAC`, or `EKFAC`).
- `use_empirical_fisher`: Determines whether to use the [empirical Fisher](https://arxiv.org/abs/1905.12558) (using actual labels from batch)
instead of the true Fisher (using sampled labels from model's predictions). It is recommended to be `False`.
- `immediate_gradient_removal`: Specifies whether to instantly set `param.grad = None` within module hooks. Generally
recommended to be `False`, as enabling it requires installing additional hooks. This should not affect the fitted factors, but
@@ -227,7 +231,7 @@ can potentially reduce peak memory.

### Fitting Covariance Matrices

`KFAC` and `EKFAC` require computing the activation and pseudo-gradient covariance matrices.
`KFAC` and `EKFAC` require computing the uncentered activation and pre-activation pseudo-gradient covariance matrices.
To fit covariance matrices, you can use `analyzer.fit_covariance_matrices`.
```python
# Fitting covariance matrices.
@@ -236,23 +240,21 @@ analyzer.fit_covariance_matrices(factors_name="initial_factor", dataset=train_da
covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")
```

You can tune:
This step corresponds to Equation 16 in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. If set to `None`,
Kronfluence computes covariance matrices for all data points.
- `covariance_data_partition_size`: Number of data partitions to use for computing covariance matrices.
For example, when `covariance_data_partition_size = 2`, the dataset is split into 2 chunks and covariance matrices
are computed separately for each chunk. These chunked covariance matrices are later aggregated. This is useful under GPU preemption, as intermediate
covariance matrices will be saved to disk. It can also be helpful when launching multiple parallel jobs, where each GPU
can compute covariance matrices on a subset of the partitioned data (you can specify `target_data_partitions` to select the partitions; see the sketch after this list).
This should not affect the quality of the fitted factors.
- `covariance_module_partition_size`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partition_size = 2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
covariance matrices cannot fit into memory). However, this will do multiple iterations over the dataset and can be slow.
This should not affect the quality of the fitted factors.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
or `torch.float16`.
- `gradient_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16`
- `gradient_covariance_dtype`: `dtype` for computing pre-activation pseudo-gradient covariance matrices. You can also use `torch.bfloat16`
or `torch.float16`.
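
A sketch of the partitioned workflow, assuming the options above are fields of `FactorArguments` and that `target_data_partitions` is accepted by the fit call:

```python
import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
    strategy="ekfac",
    covariance_max_examples=100_000,
    covariance_data_partition_size=2,
    activation_covariance_dtype=torch.float32,
    gradient_covariance_dtype=torch.float32,
)
# This job fits the first chunk; a second (possibly parallel) job could pass
# target_data_partitions=[1], and the chunked matrices are aggregated afterwards.
analyzer.fit_covariance_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    factor_args=factor_args,
    target_data_partitions=[0],
)
```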

**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
@@ -264,7 +266,7 @@

### Performing Eigendecomposition

After computing the covariance matrices, `KFAC` and `EKFAC` require performing Eigendecomposition.
After computing the covariance matrices, `KFAC` and `EKFAC` require performing eigendecomposition.

```python
# Performing Eigendecomposition.
@@ -273,13 +275,13 @@ analyzer.perform_eigendecomposition(factors_name="initial_factor", factor_args=f
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
```

You can tune:
- `eigendecomposition_dtype`: `dtype` for performing Eigendecomposition. You can also use `torch.float32`,
This corresponds to Equation 18 in the paper. You can tune:
- `eigendecomposition_dtype`: `dtype` for performing eigendecomposition. You can also use `torch.float32`,
but `torch.float64` is recommended.

### Fitting Lambda Matrices

`EKFAC` and `diagonal` require computing the Lambda matrices for all modules.
`EKFAC` and `diagonal` require computing the Lambda (eigenvalue) matrices for all modules.

```python
# Fitting Lambda matrices.
@@ -288,7 +290,7 @@ analyzer.fit_lambda_matrices(factors_name="initial_factor", dataset=train_datase
lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
```

You can tune:
This corresponds to Equation 20 in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- `lambda_data_partition_size`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partition_size`: Number of module partitions to use for computing Lambda matrices.
@@ -297,7 +299,7 @@ You can set `cached_activation_cpu_offload=True` to cache these activations in C
- `lambda_iterative_aggregate`: Whether to compute the Lambda matrices with a for-loop instead of batched matrix multiplications.
This is helpful for reducing peak memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample gradient (see the sketch after this list).
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16`
or `torch.float16`, but `torch.float32` is generally recommended.
or `torch.float16`.
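
A memory-lean configuration, sketched under the assumption that the options above are fields of `FactorArguments`:

```python
import torch
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(
    strategy="ekfac",
    lambda_dtype=torch.bfloat16,         # lower precision for the Lambda computation
    lambda_iterative_aggregate=True,     # trade speed for lower peak memory
    cached_activation_cpu_offload=True,  # keep cached activations on the CPU
)
analyzer.fit_lambda_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    factor_args=factor_args,
)
```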


**Dealing with OOMs.** Here are some steps to fix Out of Memory (OOM) errors.
@@ -312,7 +314,7 @@ or `torch.float16`, but `torch.float32` is generally recommended.
**I get different factors each time I run the code.**
This is expected as we sample labels from the model's prediction when computing covariance and Lambda matrices.
Using `use_empirical_fisher=True` could make the process more deterministic. Moreover, different hardware might compute
different eigenvectors when performing Eigendecomposition.
different eigenvectors when performing eigendecomposition.

**How should I select the batch size?**
You can use the largest possible batch size that does not result in OOM. Typically, the batch size for fitting Lambda
@@ -347,16 +349,14 @@ score_args = ScoreArguments(

- `damping`: A damping factor for the damped matrix-vector product. Uses a heuristic based on mean eigenvalues
(0.1 x mean eigenvalues) if None.
- `immediate_gradient_removal`: Whether to immediately remove `param.grad` within a hook. This should be set to
`False` in most cases.
- `immediate_gradient_removal`: Whether to immediately remove `param.grad` within a hook.
- `data_partition_size`: Number of data partitions for computing influence scores.
- `module_partition_size`: Number of module partitions for computing influence scores.
- `per_module_score`: Whether to return per-module influence scores. Instead of summing influences across
all modules, this will keep track of intermediate module-wise scores.

- `query_gradient_rank`: The rank for the query batching. If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can use `torch.float32`,
but `torch.float64` is recommended.
- `query_gradient_rank`: The rank for the query batching (low-rank approximation to the query gradient; see Section 3.2.2). If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float32`.

- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
@@ -366,15 +366,15 @@ but `torch.float32` is recommended.
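
Putting a few of these options together, a sketch of a low-memory scoring configuration (the particular combination of values is illustrative, not a recommendation):

```python
import torch
from kronfluence.arguments import ScoreArguments

score_args = ScoreArguments(
    damping=None,                        # fall back to the 0.1 x mean-eigenvalue heuristic
    query_gradient_rank=64,              # enable low-rank query batching
    query_gradient_svd_dtype=torch.float64,
    cached_activation_cpu_offload=True,
    score_dtype=torch.bfloat16,
    per_sample_gradient_dtype=torch.bfloat16,
)
```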

### Computing Influence Scores

To compute pairwise influence scores, you can run:
To compute pairwise influence scores (Equation 5 in the paper), you can run:
```python
# Computing pairwise influence scores.
analyzer.compute_pairwise_scores(scores_name="pairwise", factors_name="ekfac", score_args=score_args)
# Loading pairwise influence scores.
scores = analyzer.load_pairwise_scores(scores_name="pairwise")
```

To compute self-influence scores, you can run:
To compute self-influence scores (see Section 5.4 of the [paper](https://arxiv.org/pdf/1703.04730.pdf)), you can run:
```python
# Computing self-influence scores.
analyzer.compute_self_scores(scores_name="self", factors_name="ekfac", score_args=score_args)
@@ -388,14 +388,14 @@ scores = analyzer.load_self_scores(scores_name="self")
3. Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
4. Try setting `immediate_gradient_removal=True`.
5. Try setting `query_gradient_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query
batching is only supported for computing pairwise influence scores.
batching is only supported for computing pairwise influence scores, not self-influence scores.
6. Try setting `module_partition_size > 1`.

### FAQs

**Influence scores are very large in magnitude.**
Ideally, influence scores need to be divided by the total number of training data points. However, the code does
not normalize the scores. If you would like, you can divide the scores with the total number of data points used to
not normalize the scores. If you would like, you can divide the scores by the total number of data points (or tokens) used to
train the model.
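
For example, assuming the loaded pairwise scores behave like a `len(query_dataset) x len(train_dataset)` tensor:

```python
scores = analyzer.load_pairwise_scores(scores_name="pairwise")
# Optional normalization by the number of training examples (not done by the library).
normalized_scores = scores / len(train_dataset)
```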

## References
28 changes: 16 additions & 12 deletions README.md
@@ -2,7 +2,11 @@
<a href="#"><img width="380" img src=".assets/kronfluence.svg" alt="Kronfluence"/></a>
</p>


<p align="center">
<a href="https://pypi.org/project/kronfluence">
<img alt="License" src="https://img.shields.io/pypi/v/kronfluence.svg?style=flat-square">
</a>
<a href="https://github.com/pomonam/kronfluence/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/badge/License-Apache_2.0-blue.svg">
</a>
@@ -19,7 +23,7 @@

---

> **Kronfluence** is a repository designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
> **Kronfluence** is a research repository designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.

---
@@ -53,10 +57,10 @@ pip install -e .
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md)
page for a comprehensive guide.

### Examples
### Learn More

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples on how to use Kronfluence.
We plan to add more language model examples. **TL;DR** You need to prepare the trained model and datasets, and pass them into the `Analyzer`.
The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples.
We plan to add more examples in the future. **TL;DR** You need to prepare the trained model and datasets, and pass them into `Analyzer`.

```python
import torch
@@ -90,27 +94,27 @@ eval_dataset = torchvision.datasets.MNIST(
train=True,
)

# Initialize the task with relevant loss and measurement.
# Define the task.
task = MnistTask()

# Prepare the model for influence computation with the specified task.
# Prepare the model for influence computation.
model = prepare_model(model=model, task=task)
analyzer = Analyzer(analysis_name="mnist", model=model, task=task)

# Fit all EKFAC factors for the given model on the training dataset.
analyzer.fit_all_factors(factors_name="ekfac", dataset=train_dataset)
# Fit all EKFAC factors for the given model.
analyzer.fit_all_factors(factors_name="my_factors", dataset=train_dataset)

# Compute all pairwise influence scores using the computed factors.
# Compute all pairwise influence scores with the computed factors.
analyzer.compute_pairwise_scores(
scores_name="pairwise_scores",
factors_name="ekfac",
scores_name="my_scores",
factors_name="my_factors",
query_dataset=eval_dataset,
train_dataset=train_dataset,
per_device_query_batch_size=1024,
)

# Load the scores with dimension `len(eval_dataset) x len(train_dataset)`.
scores = analyzer.load_pairwise_scores(scores_name="pairwise_scoeres")
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```

## Contributing
3 changes: 0 additions & 3 deletions examples/_test_requirements.txt

This file was deleted.

57 changes: 57 additions & 0 deletions examples/cifar/README.md
@@ -0,0 +1,57 @@
# CIFAR-10 & ResNet-9 Example

This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is adapted from the
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb).

## Training

To train ResNet-9 on the CIFAR-10 dataset, run the following command:
```bash
python train.py --dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--num_train_epochs 25 \
--seed 1004
```

## Computing Pairwise Influence Scores

To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing EKFAC factors).

## Mislabeled Data Detection

We can use self-influence scores (see Section 5.4 of the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of training examples mislabeled by running the following command:
```bash
python train.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--num_train_epochs 25 \
--seed 1004
```

Then, compute self-influence scores with the following command:
```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```

On A100 (80GB), it takes roughly 1.5 minutes to compute the self-influence scores.
We can detect around 82% of mislabeled data points by inspecting 10% of the dataset (96% by inspecting 20%).
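
As an illustration of how such a detection rate could be computed, here is a hypothetical sketch; `self_scores` and `corrupted_indices` are assumed inputs, not outputs of the scripts above:

```python
import numpy as np


def detection_rate(self_scores: np.ndarray, corrupted_indices: np.ndarray, inspect_fraction: float = 0.1) -> float:
    # Inspect the examples with the highest self-influence scores and report the
    # fraction of corrupted (mislabeled) examples found among them.
    num_inspected = int(len(self_scores) * inspect_fraction)
    inspected = np.argsort(-self_scores)[:num_inspected]
    return float(np.isin(corrupted_indices, inspected).mean())
```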