Commit: xpos docs

Kye committed Aug 22, 2023
1 parent 4c28a6a commit c15a4bc
Showing 4 changed files with 435 additions and 14 deletions.
111 changes: 111 additions & 0 deletions docs/zeta/nn/modules/xpos.md
@@ -0,0 +1,111 @@
# XPOS Module Documentation
-------------------------

### Architecture

The XPOS module is implemented as a subclass of `torch.nn.Module` and is intended to be used inside a larger neural network model. It consists of several helper functions and a class that work together to apply rotary positional embeddings to an input tensor.

### Purpose

The purpose of the XPOS module is to incorporate positional information into the input tensor of a neural network model. It achieves this by generating fixed positional embeddings and applying them to the input tensor using rotary positional encoding techniques. This allows the model to capture the sequential order and relative positions of the input elements, which can be beneficial for tasks such as natural language processing and time series analysis.

### Functions and Methods

1. `fixed_pos_embedding(x)`: Generates fixed positional embeddings for the input tensor.

- Args:
- `x` (torch.Tensor): Input tensor of shape `(seq_len, dim)`.
- Returns:
- `sin` (torch.Tensor): Sine positional embeddings of shape `(seq_len, dim)`.
- `cos` (torch.Tensor): Cosine positional embeddings of shape `(seq_len, dim)`.
2. `rotate_every_two(x)`: Rearranges the elements of the input tensor by rotating every two elements.

- Args:
- `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
- Returns:
- `x` (torch.Tensor): Rearranged tensor of shape `(batch_size, seq_len, dim)`.
3. `duplicate_interleave(m)`: Duplicates a matrix while interleaving the copy.

- Args:
- `m` (torch.Tensor): Input matrix.
- Returns:
- `m` (torch.Tensor): Duplicated and interleaved matrix.
4. `apply_rotary_pos_emb(x, sin, cos, scale=1)`: Applies rotary positional embeddings to the input tensor.

- Args:
- `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
- `sin` (torch.Tensor): Sine positional embeddings of shape `(seq_len, dim)`.
- `cos` (torch.Tensor): Cosine positional embeddings of shape `(seq_len, dim)`.
- `scale` (float): Scaling factor for the positional embeddings (default: 1).
- Returns:
- `x` (torch.Tensor): Tensor with applied rotary positional embeddings.
5. `XPOS(head_dim, scale_base=512)`: XPOS module class.

- Args:
- `head_dim` (int): Dimensionality of the input tensor.
- `scale_base` (int): Base value for scaling the positional embeddings (default: 512).
- Methods:
- `forward(x, offset=0, downscale=False)`: Forward pass of the XPOS module.
- Args:
- `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
- `offset` (int): Offset value for positional embeddings (default: 0).
- `downscale` (bool): Boolean indicating whether to downscale the positional embeddings (default: False).
- Returns:
- `x` (torch.Tensor): Tensor with applied rotary positional embeddings.

### Usage Examples

1. Applying the XPOS module to an input tensor:

```python
import torch
from xpos import XPOS
# Create an instance of the XPOS module
xpos = XPOS(head_dim=256)
# Generate a random input tensor
x = torch.randn(1, 10, 256)
# Apply the XPOS module to the input tensor
output = xpos(x)
```


2. Applying the XPOS module with offset and downscaling:

```python
import torch
from xpos import XPOS
# Create an instance of the XPOS module
xpos = XPOS(head_dim=512)
# Generate a random input tensor
x = torch.randn(1, 20, 512)
# Apply the XPOS module to the input tensor with offset and downscaling
output = xpos(x, offset=2, downscale=True)
```


3. Using the individual functions of the XPOS module:

```python
import torch
from xpos import fixed_pos_embedding, apply_rotary_pos_emb

# Generate fixed positional embeddings from a per-position scale tensor.
# Note: the sin/cos passed to apply_rotary_pos_emb must have half the feature
# dimension of x, because duplicate_interleave doubles their last dimension.
scale = torch.randn(10, 128)
sin, cos = fixed_pos_embedding(scale)

# Apply rotary positional embeddings to an input tensor
x = torch.randn(1, 10, 256)
output = apply_rotary_pos_emb(x, sin, cos, scale=0.5)
```

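4. Using the low-level helpers `rotate_every_two` and `duplicate_interleave`:

This is a minimal sketch, assuming both helpers can be imported from the same `xpos.py` module as the examples above; it only illustrates their element-wise behavior.

```python
import torch
from xpos import rotate_every_two, duplicate_interleave

# rotate_every_two maps each adjacent pair (x1, x2) to (-x2, x1)
x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # shape (batch=1, seq_len=1, dim=4)
print(rotate_every_two(x))  # tensor([[[-2., 1., -4., 3.]]])

# duplicate_interleave repeats every element of a matrix along the last dimension
m = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
print(duplicate_interleave(m))  # tensor([[1., 1., 2., 2.], [3., 3., 4., 4.]])
```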

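5. Applying XPOS to queries and keys in an attention block:

The snippet below is only a sketch of a common usage pattern that is not shown in this file: the same XPOS instance is applied to the queries and, with `downscale=True`, to the keys, so that the per-position scales combine into a decay that depends on relative distance. The tensor shapes and the attention score computation are assumptions made for illustration.

```python
import torch
from xpos import XPOS

head_dim = 64
xpos = XPOS(head_dim=head_dim)

# Hypothetical query/key tensors of shape (batch, seq_len, head_dim)
q = torch.randn(2, 16, head_dim)
k = torch.randn(2, 16, head_dim)

q = xpos(q)                  # scale the queries
k = xpos(k, downscale=True)  # inverse-scale the keys

attn_scores = torch.einsum("bqd,bkd->bqk", q, k) / head_dim ** 0.5
```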
Note: The above examples assume that the `xpos.py` file containing the `XPOS` class and its helper functions is available on the import path.
109 changes: 109 additions & 0 deletions docs/zeta/nn/utils/helpers.md
@@ -0,0 +1,109 @@
## Documentation

### Overview

The provided module comprises utility functions and classes to streamline specific operations with Python data structures and PyTorch models. The main aspects of the module are:

- Checking the existence of a value.
- Implementing custom call behavior through classes.
- Custom decorators for function calls.
- Dictionary manipulation.
- Initialization of PyTorch layer parameters.

### Functions and Classes

1. **exists(val: Any) -> bool**:
Checks if the provided value is not `None`.

2. **default(val: Any, d: Any) -> Any**:
Returns the value if it's not `None`; otherwise, it returns a default value.

3. **once(fn: Callable) -> Callable**:
A decorator ensuring that the function is only called once.

4. **eval_decorator(fn: Callable) -> Callable**:
A decorator for `torch.nn.Module` methods to switch the module to `eval` mode during the function call and revert to its original mode afterwards.

5. **cast_tuple(val: Any, depth: int) -> Tuple**:
Casts a value to a tuple with a specific depth.

6. **maybe(fn: Callable) -> Callable**:
A decorator that calls the function only if its first argument exists.

7. **always**:
A class that always returns the specified value when called.

8. **not_equals** and **equals**:
Classes that, when instantiated with a value, check if another value is (not) equal to the specified value.

9. **init_zero_(layer: nn.Module) -> None**:
Initializes the weights and biases of a torch layer to zero.

10. **pick_and_pop(keys: List[str], d: Dict) -> Dict**:
Extracts values from a dictionary based on provided keys.

11. **group_dict_by_key(cond: Callable, d: Dict) -> Tuple[Dict, Dict]**:
Groups dictionary keys based on a given condition.

12. **string_begins_with(prefix: str, str: str) -> bool**:
Checks if a string starts with a specific prefix.

13. **group_by_key_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]**:
Groups dictionary items by keys starting with a specific prefix.

14. **groupby_prefix_and_trim(prefix: str, d: Dict) -> Tuple[Dict, Dict]**:
Similar to `group_by_key_prefix` but also removes the prefix from keys.

### Usage Examples

1. **Using the `once` decorator**:

```python
from zeta import once

@once
def greet():
    print("Hello, World!")

greet() # prints "Hello, World!"
greet() # Does nothing on the second call
```

2. **Using the `eval_decorator` with PyTorch**:

```python
import torch
import torch.nn as nn
from zeta import eval_decorator

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    @eval_decorator
    def predict(self, x):
        return self.layer(x)

model = SimpleModel()
input_tensor = torch.randn(1, 10)
output = model.predict(input_tensor) # Automatically switches to eval mode and back
```

3. **Dictionary Manipulation with Prefix Functions**:

```python
from zeta import group_by_key_prefix

sample_dict = {
    "user_name": "John",
    "user_age": 25,
    "order_id": 12345,
    "order_date": "2023-01-01"
}

user_data, order_data = group_by_key_prefix("user_", sample_dict)
print(user_data) # {'user_name': 'John', 'user_age': 25}
print(order_data) # {'order_id': 12345, 'order_date': '2023-01-01'}
```
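
4. **Handling optional values with `exists` and `default`**:

A minimal sketch, assuming both helpers are exported from `zeta` like the utilities above:

```python
from zeta import exists, default

config = None
print(exists(config))                 # False
print(default(config, {"lr": 1e-4}))  # {'lr': 0.0001}, falls back because config is None
print(default(0.5, {"lr": 1e-4}))     # 0.5, the provided value wins when it is not None
```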

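5. **Splitting keyword arguments with `groupby_prefix_and_trim`**:

Per the description above, this behaves like `group_by_key_prefix` but also strips the prefix from the matching keys; the example below is a short sketch under that assumption.

```python
from zeta import groupby_prefix_and_trim

kwargs = {"encoder_depth": 6, "encoder_heads": 8, "dropout": 0.1}

encoder_kwargs, rest = groupby_prefix_and_trim("encoder_", kwargs)
print(encoder_kwargs)  # {'depth': 6, 'heads': 8}
print(rest)            # {'dropout': 0.1}
```
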
This module is a collection of general-purpose utility functions and classes, making many common operations more concise. It's beneficial when working with PyTorch models and various data manipulation tasks.
65 changes: 54 additions & 11 deletions zeta/nn/modules/xpos_relative_position.py
@@ -1,10 +1,17 @@
# Copyright (c) 2022 Agora
# Licensed under The MIT License [see LICENSE for details]

import torch
import torch.nn as nn

def fixed_pos_embedding(x):
    """
    Generates fixed positional embeddings for the input tensor.
    Args:
    - x: Input tensor of shape (seq_len, dim)
    Returns:
    - sin: Sine positional embeddings of shape (seq_len, dim)
    - cos: Cosine positional embeddings of shape (seq_len, dim)
    """
    seq_len, dim = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
    sinusoid_inp = (
@@ -13,27 +20,52 @@ def fixed_pos_embedding(x):
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)

def rotate_every_two(x):
    """
    Rearranges the elements of the input tensor by rotating every two elements.
    Args:
    - x: Input tensor of shape (batch_size, seq_len, dim)
    Returns:
    - x: Rearranged tensor of shape (batch_size, seq_len, dim)
    """
    x1 = x[:, :, ::2]
    x2 = x[:, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')

def duplicate_interleave(m):
    """
    A simple version of `torch.repeat_interleave` that duplicates a matrix while interleaving the copy.
    Args:
    - m: Input matrix
    Returns:
    - m: Duplicated and interleaved matrix
    """
    dim0 = m.shape[0]
    m = m.view(-1, 1)  # flatten the matrix
    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
    return m

def apply_rotary_pos_emb(x, sin, cos, scale=1):
    """
    Applies rotary positional embeddings to the input tensor.
    Args:
    - x: Input tensor of shape (batch_size, seq_len, dim)
    - sin: Sine positional embeddings of shape (seq_len, dim)
    - cos: Cosine positional embeddings of shape (seq_len, dim)
    - scale: Scaling factor for the positional embeddings
    Returns:
    - x: Tensor with applied rotary positional embeddings
    """
    sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
    return (x * cos) + (rotate_every_two(x) * sin)


class XPOS(nn.Module):
    def __init__(
        self, head_dim, scale_base=512
@@ -46,6 +78,17 @@ def __init__(
        )

    def forward(self, x, offset=0, downscale=False):
        """
        Forward pass of the XPOS module.
        Args:
        - x: Input tensor of shape (batch_size, seq_len, dim)
        - offset: Offset value for positional embeddings
        - downscale: Boolean indicating whether to downscale the positional embeddings
        Returns:
        - x: Tensor with applied rotary positional embeddings
        """
        length = x.shape[1]
        min_pos = -(length + offset) // 2
        max_pos = length + offset + min_pos
@@ -61,4 +104,4 @@ def forward(self, x, offset=0, downscale=False):
            scale = 1 / scale

        x = apply_rotary_pos_emb(x, sin, cos, scale)
        return x