#4059: Add PADDED, BINARY modes to embeddings for optimized cache lookups on repeated tokens

Update RM embeddings to handle new input format and embed types. Remove split_weights arg
tt-aho committed Nov 30, 2023
1 parent ade5668 commit 52e9593
Showing 14 changed files with 481 additions and 308 deletions.
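
For orientation before the per-file diffs: the BERT demo now picks an EmbeddingsType per lookup when calling ttl.tensor.embeddings. Below is a minimal sketch of the three call shapes, mirroring the call sites in the diff that follows; the input id and weight tensors are assumed to already live on device, mem_cfg stands in for the model's OUTPUT_EMBEDDINGS_MEMCFG, and the comments on what each mode optimizes are inferred from the commit title rather than from the kernel sources.

import tt_lib as ttl

# Word embeddings: PADDED is given the pad token id, so repeated pad positions
# can be served from a cached row instead of re-fetching the weight table.
inputs_embeds = ttl.tensor.embeddings(
    input_ids,
    word_embeddings_weight,
    tilized=True,
    embeddings_type=ttl.tensor.EmbeddingsType.PADDED,
    pad_token=pad_token,
    output_mem_config=mem_cfg,
)

# Token-type embeddings: BINARY assumes the ids are only 0 or 1, so at most
# two distinct rows are ever looked up.
token_type_embeddings = ttl.tensor.embeddings(
    token_type_ids,
    token_type_embeddings_weight,
    tilized=True,
    embeddings_type=ttl.tensor.EmbeddingsType.BINARY,
    output_mem_config=mem_cfg,
)

# Position embeddings: GENERIC keeps the plain per-token lookup.
position_embeddings = ttl.tensor.embeddings(
    position_ids,
    position_embeddings_weight,
    tilized=True,
    embeddings_type=ttl.tensor.EmbeddingsType.GENERIC,
    output_mem_config=mem_cfg,
)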
51 changes: 34 additions & 17 deletions models/demos/metal_BERT_large_11/tt/embeddings.py
@@ -5,7 +5,7 @@
from typing import List, Optional, Tuple, Union
import torch
import tt_lib as ttl
from tt_lib.utils import pad_weight
from models.utility_functions import torch2tt_tensor


class TtEmbeddings:
@@ -15,6 +15,7 @@ def __init__(self, hugging_face_reference_model, device, model_config, tt_cache_
config = hugging_face_reference_model.config
state_dict = hugging_face_reference_model.state_dict()
self.embedding_dim = config.hidden_size
self.pad_token = config.pad_token_id

base_address = "bert.embeddings"
if tt_cache_path is not None:
@@ -49,30 +50,45 @@ def __init__(self, hugging_face_reference_model, device, model_config, tt_cache_
)
).to(device, self.model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"])
else:
self.word_embeddings_weight = ttl.tensor.Tensor(
pad_weight(state_dict[f"{base_address}.word_embeddings.weight"]),
self.word_embeddings_weight = torch2tt_tensor(
state_dict[f"{base_address}.word_embeddings.weight"],
device,
ttl.tensor.Layout.ROW_MAJOR,
model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"],
model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"],
).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"])
)

self.position_embeddings_weight = ttl.tensor.Tensor(
pad_weight(state_dict[f"{base_address}.position_embeddings.weight"]),
self.position_embeddings_weight = torch2tt_tensor(
state_dict[f"{base_address}.position_embeddings.weight"],
device,
ttl.tensor.Layout.ROW_MAJOR,
model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"],
model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"],
).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"])
)

self.token_type_embeddings_weight = ttl.tensor.Tensor(
pad_weight(state_dict[f"{base_address}.token_type_embeddings.weight"]),
self.token_type_embeddings_weight = torch2tt_tensor(
state_dict[f"{base_address}.token_type_embeddings.weight"],
device,
ttl.tensor.Layout.ROW_MAJOR,
model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"],
model_config["INPUT_EMBEDDINGS_WEIGHTS_DTYPE"],
).to(device, model_config["INPUT_EMBEDDINGS_WEIGHTS_MEMCFG"])
)

self.layerNorm_gamma = ttl.tensor.Tensor(
self.layerNorm_gamma = torch2tt_tensor(
state_dict[f"{base_address}.LayerNorm.weight"].reshape([1, 1, -1, 32]),
device,
ttl.tensor.Layout.ROW_MAJOR,
model_config["EMBEDDINGS_LAYERNORM_GAMMA_MEMCFG"],
model_config["EMBEDDINGS_LAYERNORM_GAMMA_DTYPE"],
).to(device, model_config["EMBEDDINGS_LAYERNORM_GAMMA_MEMCFG"])
)

self.layerNorm_beta = ttl.tensor.Tensor(
self.layerNorm_beta = torch2tt_tensor(
state_dict[f"{base_address}.LayerNorm.bias"].reshape([1, 1, -1, 32]),
device,
ttl.tensor.Layout.ROW_MAJOR,
model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"],
model_config["EMBEDDINGS_LAYERNORM_BETA_DTYPE"],
).to(device, model_config["EMBEDDINGS_LAYERNORM_BETA_MEMCFG"])
)

self.layerNorm_eps = config.layer_norm_eps

@@ -120,17 +136,18 @@ def __call__(
inputs_embeds = ttl.tensor.embeddings(
input_ids,
self.word_embeddings_weight,
split_weights=False,
tilized=True,
embeddings_type=ttl.tensor.EmbeddingsType.PADDED,
pad_token=self.pad_token,
output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"],
)
input_ids.deallocate()

token_type_embeddings = ttl.tensor.embeddings(
token_type_ids,
self.token_type_embeddings_weight,
split_weights=False,
tilized=True,
embeddings_type=ttl.tensor.EmbeddingsType.BINARY,
output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"],
)
token_type_ids.deallocate()
@@ -146,8 +163,8 @@ def __call__(
position_embeddings_tt_tensor = ttl.tensor.embeddings(
position_ids,
self.position_embeddings_weight,
split_weights=False,
tilized=True,
embeddings_type=ttl.tensor.EmbeddingsType.GENERIC,
output_mem_config=self.model_config["OUTPUT_EMBEDDINGS_MEMCFG"],
)
# Deallocate inputs_embeds and token_type_embeddings here to avoid having to move final output
18 changes: 5 additions & 13 deletions tests/tt_eager/python_api_testing/unit_testing/test_embedding.py
@@ -2,26 +2,21 @@

# SPDX-License-Identifier: Apache-2.0

import math
from pathlib import Path
import sys
import time
import os

import pytest
import torch

import tt_lib as ttl

from tt_lib.utils import is_close
from models.utility_functions import is_wormhole_b0, skip_for_wormhole_b0
from models.utility_functions import skip_for_wormhole_b0
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_equal,
)


def run_embeddings_tests(
batch_size, num_embeddings, embedding_dim, num_rows, dtype, in0_mem_config, out_mem_config, device, fused=False
batch_size, num_embeddings, embedding_dim, num_rows, dtype, in0_mem_config, out_mem_config, device, tilized=False
):
torch.manual_seed(1234)

@@ -38,9 +33,9 @@ def run_embeddings_tests(
input_tensor = tensor.Tensor(input_rows_torch, ttl.tensor.DataType.UINT32).to(dev, in0_mem_config)
weights_tensor = tensor.Tensor(weights_torch, dtype).to(dev, in0_mem_config)

ttz = tensor.embeddings(input_tensor, weights_tensor, False, fused, out_mem_config)
ttz = tensor.embeddings(input_tensor, weights_tensor, tilized, output_mem_config=out_mem_config)

if fused:
if tilized:
tt_data = ttz.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
else:
tt_data = ttz.cpu().to_torch()
@@ -57,9 +52,6 @@ def run_embeddings_tests(
assert passing_pcc


import pytest


@skip_for_wormhole_b0()
@pytest.mark.parametrize(
"out_mem_config",
@@ -13,7 +13,7 @@ void kernel_main() {
uint32_t num_sticks_per_block = get_arg_val<uint32_t>(3);
uint32_t start_id = get_arg_val<uint32_t>(4);

constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0) == 1;
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);

constexpr bool src0_is_dram = get_compile_time_arg_val(1) == 1;
#define src_stick_size_is_pow2 get_compile_time_arg_val(2) == 1
@@ -13,7 +13,7 @@ void kernel_main() {
uint32_t num_sticks_per_block = get_arg_val<uint32_t>(3);
uint32_t start_id = get_arg_val<uint32_t>(4);

constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1;
constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0);

constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1;
#define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1
@@ -13,7 +13,7 @@ void kernel_main() {
uint32_t num_sticks = get_arg_val<uint32_t>(2);
uint32_t start_id = get_arg_val<uint32_t>(3);

constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0) == 1;
constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0);

constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1;
#define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1
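
A note on the three kernel hunks above: == 1 is the idiom used to turn a compile-time arg into a boolean (as on the src0_is_dram / dst0_is_dram lines), but applied to the circular-buffer index it collapses every id other than 1 down to 0 or 1. A minimal before/after illustration of the fix, not the full kernels:

// Before: the comparison yields a bool, so cb_id_in0 ends up as 0 or 1
// no matter which circular-buffer index was passed at compile time.
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0) == 1;

// After: the raw value is kept, so the kernel addresses the intended CB.
constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0);

// The comparison stays where the arg really is a boolean flag:
constexpr bool src0_is_dram = get_compile_time_arg_val(1) == 1;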