Merge branch 'fix-ppl-test' of https://github.com/nod-ai/shark-ai into fix-ppl-test
archana-ramalingam committed Jan 25, 2025
2 parents 5da8ac1 + 4991a96 commit de38f53
Showing 17 changed files with 489 additions and 94 deletions.
16 changes: 13 additions & 3 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -1,7 +1,9 @@
"""Test fixtures and configurations."""

import hashlib
import pytest
from pathlib import Path
import hashlib
from tokenizers import Tokenizer, Encoding

from ..model_management import (
ModelProcessor,
@@ -25,8 +27,10 @@
),
"llama3.1_8b": ModelConfig(
source=ModelSource.LOCAL,
local_path=Path("/data/llama3.1/8b/llama8b_f16.irpa"),
model_file="llama8b_f16.irpa",
local_path=Path(
"/data/llama3.1/weights/8b/fp16/llama3.1_8b_fp16_instruct.irpa"
),
model_file="llama3.1_8b_fp16_instruct.irpa",
tokenizer_id="NousResearch/Meta-Llama-3.1-8B",
batch_sizes=(1, 4),
device_settings=device_settings.CPU,
@@ -89,3 +93,9 @@ def server(model_artifacts, request):

process.terminate()
process.wait()


@pytest.fixture(scope="module")
def encoded_prompt(model_artifacts: ModelArtifacts, request) -> list[int]:
tokenizer = Tokenizer.from_file(str(model_artifacts.tokenizer_path))
return tokenizer.encode(request.param).ids
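A note on the new fixture: it is meant to be parametrized indirectly, so the prompt string a test supplies becomes request.param, is tokenized with the model's tokenizer, and comes back as token ids. A minimal sketch of that pattern (the test name and assertion are illustrative; the real tests below also parametrize model_artifacts and server):

```
import pytest


@pytest.mark.parametrize("encoded_prompt", ["0 1 2 3 4 5 "], indirect=True)
def test_prompt_is_tokenized(encoded_prompt: list[int]):
    # The conftest fixture has already run the prompt through the tokenizer.
    assert all(isinstance(token_id, int) for token_id in encoded_prompt)
```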
82 changes: 66 additions & 16 deletions app_tests/integration_tests/llm/shortfin/cpu_llm_server_test.py
@@ -1,11 +1,11 @@
"""Main test module for LLM server functionality."""

from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
import pytest
import requests
import uuid
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any
import uuid

logger = logging.getLogger(__name__)

@@ -40,16 +40,10 @@ class TestLLMServer:
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
marks=pytest.mark.xfail(
reason="llama3.1_8b irpa file not available on CI machine"
),
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
marks=pytest.mark.xfail(
reason="llama3.1_8b irpa file not available on CI machine"
),
),
],
ids=[
@@ -78,6 +72,58 @@ def test_basic_generation(self, server: tuple[Any, int]) -> None:
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server,encoded_prompt",
[
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
),
(
"open_llama_3b",
{"model": "open_llama_3b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "none"},
"0 1 2 3 4 5 ",
),
pytest.param(
"llama3.1_8b",
{"model": "llama3.1_8b", "prefix_sharing": "trie"},
"0 1 2 3 4 5 ",
),
],
ids=[
"open_llama_3b_none_input_ids",
"open_llama_3b_trie_input_ids",
"llama31_8b_none_input_ids",
"llama31_8b_trie_input_ids",
],
indirect=True,
)
def test_basic_generation_input_ids(
self, server: tuple[Any, int], encoded_prompt
) -> None:
"""Tests basic text generation capabilities.
Args:
server: Tuple of (process, port) from server fixture
"""
process, port = server
assert process.poll() is None, "Server process terminated unexpectedly"

response = self._generate(encoded_prompt, port, input_ids=True)
expected_prefix = "6 7 8"
if not response.startswith(expected_prefix):
raise AccuracyValidationException(
expected=f"{expected_prefix}...",
actual=response,
message=f"Generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

@pytest.mark.parametrize(
"model_artifacts,server",
[
@@ -121,7 +167,7 @@ def test_concurrent_generation(
message=f"Concurrent generation did not match expected pattern.\nExpected to start with: {expected_prefix}\nActual response: {response}",
)

def _generate(self, prompt: str, port: int) -> str:
def _generate(self, prompt: str | list[int], port: int, input_ids=False) -> str:
"""Helper method to make generation request to server.
Args:
@@ -135,15 +181,19 @@ def _generate(self, prompt: str) -> str:
requests.exceptions.RequestException: If request fails
AccuracyValidationException: If response format is invalid
"""
payload = {
"sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"stream": False,
}
if input_ids:
payload["input_ids"] = prompt
else:
payload["text"] = prompt
response = requests.post(
f"http://localhost:{port}/generate",
headers={"Content-Type": "application/json"},
json={
"text": prompt,
"sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
"rid": uuid.uuid4().hex,
"stream": False,
},
json=payload,
timeout=30, # Add reasonable timeout
)
response.raise_for_status()
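The _generate helper above now posts either "text" or "input_ids" to the server. A standalone sketch of the same request against a running shortfin server follows; the port and tokenizer file path are assumptions (the tests obtain both from fixtures):

```
import uuid

import requests
from tokenizers import Tokenizer

PORT = 8000  # illustrative; the tests read the port from the server fixture
tokenizer = Tokenizer.from_file("tokenizer.json")  # hypothetical local path
input_ids = tokenizer.encode("0 1 2 3 4 5 ").ids

payload = {
    "input_ids": input_ids,  # or "text": "0 1 2 3 4 5 " for an untokenized prompt
    "sampling_params": {"max_completion_tokens": 15, "temperature": 0.7},
    "rid": uuid.uuid4().hex,
    "stream": False,
}
response = requests.post(
    f"http://localhost:{PORT}/generate",
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=30,
)
response.raise_for_status()
print(response.text)
```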
4 changes: 2 additions & 2 deletions docs/debug_flags.md
@@ -16,8 +16,8 @@ false.
to emit information. When disabled, it is a no-op.
* `enable_nan_checks`: Enables certain expensive nan checks that may be
included in the model.
* `save_goldens_path`: When set to a path, any tensor traced via
`trace_tensor(golden=True)` will be added to a safetensors file and output
* `trace_path`: When set to a path, any tensor traced via
`trace_tensor()` will be added to a safetensors file and output
in a deterministic way to the path.
* `use_custom_int_conv_kernel`: Uses custom kernels for integer convolution
arithmetic. This produces the most optimal compiled results but can impede
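Since the renamed trace_path flag writes traced tensors to safetensors output (and this commit adds safetensors to the test requirements), inspecting a trace offline might look like the sketch below. The file name is an assumption, and the keys depend on the layer's trace_tensor_key_prefix and the key passed to trace_tensor():

```
from safetensors.torch import load_file

# Hypothetical output file written under the path given by trace_path.
traces = load_file("traces.safetensors")
for key, tensor in traces.items():
    print(key, tuple(tensor.shape), tensor.dtype)
```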
1 change: 1 addition & 0 deletions sharktank/requirements-tests.txt
@@ -11,3 +11,4 @@ protobuf
pytest==8.0.0
pytest-html
pytest-xdist==3.5.0
safetensors>=0.4.5
85 changes: 69 additions & 16 deletions sharktank/sharktank/layers/base.py
@@ -6,41 +6,94 @@

from typing import Dict, Optional
from collections import OrderedDict
from collections.abc import Mapping
import torch
import torch.nn as nn

from ..types import InferenceTensor, Theta, AnyTensor
from ..utils import debugging
from .. import ops

__all__ = [
"BaseLayer",
"ThetaLayer",
]


def _set_recursively_submodules_default_trace_tensor_key_prefix(
module: nn.Module, prefix: str = ""
):
if isinstance(module, BaseLayer):
module.trace_tensor_key_prefix = prefix

for name, submodule in module.named_children():
submodule_prefix = f"{prefix}{name}."
_set_recursively_submodules_default_trace_tensor_key_prefix(
submodule, submodule_prefix
)


class BaseLayer(nn.Module):
"""Base class of all of our layers."""

def trace_tensor(
self, key: str, t: torch.Tensor, *, values: bool = True, golden: bool = False
):
debugging.trace_tensor(key, t, values=values, golden=golden)
def __init__(self):
super().__init__()
self._trace_tensor_key_prefix = ""

def set_recursively_submodules_default_trace_tensor_key_prefix(self):
"""All submodules get a trace key prefix that reflects their nesting with
respect to the parent module.
Example:
```
class A(BaseLayer):
def __init__(self):
...
self.b = ...
class B(BaseLayer):
def __init__(self):
...
self.c = ...
class C(BaseLayer):
def forward(self, x):
self.trace_tensor("x", x)
def trace_tensors(
a = A()
a.set_recursively_submodules_default_trace_tensor_key_prefix()
```
This will result in trace key prefixes
a -> ""
a.b -> "b."
a.b.c -> "b.c."
The trace_tensor method call in C.forward will result in a trace with key
"b.c.x".
"""
_set_recursively_submodules_default_trace_tensor_key_prefix(
self, self.trace_tensor_key_prefix
)

@property
def trace_tensor_key_prefix(self) -> str:
"""When tracing with self.trace_tensor all keys will be prefixed by this
string.
The default prefix is the empty string."""
return self._trace_tensor_key_prefix

@trace_tensor_key_prefix.setter
def trace_tensor_key_prefix(self, value: str):
self._trace_tensor_key_prefix = value

def trace_tensor(
self,
key: str,
tensors: Dict[str, torch.Tensor],
*,
values: bool = True,
golden: bool = False,
tensors: Dict[str, torch.Tensor] | list[torch.Tensor] | torch.Tensor,
):
debugging.trace_tensors(key, tensors, values=values, golden=golden)

def trace_golden(self, key: str, t: torch.Tensor):
debugging.trace_tensor(key, t, golden=True)

def trace_goldens(self, key: str, tensors: Dict[str, torch.Tensor]):
debugging.trace_tensors(key, tensors, golden=True)
debugging.trace_tensor(f"{self.trace_tensor_key_prefix}{key}", tensors)

def assert_not_nan(self, *ts: torch.Tensor):
"""Checks whether tensors have nan values in them.
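To make the docstring example in set_recursively_submodules_default_trace_tensor_key_prefix concrete, here is a runnable sketch of the prefix propagation and the widened trace_tensor signature. The class names are illustrative, and it assumes sharktank is installed so BaseLayer can be imported from the module changed here:

```
import torch
from sharktank.layers.base import BaseLayer


class C(BaseLayer):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With the prefixes set below, this records the trace under key "b.c.x".
        self.trace_tensor("x", x)
        # trace_tensor also accepts a dict or list of tensors under one key.
        self.trace_tensor("pair", {"x": x, "twice_x": 2 * x})
        return x


class B(BaseLayer):
    def __init__(self):
        super().__init__()
        self.c = C()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.c(x)


class A(BaseLayer):
    def __init__(self):
        super().__init__()
        self.b = B()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.b(x)


a = A()
a.set_recursively_submodules_default_trace_tensor_key_prefix()
a(torch.zeros(2, 2))  # tracing is typically a no-op unless trace_path is set
```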
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -199,7 +199,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
self.assert_not_nan(attn_weights)

# Apply attention mask.
self.trace_tensor("attn_weights", attn_weights, values=False)
self.trace_tensor("attn_weights", attn_weights)
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/flux/flux.py
@@ -342,7 +342,7 @@ def forward(self, x: AnyTensor) -> AnyTensor:
return self.out_layer(x)


class EmbedND(torch.nn.Module):
class EmbedND(BaseLayer):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
5 changes: 0 additions & 5 deletions sharktank/sharktank/models/flux/testing.py
@@ -47,11 +47,6 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
in_channels2 = 128
hidden_size = config.hidden_size
mlp_ratio = config.mlp_ratio
mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
mlp_hidden_size2 = int(mlp_ratio * hidden_size)
mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size)
mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size)
context_in_dim = config.context_in_dim
time_dim = 256
vec_dim = config.vec_in_dim
14 changes: 7 additions & 7 deletions sharktank/sharktank/models/punet/model.py
@@ -123,7 +123,7 @@ def forward(
# TODO: Verify on the fly upsampling is not needed (num_upsamplers != 0).
act_dtype = sample.dtype
bs, *_ = sample.shape
self.trace_goldens(
self.trace_tensor(
"inputs",
{
"sample": sample,
@@ -150,11 +150,11 @@
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
aug_embed = self.add_embedding(add_embeds)
emb = emb + aug_embed
self.trace_golden("emb", emb)
self.trace_tensor("emb", emb)

# 2. Pre-process.
sample = self.conv_in(sample)
self.trace_golden("preprocess", sample)
self.trace_tensor("preprocess", sample)

# 3. Down.
down_block_res_samples = (sample,)
@@ -167,7 +167,7 @@
encoder_attention_mask=None,
)
down_block_res_samples += res_samples
self.trace_golden(f"down_block_{i}", sample)
self.trace_tensor(f"down_block_{i}", sample)

# 4. Mid.
sample, _ = self.mid_block(
@@ -177,7 +177,7 @@
attention_mask=None,
encoder_attention_mask=None,
)
self.trace_golden("mid_block", sample)
self.trace_tensor("mid_block", sample)

# 5. Up.
for i, up_block in enumerate(self.up_blocks):
@@ -193,14 +193,14 @@
attention_mask=None,
encoder_attention_mask=None,
)
self.trace_golden(f"up_block_{i}", sample)
self.trace_tensor(f"up_block_{i}", sample)

# 6. Post-process.
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = ops.elementwise(self.conv_act, sample)
sample = self.conv_out(sample)
self.trace_golden(f"output", sample)
self.trace_tensor(f"output", sample)
return sample

def _create_down_block(
6 changes: 3 additions & 3 deletions sharktank/sharktank/models/vae/model.py
@@ -67,7 +67,7 @@ def forward(
sample ('torch.Tensor') input latents of shape (batch_size, num_channels, height, width)
"""
self.trace_goldens(
self.trace_tensor(
"inputs",
{
"sample": sample,
@@ -89,10 +89,10 @@
sample = self.post_quant_conv(sample)

sample = self.conv_in(sample)
self.trace_golden("conv_in", sample)
self.trace_tensor("conv_in", sample)
# TODO add training and gradient checkpointing support
sample = self.mid_block(sample, latent_embeds)
self.trace_golden("mid_block", sample)
self.trace_tensor("mid_block", sample)

sample = sample.to(self.upscale_dtype)
for up_block in self.up_blocks: