Add trace_tensor op and refactor tensor tracing
We don't have an op that dispatches to the underlying
iree.turbine.ops.iree.trace_tensor.

With this change we can trace into safetensors files when executing
both in eager mode and with IREE.
All tracing is now routed through the new tracing op, and the user can
set a desired trace sink.

Added some functionality to prefix trace keys based on the module
structure.

Something that I could not preserve is the tensor trace counter.
It is a global variable, and when exporting we pass multiple times
through the traced function, which makes the counter inconsistent.
The only way to distinguish traced tensors is through their trace keys.

The distinction between tensor tracing and goldens tracing has been
removed at the point of issuing a trace. If the user needs different
treatment, they can customize the trace sink.
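
For example, a minimal sketch of pointing tracing at a custom sink.
`set_trace_tensor_callback` is added in `sharktank/sharktank/utils/debugging.py`
below; the stats-printing sink itself is hypothetical:

```python
import torch

from sharktank.utils import debugging


def print_stats_sink(key: str, *tensors: torch.Tensor):
    # Hypothetical sink: print summary statistics instead of writing
    # safetensors files.
    for i, t in enumerate(tensors):
        print(f"{key}[{i}]: shape={tuple(t.shape)}, mean={t.float().mean():.4f}")


debugging.set_trace_tensor_callback(print_stats_sink)
```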
sogartar committed Jan 20, 2025
1 parent 1f50538 commit 83d36b3
Showing 14 changed files with 367 additions and 72 deletions.
4 changes: 2 additions & 2 deletions docs/debug_flags.md
@@ -16,8 +16,8 @@ false.
to emit information. When disabled, it is a no-op.
* `enable_nan_checks`: Enables certain expensive nan checks that may be
included in the model.
* `save_goldens_path`: When set to a path, any tensor traced via
`trace_tensor(golden=True)` will be added to a safetensors file and output
* `trace_path`: When set to a path, any tensor traced via
`trace_tensor()` will be added to a safetensors file and output
in a deterministic way to the path.
* `use_custom_int_conv_kernel`: Uses custom kernels for integer convolution
arithmetic. This produces the most optimal compiled results but can impede
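
For illustration, a hedged sketch of enabling the new flag from code.
`flags` is the module-level `DebugFlags` instance shown in `debugging.py`
below; setting it programmatically, rather than through the flags
environment variable (whose name is not shown in this diff), is an
assumption of this sketch:

```python
from pathlib import Path

from sharktank.utils import debugging

# Override the parsed debug flags for this session (assumption: direct
# assignment is acceptable outside of parse_from_env()).
debugging.flags.enable_tensor_trace = True
debugging.flags.trace_path = Path("/tmp/traces")
```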
1 change: 1 addition & 0 deletions sharktank/requirements-tests.txt
@@ -11,3 +11,4 @@ protobuf
pytest==8.0.0
pytest-html
pytest-xdist==3.5.0
safetensors>=0.4.5
50 changes: 34 additions & 16 deletions sharktank/sharktank/layers/base.py
@@ -6,41 +6,59 @@

from typing import Dict, Optional
from collections import OrderedDict
from collections.abc import Mapping
import torch
import torch.nn as nn

from ..types import InferenceTensor, Theta, AnyTensor
from ..utils import debugging
from .. import ops

__all__ = [
"BaseLayer",
"ThetaLayer",
]


def _set_recursively_submodules_default_trace_tensor_key_prefix(
module: nn.Module, prefix: str = ""
):
if isinstance(module, BaseLayer):
module.trace_tensor_key_prefix = prefix

for name, submodule in module.named_children():
submodule_prefix = f"{prefix}{name}."
_set_recursively_submodules_default_trace_tensor_key_prefix(
submodule, submodule_prefix
)


class BaseLayer(nn.Module):
"""Base class of all of our layers."""

def trace_tensor(
self, key: str, t: torch.Tensor, *, values: bool = True, golden: bool = False
):
debugging.trace_tensor(key, t, values=values, golden=golden)
def __init__(self):
super().__init__()
self._trace_tensor_key_prefix = ""

def trace_tensors(
def set_recursively_submodules_default_trace_tensor_key_prefix(self):
_set_recursively_submodules_default_trace_tensor_key_prefix(
self, self.trace_tensor_key_prefix
)

@property
def trace_tensor_key_prefix(self) -> str:
return self._trace_tensor_key_prefix

@trace_tensor_key_prefix.setter
def trace_tensor_key_prefix(self, value: str):
self._trace_tensor_key_prefix = value

def trace_tensor(
self,
key: str,
tensors: Dict[str, torch.Tensor],
*,
values: bool = True,
golden: bool = False,
tensors: Dict[str, torch.Tensor] | list[torch.Tensor] | torch.Tensor,
):
debugging.trace_tensors(key, tensors, values=values, golden=golden)

def trace_golden(self, key: str, t: torch.Tensor):
debugging.trace_tensor(key, t, golden=True)

def trace_goldens(self, key: str, tensors: Dict[str, torch.Tensor]):
debugging.trace_tensors(key, tensors, golden=True)
debugging.trace_tensor(f"{self.trace_tensor_key_prefix}{key}", tensors)

def assert_not_nan(self, *ts: torch.Tensor):
"""Checks whether tensors have nan values in them.
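
A short usage sketch of the new key prefixing, assuming a model built from
`BaseLayer` subclasses (the `Inner`/`Outer` layers here are hypothetical):

```python
import torch

from sharktank.layers.base import BaseLayer


class Inner(BaseLayer):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With the prefix propagated below, this traces under the key
        # "inner.x" (when tracing is enabled via the debug flags).
        self.trace_tensor("x", x)
        return x


class Outer(BaseLayer):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.inner(x)


model = Outer()
# Give every BaseLayer submodule a "<name>." trace key prefix derived
# from the module tree.
model.set_recursively_submodules_default_trace_tensor_key_prefix()
model(torch.ones(1))
```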
2 changes: 1 addition & 1 deletion sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -203,7 +203,7 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
self.assert_not_nan(attn_weights)

# Apply attention mask.
self.trace_tensor("attn_weights", attn_weights, values=False)
self.trace_tensor("attn_weights", attn_weights)
if attention_mask is not None:
# self.trace_tensor("attn_mask", attention_mask)
attn_weights = attn_weights + attention_mask
2 changes: 1 addition & 1 deletion sharktank/sharktank/models/flux/flux.py
@@ -342,7 +342,7 @@ def forward(self, x: AnyTensor) -> AnyTensor:
return self.out_layer(x)


class EmbedND(torch.nn.Module):
class EmbedND(BaseLayer):
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super().__init__()
self.dim = dim
5 changes: 0 additions & 5 deletions sharktank/sharktank/models/flux/testing.py
@@ -47,11 +47,6 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
in_channels2 = 128
hidden_size = config.hidden_size
mlp_ratio = config.mlp_ratio
mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
mlp_hidden_size2 = int(mlp_ratio * hidden_size)
mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
mlp_hidden_size4 = int((mlp_ratio + 1) * hidden_size)
mlp_hidden_size5 = int((2 * mlp_ratio - 1) * hidden_size)
context_in_dim = config.context_in_dim
time_dim = 256
vec_dim = config.vec_in_dim
14 changes: 7 additions & 7 deletions sharktank/sharktank/models/punet/model.py
@@ -123,7 +123,7 @@ def forward(
# TODO: Verify on the fly upsampling is not needed (num_upsamplers != 0).
act_dtype = sample.dtype
bs, *_ = sample.shape
self.trace_goldens(
self.trace_tensor(
"inputs",
{
"sample": sample,
@@ -150,11 +150,11 @@ def forward(
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1).to(emb.dtype)
aug_embed = self.add_embedding(add_embeds)
emb = emb + aug_embed
self.trace_golden("emb", emb)
self.trace_tensor("emb", emb)

# 2. Pre-process.
sample = self.conv_in(sample)
self.trace_golden("preprocess", sample)
self.trace_tensor("preprocess", sample)

# 3. Down.
down_block_res_samples = (sample,)
@@ -167,7 +167,7 @@ def forward(
encoder_attention_mask=None,
)
down_block_res_samples += res_samples
self.trace_golden(f"down_block_{i}", sample)
self.trace_tensor(f"down_block_{i}", sample)

# 4. Mid.
sample, _ = self.mid_block(
@@ -177,7 +177,7 @@ def forward(
attention_mask=None,
encoder_attention_mask=None,
)
self.trace_golden("mid_block", sample)
self.trace_tensor("mid_block", sample)

# 5. Up.
for i, up_block in enumerate(self.up_blocks):
@@ -193,14 +193,14 @@ def forward(
attention_mask=None,
encoder_attention_mask=None,
)
self.trace_golden(f"up_block_{i}", sample)
self.trace_tensor(f"up_block_{i}", sample)

# 6. Post-process.
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = ops.elementwise(self.conv_act, sample)
sample = self.conv_out(sample)
self.trace_golden(f"output", sample)
self.trace_tensor(f"output", sample)
return sample

def _create_down_block(
6 changes: 3 additions & 3 deletions sharktank/sharktank/models/vae/model.py
@@ -67,7 +67,7 @@ def forward(
sample ('torch.Tensor') input latents of shape (batch_size, num_channels, height, width)
"""
self.trace_goldens(
self.trace_tensor(
"inputs",
{
"sample": sample,
@@ -89,10 +89,10 @@ def forward(
sample = self.post_quant_conv(sample)

sample = self.conv_in(sample)
self.trace_golden("conv_in", sample)
self.trace_tensor("conv_in", sample)
# TODO add training and gradient checkpointing support
sample = self.mid_block(sample, latent_embeds)
self.trace_golden("mid_block", sample)
self.trace_tensor("mid_block", sample)

sample = sample.to(self.upscale_dtype)
for up_block in self.up_blocks:
7 changes: 7 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -436,6 +436,13 @@ def to_default(tensor: Tensor, *args, **kwargs):
return unbox_tensor(tensor).to(*args, **kwargs)


@trace_tensor.override(AllOfExprsVariadic(IsOfType(Tensor, InferenceTensor)))
def trace_tensor(key: str, *tensors: tuple[AnyTensor]):
if len(tensors) != 1:
raise ValueError("Tracing more than one tensor at a time is not supported.")
iree.turbine.ops.iree.trace_tensor(key, unshard(tensors[0]))


@transfer_to_logical_device.override(Tensor)
def transfer_to_logical_device_default(tensor: Tensor, ordinal: int):
return iree.turbine.ops.iree.transfer_to_logical_device(
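
A small usage note on the override's contract: per the `ValueError` above,
this implementation only accepts a single tensor per trace (a sketch; the
key names are arbitrary):

```python
import torch

from sharktank import ops

ops.trace_tensor("attn_weights", torch.zeros(4, 4))  # OK: exactly one tensor
# ops.trace_tensor("qk", q, k)  # would raise ValueError with this override
```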
18 changes: 18 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -57,6 +57,7 @@
"softmax",
"squeeze",
"to",
"trace_tensor",
"transfer_to_logical_device",
"transpose",
"unflatten",
@@ -1025,6 +1026,23 @@ def _to_trampoline(d: SignatureDispatcher, tensor: AnyTensor, *args, **kwargs):
d.fail(dispatch_args)


@overridable
def trace_tensor(key: str, *tensors: tuple[AnyTensor]):
...


@trace_tensor.trampoline
def _trace_tensor_trampoline(
d: SignatureDispatcher, key: str, *tensors: tuple[AnyTensor]
):
for override in d.find_overrides(tensors):
result = override(key, *tensors)
if result is not NotImplemented:
return override, result
else:
d.fail(tensors)


@overridable
def transfer_to_logical_device(tensor: AnyTensor, ordinal: int) -> AnyTensor:
"""Transfer the tensor to a device with ordinal `ordinal`."""
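
The trampoline above plugs into sharktank's existing dispatch machinery. As
a rough illustration only (not the library's actual `SignatureDispatcher`
API), the override/trampoline idea reduces to something like:

```python
from typing import Callable


class DispatcherSketch:
    """Illustrative stand-in for the real SignatureDispatcher (assumption)."""

    def __init__(self, name: str):
        self.name = name
        self._overrides: list[tuple[tuple[type, ...], Callable]] = []

    def override(self, *types: type) -> Callable:
        def decorator(fn: Callable) -> Callable:
            self._overrides.append((types, fn))
            return fn

        return decorator

    def __call__(self, key: str, *tensors):
        # Mirrors the trampoline: try each override whose type constraint
        # matches and return its first non-NotImplemented result.
        for types, fn in self._overrides:
            if all(isinstance(t, types) for t in tensors):
                result = fn(key, *tensors)
                if result is not NotImplemented:
                    return result
        raise NotImplementedError(f"{self.name}: no override for {tensors!r}")


trace = DispatcherSketch("trace_tensor")


@trace.override(int, float)
def _trace_numbers(key: str, *tensors):
    print(key, tensors)


trace("demo", 1, 2.0)  # prints: demo (1, 2.0)
```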
83 changes: 52 additions & 31 deletions sharktank/sharktank/utils/debugging.py
@@ -5,13 +5,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Tools for debugging models."""
from typing import Dict, Optional

from typing import Callable, Dict, Optional, Tuple
from collections.abc import Mapping
from dataclasses import dataclass
import re
import os
from pathlib import Path
from typing import Sequence
import iree.turbine.support.debugging

import torch

@@ -29,8 +29,7 @@
class DebugFlags:
enable_tensor_trace: bool = False
enable_nan_checks: bool = False
save_goldens_path: Optional[Path] = None
golden_sequence_value: int = 0
trace_path: Optional[Path] = None

# Feature flags.
# Enables use of custom IREE kernels in lieu of PyTorch general
@@ -52,8 +51,8 @@ def set(self, part: str):
self.enable_tensor_trace = logical_sense
elif name == "enable_nan_checks":
self.enable_nan_checks = logical_sense
elif name == "save_goldens_path":
self.save_goldens_path = Path(value)
elif name == "trace_path":
self.trace_path = Path(value)
elif name == "use_custom_iree_kernels":
self.use_custom_iree_kernels = logical_sense
else:
@@ -84,39 +83,61 @@ def parse_from_env() -> "DebugFlags":


def trace_tensor(
key: str, t: torch.Tensor, *, values: bool = True, golden: bool = False
key: str, tensors: Dict[str, torch.Tensor] | list[torch.Tensor] | torch.Tensor
):
trace_tensors(key, {"default": t}, values=values, golden=golden)
if not flags.enable_tensor_trace:
return

if isinstance(tensors, Mapping):
sub_keys = list(tensors.keys())
sub_keys.sort()

def trace_tensors(
key: str,
tensors: Dict[str, torch.Tensor],
*,
values: bool = True,
golden: bool = False,
):
if golden:
if flags.save_goldens_path:
_save_goldens(key, tensors)
return
if not flags.enable_tensor_trace:
for sub_key in sub_keys:
trace_tensor(f"{key}.{sub_key}", tensors[sub_key])
return
for name, t in tensors.items():
if t is not None:
values_repr = repr(t) if values else "...elided..."
print(f"::: TRACE {key}:{name}({list(t.shape), t.dtype}) =\n{values_repr}")

if isinstance(tensors, torch.Tensor):
tensors = (tensors,)

from .. import ops

ops.trace_tensor(key, *tensors)


TraceKey = str
TraceTensors = Callable[[TraceKey, *Tuple[torch.Tensor, ...]], None]


def set_trace_tensor_callback(callback: TraceTensors):
iree.turbine.support.debugging.trace_tensor_callback = callback


def get_trace_tensor_callback() -> Optional[TraceTensors]:
return iree.turbine.support.debugging.trace_tensor_callback


def null_trace_tensor_callback(key: str, *tensors: Tuple[torch.Tensor]):
return


def trace_tensor_to_safetensors_callback(key: str, *tensors: Tuple[torch.Tensor]):
if len(tensors) == 1:
tensors_in_dict = {"": t for t in tensors}
else:
tensors_in_dict = {f"{i}": t for i, t in enumerate(tensors)}
trace_tensors_to_safetensors(key, tensors_in_dict)


set_trace_tensor_callback(trace_tensor_to_safetensors_callback)


def _save_goldens(key: str, tensors: Dict[str, torch.Tensor]):
next_sequence = flags.golden_sequence_value
flags.golden_sequence_value += 1
def trace_tensors_to_safetensors(key: str, tensors: Dict[str, torch.Tensor]):
# Sanitize as path.
key = re.sub("[" + re.escape(r"""#~!@$%^&*()[]{}:;"'""") + "]", "", key)
from safetensors.torch import save_file

path: Path = flags.save_goldens_path / f"{next_sequence:04d}_{key}.safetensors"
path: Path = flags.trace_path / f"{key}.safetensors"
path.parent.mkdir(parents=True, exist_ok=True)
print(f"::: SAVE GOLDEN {path}")
print(f"::: TRACE TENSOR(S) {path}")
non_none_tensors = {k: v.contiguous() for k, v in tensors.items() if v is not None}
save_file(non_none_tensors, path)
save_file(non_none_tensors, filename=path)
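
Putting the pieces together, a hedged end-to-end sketch in eager mode (the
path and key are hypothetical; the default safetensors sink installed above
does the writing):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file

from sharktank.utils import debugging

debugging.flags.enable_tensor_trace = True
debugging.flags.trace_path = Path("/tmp/traces")

# Routed through ops.trace_tensor; with the default sink this should
# write /tmp/traces/activations.safetensors.
debugging.trace_tensor("activations", torch.ones(2, 3))

traced = load_file("/tmp/traces/activations.safetensors")
print(traced[""].shape)  # the single-tensor case is stored under key ""
```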