Commit

[DOCS][FusedDenseGELUDense]

Kye committed Dec 20, 2023
Kye committed Dec 20, 2023
1 parent bdc229a commit bddc2df
Showing 6 changed files with 311 additions and 2 deletions.
140 changes: 140 additions & 0 deletions docs/zeta/nn/modules/fused_gelu_dense.md
@@ -0,0 +1,140 @@
# `FusedDenseGELUDense`

## Overview

The `FusedDenseGELUDense` module is a versatile neural network layer designed for efficient computation of dense layers with GELU (Gaussian Error Linear Unit) activations. This documentation will provide an in-depth understanding of the module's architecture, purpose, parameters, and usage examples.

## Table of Contents

1. [Introduction](#introduction)
2. [Architecture](#architecture)
3. [Purpose](#purpose)
4. [Class Definition](#class-definition)
   - [Parameters](#parameters)
   - [Internal Layers](#internal-layers)
5. [Functionality and Usage](#functionality-and-usage)
   - [Forward Pass](#forward-pass)
6. [Examples](#examples)
   - [Basic Usage](#basic-usage)
   - [Custom Configuration](#custom-configuration)
   - [Quantization with bitsandbytes](#quantization-with-bitsandbytes)
7. [Additional Information](#additional-information)
8. [References](#references)

---

## 1. Introduction <a name="introduction"></a>

The `FusedDenseGELUDense` module packages two dense (linear) projections and a GELU activation into a single layer. When the `bitsandbytes` library is installed, it uses 8-bit quantized linear layers, which can reduce memory usage in deep learning applications.

## 2. Architecture <a name="architecture"></a>

The `FusedDenseGELUDense` layer consists of two dense sub-layers with a GELU activation between them. The first sub-layer projects the input from `dim` to `dim_out`, and the second projects it back from `dim_out` to `dim`, so the output tensor has the same feature dimension as the input.

## 3. Purpose <a name="purpose"></a>

The primary purpose of the `FusedDenseGELUDense` layer is to efficiently compute dense transformations with GELU activations. It is designed for use in neural networks, providing a convenient way to incorporate these operations into deep learning models.

## 4. Class Definition <a name="class-definition"></a>

### Parameters <a name="parameters"></a>

- `dim` (int): Input dimension.
- `dim_out` (int): Intermediate (hidden) dimension; the second dense layer projects back to `dim`, so the final output dimension equals `dim`.
- `bias` (bool, optional): Whether to include bias terms. Defaults to True.
- `has_fp16_weights` (bool, optional): Whether to keep fp16 weights in the quantized layers (forwarded to `bitsandbytes`; ignored by the plain PyTorch fallback). Defaults to False.
- `threshold` (float, optional): Outlier threshold for 8-bit quantization (forwarded to `bitsandbytes`; ignored by the fallback). Defaults to 6.0.

### Internal Layers <a name="internal-layers"></a>

The `FusedDenseGELUDense` layer consists of the following internal layers (a rough pure-PyTorch equivalent is sketched below):

1. `dense1`: The first dense layer, projecting from `dim` to `dim_out`.
2. `act`: The GELU activation function.
3. `dense2`: The second dense layer, projecting from `dim_out` back to `dim`.
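
As a minimal sketch (assuming `dim=512` and `dim_out=1024`), the non-quantized fallback corresponds to the plain PyTorch modules below; when `bitsandbytes` is available, `bnb.nn.Linear8bitLt` is used in place of `nn.Linear`:

```python
import torch
from torch import nn

dim, dim_out = 512, 1024

dense1 = nn.Linear(dim, dim_out)  # project dim -> dim_out
act = nn.GELU()                   # element-wise GELU activation
dense2 = nn.Linear(dim_out, dim)  # project dim_out -> dim
```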

## 5. Functionality and Usage <a name="functionality-and-usage"></a>

### Forward Pass <a name="forward-pass"></a>

The `forward` method of the `FusedDenseGELUDense` layer performs the following operations (see the standalone sketch below):

1. Applies the first dense layer (`dense1`) to the input tensor.
2. Applies the GELU activation function (`act`) to the result.
3. Applies the second dense layer (`dense2`) to the GELU-activated output, projecting back to `dim`.
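
A minimal, self-contained sketch of the equivalent computation (using plain `nn.Linear` layers and assuming `dim=512`, `dim_out=1024`):

```python
import torch
from torch import nn

dim, dim_out = 512, 1024
dense1, act, dense2 = nn.Linear(dim, dim_out), nn.GELU(), nn.Linear(dim_out, dim)

x = torch.randn(1, dim)  # [1, 512]
h = act(dense1(x))       # [1, 1024] after projection and GELU
out = dense2(h)          # [1, 512], projected back to dim
print(out.shape)         # torch.Size([1, 512])
```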

## 6. Examples <a name="examples"></a>

### Basic Usage <a name="basic-usage"></a>

Here's a basic example of using the `FusedDenseGELUDense` layer:

```python
import torch
from zeta.nn import FusedDenseGELUDense

# Create an instance of FusedDenseGELUDense
model = FusedDenseGELUDense(dim=512, dim_out=1024)

# Generate random input tensor
x = torch.randn(1, 512)

# Forward pass
out = model(x)

# Check the output shape: the second dense layer projects back to `dim`,
# so the output keeps 512 features even though dim_out=1024
print(out.shape)  # torch.Size([1, 512])
```

### Custom Configuration <a name="custom-configuration"></a>

You can customize the layer by specifying different parameters:

```python
# Create a custom FusedDenseGELUDense layer
custom_model = FusedDenseGELUDense(
    dim=256, dim_out=512, bias=False, has_fp16_weights=True, threshold=4.0
)

# Generate random input tensor
x = torch.randn(1, 256)

# Forward pass with the custom configuration
out = custom_model(x)
```

### Quantization with bitsandbytes <a name="quantization-with-bitsandbytes"></a>

Quantized 8-bit linear layers (`bnb.nn.Linear8bitLt`) are used automatically when the `bitsandbytes` library is installed; otherwise the layer falls back to `torch.nn.Linear`. The `has_fp16_weights` and `threshold` arguments are forwarded to the quantized layers:

```python
# Install bitsandbytes if not already installed
# pip install bitsandbytes

import torch
from zeta.nn import FusedDenseGELUDense

# Create an instance of FusedDenseGELUDense
# (8-bit quantized layers are used only when bitsandbytes is installed)
quantized_model = FusedDenseGELUDense(
    dim=512, dim_out=1024, has_fp16_weights=True, threshold=4.0
)

# Generate random input tensor
x = torch.randn(1, 512)

# Forward pass with quantization
out = quantized_model(x)
```

## 7. Additional Information <a name="additional-information"></a>

- The `FusedDenseGELUDense` layer combines two dense projections and a GELU activation in one module.
- Custom configurations for bias, weight precision, and quantization threshold are supported.
- 8-bit quantization via `bitsandbytes` is used automatically when the library is installed.

## 8. References <a name="references"></a>

For more information on GELU activations and dense layers in PyTorch, refer to the official PyTorch documentation:

- [GELU Activation Function](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html)
- [Linear (Dense) Layer](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -109,6 +109,7 @@ nav:
- MultiModalAdapterDenseNetwork: "zeta/nn/modules/mm_adapter.md"
- CustomMLP: "zeta/nn/modules/custom_mlp.md"
- PolymorphicNeuronLayer: "zeta/nn/modules/polymorphic_activation.md"
- FusedDenseGELUDense: "zeta/nn/modules/fused_gelu_dense.md"
- zeta.nn.attention:
- FlashAttention: "zeta/nn/attention/flash_attention.md"
- MultiQueryAttention: "zeta/nn/attention/multiquery.md"
70 changes: 70 additions & 0 deletions tests/nn/modules/test_fused_gelu_dense.py
@@ -0,0 +1,70 @@
import pytest
import torch
from zeta.nn.modules.fused_gelu_dense import FusedDenseGELUDense

def test_class_init():
    model = FusedDenseGELUDense(512, 1024)

    assert model.dim == 512
    assert model.dim_out == 1024
    assert model.bias is True
    assert model.has_fp16_weights is False
    assert model.threshold == 6.0

def test_class_init_with_args():
    model = FusedDenseGELUDense(
        512, 1024, bias=False, has_fp16_weights=True, threshold=5.0
    )

    assert model.dim == 512
    assert model.dim_out == 1024
    assert model.bias is False
    assert model.has_fp16_weights is True
    assert model.threshold == 5.0

def test_forward():
    model = FusedDenseGELUDense(512, 1024)
    x = torch.randn(1, 512)
    out = model(x)

    assert out.shape == torch.Size([1, 512])

def test_forward_with_different_input():
    model = FusedDenseGELUDense(512, 1024)
    x = torch.randn(2, 512)
    out = model(x)

    assert out.shape == torch.Size([2, 512])

def test_forward_with_different_dim():
    model = FusedDenseGELUDense(256, 512)
    x = torch.randn(1, 256)
    out = model(x)

    assert out.shape == torch.Size([1, 256])

def test_forward_with_different_dim_out():
    model = FusedDenseGELUDense(512, 2048)
    x = torch.randn(1, 512)
    out = model(x)

    assert out.shape == torch.Size([1, 512])

def test_forward_with_no_bias():
    model = FusedDenseGELUDense(512, 1024, bias=False)
    x = torch.randn(1, 512)
    out = model(x)

    assert out.shape == torch.Size([1, 512])

def test_forward_with_fp16_weights():
    model = FusedDenseGELUDense(512, 1024, has_fp16_weights=True)
    x = torch.randn(1, 512)
    out = model(x)

    assert out.shape == torch.Size([1, 512])

def test_forward_with_different_threshold():
    model = FusedDenseGELUDense(512, 1024, threshold=5.0)
    x = torch.randn(1, 512)
    out = model(x)

    assert out.shape == torch.Size([1, 512])
2 changes: 1 addition & 1 deletion zeta/cloud/main.py
@@ -13,7 +13,7 @@

def zetacloud(
task_name: str = None,
cluster_name: str = "[ZetaTrainingRun]",
cluster_name: str = "ZetaTrainingRun",
cloud: Any = AWS(),
gpus: str = None,
filename: str = "train.py",
98 changes: 98 additions & 0 deletions zeta/nn/modules/fused_gelu_dense.py
@@ -0,0 +1,98 @@
import torch
from torch import nn

class FusedDenseGELUDense(nn.Module):
    """FusedDenseGELUDense

    A dense -> GELU -> dense block: the first linear layer projects from
    ``dim`` to ``dim_out`` and the second projects back to ``dim``. When
    ``bitsandbytes`` is installed, 8-bit quantized linear layers are used.

    Args:
        dim (int): Input dimension.
        dim_out (int): Intermediate (hidden) dimension.
        bias (bool, optional): Whether to use bias terms. Defaults to True.
        has_fp16_weights (bool, optional): Use fp16 weights (bitsandbytes only). Defaults to False.
        threshold (float, optional): Outlier threshold for 8-bit quantization (bitsandbytes only). Defaults to 6.0.

    Examples:
        >>> x = torch.randn(1, 512)
        >>> model = FusedDenseGELUDense(512, 1024)
        >>> out = model(x)
        >>> out.shape
        torch.Size([1, 512])
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        bias: bool = True,
        has_fp16_weights: bool = False,
        threshold: float = 6.0,
        *args,
        **kwargs
    ):
        super(FusedDenseGELUDense, self).__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.bias = bias
        self.has_fp16_weights = has_fp16_weights
        self.threshold = threshold

        try:
            import bitsandbytes as bnb

            # Use bitsandbytes 8-bit quantized linear layers when available
            self.dense1 = bnb.nn.Linear8bitLt(
                dim,
                dim_out,
                bias=bias,
                has_fp16_weights=has_fp16_weights,
                threshold=threshold,
                *args,
                **kwargs
            )

            # Reverse projection: dim_out -> dim
            self.dense2 = bnb.nn.Linear8bitLt(
                dim_out,
                dim,
                bias=bias,
                has_fp16_weights=has_fp16_weights,
                threshold=threshold,
                *args,
                **kwargs
            )

        except ModuleNotFoundError:
            # Fall back to torch.nn.Linear when bitsandbytes is not installed
            self.dense1 = nn.Linear(
                dim,
                dim_out,
                bias=bias,
                *args,
                **kwargs
            )

            # Reverse projection: dim_out -> dim
            self.dense2 = nn.Linear(
                dim_out,
                dim,
                bias=bias,
                *args,
                **kwargs
            )

        # Activation
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: dense1 -> GELU -> dense2.

        Args:
            x (torch.Tensor): Input tensor of shape ``(..., dim)``.

        Returns:
            torch.Tensor: Output tensor of shape ``(..., dim)``.
        """
        x = self.dense1(x)
        x = self.act(x)
        x = self.dense2(x)
        return x

2 changes: 1 addition & 1 deletion zeta/optim/__init__.py
@@ -27,5 +27,5 @@
"StableAdamWUnfused",
"GradientAscent",
"GradientEquilibrum",
"DecoupledLionW8Bit"
"DecoupledLionW8Bit",
]
