Mean-Pooling Layer for Bert for packed/nested/remove_input_padding tensor t.region->getDataType() == DataType::kINT32 failed. #2554

Open

michaelfeil opened this issue Dec 10, 2024 · 0 comments
I am trying to implement a BERT-style mean-pooling layer using tensorrt_llm's Python functional.py API.

Relevant code:

if not default_net().plugin_config.remove_input_padding:

Relevant PyTorch code would be e.g.:

# outputs: [B, Tokens, Hidden_Dim]
# mean pooling over the token dimension, ignoring padding
outputs = torch.sum(
    outputs * inputs["attention_mask"][:, :, None], dim=1
) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)

With nested tensors:

nested_outputs.to_padded(fill=0.0)
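For reference, the intended numerics of the padded case can be checked in plain NumPy (names and shapes here are illustrative, not from the actual model):

```python
import numpy as np

# Reference numerics for masked mean pooling over the token dimension.
B, T, H = 2, 4, 3
rng = np.random.default_rng(0)
outputs = rng.standard_normal((B, T, H)).astype(np.float32)
attention_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]], dtype=np.float32)  # 1 = real token

masked = outputs * attention_mask[:, :, None]          # zero out padding
pooled = masked.sum(axis=1) / attention_mask.sum(axis=1, keepdims=True)

# Each row of `pooled` equals the mean over that sequence's real tokens.
assert np.allclose(pooled[0], outputs[0, :3].mean(axis=0))
assert np.allclose(pooled[1], outputs[1, :2].mean(axis=0))
```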

How would I implement this layer using tensorrt_llm abstractions?

from typing import Optional

import numpy as np

from tensorrt_llm._utils import trt_dtype_to_np
from tensorrt_llm.functional import (
    ACT2FN,
    Tensor,
    concat,
    constant,
    cumsum,
    expand,
    index_select,
    select,
    shape,
    slice,
    unsqueeze,
    mean,
)
from tensorrt_llm.functional import sum as trt_sum
from tensorrt_llm.module import Module

class PoolingLayer(Module):
    """Custom Function, performs average over dimension 1 of a tensor"""

    def __init__(self, hidden_size: int):
        super().__init__()        
        self.hidden_size = hidden_size

    def forward(
        self,
        hidden_states: "Tensor",
        input_lengths: "Tensor",
        remove_input_padding: bool,
        attention_mask: Optional["Tensor"],
        max_input_length: Optional[int],
    ):
        """
        if remove_input_padding:
            hidden_states: [num_tokens, hidden_size]
            input_lengths: [batch_size]
            attention_mask: None
        if not remove_input_padding:
            hidden_states: [batch_size, num_tokens, hidden_size]
            input_lengths: [batch_size]
            attention_mask: [batch_size, num_tokens, hidden_size] for 1=participate, 0=not participate
        
        """
        if not remove_input_padding:
            # attention_mask: [batch_size, num_tokens, hidden_size],
            # 1 = token participates, 0 = padding
            sum_hidden = trt_sum(hidden_states * attention_mask, dim=1)
            # sum_hidden: [batch_size, hidden_size]
            mean_hidden = sum_hidden / trt_sum(attention_mask, dim=1)
            return mean_hidden
        else:
            # hidden_states is packed: [num_tokens, hidden_size]
            # e.g. input_lengths = [8, 5, 6]
            # hidden_states: [8+5+6=19, hidden_size]
            # mean_embeddings: [len(input_lengths), hidden_size]
            # mean_embeddings = [mean(hidden_states[0:8], dim=0),
            #                    mean(hidden_states[8:13], dim=0),
            #                    mean(hidden_states[13:19], dim=0)]
            # cumsum boundaries = [0, 8, 13, 19]

            offsets = cumsum(input_lengths, 0)
            dtype_in = hidden_states.dtype

            # hidden_cumsum = cumsum(hidden_states, 0)  # [num_tokens, hidden_size], CORRECT but does not compile
            hidden_cumsum = hidden_states  # [num_tokens, hidden_size] -> INCORRECT, but does compile

            cumsum_states_with_zero = concat(
                [
                    constant(np.ascontiguousarray(np.zeros((1, self.hidden_size), dtype=trt_dtype_to_np(dtype_in)))), # [1, hidden_size]
                    hidden_cumsum # [num_tokens, hidden_size]
                ],
                dim=0 
            ) # size = [num_tokens+1, hidden_size]           
            
            start_indices = offsets - input_lengths  # [batch_size]
            end_indices = offsets  # [batch_size]
            
            # go from a [num_tokens+1, hidden_size] tensor to a 2x [batch_size, hidden_size] tensor with the cumsum
            sum_start = index_select(cumsum_states_with_zero, 0, start_indices)
            sum_end = index_select(cumsum_states_with_zero, 0, end_indices)
             
            s_hidden = (sum_end - sum_start) # [batch_size, hidden_size]
            
            m_hidden = s_hidden / unsqueeze(input_lengths, 1).cast("float32") # [batch_size, hidden_size]
            
            return m_hidden.cast(dtype_in)
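The packed branch above follows the standard exclusive-prefix-sum trick for segment means. The numerics it is aiming for can be modeled in NumPy (shapes and lengths here are just the example from the comments):

```python
import numpy as np

# NumPy model of the packed branch: per-sequence means from a
# [num_tokens, hidden_size] tensor via an exclusive prefix sum.
input_lengths = np.array([8, 5, 6])
hidden_size = 4
hidden_states = np.random.default_rng(1).standard_normal(
    (input_lengths.sum(), hidden_size))

offsets = np.cumsum(input_lengths)                     # [8, 13, 19]
csum = np.concatenate([np.zeros((1, hidden_size)),
                       np.cumsum(hidden_states, axis=0)], axis=0)
sums = csum[offsets] - csum[offsets - input_lengths]   # [batch, hidden]
means = sums / input_lengths[:, None]

expected = np.stack([hidden_states[0:8].mean(0),
                     hidden_states[8:13].mean(0),
                     hidden_states[13:19].mean(0)])
assert np.allclose(means, expected)
```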

Easier reproducer:

            first_token_indices = cumsum(
                concat(
                    [
                        0,
                        slice(
                            input_lengths,
                            starts=[0],
                            sizes=(
                                shape(input_lengths)
                                - constant(np.array([1], dtype=np.int32))
                            ),
                        ),
                    ]
                ),
                dim=0,
            )
            # hidden_states2 = cumsum(hidden_states, 0) # breaks
            hidden_states2 = hidden_states # works
            first_token_tensor = index_select(hidden_states2, 0, first_token_indices)
[12/11/2024-23:39:05] [TRT] [I] Compiler backend is used during engine build.
[12/11/2024-23:39:10] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception Assertion t.region->getDataType() == DataType::kINT32 failed. 
[12/11/2024-23:39:10] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[BertModel/layers/11/post_layernorm/value_L145/get_tensor_L136/get_constant_tensor_L127/_create_constant_tensor_L111/constant_L1150/CONSTANT_0 + BertModel/layers/11/post_layernorm/layer_norm_L5214/broadcast_helper_L2800/expand_dims_like_L1905/expand_dims_L1737/view_L1645/SHUFFLE_0...BertModel/norm_L625/__truediv___L357/elementwise_binary_L2855/ELEMENTWISE_DIV_0]}.)
I also tested things like:
            # offsets = cumsum(input_lengths, 0)
            # mean_v = concat(
            #     [
            #         mean(slice(hidden_states, starts=[slice(offsets, [i], sizes=[0])], sizes=(input_lengths[i],)), dim=0)
            #         for i in arange(0, shape(input_lengths, 0), "int32")
            #     ]
            # )
            # return mean_v
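One cumsum-free formulation (not tried in tensorrt_llm here, purely a NumPy sketch under the assumption that a batched matmul compiles where cumsum does not) expresses the same pooling as a matmul with a per-sequence averaging matrix:

```python
import numpy as np

# Hypothetical cumsum-free alternative: build a [batch, num_tokens] matrix
# whose row b holds 1/len_b over sequence b's token span, then pool with a
# single matmul.
input_lengths = np.array([8, 5, 6])
num_tokens = int(input_lengths.sum())
hidden_states = np.random.default_rng(2).standard_normal((num_tokens, 3))

offsets = np.cumsum(input_lengths)
starts = offsets - input_lengths
avg = np.zeros((len(input_lengths), num_tokens))
for b, (s, e) in enumerate(zip(starts, offsets)):
    avg[b, s:e] = 1.0 / input_lengths[b]

means = avg @ hidden_states                            # [batch, hidden]
assert np.allclose(means[1], hidden_states[8:13].mean(0))
```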