Skip to content

Commit

Permalink
directly set use_tf_gamma on Attention instances
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 28, 2023
1 parent ece4306 commit d2dbc21
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion enformer_pytorch/modeling_enformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,10 @@ def from_pretrained(name, use_tf_gamma = None, **kwargs):
enformer = Enformer.from_pretrained(name, **kwargs)

if name == 'EleutherAI/enformer-official-rough':
enformer.use_tf_gamma = default(use_tf_gamma, True)
use_tf_gamma = default(use_tf_gamma, True)

for module in enformer.modules():
if isinstance(module, Attention):
module.use_tf_gamma = use_tf_gamma

return enformer
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.8.2',
version = '0.8.3',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d2dbc21

Please sign in to comment.