Commit `c15a4bc` (parent `4c28a6a`) by Kye, committed Aug 22, 2023. 4 changed files with 435 additions and 14 deletions.
# XPOS Module Documentation

### Architecture

The XPOS module is implemented as a subclass of `torch.nn.Module`. It consists of several functions and a class that work together to apply rotary positional embeddings to an input tensor.

### Purpose

The purpose of the XPOS module is to incorporate positional information into the input tensor of a neural network model. It generates fixed positional embeddings and applies them to the input tensor using rotary positional encoding. This allows the model to capture the sequential order and relative positions of input elements, which benefits tasks such as natural language processing and time-series analysis.

### Functions and Methods
1. `fixed_pos_embedding(x)`: Generates fixed positional embeddings for the input tensor.
   - Args:
     - `x` (torch.Tensor): Input tensor of shape `(seq_len, dim)`.
   - Returns:
     - `sin` (torch.Tensor): Sine positional embeddings of shape `(seq_len, dim)`.
     - `cos` (torch.Tensor): Cosine positional embeddings of shape `(seq_len, dim)`.

2. `rotate_every_two(x)`: Rearranges the elements of the input tensor by rotating every two adjacent elements.
   - Args:
     - `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
   - Returns:
     - `x` (torch.Tensor): Rearranged tensor of shape `(batch_size, seq_len, dim)`.

3. `duplicate_interleave(m)`: Duplicates a matrix while interleaving the copy.
   - Args:
     - `m` (torch.Tensor): Input matrix.
   - Returns:
     - `m` (torch.Tensor): Duplicated and interleaved matrix.

4. `apply_rotary_pos_emb(x, sin, cos, scale=1)`: Applies rotary positional embeddings to the input tensor.
   - Args:
     - `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
     - `sin` (torch.Tensor): Sine positional embeddings of shape `(seq_len, dim)`.
     - `cos` (torch.Tensor): Cosine positional embeddings of shape `(seq_len, dim)`.
     - `scale` (float): Scaling factor for the positional embeddings (default: 1).
   - Returns:
     - `x` (torch.Tensor): Tensor with applied rotary positional embeddings.

5. `XPOS(head_dim, scale_base=512)`: The XPOS module class.
   - Args:
     - `head_dim` (int): Dimensionality of the input tensor.
     - `scale_base` (int): Base value for scaling the positional embeddings (default: 512).
   - Methods:
     - `forward(x, offset=0, downscale=False)`: Forward pass of the XPOS module.
       - Args:
         - `x` (torch.Tensor): Input tensor of shape `(batch_size, seq_len, dim)`.
         - `offset` (int): Offset value for positional embeddings (default: 0).
         - `downscale` (bool): Whether to downscale the positional embeddings (default: False).
       - Returns:
         - `x` (torch.Tensor): Tensor with applied rotary positional embeddings.
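To build intuition for how these helpers fit together, here is a minimal sketch that assumes only the documented signatures and shapes above; it is an illustration, not the module's actual source:

```python
import torch

def fixed_pos_embedding(x):
    # x: (seq_len, dim); build a sinusoid table from inverse frequencies
    seq_len, dim = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim).float() / dim))
    sinusoid = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    return torch.sin(sinusoid), torch.cos(sinusoid)

def rotate_every_two(x):
    # map adjacent pairs (x1, x2) -> (-x2, x1) along the last dimension
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_pos_emb(x, sin, cos, scale=1):
    # standard rotary update: x*cos + rotate(x)*sin, with optional scaling
    return (x * cos * scale) + (rotate_every_two(x) * sin * scale)
```

The `(seq_len, dim)` embeddings broadcast over the batch dimension of `x`; the real implementation may differ in details such as operating on half the head dimension.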
### Usage Examples

1. Applying the XPOS module to an input tensor:

```python
import torch
from xpos import XPOS

# Create an instance of the XPOS module
xpos = XPOS(head_dim=256)

# Generate a random input tensor
x = torch.randn(1, 10, 256)

# Apply the XPOS module to the input tensor
output = xpos(x)
```

2. Applying the XPOS module with offset and downscaling:

```python
import torch
from xpos import XPOS

# Create an instance of the XPOS module
xpos = XPOS(head_dim=512)

# Generate a random input tensor
x = torch.randn(1, 20, 512)

# Apply the XPOS module with an offset and downscaling
output = xpos(x, offset=2, downscale=True)
```

3. Using the individual functions of the XPOS module:

```python
import torch
from xpos import fixed_pos_embedding, apply_rotary_pos_emb

# Generate fixed positional embeddings
scale = torch.randn(10, 256)
sin, cos = fixed_pos_embedding(scale)

# Apply rotary positional embeddings to an input tensor
x = torch.randn(1, 10, 256)
output = apply_rotary_pos_emb(x, sin, cos, scale=0.5)
```

Note: The above examples assume that the `xpos.py` file defining the XPOS module and its helper functions is available on the import path.
## Utility Module Documentation

### Overview

This module comprises utility functions and classes that streamline common operations on Python data structures and PyTorch models. Its main aspects are:

- Checking the existence of a value.
- Implementing custom call behavior through classes.
- Custom decorators for function calls.
- Dictionary manipulation.
- Initialization of PyTorch layer parameters.

### Functions and Classes
1. **exists(val: Any) -> bool**:
   Checks whether the provided value is not `None`.

2. **default(val: Any, d: Any) -> Any**:
   Returns `val` if it is not `None`; otherwise returns the default `d`.

3. **once(fn: Callable) -> Callable**:
   A decorator ensuring that the function is only called once.

4. **eval_decorator(fn: Callable) -> Callable**:
   A decorator for `torch.nn.Module` methods that switches the module to `eval` mode for the duration of the call and restores the original mode afterwards.

5. **cast_tuple(val: Any, depth: int) -> Tuple**:
   Casts a value to a tuple of the given depth.

6. **maybe(fn: Callable) -> Callable**:
   A decorator that calls the function only if its first argument exists (is not `None`).

7. **always**:
   A class whose instances always return the specified value when called.

8. **not_equals** and **equals**:
   Classes that, when instantiated with a value, check whether another value is (not) equal to it.

9. **init_zero_(layer: nn.Module) -> None**:
   Initializes the weights and biases of a torch layer to zero.

10. **pick_and_pop(keys: List[str], d: Dict) -> Dict**:
    Removes the given keys from a dictionary and returns them as a new dictionary.

11. **group_dict_by_key(cond: Callable, d: Dict) -> Tuple[Dict, Dict]**:
    Splits a dictionary into two based on whether each key satisfies a condition.

12. **string_begins_with(prefix: str, str: str) -> bool**:
    Checks whether a string starts with a specific prefix.

13. **group_by_key_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]**:
    Groups dictionary items by whether their keys start with a specific prefix.

14. **groupby_prefix_and_trim(prefix: str, d: Dict) -> Tuple[Dict, Dict]**:
    Like `group_by_key_prefix`, but also removes the prefix from the matching keys.
||
### Usage Examples | ||
|
||
1. **Using the `once` decorator**: | ||
|
||
```python | ||
from zeta import once | ||
|
||
@once | ||
def greet(): | ||
print("Hello, World!") | ||
|
||
greet() # prints "Hello, World!" | ||
greet() # Does nothing on the second call | ||
``` | ||
|
||
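For intuition, `once` can be written in a few lines. This is a sketch of the described behavior, not necessarily the library's code:

```python
from functools import wraps

def once(fn):
    # Call fn at most once; later calls return None without running fn.
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner
```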
2. **Using the `eval_decorator` with PyTorch**:

```python
import torch
import torch.nn as nn
from zeta import eval_decorator

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    @eval_decorator
    def predict(self, x):
        return self.layer(x)

model = SimpleModel()
input_tensor = torch.randn(1, 10)
output = model.predict(input_tensor)  # switches to eval mode and back automatically
```
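The decorator itself can be sketched as follows; this is a plausible implementation inferred from the described behavior, not necessarily the library's code:

```python
import torch.nn as nn

def eval_decorator(fn):
    # Wrap a module method: switch to eval(), run, then restore the prior mode.
    def inner(model: nn.Module, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner
```

Saving `model.training` first means the wrapper restores whichever mode the module was in, rather than unconditionally switching back to training mode.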
3. **Dictionary manipulation with prefix functions**:

```python
from zeta import group_by_key_prefix

sample_dict = {
    "user_name": "John",
    "user_age": 25,
    "order_id": 12345,
    "order_date": "2023-01-01"
}

user_data, order_data = group_by_key_prefix("user_", sample_dict)
print(user_data)   # {'user_name': 'John', 'user_age': 25}
print(order_data)  # {'order_id': 12345, 'order_date': '2023-01-01'}
```

This module is a collection of general-purpose utility functions and classes that make many common operations more concise. It is particularly useful when working with PyTorch models and everyday data-manipulation tasks.