
Use tilized operators for Mistral AI speedup on inference #3812

Closed
muthutt opened this issue Nov 15, 2023 · 8 comments

Labels: mistral (Mistral AI bringup), models (Models that run in tt-metal)

Comments

@muthutt
Contributor

muthutt commented Nov 15, 2023

  • Use tilized operators for Mistral AI; migrate from the row-major layout.
  • Tilized tensors are faster on GS/WH TT Tensix cores.
  • Use tt_lib.tensor.tilize_with_val_padding where padding is necessary (see the sketch below the reference link).

Reference: https://github.com/tenstorrent-metal/tt-metal/tree/a3740def58c3b8672b7e3279261506ae70b97810/models/demos/resnet/tt/metalResnetBlock50.py
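
As a rough illustration of what that conversion looks like, here is a minimal sketch; the helper name is illustrative, and the exact tilize_with_val_padding signature (including the tensor-start argument) should be verified against the tt_lib docs.

import tt_lib

def to_tile_layout(x, pad_value=0.0):
    # Tile layout requires the last two dims to be multiples of 32,
    # so pad up to the next tile boundary before tilizing.
    shape = x.shape()
    padded_shape = shape[:2] + [((d + 31) // 32) * 32 for d in shape[2:]]
    # Assumed signature: (input, padded output shape, input start, pad value).
    return tt_lib.tensor.tilize_with_val_padding(x, padded_shape, [0, 0, 0, 0], pad_value)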

@muthutt muthutt added the mistral Mistral AI bringup label Nov 15, 2023
@muthutt muthutt changed the title from “use tilized operators for Mistral AI migrate from row-major” to “use tilized operators for Mistral AI speedup on inference” Nov 15, 2023
@boris-drazic boris-drazic added the models Models that run in tt-metal label Nov 16, 2023
@saichandax saichandax changed the title from “use tilized operators for Mistral AI speedup on inference” to “Use tilized operators for Mistral AI speedup on inference” Nov 28, 2023
@Sudharsan-V
Contributor

The tensors in the Mistral model are now converted to TILE layout so that the tilized operators can be used, and the commit has been updated accordingly.
Corresponding PR: #4029

@boris-drazic
Contributor

Further performance improvements for tiles...

Our goal is to have tensors only in tile layout from input to output.
There are some OPs in the model that will not work with tilized tensors, but they seem to be confined to rotary embedding and cache updates.
Let's have tiles everywhere else.

Starting in models/experimental/mistral/tt/mistral_transformer.py, where we loop over the layers:

h = torch_to_tt_tensor_rm(h, self.device, put_on_device=False)
for layer in self.layers:
    h = layer(h, freqs_cis, positions, mask)

Put h in tile layout here instead of row-major (the sequence-length dimension will need to be padded to a multiple of 32); h will then go into each layer as tile and come out as tile.
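
A minimal sketch of that change (the import path and the plain tilize call are assumptions to verify; torch_to_tt_tensor_rm is the helper already used above):

import torch
import tt_lib
from models.utility_functions import torch_to_tt_tensor_rm  # path assumed

seq_len = h.shape[-2]
padded_len = ((seq_len + 31) // 32) * 32
# Zero-pad the sequence dimension up to a multiple of 32 for tile layout.
h = torch.nn.functional.pad(h, (0, 0, 0, padded_len - seq_len))
h = torch_to_tt_tensor_rm(h, self.device, put_on_device=True)
h = tt_lib.tensor.tilize(h)  # h now enters and leaves each layer as tile
for layer in self.layers:
    h = layer(h, freqs_cis, positions, mask)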

Then in models/experimental/mistral/tt/mistral_transformer_block.py x is tilized and padded when sent to RMSNorm

r = self.attention.forward(self.attention_norm(x), freqs_cis, positions, mask)

Update RMSNorm so that it just applies tt_lib.tensor.rmsnorm to the input tile tensor. Since we have full rows of data followed by full rows of padding, RMSNorm should work fine: no data row contains any padding, so each row is divided by the correct count inside tt_lib.tensor.rmsnorm and produces correct results.
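
A quick torch check of that claim: because RMSNorm's mean is taken per row, all-zero padding rows leave the data rows untouched.

import torch

def rms_norm(x, eps=1e-6):
    # Row-wise RMSNorm over the last dimension (no learned weight, for brevity).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

data = torch.randn(11, 4096)                        # 11 real rows
padded = torch.cat([data, torch.zeros(21, 4096)])   # zero-pad rows up to 32

# The first 11 output rows are identical with or without padding.
assert torch.allclose(rms_norm(data), rms_norm(padded)[:11])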

Next up is models/experimental/mistral/tt/mistral_attention.py
The first thing here is the linears, which should be fine with zero padding:

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

Then let's skip all padding and reshaping until the rotary embedding.
At that point, convert to torch and slice out the padding in torch, so that xq has shape [batch, sequence, 32, 128] and xk has shape [batch, sequence, 8, 128].

Then, after key and value are computed with key, value = self.repeat_kv,
we have shape [batch, sequence, 32, 128] for query, key, and value.
Can we just zero-pad the second dim from sequence up to 32, convert to tile layout, and then do the transpose and the score computation entirely in tile, with no more conversions?
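
In torch terms, the proposal amounts to the shape bookkeeping below; the commented tt_lib calls are assumptions to check against the op list, and the padded key columns would still need masking before softmax.

import torch

batch, seq, heads, head_dim = 1, 11, 32, 128        # shapes quoted above
query = torch.randn(batch, seq, heads, head_dim)
key = torch.randn(batch, seq, heads, head_dim)

# Zero-pad the sequence dim (dim 1) up to a multiple of 32 for tile layout.
pad = (32 - seq % 32) % 32
query = torch.nn.functional.pad(query, (0, 0, 0, 0, 0, pad))
key = torch.nn.functional.pad(key, (0, 0, 0, 0, 0, pad))

# On device this would stay in tile layout throughout, e.g.
#   key_t = tt_lib.tensor.transpose(key_tt, -2, -1)   # signature assumed
#   scores = tt_lib.tensor.bmm(query_tt, key_t)       # op name assumed
q = query.transpose(1, 2)                             # [batch, heads, seq_p, head_dim]
k = key.transpose(1, 2)
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5    # [batch, heads, seq_p, seq_p]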

@Sudharsan-V
Contributor

Sudharsan-V commented Nov 30, 2023

The commit has been revised based on the provided suggestions.
In models/experimental/mistral/tt/mistral_transformer.py, the input h is now converted to TILE layout, so it enters and exits the layer loop in TILE layout.

In models/experimental/mistral/tt/mistral_transformer_block.py, rms_norm performs well when its input is in TILE layout.

In models/experimental/mistral/tt/mistral_attention.py, I have removed the padding entirely up to the rotary embedding. However, I couldn't eliminate the reshape completely. The linear operation produces xq with shape [1, 1, 32, 4096], xk with shape [1, 1, 32, 1024], and xv with shape [1, 1, 32, 1024], which need to be transformed into [1, 11, 32, 128] (xq), [1, 11, 8, 128] (xk), and [1, 11, 8, 128] (xv) before feeding into the rotary embedding. Therefore, after the linear operation, I converted to a torch tensor, sliced out the padding, and reshaped to the desired shapes.
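
The slice-and-reshape being described, as a standalone torch sketch (seq_len = 11 matches the shapes quoted above):

import torch

seq_len = 11

# Linear outputs, padded to 32 rows for tile layout:
xq = torch.randn(1, 1, 32, 4096)
xk = torch.randn(1, 1, 32, 1024)

# Slice off the pad rows, then split the hidden dim into heads:
xq = xq[:, :, :seq_len, :].reshape(1, seq_len, 32, 128)  # 32 query heads
xk = xk[:, :, :seq_len, :].reshape(1, seq_len, 8, 128)   # 8 key/value heads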

Corresponding PR: #4029

@Sudharsan-V
Contributor

Sudharsan-V commented Dec 7, 2023

The commit for the mistral model is updated by optimizing the rotary_embed method.
Previously, the conversion of freqs_cis (a PyTorch complex tensor) to a tt_lib complex tensor happened
#transformer_block + (max_tokens * #transformer_block) times:

Prefill stage: #transformer_block
Decode stage: max_tokens * #transformer_block
Total: #transformer_block + (max_tokens * #transformer_block)

In the optimized version of the mistral model, the conversion of freqs_cis (a PyTorch complex tensor) to a tt_lib complex tensor happens only 1 + max_tokens times:

Prefill stage: 1
Decode stage: max_tokens
Total: 1 + max_tokens

Note: #transformer_block = 32 (the number of transformer blocks)
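
In outline, the optimization hoists the conversion out of the per-block loop (to_tt_complex stands in for whatever conversion helper the model uses; the name is hypothetical):

# Before: every one of the 32 transformer blocks converted freqs_cis itself,
# so the torch-complex -> tt_lib-complex conversion ran once per block per token.
for layer in self.layers:
    h = layer(h, freqs_cis, positions, mask)          # conversion inside layer()

# After: convert once per forward pass and hand the tt_lib tensor to every block.
tt_freqs_cis = to_tt_complex(freqs_cis, self.device)  # hypothetical helper
for layer in self.layers:
    h = layer(h, tt_freqs_cis, positions, mask)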

Corresponding PR: #4029

@muthutt
Contributor Author

muthutt commented Dec 7, 2023

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64

@vigneshkeerthivasanx you can also attempt to move this code on-device, since we have tt_lib.tensor.arange and the other tt_lib.tensor ops; torch.polar can be implemented as shown below:

def tt_lib_polar(abs, angle):
    # torch.polar(abs, angle) builds abs * (cos(angle) + i*sin(angle))
    s = tt_lib.tensor.sin(angle)
    c = tt_lib.tensor.cos(angle)
    r = tt_lib.tensor.mul(abs, c)
    i = tt_lib.tensor.mul(abs, s)
    return tt_lib.tensor.complex_tensor(r, i)

This should create the tensors on device and keep them there, saving the time otherwise spent moving data back and forth.
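
Putting the pieces together, an on-device version might look like the sketch below. arange, exp, log, sin, cos, mul, and complex_tensor are ops already named in this thread; mul_unary, outer, and ones_like are assumptions to verify against the tt_lib op list.

import math
import tt_lib

def tt_precompute_freqs_cis(dim, end, device, theta=10000.0):
    # freqs = theta ** (-idx / dim) == exp(-(idx / dim) * log(theta)),
    # avoiding a direct pow op (see the workaround noted further below).
    idx = tt_lib.tensor.arange(0, dim, 2, device)              # signature assumed
    freqs = tt_lib.tensor.exp(
        tt_lib.tensor.mul_unary(idx, -math.log(theta) / dim)   # op name assumed
    )
    t = tt_lib.tensor.arange(0, end, 1, device)
    angles = tt_lib.tensor.outer(t, freqs)                     # op name assumed
    ones = tt_lib.tensor.ones_like(angles)                     # op name assumed
    return tt_lib_polar(ones, angles)                          # from the snippet above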

@Sudharsan-V
Contributor

@muthutt, the commit has been updated to incorporate the comments above.

After modifying precompute_freqs_cis, there is a very slight degradation in the pcc.
Previously the pcc was 0.9928005641025528; after the modification it is 0.9921967026489718.

Even though the variation in the pcc looks insignificant, this modification has noticeably degraded the gs-demo output.

Input Prompt: ['A man is sitting on a roof ']
Output Prompt before modifying the `precompute_freqs_cis`: 'A man is sitting on a roof 100 meters above the ground.\n\nA man is sitting on'
Output Prompt after modification: 'A man is sitting on a roof 100 feet of a a.10.10.1'

Note: I was not able to use tt_lib's power op directly, so I have used the following identity:
pow(base, exponent) = exp(exponent * log(base))

I have used tt_lib.tensor.exp, tt_lib.tensor.mul, and tt_lib.tensor.log accordingly.
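
That identity, checked in torch for clarity (it holds for positive bases):

import torch

base = torch.full((4,), 10000.0)
exponent = torch.rand(4)

direct = torch.pow(base, exponent)
via_exp_log = torch.exp(exponent * torch.log(base))  # pow without a pow op

assert torch.allclose(direct, via_exp_log, rtol=1e-4)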

Corresponding PR: #4029

@muthutt
Contributor Author

muthutt commented Dec 11, 2023 via email

@Sudharsan-V
Contributor

Sudharsan-V commented Dec 12, 2023

The commit has been reverted to preserve the sensible output.
Corresponding PR: #4029
Issue for TT-Scatter op: #4294
