yarn embeddings
Kye committed Oct 1, 2023
1 parent 6273935 commit 4d3e2e0
Showing 2 changed files with 34 additions and 2 deletions.
zeta/nn/embeddings/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -27,4 +27,4 @@
 )
 from zeta.nn.embeddings.yarn import *
 from zeta.nn.embeddings.yarn import YarnEmbedding
-from zeta.nn.embeddings.position
+from zeta.nn.embeddings.positional_interpolation import PositionalInterpolation
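After this change, both embedding classes become importable from the package root, since any name imported inside __init__.py is exposed as a package attribute. A minimal import sketch, assuming zeta is installed:

    # Names imported in zeta/nn/embeddings/__init__.py are re-exported
    # at package level, so both classes can be pulled from one place.
    from zeta.nn.embeddings import PositionalInterpolation, YarnEmbedding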
zeta/nn/embeddings/positional_interpolation.py (34 changes: 33 additions & 1 deletion)
@@ -4,7 +4,39 @@
class PositionInterpolationEmbeddings(nn.Module):
    """
    PositionInterpolation

    Overview
    ========
    Rotary-style sinusoidal positional embeddings with position
    interpolation, which rescales position indices so that sequences
    longer than the trained context length can be embedded.

    Parameters
    ==========
    dim: int
        Dimension of the input embedding.
    max_positions: int
        Maximum number of positions to embed.
    base: int
        Base of the sinusoidal embedding.
    device: torch.device
        Device to store the embeddings on.

    Attributes
    ==========
    inv_freq: torch.Tensor
        Cached inverse frequencies.
    max_seq_len_cached: int
        Maximum sequence length cached.
    scale: float
        Scale of the sinusoidal embedding.
    cos_cached: torch.Tensor
        Cached cosine values.
    sin_cached: torch.Tensor
        Cached sine values.

    Methods
    =======
    forward(x, seq_len=None)
        Forward pass of the PositionInterpolationEmbeddings module.
    """

    def __init__(
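The docstring documents the constructor parameters and a forward(x, seq_len=None) entry point, but the diff truncates before the implementation, so the return value is not visible here. A hypothetical usage sketch, assuming the constructor takes the documented parameters and that forward returns the cached (cos, sin) tables in the style of rotary-embedding modules:

    # Hypothetical sketch; the (cos, sin) return value and the argument
    # values below are assumptions, not confirmed by the truncated diff.
    import torch

    from zeta.nn.embeddings.positional_interpolation import (
        PositionInterpolationEmbeddings,
    )

    # dim per the docstring; max_positions and base are illustrative values.
    pos = PositionInterpolationEmbeddings(dim=64, max_positions=2048, base=10000)

    x = torch.randn(1, 128, 64)     # (batch, seq_len, dim)
    cos, sin = pos(x, seq_len=128)  # assumed: cos/sin tables for 128 positions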
