Skip to content

Commit

Permalink
adopt the fix proposed by @johahi for fixing the small numerical disc…
Browse files Browse the repository at this point in the history
…repancy in the pretrained model between tensorflow and pytorch #31
  • Loading branch information
lucidrains committed Sep 27, 2023
1 parent 05a3654 commit a99626a
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 32 deletions.
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1 @@
recursive-include enformer_pytorch *.yml
recursive-include enformer_pytorch *.pt
28 changes: 16 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,18 @@ Deepmind has released the weights for their tensorflow sonnet Enformer model! I

Update: <a href="https://github.com/jstjohn">John St. John</a> did some work and found that the `enformer-official-rough` model hits the reported marks in the paper - human pearson R of `0.625` for validation, and `0.65` for test.

Update: As of version 0.8.0, if one were to use the `from_pretrained` function to load the pretrained model, it should automatically use precomputed gamma positions to address a difference between tensorflow and pytorch `xlogy`. This should resolve the numerical discrepancy above. If you were to further finetune and not be using the `from_pretrained` function, please make sure to set `use_tf_gamma = True` when using `.from_hparams` to instantiate the `Enformer`

```bash
$ pip install enformer-pytorch>=0.5
````

Loading the model

```python
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = from_pretrained('EleutherAI/enformer-official-rough')
```

Quick sanity check on a single human validation point
Expand All @@ -143,19 +145,19 @@ This is all made possible thanks to HuggingFace's [custom model](https://hugging
You can also load, with overriding of the `target_length` parameter, if you are working with shorter sequence lengths
```python
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
model = Enformer.from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
model = from_pretrained('EleutherAI/enformer-official-rough', target_length = 128, dropout_rate = 0.1)
# do your fine-tuning
```
To save on memory during fine-tuning a large Enformer model
```python
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
enformer = from_pretrained('EleutherAI/enformer-official-rough', use_checkpointing = True)
# finetune enformer on a limited budget
```
Expand All @@ -168,10 +170,10 @@ Fine-tuning on new tracks
```python
import torch
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper
enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = HeadAdapterWrapper(
enformer = enformer,
Expand All @@ -190,10 +192,10 @@ Finetuning on contextual data (cell type, transcription factor, etc)
```python
import torch
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAdapterWrapper

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = from_pretrained('EleutherAI/enformer-official-rough')

model = ContextAdapterWrapper(
enformer = enformer,
Expand All @@ -218,10 +220,10 @@ Finally, there is also a way to use attention aggregation from a set of context
```python
import torch
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAttentionAdapterWrapper

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough')
enformer = from_pretrained('EleutherAI/enformer-official-rough')

model = ContextAttentionAdapterWrapper(
enformer = enformer,
Expand Down Expand Up @@ -315,6 +317,8 @@ seq, rand_shift_val, rc_bool = ds[0] # (196608,), (1,), (1,)
Special thanks goes out to <a href="https://www.eleuther.ai/">EleutherAI</a> for providing the resources to retrain the model, during a time when the official model from Deepmind had not been released yet.
Thanks also goes out to <a href="johahi">@johahi</a> for finding out that there are numerical differences between the torch and tensorflow implementations of `xlogy`. He provided a fix for this difference, which is adopted in this repository in `v0.8.0`
## Todo
- [x] script to load weights from trained tensorflow enformer model to pytorch model
Expand Down
2 changes: 1 addition & 1 deletion enformer_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from enformer_pytorch.config_enformer import EnformerConfig
from enformer_pytorch.modeling_enformer import Enformer, SEQUENCE_LENGTH, AttentionPool
from enformer_pytorch.modeling_enformer import Enformer, from_pretrained, SEQUENCE_LENGTH, AttentionPool
from enformer_pytorch.data import seq_indices_to_one_hot, str_to_one_hot, GenomeIntervalDataset, FastaInterval
4 changes: 3 additions & 1 deletion enformer_pytorch/config_enformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
use_convnext = False,
num_downsamples = 7, # genetic sequence is downsampled 2 ** 7 == 128x in default Enformer - can be changed for higher resolution
dim_divisible_by = 128,
use_tf_gamma = False,
**kwargs,
):
self.dim = dim
Expand All @@ -32,5 +33,6 @@ def __init__(
self.use_checkpointing = use_checkpointing
self.num_downsamples = num_downsamples
self.dim_divisible_by = dim_divisible_by

self.use_tf_gamma = use_tf_gamma

super().__init__(**kwargs)
58 changes: 44 additions & 14 deletions enformer_pytorch/modeling_enformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import math
from pathlib import Path

import torch
from torch import nn, einsum
import torch.nn.functional as F
Expand All @@ -18,6 +20,13 @@
SEQUENCE_LENGTH = 196_608
TARGET_LENGTH = 896

# gamma positions from tensorflow
# addressing a difference between xlogy results from tensorflow and pytorch
# solution came from @johahi

DIR = Path(__file__).parents[0]
TF_GAMMAS = torch.load(str(DIR / "precomputed"/ "tf_gammas.pt"))

# helpers

def exists(val):
Expand All @@ -26,6 +35,12 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

def always(val):
def inner(*args, **kwargs):
print(val.shape)
return val
return inner

def map_values(fn, d):
return {key: fn(values) for key, values in d.items()}

Expand Down Expand Up @@ -75,30 +90,24 @@ def get_positional_features_gamma(positions, features, seq_len, stddev = None, s
if not exists(start_mean):
start_mean = seq_len / features

# turns out xlogy between tensorflow and torch differs because of the log - thanks to phd student @johahi for finding this!
# do everything in float64 here for precision

dtype = positions.dtype
positions = positions.double()
mean = torch.linspace(start_mean, seq_len, features, device = positions.device, dtype = torch.float64)
mean = torch.linspace(start_mean, seq_len, features, device = positions.device)

mean = mean[None, ...]
concentration = (mean / stddev) ** 2
rate = mean / stddev ** 2

probabilities = gamma_pdf(positions.abs()[..., None], concentration, rate)
probabilities = gamma_pdf(positions.float().abs()[..., None], concentration, rate)
probabilities = probabilities + eps
outputs = probabilities / torch.amax(probabilities, dim = -1, keepdim = True)
return outputs

return outputs.to(dtype)

def get_positional_embed(seq_len, feature_size, device):
def get_positional_embed(seq_len, feature_size, device, use_tf_gamma):
distances = torch.arange(-seq_len + 1, seq_len, device = device)

feature_functions = [
get_positional_features_exponential,
get_positional_features_central_mask,
get_positional_features_gamma
get_positional_features_gamma if not use_tf_gamma else always(TF_GAMMAS.to(device))
]

num_components = len(feature_functions) * 2
Expand Down Expand Up @@ -213,7 +222,8 @@ def __init__(
dim_key = 64,
dim_value = 64,
dropout = 0.,
pos_dropout = 0.
pos_dropout = 0.,
use_tf_gamma = False
):
super().__init__()
self.scale = dim_key ** -0.5
Expand All @@ -240,6 +250,10 @@ def __init__(
self.pos_dropout = nn.Dropout(pos_dropout)
self.attn_dropout = nn.Dropout(dropout)

# whether to use tf gamma

self.use_tf_gamma = use_tf_gamma

def forward(self, x):
n, h, device = x.shape[-2], self.heads, x.device

Expand All @@ -253,7 +267,7 @@ def forward(self, x):

content_logits = einsum('b h i d, b h j d -> b h i j', q + self.rel_content_bias, k)

positions = get_positional_embed(n, self.num_rel_pos_features, device)
positions = get_positional_embed(n, self.num_rel_pos_features, device, use_tf_gamma = self.use_tf_gamma)
positions = self.pos_dropout(positions)
rel_k = self.to_rel_k(positions)

Expand Down Expand Up @@ -308,6 +322,11 @@ def __init__(self, config):

self.conv_tower = nn.Sequential(*conv_layers)

# whether to use tensorflow gamma positions

use_tf_gamma = config.use_tf_gamma
self.use_tf_gamma = use_tf_gamma

# transformer

transformer = []
Expand All @@ -322,7 +341,8 @@ def __init__(self, config):
dim_value = config.dim // config.heads,
dropout = config.attn_dropout,
pos_dropout = config.pos_dropout,
num_rel_pos_features = config.dim // config.heads
num_rel_pos_features = config.dim // config.heads,
use_tf_gamma = use_tf_gamma
),
nn.Dropout(config.dropout_rate)
)),
Expand Down Expand Up @@ -454,3 +474,13 @@ def forward(
return out, x

return out

# from pretrained function

def from_pretrained(name, use_tf_gamma = None, **kwargs):
enformer = Enformer.from_pretrained(name, **kwargs)

if name == 'EleutherAI/enformer-official-rough':
enformer.use_tf_gamma = default(use_tf_gamma, True)

return enformer
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.7.7',
version = '0.8.0',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
Expand Down
4 changes: 2 additions & 2 deletions test_pretrained.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from enformer_pytorch import Enformer
from enformer_pytorch import from_pretrained

enformer = Enformer.from_pretrained('EleutherAI/enformer-official-rough').cuda()
enformer = from_pretrained('EleutherAI/enformer-official-rough').cuda()
enformer.eval()

data = torch.load('./data/test-sample.pt')
Expand Down

0 comments on commit a99626a

Please sign in to comment.