Add attention mask as an option
archana-ramalingam committed Jan 24, 2025
1 parent 3251afe commit 5dd3fe6
Showing 4 changed files with 229 additions and 204 deletions.
10 changes: 9 additions & 1 deletion sharktank/sharktank/evaluate/perplexity_iree.py
@@ -69,6 +69,7 @@ def __init__(
         tensor_parallelism_size,
         attention_kernel,
         block_seq_stride,
+        use_attention_mask,
     ):
         self.torch_device = torch_device
         self.iree_device = iree_device
@@ -80,6 +81,7 @@ def __init__(
         self.attention_dtype = torch.float16
         self.tensor_parallelism_size = tensor_parallelism_size
         self.attention_kernel = attention_kernel
+        self.use_attention_mask = use_attention_mask

     def timeit(func):
         def wrapper(*args, **kwargs):
@@ -139,6 +141,7 @@ def compile_model(self, weight_path_str):
             attention_kernel=self.attention_kernel,
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=self.block_seq_stride,
+            use_attention_mask=self.use_attention_mask,
         )
         vmfb_path = export_artifacts.get_artifacts()
         return vmfb_path
@@ -398,6 +401,7 @@ def run_perplexity(
     attention_kernel,
     num_prompts,
     block_seq_stride,
+    use_attention_mask,
 ):
     start = time.time()
     perplexity = Perplexity(
@@ -409,6 +413,7 @@ def run_perplexity(
         tensor_parallelism_size=tensor_parallelism_size,
         attention_kernel=attention_kernel,
         block_seq_stride=block_seq_stride,
+        use_attention_mask=use_attention_mask,
     )

     perplexity.get_prompts(num_prompts=num_prompts)
@@ -434,7 +439,7 @@ def main(argv):
         "--attention-kernel",
         type=str,
         default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
+        choices=["decomposed", "torch"],
     )
     parser.add_argument(
         "--block-seq-stride",
@@ -478,6 +483,8 @@ def main(argv):
     weight_path = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)

+    use_attention_mask = True
+
     ppl = run_perplexity(
         weight_path=weight_path,
         weight_path_str=str(args.irpa_file),
@@ -491,6 +498,7 @@ def main(argv):
         attention_kernel=args.attention_kernel,
         num_prompts=args.num_prompts,
         block_seq_stride=args.block_seq_stride,
+        use_attention_mask=use_attention_mask,
     )

     logger.info(f"\n{json.dumps(ppl, indent=2)}")
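In this harness, main() pins use_attention_mask = True, so every IREE perplexity run now exports the model with an explicit attention mask rather than relying on causal attention alone. For comparison, a hypothetical way to surface the same switch on the perplexity CLI (not part of this commit; the flag name and help text below are illustrative only) would be:

    import argparse

    # Hypothetical flag, mirroring the exporter's --use-attention-mask option;
    # the commit above simply hardcodes use_attention_mask = True instead.
    parser = argparse.ArgumentParser(description="IREE perplexity harness (sketch)")
    parser.add_argument(
        "--use-attention-mask",
        help="Export the model with an explicit attention mask",
        action="store_true",
    )
    args = parser.parse_args(["--use-attention-mask"])
    use_attention_mask = args.use_attention_mask  # True when the flag is passed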
15 changes: 14 additions & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -61,6 +61,11 @@ def main():
         help="Enables strictness during export",
         action="store_true",
     )
+    parser.add_argument(
+        "--use-attention-mask",
+        help="Generates attention mask during export",
+        action="store_true",
+    )

     cli.add_quantization_options(parser)
     cli.add_model_options(parser)
@@ -220,16 +225,24 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
             else:
                 cache_tensors = cs

+            attention_mask = None
+            if args.use_attention_mask:
+                sl = tokens.shape[1]
+                input_mask = model.input_mask(seq_lens, sl)
+                attention_mask = model.attention_mask(input_mask)
+
             if llama_config.tensor_parallelism_size != 1:
                 shard_count = llama_config.tensor_parallelism_size

                 tokens = ops.replicate(tokens, count=shard_count)
+                if attention_mask:
+                    attention_mask = ops.replicate(attention_mask, count=shard_count)
                 seq_block_ids = ops.replicate(seq_block_ids, count=shard_count)
                 cache_tensors = repack_cache(cs, cache_shard_dim)

             logits = model.prefill(
                 tokens,
-                attention_mask=None,  # We rely on causal attention
+                attention_mask=attention_mask,
                 seq_block_ids=seq_block_ids,
                 cache_state=cache_tensors,
             )
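With the flag set, the prefill export asks the model for its masks instead of passing attention_mask=None: model.input_mask(seq_lens, sl) marks which of the sl token positions are padding, and model.attention_mask(...) turns that into the mask tensor prefill consumes. A minimal, self-contained sketch of what such helpers typically compute (illustrative only; sharktank's actual implementations may differ in dtype handling and in how the causal part of the mask is composed):

    import torch

    def input_mask(seq_lens: torch.Tensor, batch_seqlen: int) -> torch.Tensor:
        # True where a position lies beyond the prompt's real length, i.e. padding.
        positions = torch.arange(batch_seqlen, device=seq_lens.device)[None, :]
        return positions >= seq_lens[:, None]  # [bs, batch_seqlen], bool

    def attention_mask(boolean_mask: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
        # Convert the boolean padding mask into an additive bias: 0 for valid
        # positions, a very large negative value for padded ones, shaped so it
        # broadcasts over heads and query positions.
        bias = torch.zeros(boolean_mask.shape, dtype=dtype)
        bias.masked_fill_(boolean_mask, torch.finfo(dtype).min)
        return bias[:, None, None, :]  # [bs, 1, 1, batch_seqlen]

    seq_lens = torch.tensor([3, 5])
    mask = input_mask(seq_lens, batch_seqlen=6)
    print(attention_mask(mask).shape)  # torch.Size([2, 1, 1, 6])

One detail worth noting in the sharded branch: once the mask is enabled, "if attention_mask:" truth-tests a tensor, and PyTorch raises for multi-element tensors in that position; "if attention_mask is not None:" is the usual guard for this pattern.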
4 changes: 4 additions & 0 deletions sharktank/sharktank/utils/export_artifacts.py
@@ -93,6 +93,7 @@ def __init__(
         tensor_parallelism_size: int,
         block_seq_stride: int,
         iree_hal_target_device: str,
+        use_attention_mask: bool = False,
     ):
         self.sharktank_dir = str(
             Path(os.path.dirname(os.path.abspath(__file__))).parent.parent.parent
@@ -104,6 +105,7 @@ def __init__(
         self.attention_kernel = attention_kernel
         self.tensor_parallelism_size = tensor_parallelism_size
         self.block_seq_stride = block_seq_stride
+        self.use_attention_mask = use_attention_mask

     def timeit(func):
         def wrapper(*args, **kwargs):
@@ -187,6 +189,8 @@ def export_to_mlir(
         if self.attention_kernel in ["decomposed", "torch"]:
             export_args.append("--attention-kernel")
             export_args.append(self.attention_kernel)
+        if self.use_attention_mask:
+            export_args.append("--use-attention-mask")

         cwd = self.sharktank_dir
         cmd = subprocess.list2cmdline(export_args)
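ExportArtifacts forwards the option by appending --use-attention-mask to the export command it builds for sharktank.examples.export_paged_llm_v1. A stripped-down sketch of that pattern (only the new flag is taken from the diff above; the interpreter name and every other flag the real exporter passes are elided):

    import subprocess

    # Sketch: conditionally extend the exporter's argument list, then render it
    # the same way export_artifacts.py does, via subprocess.list2cmdline.
    export_args = ["python3", "-m", "sharktank.examples.export_paged_llm_v1"]

    use_attention_mask = True
    if use_attention_mask:
        export_args.append("--use-attention-mask")

    print(subprocess.list2cmdline(export_args))
    # python3 -m sharktank.examples.export_paged_llm_v1 --use-attention-mask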
(The diff for the fourth changed file is not rendered on this page.)

0 comments on commit 5dd3fe6
