From 18715a4f05de65278225c5ff055189b5d69215da Mon Sep 17 00:00:00 2001 From: erenup Date: Fri, 29 Dec 2023 14:17:35 -0500 Subject: [PATCH 1/4] add bertforsequenceclassification; add attention mask in bert; add more test cases; add real data input test cases --- tensorrt_llm/models/__init__.py | 2 +- tensorrt_llm/models/bert/model.py | 86 +++++++++++++++- tests/model/test_bert.py | 165 +++++++++++++++++++++++++----- 3 files changed, 224 insertions(+), 29 deletions(-) diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index ee886121e..68dd7ea7f 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from .baichuan.model import BaichuanForCausalLM -from .bert.model import BertForQuestionAnswering, BertModel +from .bert.model import BertForQuestionAnswering, BertModel, BertForSequenceClassification from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.model import ChatGLMHeadModel, ChatGLMModel from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder diff --git a/tensorrt_llm/models/bert/model.py b/tensorrt_llm/models/bert/model.py index 7d24a664c..3e519a0ed 100644 --- a/tensorrt_llm/models/bert/model.py +++ b/tensorrt_llm/models/bert/model.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from email.quoprimime import unquote import math import numpy as np from ..._common import default_net from ...functional import (bert_attention, concat, constant, expand, matmul, - shape, slice, softmax, split) + shape, slice, softmax, split, cast, unsqueeze, select, ACT2FN) from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear from ...mapping import Mapping from ...module import Module, ModuleList @@ -212,7 +213,8 @@ def __init__(self, mapping=Mapping(), dtype=None): super().__init__() - + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype self.embedding = BertEmbedding( vocab_size=vocab_size, hidden_size=hidden_size, @@ -238,9 +240,28 @@ def forward(self, hidden_states=None): hidden_states = self.embedding(input_ids, position_ids, token_type_ids) + # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py + seq_len_2d = concat([1, shape(input_ids, 1)]) + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = cast(mask, 'int32') + extended_attention_mask = unsqueeze(mask, 1) + extended_attention_mask = unsqueeze(extended_attention_mask, 1) # Bx1x1xL + extended_attention_mask = (1 - extended_attention_mask) * -214748364 # a small negative number in int32 range + extended_attention_mask = cast(extended_attention_mask, self.dtype) + for layer in self.layers: hidden_states = layer(hidden_states=hidden_states, - input_lengths=input_lengths) + input_lengths=input_lengths, + attention_mask=extended_attention_mask) return hidden_states @@ -287,3 +308,62 @@ 
def forward(self, logits = self.qa_outputs(hidden_states) return logits + +class BertPooler(Module): + def __init__(self, hidden_size, dtype): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.activation = ACT2FN['tanh'] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = select(hidden_states, 1, 0) + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertForSequenceClassification(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.bert = BertModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + mapping=mapping, + dtype=dtype) + self.num_labels = num_labels + self.pooler = BertPooler(hidden_size=hidden_size, dtype=dtype) + self.classifier = Linear(hidden_size, num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.bert.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + pooled_output = self.pooler(hidden_states) + logits = self.classifier(pooled_output) + + return logits diff --git a/tests/model/test_bert.py b/tests/model/test_bert.py index ddfa6525e..29f1c4d04 100644 --- a/tests/model/test_bert.py +++ b/tests/model/test_bert.py @@ -25,12 +25,12 @@ import tensorrt as trt # isort: on from parameterized import parameterized -from transformers import BertConfig, BertForQuestionAnswering, BertModel +from transformers import BertConfig, BertForQuestionAnswering, BertModel, BertForSequenceClassification, AutoTokenizer import tensorrt_llm import tensorrt_llm.runtime from tensorrt_llm import Builder -from tensorrt_llm._utils import trt_dtype_to_torch +from tensorrt_llm._utils import trt_dtype_to_torch, str_dtype_to_trt from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.runtime import TensorInfo @@ -148,18 +148,41 @@ def load_from_hf_qa_bert(tensorrt_llm_qa_bert, tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( torch_dtype).cpu().numpy() +def load_from_hf_cls_bert(tensorrt_llm_qa_bert, + hf_qa_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + rank, tensor_parallel, fp16) + states = hf_qa_bert.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_bert.pooler.dense.weight.value = states[ + 'bert.pooler.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.pooler.dense.bias.value = states[ + 'bert.pooler.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_qa_bert.classifier.weight.value = states[ + 'classifier.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.classifier.bias.value = states[ + 'classifier.bias'].to(torch_dtype).cpu().numpy() + class TestBert(unittest.TestCase): def load_test_cases(): - models = [BertModel.__name__, BertForQuestionAnswering.__name__] + 
models = [BertForSequenceClassification.__name__, BertModel.__name__, BertForQuestionAnswering.__name__] + model_dirs = ['', 'bert-base-uncased'] # add more tests for read data. test_cases = [] test_cases += product(models, [False], [False], [False], - [ContextFMHAType.disabled], ['float32']) + [ContextFMHAType.disabled], ['float32'], model_dirs) test_cases += product(models, [False], [True], [True], [ ContextFMHAType.disabled, ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc - ], ['float16']) + ], ['float16'], model_dirs) return test_cases @@ -171,7 +194,7 @@ def custom_name_func(testcase_func, param_num, param): @parameterized.expand(load_test_cases, name_func=custom_name_func) def test_bert(self, model, use_refit, use_plugin, fast_building, - context_fmha_type, dtype): + context_fmha_type, dtype, model_dir): tensorrt_llm.logger.set_level('error') fp16 = (dtype == 'float16') world_size = 1 @@ -224,22 +247,36 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, [bs_range]) ])) # Initialize model - bert_config = BertConfig( - vocab_size=vocab_size, - hidden_size=hidden_size, - num_hidden_layers=num_layers, - num_attention_heads=num_heads, - intermediate_size=4 * hidden_size, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - torch_dtype=torch_dtype, - ) + if model_dir: + bert_config = BertConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) + vocab_size = bert_config.vocab_size + hidden_size = bert_config.hidden_size + num_layers = bert_config.num_hidden_layers + num_heads = bert_config.num_attention_heads + hidden_size = bert_config.intermediate_size // 4 + hidden_act = bert_config.hidden_act + max_position_embeddings = bert_config.max_position_embeddings + else: + bert_config = BertConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + intermediate_size=4 * hidden_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + torch_dtype=torch_dtype, + ) output_name = "hidden_states" if model == BertModel.__name__: - hf_bert = BertModel( - bert_config, - add_pooling_layer=False).cuda().to(torch_dtype).eval() + if model_dir: + hf_bert = BertModel.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertModel( + bert_config, + add_pooling_layer=False).cuda().to(torch_dtype).eval() tensorrt_llm_bert = tensorrt_llm.models.BertModel( num_layers=num_layers, num_heads=num_heads, @@ -260,8 +297,12 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, tensor_parallel=world_size, fp16=fp16) elif model == BertForQuestionAnswering.__name__: - hf_bert = BertForQuestionAnswering(bert_config).cuda().to( - torch_dtype).eval() + if model_dir: + hf_bert = BertForQuestionAnswering.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForQuestionAnswering(bert_config).cuda().to( + torch_dtype).eval() output_name = "logits" tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( num_layers=num_layers, @@ -284,6 +325,36 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, rank=rank, tensor_parallel=world_size, fp16=fp16) + elif model == BertForSequenceClassification.__name__: + if model_dir: + hf_bert = BertForSequenceClassification.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForSequenceClassification(bert_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_bert = 
tensorrt_llm.models.BertForSequenceClassification( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_cls_bert(tensorrt_llm_bert, + hf_bert, + bert_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + else: assert False, f"Unknown model {model}" # Prepare @@ -298,6 +369,9 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, output_dtype = trt.float16 if fp16 else trt.float32 output.mark_output(output_name, output_dtype) + for k, v in tensorrt_llm_bert.named_network_outputs(): + network._mark_output(v, k, str_dtype_to_trt(dtype)) + # Build engine engine_buffer = builder.build_engine(network, builder_config) session = tensorrt_llm.runtime.Session.from_serialized_engine( @@ -307,9 +381,25 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, # Inference # The dtype of input_ids should be queried from the engine, # for testing purpose, int32 is fine for now. - input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() - input_lengths = input_len * torch.ones( - (batch_size, ), dtype=torch.int32, device='cuda') + attention_mask = None + if model_dir: + hf_tokenizer = AutoTokenizer.from_pretrained(model_dir) + input_strings = ['Hello world!' for _ in range(batch_size)] + input_ids_with_padding = hf_tokenizer( + input_strings, padding='max_length', max_length=input_len) + input_ids_without_padding = hf_tokenizer( + input_strings) + input_ids = torch.tensor(input_ids_with_padding['input_ids']).int().cuda() + input_lengths = [len(x) for x in input_ids_without_padding['input_ids']] + input_lengths = torch.tensor( + input_lengths, device=input_ids.device, dtype=torch.int32) + attention_mask = torch.tensor( + input_ids_with_padding['attention_mask'], + device=input_ids.device, dtype=torch.int32) + else: + input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + input_lengths = input_len * torch.ones( + (batch_size, ), dtype=torch.int32, device='cuda') output_info = session.infer_shapes([ TensorInfo('input_ids', trt.DataType.INT32, @@ -335,11 +425,23 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, res = outputs[output_name] with torch.no_grad(): - hf_outputs = hf_bert.forward(input_ids) + if model_dir: + hf_outputs = hf_bert.forward( + input_ids=input_ids, + attention_mask=attention_mask) + else: + hf_outputs = hf_bert.forward(input_ids) torch.cuda.synchronize() if model == BertModel.__name__: ref = hf_outputs.last_hidden_state + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + attention_mask_tmp = attention_mask.unsqueeze(-1) + ref = ref * attention_mask_tmp + res = res * attention_mask_tmp + np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), atol=1e-2, @@ -351,6 +453,13 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, ref_start_logits = hf_outputs.start_logits ref_end_logits = hf_outputs.end_logits + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + ref_start_logits = ref_start_logits 
* attention_mask + ref_end_logits = ref_end_logits * attention_mask + res_start_logits = res_start_logits * attention_mask + res_end_logits = res_end_logits * attention_mask np.testing.assert_allclose(ref_start_logits.cpu().numpy(), res_start_logits.cpu().numpy(), @@ -358,6 +467,12 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, np.testing.assert_allclose(ref_end_logits.cpu().numpy(), res_end_logits.cpu().numpy(), atol=1.5e-2) + elif model == BertForSequenceClassification.__name__: + ref = hf_outputs.logits + np.testing.assert_allclose(ref.cpu().numpy(), + res.cpu().numpy(), + atol=1e-2, + rtol=1e-2) if __name__ == '__main__': From 6c2f91f2419dceb0e1e55791a65884bc456f5838 Mon Sep 17 00:00:00 2001 From: erenup Date: Fri, 29 Dec 2023 16:18:20 -0500 Subject: [PATCH 2/4] copy from bert to create roberta --- examples/roberta/.gitignore | 2 + examples/roberta/README.md | 54 ++ examples/roberta/base_benchmark/config.json | 22 + .../config.json | 22 + examples/roberta/build.py | 251 +++++++++ examples/roberta/large_benchmark/config.json | 22 + .../config.json | 22 + examples/roberta/run.py | 121 +++++ examples/roberta/weight.py | 129 +++++ tensorrt_llm/models/roberta/__init__.py | 14 + tensorrt_llm/models/roberta/model.py | 369 ++++++++++++++ tests/model/test_roberta.py | 479 ++++++++++++++++++ 12 files changed, 1507 insertions(+) create mode 100644 examples/roberta/.gitignore create mode 100644 examples/roberta/README.md create mode 100644 examples/roberta/base_benchmark/config.json create mode 100644 examples/roberta/base_with_attention_plugin_benchmark/config.json create mode 100644 examples/roberta/build.py create mode 100644 examples/roberta/large_benchmark/config.json create mode 100644 examples/roberta/large_with_attention_plugin_benchmark/config.json create mode 100644 examples/roberta/run.py create mode 100644 examples/roberta/weight.py create mode 100644 tensorrt_llm/models/roberta/__init__.py create mode 100644 tensorrt_llm/models/roberta/model.py create mode 100644 tests/model/test_roberta.py diff --git a/examples/roberta/.gitignore b/examples/roberta/.gitignore new file mode 100644 index 000000000..4b9ff316e --- /dev/null +++ b/examples/roberta/.gitignore @@ -0,0 +1,2 @@ +bert* +*.log diff --git a/examples/roberta/README.md b/examples/roberta/README.md new file mode 100644 index 000000000..7a9ff87db --- /dev/null +++ b/examples/roberta/README.md @@ -0,0 +1,54 @@ +# BERT + +This document explains how to build the [BERT](https://huggingface.co/docs/transformers/model_doc/bert) model using TensorRT-LLM. It also describes how to run on a single GPU and two GPUs. + +## Overview + +The TensorRT-LLM BERT implementation can be found in [`tensorrt_llm/models/bert/model.py`](../../tensorrt_llm/models/bert/model.py). The TensorRT-LLM BERT example +code is located in [`examples/bert`](./). There are four main files in that folder: + + * [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the BERT model, + * [`run.py`](./run.py) to run the inference on an input text, + +## Build and run BERT on a single GPU + +In this example, TensorRT-LLM builds TensorRT engine(s) from the [HuggingFace BERT](https://huggingface.co/docs/transformers/model_doc/bert) model. +Use the following command to build the TensorRT engine: + +```bash +python3 build.py --dtype=float16 --log_level=verbose + +# Enable the special TensorRT-LLM BERT Attention plugin (--use_bert_attention_plugin) to increase runtime performance. 
+python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16 +# Enable half accumulation for attention BMM1 (applied to unfused MHA plugins) +python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16 --enable_qk_half_accum +``` + +The following command can be used to run the BERT model on a single GPU: + +```bash +python3 run.py +``` + +#### Fused MultiHead Attention (FMHA) + +You can enable the FMHA kernels for BERT by adding `--enable_context_fmha` to the invocation of `build.py`. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention. + +If you find that the default fp16 accumulation (`--enable_context_fmha`) cannot meet the requirement, you can try to enable fp32 accumulation by adding `--enable_context_fmha_fp32_acc`. However, it is expected to see performance drop. + +Note `--enable_context_fmha` / `--enable_context_fmha_fp32_acc` has to be used together with `--use_bert_attention_plugin float16`. + +## Build and run BERT on two GPUs + +The following two commands can be used to build TensorRT engines to run BERT on two GPUs. The first command builds one engine for the first GPU. The second command builds another engine for the second GPU. + +```bash +python3 build.py --world_size=2 --rank=0 +python3 build.py --world_size=2 --rank=1 +``` + +The following command can be used to run the inference on 2 GPUs. It uses MPI with `mpirun`. + +```bash +mpirun -n 2 python3 run.py +``` diff --git a/examples/roberta/base_benchmark/config.json b/examples/roberta/base_benchmark/config.json new file mode 100644 index 000000000..025f0383e --- /dev/null +++ b/examples/roberta/base_benchmark/config.json @@ -0,0 +1,22 @@ +{ + "builder_config": { + "max_batch_size": 256, + "max_input_len": 512, + "name": "bert", + "precision": "float16", + "tensor_parallel": 1, + "use_refit": false + }, + "plugin_config": { + "bert_attention_plugin": "float16", + "context_fmha_enabled": true, + "gemm_plugin": "float16", + "gpt_attention_plugin": false, + "identity_plugin": false, + "layernorm_plugin": false, + "layernorm_quantization_plugin": false, + "nccl_plugin": false, + "smooth_quant_gemm_plugin": false, + "weight_only_quant_matmul_plugin": false + } +} diff --git a/examples/roberta/base_with_attention_plugin_benchmark/config.json b/examples/roberta/base_with_attention_plugin_benchmark/config.json new file mode 100644 index 000000000..025f0383e --- /dev/null +++ b/examples/roberta/base_with_attention_plugin_benchmark/config.json @@ -0,0 +1,22 @@ +{ + "builder_config": { + "max_batch_size": 256, + "max_input_len": 512, + "name": "bert", + "precision": "float16", + "tensor_parallel": 1, + "use_refit": false + }, + "plugin_config": { + "bert_attention_plugin": "float16", + "context_fmha_enabled": true, + "gemm_plugin": "float16", + "gpt_attention_plugin": false, + "identity_plugin": false, + "layernorm_plugin": false, + "layernorm_quantization_plugin": false, + "nccl_plugin": false, + "smooth_quant_gemm_plugin": false, + "weight_only_quant_matmul_plugin": false + } +} diff --git a/examples/roberta/build.py b/examples/roberta/build.py new file mode 100644 index 000000000..090fe33ef --- /dev/null +++ b/examples/roberta/build.py @@ -0,0 +1,251 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from collections import OrderedDict + +# isort: off +import torch +import tensorrt as trt +# isort: on +from transformers import BertConfig, BertForQuestionAnswering, BertModel + +import tensorrt_llm +from tensorrt_llm.builder import Builder +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType + +from weight import load_from_hf_bert, load_from_hf_qa_bert # isort:skip + + +def get_engine_name(model, dtype, tp_size, rank): + return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', + type=int, + default=1, + help='Tensor parallelism size') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float16', 'float32']) + parser.add_argument('--timing_cache', type=str, default='model.cache') + parser.add_argument('--log_level', type=str, default='info') + parser.add_argument('--vocab_size', type=int, default=51200) + parser.add_argument('--n_labels', type=int, default=2) + parser.add_argument('--n_layer', type=int, default=24) + parser.add_argument('--n_positions', type=int, default=1024) + parser.add_argument('--n_embd', type=int, default=1024) + parser.add_argument('--n_head', type=int, default=16) + parser.add_argument('--hidden_act', type=str, default='gelu') + parser.add_argument('--max_batch_size', type=int, default=256) + parser.add_argument('--max_input_len', type=int, default=512) + parser.add_argument('--gpus_per_node', type=int, default=8) + parser.add_argument('--output_dir', type=str, default='bert_outputs') + parser.add_argument('--use_bert_attention_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--use_gemm_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--use_layernorm_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--enable_qk_half_accum', + default=False, + action='store_true') + parser.add_argument('--enable_context_fmha', + default=False, + action='store_true') + parser.add_argument('--enable_context_fmha_fp32_acc', + default=False, + action='store_true') + parser.add_argument( + '--model', + default=tensorrt_llm.models.BertModel.__name__, + choices=[ + tensorrt_llm.models.BertModel.__name__, + tensorrt_llm.models.BertForQuestionAnswering.__name__ + ]) + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + tensorrt_llm.logger.set_level(args.log_level) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size] + inlen_range = [1, 
(args.max_input_len + 1) // 2, args.max_input_len] + torch_dtype = torch.float16 if args.dtype == 'float16' else torch.float32 + trt_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 + + builder = Builder() + builder_config = builder.create_builder_config( + name=args.model, + precision=args.dtype, + timing_cache=args.timing_cache, + tensor_parallel=args.world_size, # TP only + max_batch_size=args.max_batch_size, + max_input_len=args.max_input_len, + ) + # Initialize model + + bert_config = BertConfig( + vocab_size=args.vocab_size, + hidden_size=args.n_embd, + num_hidden_layers=args.n_layer, + num_attention_heads=args.n_head, + intermediate_size=4 * args.n_embd, + hidden_act=args.hidden_act, + max_position_embeddings=args.n_positions, + torch_dtype=torch_dtype, + ) + + output_name = 'hidden_states' + if args.model == tensorrt_llm.models.BertModel.__name__: + hf_bert = BertModel(bert_config, add_pooling_layer=False) + tensorrt_llm_bert = tensorrt_llm.models.BertModel( + num_layers=bert_config.num_hidden_layers, + num_heads=bert_config.num_attention_heads, + hidden_size=bert_config.hidden_size, + vocab_size=bert_config.vocab_size, + hidden_act=bert_config.hidden_act, + max_position_embeddings=bert_config.max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + mapping=Mapping(world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_bert( + tensorrt_llm_bert, + hf_bert, + bert_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16'), + ) + + elif args.model == tensorrt_llm.models.BertForQuestionAnswering.__name__: + hf_bert = BertForQuestionAnswering(bert_config) + tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( + num_layers=bert_config.num_hidden_layers, + num_heads=bert_config.num_attention_heads, + hidden_size=bert_config.hidden_size, + vocab_size=bert_config.vocab_size, + hidden_act=bert_config.hidden_act, + max_position_embeddings=bert_config.max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + num_labels=args. 
+ n_labels, # TODO: this might just need to be a constant + mapping=Mapping(world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_qa_bert( + tensorrt_llm_bert, + hf_bert, + bert_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16'), + ) + output_name = 'logits' + else: + assert False, f"Unknown BERT model {args.model}" + + # Module -> Network + network = builder.create_network() + if args.use_bert_attention_plugin: + network.plugin_config.set_bert_attention_plugin( + dtype=args.use_bert_attention_plugin) + if args.use_gemm_plugin: + network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) + if args.use_layernorm_plugin: + network.plugin_config.set_layernorm_plugin( + dtype=args.use_layernorm_plugin) + if args.enable_qk_half_accum: + network.plugin_config.enable_qk_half_accum() + assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) + if args.enable_context_fmha: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) + if args.enable_context_fmha_fp32_acc: + network.plugin_config.set_context_fmha( + ContextFMHAType.enabled_with_fp32_acc) + if args.world_size > 1: + network.plugin_config.set_nccl_plugin(args.dtype) + with net_guard(network): + # Prepare + network.set_named_parameters(tensorrt_llm_bert.named_parameters()) + + # Forward + input_ids = tensorrt_llm.Tensor( + name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + + # also called segment_ids + token_type_ids = tensorrt_llm.Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + + input_lengths = tensorrt_llm.Tensor(name='input_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]) + ])) + + # logits for QA BERT, or hidden_state for vanila BERT + output = tensorrt_llm_bert(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids) + + # Mark outputs + output_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 + output.mark_output(output_name, output_dtype) + + # Network -> Engine + engine = builder.build_engine(network, builder_config) + assert engine is not None, 'Failed to build engine.' 
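+    # Persist the build artifacts: the serialized engine goes to
+    # <output_dir>/<model>_<dtype>_tp<world_size>_rank<rank>.engine (see
+    # get_engine_name above), and the builder config is saved alongside it as
+    # config.json so run.py can recover the precision, tensor parallelism and
+    # model name at runtime.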
+ engine_file = os.path.join( + args.output_dir, + get_engine_name(args.model, args.dtype, args.world_size, args.rank)) + with open(engine_file, 'wb') as f: + f.write(engine) + builder.save_config(builder_config, + os.path.join(args.output_dir, 'config.json')) diff --git a/examples/roberta/large_benchmark/config.json b/examples/roberta/large_benchmark/config.json new file mode 100644 index 000000000..c720c7ac6 --- /dev/null +++ b/examples/roberta/large_benchmark/config.json @@ -0,0 +1,22 @@ +{ + "builder_config": { + "max_batch_size": 256, + "max_input_len": 512, + "name": "bert", + "precision": "float16", + "tensor_parallel": 1, + "use_refit": false + }, + "plugin_config": { + "bert_attention_plugin": false, + "context_fmha_enabled": false, + "gemm_plugin": false, + "gpt_attention_plugin": false, + "identity_plugin": false, + "layernorm_plugin": false, + "layernorm_quantization_plugin": false, + "nccl_plugin": false, + "smooth_quant_gemm_plugin": false, + "weight_only_quant_matmul_plugin": false + } +} diff --git a/examples/roberta/large_with_attention_plugin_benchmark/config.json b/examples/roberta/large_with_attention_plugin_benchmark/config.json new file mode 100644 index 000000000..025f0383e --- /dev/null +++ b/examples/roberta/large_with_attention_plugin_benchmark/config.json @@ -0,0 +1,22 @@ +{ + "builder_config": { + "max_batch_size": 256, + "max_input_len": 512, + "name": "bert", + "precision": "float16", + "tensor_parallel": 1, + "use_refit": false + }, + "plugin_config": { + "bert_attention_plugin": "float16", + "context_fmha_enabled": true, + "gemm_plugin": "float16", + "gpt_attention_plugin": false, + "identity_plugin": false, + "layernorm_plugin": false, + "layernorm_quantization_plugin": false, + "nccl_plugin": false, + "smooth_quant_gemm_plugin": false, + "weight_only_quant_matmul_plugin": false + } +} diff --git a/examples/roberta/run.py b/examples/roberta/run.py new file mode 100644 index 000000000..470d26527 --- /dev/null +++ b/examples/roberta/run.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
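+# run.py loads the config.json and serialized engine written by build.py,
+# creates a TensorRT-LLM Session, and runs a few inference passes with random
+# input_ids / token_type_ids at increasing batch sizes and sequence lengths.
+# Note: like the rest of examples/roberta in this patch, the file is copied
+# from examples/bert as-is (engine dir and model names still say "bert").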
+import argparse +import json +import os + +# isort: off +import torch +import tensorrt as trt +# isort: on + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm.models import BertForQuestionAnswering, BertModel +from tensorrt_llm.runtime import Session, TensorInfo + +from build import get_engine_name # isort:skip + + +def trt_dtype_to_torch(dtype): + if dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + elif dtype == trt.int32: + return torch.int32 + else: + raise TypeError("%s is not supported" % dtype) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--log_level', type=str, default='info') + parser.add_argument('--engine_dir', type=str, default='bert_outputs') + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + tensorrt_llm.logger.set_level(args.log_level) + + config_path = os.path.join(args.engine_dir, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + dtype = config['builder_config']['precision'] + world_size = config['builder_config']['tensor_parallel'] + assert world_size == tensorrt_llm.mpi_world_size(), \ + f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' + + model_name = config['builder_config']['name'] + runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0 + + runtime_mapping = tensorrt_llm.Mapping(world_size, + runtime_rank, + tp_size=world_size) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + + serialize_path = get_engine_name(model_name, dtype, world_size, + runtime_rank) + serialize_path = os.path.join(args.engine_dir, serialize_path) + + stream = torch.cuda.current_stream().cuda_stream + logger.info(f'Loading engine from {serialize_path}') + with open(serialize_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine') + session = Session.from_serialized_engine(engine_buffer) + + for i in range(3): + batch_size = (i + 1) * 4 + seq_len = (i + 1) * 32 + input_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() + input_lengths = seq_len * torch.ones( + (batch_size, ), dtype=torch.int32, device='cuda') + token_type_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() + + inputs = { + 'input_ids': input_ids, + 'input_lengths': input_lengths, + 'token_type_ids': token_type_ids + } + output_info = session.infer_shapes([ + TensorInfo('input_ids', trt.DataType.INT32, input_ids.shape), + TensorInfo('input_lengths', trt.DataType.INT32, + input_lengths.shape), + TensorInfo('token_type_ids', trt.DataType.INT32, + token_type_ids.shape), + ]) + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + if (model_name == BertModel.__name__): + output_name = 'hidden_states' + elif (model_name == BertForQuestionAnswering.__name__): + output_name = 'logits' + else: + assert False, f"Unknown BERT model {model_name}" + + assert output_name in outputs, f'{output_name} not found in outputs, check if build.py set the name correctly' + + ok = session.run(inputs, outputs, stream) + assert ok, "Runtime execution failed" + torch.cuda.synchronize() + res = outputs[output_name] diff --git a/examples/roberta/weight.py b/examples/roberta/weight.py new file mode 100644 index 000000000..78920ac41 --- /dev/null +++ b/examples/roberta/weight.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return np.ascontiguousarray(np.split(v, tp_size)[idx].copy()) + elif len(v.shape) == 2: + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx].copy()) + return None + + +def load_from_hf_bert(tensorrt_llm_bert, + hf_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + qkv_weight = [[None, None, None] + for _ in range(hf_bert_config.num_hidden_layers)] + + qkv_bias = [[None, None, None] + for _ in range(hf_bert_config.num_hidden_layers)] + + for k, v in hf_bert.state_dict().items(): + torch_dtype = torch.float16 if fp16 else torch.float32 + v = v.to(torch_dtype).cpu().numpy() + if 'embeddings.word_embeddings.weight' in k: + tensorrt_llm_bert.embedding.vocab_embedding.weight.value = v + elif 'embeddings.position_embeddings.weight' in k: + tensorrt_llm_bert.embedding.position_embedding.weight.value = v + elif 'embeddings.token_type_embeddings.weight' in k: + tensorrt_llm_bert.embedding.token_embedding.weight.value = v + elif 'embeddings.LayerNorm.weight' in k: + tensorrt_llm_bert.embedding.embedding_ln.weight.value = v + elif 'embeddings.LayerNorm.bias' in k: + tensorrt_llm_bert.embedding.embedding_ln.bias.value = v + else: + layer_idx = extract_layer_idx(k) + if layer_idx is None: + continue + idx = int(layer_idx) + if 'attention.output.dense.weight' in k: + tensorrt_llm_bert.layers[ + idx].attention.dense.weight.value = split(v, + tensor_parallel, + rank, + dim=1) + elif 'attention.output.dense.bias' in k: + tensorrt_llm_bert.layers[idx].attention.dense.bias.value = v + elif 'attention.output.LayerNorm.weight' in k: + tensorrt_llm_bert.layers[idx].input_layernorm.weight.value = v + elif 'attention.output.LayerNorm.bias' in k: + tensorrt_llm_bert.layers[idx].input_layernorm.bias.value = v + elif 'intermediate.dense.weight' in k: + tensorrt_llm_bert.layers[idx].mlp.fc.weight.value = split( + v, tensor_parallel, rank) + elif 'intermediate.dense.bias' in k: + tensorrt_llm_bert.layers[idx].mlp.fc.bias.value = split( + v, tensor_parallel, rank) + elif 'output.dense.weight' in k: + tensorrt_llm_bert.layers[idx].mlp.proj.weight.value = split( + v, tensor_parallel, rank, dim=1) + elif 'output.dense.bias' in k: + tensorrt_llm_bert.layers[idx].mlp.proj.bias.value = v + elif 'output.LayerNorm.weight' in k: + tensorrt_llm_bert.layers[idx].post_layernorm.weight.value = v + elif 'output.LayerNorm.bias' in k: + tensorrt_llm_bert.layers[idx].post_layernorm.bias.value = v + elif 'attention.self.query.weight' in k: + qkv_weight[idx][0] = v + elif 'attention.self.query.bias' in k: + qkv_bias[idx][0] = v + elif 'attention.self.key.weight' in k: + qkv_weight[idx][1] = v + elif 'attention.self.key.bias' in k: + qkv_bias[idx][1] = 
v + elif 'attention.self.value.weight' in k: + qkv_weight[idx][2] = v + elif 'attention.self.value.bias' in k: + qkv_bias[idx][2] = v + + for i in range(hf_bert_config.num_hidden_layers): + tensorrt_llm_bert.layers[i].attention.qkv.weight.value = split( + np.concatenate(qkv_weight[i]), tensor_parallel, rank) + tensorrt_llm_bert.layers[i].attention.qkv.bias.value = split( + np.concatenate(qkv_bias[i]), tensor_parallel, rank) + + +def load_from_hf_qa_bert(tensorrt_llm_qa_bert, + hf_qa_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + rank, tensor_parallel, fp16) + states = hf_qa_bert.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_bert.qa_outputs.weight.value = states[ + 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( + torch_dtype).cpu().numpy() diff --git a/tensorrt_llm/models/roberta/__init__.py b/tensorrt_llm/models/roberta/__init__.py new file mode 100644 index 000000000..2a36ca922 --- /dev/null +++ b/tensorrt_llm/models/roberta/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/roberta/model.py b/tensorrt_llm/models/roberta/model.py new file mode 100644 index 000000000..3e519a0ed --- /dev/null +++ b/tensorrt_llm/models/roberta/model.py @@ -0,0 +1,369 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
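+# The extended attention mask built in BertModel.forward below turns the
+# per-sequence lengths into an additive bias of shape [B, 1, 1, L]: valid
+# positions get 0 and padded positions get a large negative value, which the
+# non-plugin attention path adds to the attention scores before the softmax.
+# A rough eager-mode equivalent (illustrative sketch only, names hypothetical):
+#
+#   positions = torch.arange(seq_len).unsqueeze(0)             # 1 x L
+#   mask = (positions < input_lengths.unsqueeze(1)).int()      # B x L
+#   bias = (1 - mask)[:, None, None, :] * -214748364           # B x 1 x 1 x L
+#   bias = bias.to(dtype)                                       # added to scores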
+from email.quoprimime import unquote +import math + +import numpy as np + +from ..._common import default_net +from ...functional import (bert_attention, concat, constant, expand, matmul, + shape, slice, softmax, split, cast, unsqueeze, select, ACT2FN) +from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear +from ...mapping import Mapping +from ...module import Module, ModuleList + + +class BertEmbedding(Module): + + def __init__(self, + vocab_size, + hidden_size, + max_position_embeddings, + type_vocab_size, + dtype=None): + super().__init__() + self.vocab_embedding = Embedding(vocab_size, hidden_size, dtype=dtype) + self.position_embedding = Embedding(max_position_embeddings, + hidden_size, + dtype=dtype) + self.token_embedding = Embedding(type_vocab_size, + hidden_size, + dtype=dtype) + self.max_position_embeddings = max_position_embeddings + + self.embedding_ln = LayerNorm(normalized_shape=hidden_size, dtype=dtype) + + def forward(self, input_ids, position_ids=None, token_type_ids=None): + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + + token_type_ids_buffer = constant( + np.expand_dims( + np.zeros(self.max_position_embeddings).astype(np.int32), 0)) + + seq_len_2d = concat([1, shape(input_ids, 1)]) + + if position_ids is None: + # slice + position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + position_ids = expand(position_ids, shape(input_ids)) + + if token_type_ids is None: + # slice + token_type_ids = slice(token_type_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + token_type_ids = expand(token_type_ids, shape(input_ids)) + + x = self.vocab_embedding(input_ids) + x = x + self.position_embedding(position_ids) + x = x + self.token_embedding(token_type_ids) + x = self.embedding_ln(x) + return x + + +class BertAttention(Module): + + def __init__(self, + hidden_size, + num_attention_heads, + max_position_embeddings, + dtype=None, + tp_group=None, + tp_size=1): + super().__init__() + + self.attention_head_size = hidden_size // num_attention_heads + self.num_attention_heads = num_attention_heads // tp_size + self.hidden_size = hidden_size // tp_size + self.max_position_embeddings = max_position_embeddings + self.norm_factor = math.sqrt(self.attention_head_size) + + self.qkv = ColumnLinear(hidden_size, + hidden_size * 3, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size, + gather_output=False) + self.dense = RowLinear(hidden_size, + hidden_size, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size) + + def forward(self, hidden_states, attention_mask=None, input_lengths=None): + qkv = self.qkv(hidden_states) + + # attention + if default_net().plugin_config.bert_attention_plugin: + assert input_lengths is not None + context = bert_attention(qkv, input_lengths, + self.num_attention_heads, + self.attention_head_size, 1.0) + else: + + def transpose_for_scores(x): + new_x_shape = concat([ + shape(x, 0), + shape(x, 1), self.num_attention_heads, + self.attention_head_size + ]) + return x.view(new_x_shape).permute([0, 2, 1, 3]) + + query, key, value = split(qkv, self.hidden_size, dim=2) + query = transpose_for_scores(query) + key = transpose_for_scores(key) + value = transpose_for_scores(value) + + key = key.permute([0, 1, 3, 2]) + attention_scores = matmul(query, key) + attention_scores = attention_scores / self.norm_factor + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = softmax(attention_scores, dim=-1) + 
+ context = matmul(attention_probs, value).permute([0, 2, 1, 3]) + context = context.view( + concat([shape(context, 0), + shape(context, 1), self.hidden_size])) + + context = self.dense(context) + + return context + + +class BertEncoderLayer(Module): + + def __init__(self, + hidden_size, + num_attention_heads, + max_position_embeddings, + hidden_act='relu', + tp_group=None, + tp_size=1, + dtype=None): + super().__init__() + self.input_layernorm = LayerNorm(normalized_shape=hidden_size, + dtype=dtype) + + self.attention = BertAttention(hidden_size, + num_attention_heads, + max_position_embeddings, + tp_group=tp_group, + tp_size=tp_size, + dtype=dtype) + self.mlp = MLP(hidden_size=hidden_size, + ffn_hidden_size=hidden_size * 4, + hidden_act=hidden_act, + tp_group=tp_group, + tp_size=tp_size, + dtype=dtype) + self.post_layernorm = LayerNorm(normalized_shape=hidden_size, + dtype=dtype) + + def forward(self, hidden_states, attention_mask=None, input_lengths=None): + residual = hidden_states + + attention_output = self.attention(hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths) + + hidden_states = residual + attention_output + + hidden_states = self.input_layernorm(hidden_states) + + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + hidden_states = self.post_layernorm(hidden_states) + + return hidden_states + + +class BertModel(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + mapping=Mapping(), + dtype=None): + super().__init__() + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype + self.embedding = BertEmbedding( + vocab_size=vocab_size, + hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + dtype=dtype) + + self.layers = ModuleList([ + BertEncoderLayer(hidden_size=hidden_size, + num_attention_heads=num_heads, + max_position_embeddings=max_position_embeddings, + hidden_act=hidden_act, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + dtype=dtype) for _ in range(num_layers) + ]) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + hidden_states = self.embedding(input_ids, position_ids, token_type_ids) + + # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py + seq_len_2d = concat([1, shape(input_ids, 1)]) + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = cast(mask, 'int32') + extended_attention_mask = unsqueeze(mask, 1) + extended_attention_mask = unsqueeze(extended_attention_mask, 1) # Bx1x1xL + extended_attention_mask = (1 - extended_attention_mask) * -214748364 # a small negative number in int32 range + extended_attention_mask = cast(extended_attention_mask, self.dtype) + + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, + input_lengths=input_lengths, + attention_mask=extended_attention_mask) + + return hidden_states + + +class 
BertForQuestionAnswering(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.bert = BertModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + mapping=mapping, + dtype=dtype) + self.num_labels = num_labels + self.qa_outputs = Linear(hidden_size, num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.bert.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + + logits = self.qa_outputs(hidden_states) + + return logits + +class BertPooler(Module): + def __init__(self, hidden_size, dtype): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.activation = ACT2FN['tanh'] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = select(hidden_states, 1, 0) + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertForSequenceClassification(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.bert = BertModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + mapping=mapping, + dtype=dtype) + self.num_labels = num_labels + self.pooler = BertPooler(hidden_size=hidden_size, dtype=dtype) + self.classifier = Linear(hidden_size, num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.bert.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + pooled_output = self.pooler(hidden_states) + logits = self.classifier(pooled_output) + + return logits diff --git a/tests/model/test_roberta.py b/tests/model/test_roberta.py new file mode 100644 index 000000000..29f1c4d04 --- /dev/null +++ b/tests/model/test_roberta.py @@ -0,0 +1,479 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
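+# test_roberta.py is currently a copy of tests/model/test_bert.py: it builds
+# TensorRT-LLM engines for BertModel, BertForQuestionAnswering and
+# BertForSequenceClassification and compares their outputs against the
+# HuggingFace implementations, both with randomly initialized configs and with
+# the real 'bert-base-uncased' checkpoint (tokenized "Hello world!" inputs).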
+import tempfile +import unittest +from collections import OrderedDict +from itertools import product + +import numpy as np +import parameterized + +# isort: off +import torch +import tensorrt as trt +# isort: on +from parameterized import parameterized +from transformers import BertConfig, BertForQuestionAnswering, BertModel, BertForSequenceClassification, AutoTokenizer + +import tensorrt_llm +import tensorrt_llm.runtime +from tensorrt_llm import Builder +from tensorrt_llm._utils import trt_dtype_to_torch, str_dtype_to_trt +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.runtime import TensorInfo + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return np.ascontiguousarray(np.split(v, tp_size)[idx]) + elif len(v.shape) == 2: + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + return None + + +def load_from_hf_bert(tensorrt_llm_bert, + hf_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + qkv_weight = [[None, None, None] + for _ in range(hf_bert_config.num_hidden_layers)] + + qkv_bias = [[None, None, None] + for _ in range(hf_bert_config.num_hidden_layers)] + + torch_dtype = torch.float16 if fp16 else torch.float32 + for k, v in hf_bert.state_dict().items(): + v = v.to(torch_dtype).cpu().numpy() + if 'embeddings.word_embeddings.weight' in k: + tensorrt_llm_bert.embedding.vocab_embedding.weight.value = v + elif 'embeddings.position_embeddings.weight' in k: + tensorrt_llm_bert.embedding.position_embedding.weight.value = v + elif 'embeddings.token_type_embeddings.weight' in k: + tensorrt_llm_bert.embedding.token_embedding.weight.value = v + elif 'embeddings.LayerNorm.weight' in k: + tensorrt_llm_bert.embedding.embedding_ln.weight.value = v + elif 'embeddings.LayerNorm.bias' in k: + tensorrt_llm_bert.embedding.embedding_ln.bias.value = v + else: + layer_idx = extract_layer_idx(k) + if layer_idx is None: + continue + idx = int(layer_idx) + if 'attention.output.dense.weight' in k: + tensorrt_llm_bert.layers[ + idx].attention.dense.weight.value = split(v, + tensor_parallel, + rank, + dim=1) + elif 'attention.output.dense.bias' in k: + tensorrt_llm_bert.layers[idx].attention.dense.bias.value = v + elif 'attention.output.LayerNorm.weight' in k: + tensorrt_llm_bert.layers[idx].input_layernorm.weight.value = v + elif 'attention.output.LayerNorm.bias' in k: + tensorrt_llm_bert.layers[idx].input_layernorm.bias.value = v + elif 'intermediate.dense.weight' in k: + tensorrt_llm_bert.layers[idx].mlp.fc.weight.value = split( + v, tensor_parallel, rank) + elif 'intermediate.dense.bias' in k: + tensorrt_llm_bert.layers[idx].mlp.fc.bias.value = split( + v, tensor_parallel, rank) + elif 'output.dense.weight' in k: + tensorrt_llm_bert.layers[idx].mlp.proj.weight.value = split( + v, tensor_parallel, rank, dim=1) + elif 'output.dense.bias' in k: + tensorrt_llm_bert.layers[idx].mlp.proj.bias.value = v + elif 'output.LayerNorm.weight' in k: + tensorrt_llm_bert.layers[idx].post_layernorm.weight.value = v + elif 'output.LayerNorm.bias' in k: + tensorrt_llm_bert.layers[idx].post_layernorm.bias.value = v + elif 'attention.self.query.weight' in k: + qkv_weight[idx][0] = v + elif 'attention.self.query.bias' in k: + qkv_bias[idx][0] = v + elif 'attention.self.key.weight' in k: + qkv_weight[idx][1] = v + elif 'attention.self.key.bias' in k: + qkv_bias[idx][1] 
= v + elif 'attention.self.value.weight' in k: + qkv_weight[idx][2] = v + elif 'attention.self.value.bias' in k: + qkv_bias[idx][2] = v + + for i in range(hf_bert_config.num_hidden_layers): + tensorrt_llm_bert.layers[i].attention.qkv.weight.value = split( + np.concatenate(qkv_weight[i]), tensor_parallel, rank) + tensorrt_llm_bert.layers[i].attention.qkv.bias.value = split( + np.concatenate(qkv_bias[i]), tensor_parallel, rank) + + +def load_from_hf_qa_bert(tensorrt_llm_qa_bert, + hf_qa_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + rank, tensor_parallel, fp16) + states = hf_qa_bert.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_bert.qa_outputs.weight.value = states[ + 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( + torch_dtype).cpu().numpy() + +def load_from_hf_cls_bert(tensorrt_llm_qa_bert, + hf_qa_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + rank, tensor_parallel, fp16) + states = hf_qa_bert.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_bert.pooler.dense.weight.value = states[ + 'bert.pooler.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.pooler.dense.bias.value = states[ + 'bert.pooler.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_qa_bert.classifier.weight.value = states[ + 'classifier.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.classifier.bias.value = states[ + 'classifier.bias'].to(torch_dtype).cpu().numpy() + + +class TestBert(unittest.TestCase): + + def load_test_cases(): + models = [BertForSequenceClassification.__name__, BertModel.__name__, BertForQuestionAnswering.__name__] + model_dirs = ['', 'bert-base-uncased'] # add more tests for read data. 
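+        # Each test case is a tuple of (model, use_refit, use_plugin,
+        # fast_building, context_fmha_type, dtype, model_dir); an empty
+        # model_dir means a randomly initialized config, otherwise weights and
+        # tokenizer are loaded from the HuggingFace hub.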
+ test_cases = [] + test_cases += product(models, [False], [False], [False], + [ContextFMHAType.disabled], ['float32'], model_dirs) + test_cases += product(models, [False], [True], [True], [ + ContextFMHAType.disabled, ContextFMHAType.enabled, + ContextFMHAType.enabled_with_fp32_acc + ], ['float16'], model_dirs) + + return test_cases + + def custom_name_func(testcase_func, param_num, param): + return "%s_%s" % ( + testcase_func.__name__, + parameterized.to_safe_name("_".join(str(x) for x in param.args)), + ) + + @parameterized.expand(load_test_cases, name_func=custom_name_func) + def test_bert(self, model, use_refit, use_plugin, fast_building, + context_fmha_type, dtype, model_dir): + tensorrt_llm.logger.set_level('error') + fp16 = (dtype == 'float16') + world_size = 1 + rank = 0 + batch_size = 8 + input_len = 128 + vocab_size = 51200 + num_layers = 12 + num_heads = 12 + hidden_act = 'gelu' + max_position_embeddings = 512 + hidden_size = 768 + bs_range = [1, (batch_size + 1) // 2, batch_size] + inlen_range = [1, (input_len + 1) // 2, input_len] + torch_dtype = torch.float16 if fp16 else torch.float32 + trt_dtype = trt.float16 if fp16 else trt.float32 + timing_cache = 'model.cache' + + torch.manual_seed(0) + + builder = Builder() + with tempfile.TemporaryDirectory() as tmpdirname: + builder_config = builder.create_builder_config( + name=model, + precision='float16' if fp16 else 'float32', + timing_cache=timing_cache, + tensor_parallel=world_size, # TP only + use_refit=use_refit) + network = builder.create_network() + if use_plugin: + network.plugin_config.set_bert_attention_plugin(dtype) + if fast_building: + network.plugin_config.set_gemm_plugin(dtype) + network.plugin_config.set_context_fmha(context_fmha_type) + with net_guard(network): + # Prepare inputs + # TODO: could class be better than dict for profiles? 
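As a rough illustration of how the `product` calls above expand, the sketch below rebuilds the same test matrix with plain strings standing in for the model classes and `ContextFMHAType` members, so it runs without `transformers` or `tensorrt_llm` installed.

```python
from itertools import product

models = ['BertForSequenceClassification', 'BertModel', 'BertForQuestionAnswering']
fmha_modes = ['disabled', 'enabled', 'enabled_with_fp32_acc']  # stand-ins for ContextFMHAType
model_dirs = ['', 'bert-base-uncased']

cases = []
# float32 path: no plugin, no fast building, FMHA disabled.
cases += product(models, [False], [False], [False], ['disabled'], ['float32'], model_dirs)
# float16 path: attention plugin + fast building, all three FMHA modes.
cases += product(models, [False], [True], [True], fmha_modes, ['float16'], model_dirs)
cases = list(cases)

print(len(cases))  # 6 + 18 = 24 parameterized test cases
print(cases[0])    # ('BertForSequenceClassification', False, False, False, 'disabled', 'float32', '')
```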
+ input_ids = tensorrt_llm.Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]) + ])) + input_lengths = tensorrt_llm.Tensor(name='input_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size', + [bs_range]) + ])) + # Initialize model + if model_dir: + bert_config = BertConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) + vocab_size = bert_config.vocab_size + hidden_size = bert_config.hidden_size + num_layers = bert_config.num_hidden_layers + num_heads = bert_config.num_attention_heads + hidden_size = bert_config.intermediate_size // 4 + hidden_act = bert_config.hidden_act + max_position_embeddings = bert_config.max_position_embeddings + else: + bert_config = BertConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + intermediate_size=4 * hidden_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + torch_dtype=torch_dtype, + ) + + output_name = "hidden_states" + if model == BertModel.__name__: + if model_dir: + hf_bert = BertModel.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertModel( + bert_config, + add_pooling_layer=False).cuda().to(torch_dtype).eval() + tensorrt_llm_bert = tensorrt_llm.models.BertModel( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_bert(tensorrt_llm_bert, + hf_bert, + bert_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + elif model == BertForQuestionAnswering.__name__: + if model_dir: + hf_bert = BertForQuestionAnswering.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForQuestionAnswering(bert_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_qa_bert(tensorrt_llm_bert, + hf_bert, + bert_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + elif model == BertForSequenceClassification.__name__: + if model_dir: + hf_bert = BertForSequenceClassification.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForSequenceClassification(bert_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) 
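One detail worth calling out in the pretrained-checkpoint branch above: the test recovers `hidden_size` from `intermediate_size // 4` rather than keeping the value read from `hidden_size`, which relies on the conventional 4x FFN width. A standalone snippet (hypothetical `params` dict, not part of the test) showing the mapping from a Hugging Face config to the constructor arguments used here, under that assumption:

```python
from transformers import BertConfig

cfg = BertConfig()  # bert-base-uncased-style defaults, no download required

params = dict(
    vocab_size=cfg.vocab_size,
    hidden_size=cfg.intermediate_size // 4,  # 3072 // 4 == 768 == cfg.hidden_size for BERT-base
    num_layers=cfg.num_hidden_layers,
    num_heads=cfg.num_attention_heads,
    hidden_act=cfg.hidden_act,
    max_position_embeddings=cfg.max_position_embeddings,
    type_vocab_size=cfg.type_vocab_size,
)
print(params)
```

For checkpoints whose `intermediate_size` is not exactly 4x `hidden_size`, reading `cfg.hidden_size` directly would be the safer choice.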
+ load_from_hf_cls_bert(tensorrt_llm_bert, + hf_bert, + bert_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + + else: + assert False, f"Unknown model {model}" + # Prepare + network.set_named_parameters( + tensorrt_llm_bert.named_parameters()) + + # Forward + output = tensorrt_llm_bert(input_ids=input_ids, + input_lengths=input_lengths) + + # Mark outputs + output_dtype = trt.float16 if fp16 else trt.float32 + output.mark_output(output_name, output_dtype) + + for k, v in tensorrt_llm_bert.named_network_outputs(): + network._mark_output(v, k, str_dtype_to_trt(dtype)) + + # Build engine + engine_buffer = builder.build_engine(network, builder_config) + session = tensorrt_llm.runtime.Session.from_serialized_engine( + engine_buffer) + stream = torch.cuda.current_stream().cuda_stream + + # Inference + # The dtype of input_ids should be queried from the engine, + # for testing purpose, int32 is fine for now. + attention_mask = None + if model_dir: + hf_tokenizer = AutoTokenizer.from_pretrained(model_dir) + input_strings = ['Hello world!' for _ in range(batch_size)] + input_ids_with_padding = hf_tokenizer( + input_strings, padding='max_length', max_length=input_len) + input_ids_without_padding = hf_tokenizer( + input_strings) + input_ids = torch.tensor(input_ids_with_padding['input_ids']).int().cuda() + input_lengths = [len(x) for x in input_ids_without_padding['input_ids']] + input_lengths = torch.tensor( + input_lengths, device=input_ids.device, dtype=torch.int32) + attention_mask = torch.tensor( + input_ids_with_padding['attention_mask'], + device=input_ids.device, dtype=torch.int32) + else: + input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + input_lengths = input_len * torch.ones( + (batch_size, ), dtype=torch.int32, device='cuda') + + output_info = session.infer_shapes([ + TensorInfo('input_ids', trt.DataType.INT32, + (batch_size, input_len)), + TensorInfo('input_lengths', trt.DataType.INT32, (batch_size, )) + ]) + session._print_engine_info() + + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + assert output_name in outputs, f'{output_name} not found in outputs' + session.run(inputs={ + 'input_ids': input_ids, + 'input_lengths': input_lengths + }, + outputs=outputs, + stream=stream) + torch.cuda.synchronize() + res = outputs[output_name] + + with torch.no_grad(): + if model_dir: + hf_outputs = hf_bert.forward( + input_ids=input_ids, + attention_mask=attention_mask) + else: + hf_outputs = hf_bert.forward(input_ids) + torch.cuda.synchronize() + + if model == BertModel.__name__: + ref = hf_outputs.last_hidden_state + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + attention_mask_tmp = attention_mask.unsqueeze(-1) + ref = ref * attention_mask_tmp + res = res * attention_mask_tmp + + np.testing.assert_allclose(ref.cpu().numpy(), + res.cpu().numpy(), + atol=1e-2, + rtol=1e-2) + elif model == BertForQuestionAnswering.__name__: + res_start_logits, res_end_logits = torch.split(res, 1, -1) + res_start_logits = res_start_logits.squeeze() + res_end_logits = res_end_logits.squeeze() + + ref_start_logits = hf_outputs.start_logits + ref_end_logits = hf_outputs.end_logits + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + ref_start_logits = ref_start_logits * 
attention_mask + ref_end_logits = ref_end_logits * attention_mask + res_start_logits = res_start_logits * attention_mask + res_end_logits = res_end_logits * attention_mask + + np.testing.assert_allclose(ref_start_logits.cpu().numpy(), + res_start_logits.cpu().numpy(), + atol=1.5e-2) + np.testing.assert_allclose(ref_end_logits.cpu().numpy(), + res_end_logits.cpu().numpy(), + atol=1.5e-2) + elif model == BertForSequenceClassification.__name__: + ref = hf_outputs.logits + np.testing.assert_allclose(ref.cpu().numpy(), + res.cpu().numpy(), + atol=1e-2, + rtol=1e-2) + + +if __name__ == '__main__': + unittest.main() From e7ccf0056f322d081a0a4eb988ccfe44949788c3 Mon Sep 17 00:00:00 2001 From: erenup Date: Fri, 29 Dec 2023 18:04:40 -0500 Subject: [PATCH 3/4] support roberta --- examples/roberta/.gitignore | 2 +- examples/roberta/README.md | 24 +-- examples/roberta/base_benchmark/config.json | 22 -- .../config.json | 22 -- examples/roberta/build.py | 105 ++++++---- examples/roberta/large_benchmark/config.json | 22 -- .../config.json | 22 -- examples/roberta/run.py | 10 +- examples/roberta/weight.py | 90 +++++--- tensorrt_llm/models/__init__.py | 5 + tensorrt_llm/models/roberta/model.py | 85 +++++--- tests/model/test_bert.py | 2 +- tests/model/test_roberta.py | 194 +++++++++--------- 13 files changed, 302 insertions(+), 303 deletions(-) delete mode 100644 examples/roberta/base_benchmark/config.json delete mode 100644 examples/roberta/base_with_attention_plugin_benchmark/config.json delete mode 100644 examples/roberta/large_benchmark/config.json delete mode 100644 examples/roberta/large_with_attention_plugin_benchmark/config.json diff --git a/examples/roberta/.gitignore b/examples/roberta/.gitignore index 4b9ff316e..70df3ea68 100644 --- a/examples/roberta/.gitignore +++ b/examples/roberta/.gitignore @@ -1,2 +1,2 @@ -bert* +roberta* *.log diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 7a9ff87db..bf3dd5952 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -1,30 +1,30 @@ -# BERT +# Roberta -This document explains how to build the [BERT](https://huggingface.co/docs/transformers/model_doc/bert) model using TensorRT-LLM. It also describes how to run on a single GPU and two GPUs. +This document explains how to build the [Roberta](https://huggingface.co/docs/transformers/model_doc/roberta) model using TensorRT-LLM. It also describes how to run on a single GPU and two GPUs. ## Overview -The TensorRT-LLM BERT implementation can be found in [`tensorrt_llm/models/bert/model.py`](../../tensorrt_llm/models/bert/model.py). The TensorRT-LLM BERT example -code is located in [`examples/bert`](./). There are four main files in that folder: +The TensorRT-LLM Roberta implementation can be found in [`tensorrt_llm/models/roberta/model.py`](../../tensorrt_llm/models/roberta/model.py). The TensorRT-LLM Roberta example +code is located in [`examples/roberta`](./). There are four main files in that folder: - * [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the BERT model, + * [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Roberta model, * [`run.py`](./run.py) to run the inference on an input text, -## Build and run BERT on a single GPU +## Build and run Roberta on a single GPU -In this example, TensorRT-LLM builds TensorRT engine(s) from the [HuggingFace BERT](https://huggingface.co/docs/transformers/model_doc/bert) model. 
+In this example, TensorRT-LLM builds TensorRT engine(s) from the [HuggingFace Roberta](https://huggingface.co/docs/transformers/model_doc/roberta) model. Use the following command to build the TensorRT engine: ```bash python3 build.py --dtype=float16 --log_level=verbose -# Enable the special TensorRT-LLM BERT Attention plugin (--use_bert_attention_plugin) to increase runtime performance. +# Enable the special TensorRT-LLM Roberta Attention plugin (--use_bert_attention_plugin) to increase runtime performance. python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16 # Enable half accumulation for attention BMM1 (applied to unfused MHA plugins) python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16 --enable_qk_half_accum ``` -The following command can be used to run the BERT model on a single GPU: +The following command can be used to run the Roberta model on a single GPU: ```bash python3 run.py @@ -32,15 +32,15 @@ python3 run.py #### Fused MultiHead Attention (FMHA) -You can enable the FMHA kernels for BERT by adding `--enable_context_fmha` to the invocation of `build.py`. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention. +You can enable the FMHA kernels for Roberta by adding `--enable_context_fmha` to the invocation of `build.py`. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention. If you find that the default fp16 accumulation (`--enable_context_fmha`) cannot meet the requirement, you can try to enable fp32 accumulation by adding `--enable_context_fmha_fp32_acc`. However, it is expected to see performance drop. Note `--enable_context_fmha` / `--enable_context_fmha_fp32_acc` has to be used together with `--use_bert_attention_plugin float16`. -## Build and run BERT on two GPUs +## Build and run Roberta on two GPUs -The following two commands can be used to build TensorRT engines to run BERT on two GPUs. The first command builds one engine for the first GPU. The second command builds another engine for the second GPU. +The following two commands can be used to build TensorRT engines to run Roberta on two GPUs. The first command builds one engine for the first GPU. The second command builds another engine for the second GPU. 
```bash python3 build.py --world_size=2 --rank=0 diff --git a/examples/roberta/base_benchmark/config.json b/examples/roberta/base_benchmark/config.json deleted file mode 100644 index 025f0383e..000000000 --- a/examples/roberta/base_benchmark/config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "builder_config": { - "max_batch_size": 256, - "max_input_len": 512, - "name": "bert", - "precision": "float16", - "tensor_parallel": 1, - "use_refit": false - }, - "plugin_config": { - "bert_attention_plugin": "float16", - "context_fmha_enabled": true, - "gemm_plugin": "float16", - "gpt_attention_plugin": false, - "identity_plugin": false, - "layernorm_plugin": false, - "layernorm_quantization_plugin": false, - "nccl_plugin": false, - "smooth_quant_gemm_plugin": false, - "weight_only_quant_matmul_plugin": false - } -} diff --git a/examples/roberta/base_with_attention_plugin_benchmark/config.json b/examples/roberta/base_with_attention_plugin_benchmark/config.json deleted file mode 100644 index 025f0383e..000000000 --- a/examples/roberta/base_with_attention_plugin_benchmark/config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "builder_config": { - "max_batch_size": 256, - "max_input_len": 512, - "name": "bert", - "precision": "float16", - "tensor_parallel": 1, - "use_refit": false - }, - "plugin_config": { - "bert_attention_plugin": "float16", - "context_fmha_enabled": true, - "gemm_plugin": "float16", - "gpt_attention_plugin": false, - "identity_plugin": false, - "layernorm_plugin": false, - "layernorm_quantization_plugin": false, - "nccl_plugin": false, - "smooth_quant_gemm_plugin": false, - "weight_only_quant_matmul_plugin": false - } -} diff --git a/examples/roberta/build.py b/examples/roberta/build.py index 090fe33ef..dd3107d05 100644 --- a/examples/roberta/build.py +++ b/examples/roberta/build.py @@ -20,7 +20,7 @@ import torch import tensorrt as trt # isort: on -from transformers import BertConfig, BertForQuestionAnswering, BertModel +from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification import tensorrt_llm from tensorrt_llm.builder import Builder @@ -28,7 +28,7 @@ from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType -from weight import load_from_hf_bert, load_from_hf_qa_bert # isort:skip +from weight import load_from_hf_roberta, load_from_hf_qa_roberta, load_from_hf_cls_roberta # isort:skip def get_engine_name(model, dtype, tp_size, rank): @@ -58,7 +58,7 @@ def parse_arguments(): parser.add_argument('--max_batch_size', type=int, default=256) parser.add_argument('--max_input_len', type=int, default=512) parser.add_argument('--gpus_per_node', type=int, default=8) - parser.add_argument('--output_dir', type=str, default='bert_outputs') + parser.add_argument('--output_dir', type=str, default='roberta_outputs') parser.add_argument('--use_bert_attention_plugin', nargs='?', const='float16', @@ -88,10 +88,10 @@ def parse_arguments(): action='store_true') parser.add_argument( '--model', - default=tensorrt_llm.models.BertModel.__name__, + default=tensorrt_llm.models.RobertaModel.__name__, choices=[ - tensorrt_llm.models.BertModel.__name__, - tensorrt_llm.models.BertForQuestionAnswering.__name__ + tensorrt_llm.models.RobertaModel.__name__, + tensorrt_llm.models.RobertaForQuestionAnswering.__name__ ]) return parser.parse_args() @@ -118,7 +118,7 @@ def parse_arguments(): ) # Initialize model - bert_config = BertConfig( + roberta_config = RobertaConfig( vocab_size=args.vocab_size, hidden_size=args.n_embd, 
num_hidden_layers=args.n_layer, @@ -130,56 +130,85 @@ def parse_arguments(): ) output_name = 'hidden_states' - if args.model == tensorrt_llm.models.BertModel.__name__: - hf_bert = BertModel(bert_config, add_pooling_layer=False) - tensorrt_llm_bert = tensorrt_llm.models.BertModel( - num_layers=bert_config.num_hidden_layers, - num_heads=bert_config.num_attention_heads, - hidden_size=bert_config.hidden_size, - vocab_size=bert_config.vocab_size, - hidden_act=bert_config.hidden_act, - max_position_embeddings=bert_config.max_position_embeddings, - type_vocab_size=bert_config.type_vocab_size, + if args.model == tensorrt_llm.models.RobertaModel.__name__: + hf_roberta = RobertaModel(roberta_config, add_pooling_layer=False) + tensorrt_llm_roberta = tensorrt_llm.models.RobertaModel( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, mapping=Mapping(world_size=args.world_size, rank=args.rank, tp_size=args.world_size), # TP only dtype=trt_dtype) - load_from_hf_bert( - tensorrt_llm_bert, - hf_bert, - bert_config, + load_from_hf_roberta( + tensorrt_llm_roberta, + hf_roberta, + roberta_config, rank=args.rank, tensor_parallel=args.world_size, fp16=(args.dtype == 'float16'), ) - elif args.model == tensorrt_llm.models.BertForQuestionAnswering.__name__: - hf_bert = BertForQuestionAnswering(bert_config) - tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( - num_layers=bert_config.num_hidden_layers, - num_heads=bert_config.num_attention_heads, - hidden_size=bert_config.hidden_size, - vocab_size=bert_config.vocab_size, - hidden_act=bert_config.hidden_act, - max_position_embeddings=bert_config.max_position_embeddings, - type_vocab_size=bert_config.type_vocab_size, + elif args.model == tensorrt_llm.models.RobertaForQuestionAnswering.__name__: + hf_roberta = RobertaForQuestionAnswering(roberta_config) + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForQuestionAnswering( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, num_labels=args. 
n_labels, # TODO: this might just need to be a constant mapping=Mapping(world_size=args.world_size, rank=args.rank, tp_size=args.world_size), # TP only dtype=trt_dtype) - load_from_hf_qa_bert( - tensorrt_llm_bert, - hf_bert, - bert_config, + load_from_hf_qa_roberta( + tensorrt_llm_roberta, + hf_roberta, + roberta_config, rank=args.rank, tensor_parallel=args.world_size, fp16=(args.dtype == 'float16'), ) output_name = 'logits' + elif args.model == tensorrt_llm.models.RobertaForSequenceClassification.__name__: + hf_roberta = RobertaForSequenceClassification(roberta_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForSequenceClassification( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_cls_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16')) + else: - assert False, f"Unknown BERT model {args.model}" + assert False, f"Unknown Roberta model {args.model}" # Module -> Network network = builder.create_network() @@ -203,7 +232,7 @@ def parse_arguments(): network.plugin_config.set_nccl_plugin(args.dtype) with net_guard(network): # Prepare - network.set_named_parameters(tensorrt_llm_bert.named_parameters()) + network.set_named_parameters(tensorrt_llm_roberta.named_parameters()) # Forward input_ids = tensorrt_llm.Tensor( @@ -231,7 +260,7 @@ def parse_arguments(): ])) # logits for QA BERT, or hidden_state for vanila BERT - output = tensorrt_llm_bert(input_ids=input_ids, + output = tensorrt_llm_roberta(input_ids=input_ids, input_lengths=input_lengths, token_type_ids=token_type_ids) diff --git a/examples/roberta/large_benchmark/config.json b/examples/roberta/large_benchmark/config.json deleted file mode 100644 index c720c7ac6..000000000 --- a/examples/roberta/large_benchmark/config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "builder_config": { - "max_batch_size": 256, - "max_input_len": 512, - "name": "bert", - "precision": "float16", - "tensor_parallel": 1, - "use_refit": false - }, - "plugin_config": { - "bert_attention_plugin": false, - "context_fmha_enabled": false, - "gemm_plugin": false, - "gpt_attention_plugin": false, - "identity_plugin": false, - "layernorm_plugin": false, - "layernorm_quantization_plugin": false, - "nccl_plugin": false, - "smooth_quant_gemm_plugin": false, - "weight_only_quant_matmul_plugin": false - } -} diff --git a/examples/roberta/large_with_attention_plugin_benchmark/config.json b/examples/roberta/large_with_attention_plugin_benchmark/config.json deleted file mode 100644 index 025f0383e..000000000 --- a/examples/roberta/large_with_attention_plugin_benchmark/config.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "builder_config": { - "max_batch_size": 256, - "max_input_len": 512, - "name": "bert", - "precision": "float16", - "tensor_parallel": 1, - "use_refit": false - }, - "plugin_config": { - "bert_attention_plugin": "float16", - "context_fmha_enabled": true, - 
"gemm_plugin": "float16", - "gpt_attention_plugin": false, - "identity_plugin": false, - "layernorm_plugin": false, - "layernorm_quantization_plugin": false, - "nccl_plugin": false, - "smooth_quant_gemm_plugin": false, - "weight_only_quant_matmul_plugin": false - } -} diff --git a/examples/roberta/run.py b/examples/roberta/run.py index 470d26527..3a03dc2ee 100644 --- a/examples/roberta/run.py +++ b/examples/roberta/run.py @@ -23,7 +23,7 @@ import tensorrt_llm from tensorrt_llm import logger -from tensorrt_llm.models import BertForQuestionAnswering, BertModel +from tensorrt_llm.models import RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification from tensorrt_llm.runtime import Session, TensorInfo from build import get_engine_name # isort:skip @@ -43,7 +43,7 @@ def trt_dtype_to_torch(dtype): def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--log_level', type=str, default='info') - parser.add_argument('--engine_dir', type=str, default='bert_outputs') + parser.add_argument('--engine_dir', type=str, default='roberta_outputs') return parser.parse_args() @@ -106,9 +106,11 @@ def parse_arguments(): device='cuda') for t in output_info } - if (model_name == BertModel.__name__): + if (model_name == RobertaModel.__name__): output_name = 'hidden_states' - elif (model_name == BertForQuestionAnswering.__name__): + elif (model_name == RobertaForQuestionAnswering.__name__): + output_name = 'logits' + elif (model_name == RobertaForSequenceClassification.__name__): output_name = 'logits' else: assert False, f"Unknown BERT model {model_name}" diff --git a/examples/roberta/weight.py b/examples/roberta/weight.py index 78920ac41..97cd2c349 100644 --- a/examples/roberta/weight.py +++ b/examples/roberta/weight.py @@ -28,69 +28,69 @@ def split(v, tp_size, idx, dim=0): if tp_size == 1: return v if len(v.shape) == 1: - return np.ascontiguousarray(np.split(v, tp_size)[idx].copy()) + return np.ascontiguousarray(np.split(v, tp_size)[idx]) elif len(v.shape) == 2: - return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx].copy()) + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) return None -def load_from_hf_bert(tensorrt_llm_bert, - hf_bert, - hf_bert_config, +def load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + hf_roberta_config, rank=0, tensor_parallel=1, fp16=False): qkv_weight = [[None, None, None] - for _ in range(hf_bert_config.num_hidden_layers)] + for _ in range(hf_roberta_config.num_hidden_layers)] qkv_bias = [[None, None, None] - for _ in range(hf_bert_config.num_hidden_layers)] + for _ in range(hf_roberta_config.num_hidden_layers)] - for k, v in hf_bert.state_dict().items(): - torch_dtype = torch.float16 if fp16 else torch.float32 + torch_dtype = torch.float16 if fp16 else torch.float32 + for k, v in hf_roberta.state_dict().items(): v = v.to(torch_dtype).cpu().numpy() if 'embeddings.word_embeddings.weight' in k: - tensorrt_llm_bert.embedding.vocab_embedding.weight.value = v + tensorrt_llm_roberta.embedding.vocab_embedding.weight.value = v elif 'embeddings.position_embeddings.weight' in k: - tensorrt_llm_bert.embedding.position_embedding.weight.value = v + tensorrt_llm_roberta.embedding.position_embedding.weight.value = v elif 'embeddings.token_type_embeddings.weight' in k: - tensorrt_llm_bert.embedding.token_embedding.weight.value = v + tensorrt_llm_roberta.embedding.token_embedding.weight.value = v elif 'embeddings.LayerNorm.weight' in k: - tensorrt_llm_bert.embedding.embedding_ln.weight.value = v + 
tensorrt_llm_roberta.embedding.embedding_ln.weight.value = v elif 'embeddings.LayerNorm.bias' in k: - tensorrt_llm_bert.embedding.embedding_ln.bias.value = v + tensorrt_llm_roberta.embedding.embedding_ln.bias.value = v else: layer_idx = extract_layer_idx(k) if layer_idx is None: continue idx = int(layer_idx) if 'attention.output.dense.weight' in k: - tensorrt_llm_bert.layers[ + tensorrt_llm_roberta.layers[ idx].attention.dense.weight.value = split(v, tensor_parallel, rank, dim=1) elif 'attention.output.dense.bias' in k: - tensorrt_llm_bert.layers[idx].attention.dense.bias.value = v + tensorrt_llm_roberta.layers[idx].attention.dense.bias.value = v elif 'attention.output.LayerNorm.weight' in k: - tensorrt_llm_bert.layers[idx].input_layernorm.weight.value = v + tensorrt_llm_roberta.layers[idx].input_layernorm.weight.value = v elif 'attention.output.LayerNorm.bias' in k: - tensorrt_llm_bert.layers[idx].input_layernorm.bias.value = v + tensorrt_llm_roberta.layers[idx].input_layernorm.bias.value = v elif 'intermediate.dense.weight' in k: - tensorrt_llm_bert.layers[idx].mlp.fc.weight.value = split( + tensorrt_llm_roberta.layers[idx].mlp.fc.weight.value = split( v, tensor_parallel, rank) elif 'intermediate.dense.bias' in k: - tensorrt_llm_bert.layers[idx].mlp.fc.bias.value = split( + tensorrt_llm_roberta.layers[idx].mlp.fc.bias.value = split( v, tensor_parallel, rank) elif 'output.dense.weight' in k: - tensorrt_llm_bert.layers[idx].mlp.proj.weight.value = split( + tensorrt_llm_roberta.layers[idx].mlp.proj.weight.value = split( v, tensor_parallel, rank, dim=1) elif 'output.dense.bias' in k: - tensorrt_llm_bert.layers[idx].mlp.proj.bias.value = v + tensorrt_llm_roberta.layers[idx].mlp.proj.bias.value = v elif 'output.LayerNorm.weight' in k: - tensorrt_llm_bert.layers[idx].post_layernorm.weight.value = v + tensorrt_llm_roberta.layers[idx].post_layernorm.weight.value = v elif 'output.LayerNorm.bias' in k: - tensorrt_llm_bert.layers[idx].post_layernorm.bias.value = v + tensorrt_llm_roberta.layers[idx].post_layernorm.bias.value = v elif 'attention.self.query.weight' in k: qkv_weight[idx][0] = v elif 'attention.self.query.bias' in k: @@ -104,26 +104,48 @@ def load_from_hf_bert(tensorrt_llm_bert, elif 'attention.self.value.bias' in k: qkv_bias[idx][2] = v - for i in range(hf_bert_config.num_hidden_layers): - tensorrt_llm_bert.layers[i].attention.qkv.weight.value = split( + for i in range(hf_roberta_config.num_hidden_layers): + tensorrt_llm_roberta.layers[i].attention.qkv.weight.value = split( np.concatenate(qkv_weight[i]), tensor_parallel, rank) - tensorrt_llm_bert.layers[i].attention.qkv.bias.value = split( + tensorrt_llm_roberta.layers[i].attention.qkv.bias.value = split( np.concatenate(qkv_bias[i]), tensor_parallel, rank) -def load_from_hf_qa_bert(tensorrt_llm_qa_bert, - hf_qa_bert, - hf_bert_config, +def load_from_hf_qa_roberta(tensorrt_llm_qa_roberta, + hf_qa_roberta, + hf_roberta_config, rank=0, tensor_parallel=1, fp16=False): - load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + load_from_hf_roberta(tensorrt_llm_qa_roberta.roberta, hf_qa_roberta, hf_roberta_config, rank, tensor_parallel, fp16) - states = hf_qa_bert.state_dict() + states = hf_qa_roberta.state_dict() torch_dtype = torch.float16 if fp16 else torch.float32 - tensorrt_llm_qa_bert.qa_outputs.weight.value = states[ + tensorrt_llm_qa_roberta.qa_outputs.weight.value = states[ 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() - tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( + 
tensorrt_llm_qa_roberta.qa_outputs.bias.value = states['qa_outputs.bias'].to( torch_dtype).cpu().numpy() + +def load_from_hf_cls_roberta(tensorrt_llm_cls_roberta, + hf_qa_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_roberta(tensorrt_llm_cls_roberta.roberta, hf_qa_roberta, hf_roberta_config, + rank, tensor_parallel, fp16) + states = hf_qa_roberta.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_cls_roberta.classifier.dense.weight.value = states[ + 'classifier.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.dense.bias.value = states[ + 'classifier.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_cls_roberta.classifier.out_proj.weight.value = states[ + 'classifier.out_proj.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.out_proj.bias.value = states[ + 'classifier.out_proj.bias'].to(torch_dtype).cpu().numpy() diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index 68dd7ea7f..e6c3837ad 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -14,6 +14,7 @@ # limitations under the License. from .baichuan.model import BaichuanForCausalLM from .bert.model import BertForQuestionAnswering, BertModel, BertForSequenceClassification +from .roberta.model import RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.model import ChatGLMHeadModel, ChatGLMModel from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder @@ -31,6 +32,10 @@ __all__ = [ 'BertModel', 'BertForQuestionAnswering', + 'BertForSequenceClassification', + 'RobertaModel', + 'RobertaForQuestionAnswering', + 'RobertaForSequenceClassification', 'BloomModel', 'BloomForCausalLM', 'FalconForCausalLM', diff --git a/tensorrt_llm/models/roberta/model.py b/tensorrt_llm/models/roberta/model.py index 3e519a0ed..e40306294 100644 --- a/tensorrt_llm/models/roberta/model.py +++ b/tensorrt_llm/models/roberta/model.py @@ -25,15 +25,17 @@ from ...module import Module, ModuleList -class BertEmbedding(Module): +class RobertaEmbedding(Module): def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, + pad_token_id, dtype=None): super().__init__() + self.padding_idx = pad_token_id, self.vocab_embedding = Embedding(vocab_size, hidden_size, dtype=dtype) self.position_embedding = Embedding(max_position_embeddings, hidden_size, @@ -46,23 +48,12 @@ def __init__(self, self.embedding_ln = LayerNorm(normalized_shape=hidden_size, dtype=dtype) def forward(self, input_ids, position_ids=None, token_type_ids=None): - position_ids_buffer = constant( - np.expand_dims( - np.arange(self.max_position_embeddings).astype(np.int32), 0)) - token_type_ids_buffer = constant( np.expand_dims( np.zeros(self.max_position_embeddings).astype(np.int32), 0)) seq_len_2d = concat([1, shape(input_ids, 1)]) - if position_ids is None: - # slice - position_ids = slice(position_ids_buffer, - starts=[0, 0], - sizes=seq_len_2d) - position_ids = expand(position_ids, shape(input_ids)) - if token_type_ids is None: # slice token_type_ids = slice(token_type_ids_buffer, @@ -77,7 +68,7 @@ def forward(self, input_ids, position_ids=None, token_type_ids=None): return x -class BertAttention(Module): +class RobertaAttention(Module): def __init__(self, hidden_size, @@ -149,7 +140,7 @@ def transpose_for_scores(x): return context -class BertEncoderLayer(Module): +class 
RobertaEncoderLayer(Module): def __init__(self, hidden_size, @@ -163,7 +154,7 @@ def __init__(self, self.input_layernorm = LayerNorm(normalized_shape=hidden_size, dtype=dtype) - self.attention = BertAttention(hidden_size, + self.attention = RobertaAttention(hidden_size, num_attention_heads, max_position_embeddings, tp_group=tp_group, @@ -200,7 +191,7 @@ def forward(self, hidden_states, attention_mask=None, input_lengths=None): return hidden_states -class BertModel(Module): +class RobertaModel(Module): def __init__(self, num_layers, @@ -210,20 +201,24 @@ def __init__(self, hidden_act, max_position_embeddings, type_vocab_size, + pad_token_id, mapping=Mapping(), dtype=None): super().__init__() self.max_position_embeddings = max_position_embeddings self.dtype = dtype - self.embedding = BertEmbedding( + self.padding_idx = pad_token_id + + self.embedding = RobertaEmbedding( vocab_size=vocab_size, hidden_size=hidden_size, max_position_embeddings=max_position_embeddings, type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, dtype=dtype) self.layers = ModuleList([ - BertEncoderLayer(hidden_size=hidden_size, + RobertaEncoderLayer(hidden_size=hidden_size, num_attention_heads=num_heads, max_position_embeddings=max_position_embeddings, hidden_act=hidden_act, @@ -238,9 +233,7 @@ def forward(self, token_type_ids=None, position_ids=None, hidden_states=None): - hidden_states = self.embedding(input_ids, position_ids, token_type_ids) - # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py seq_len_2d = concat([1, shape(input_ids, 1)]) position_ids_buffer = constant( np.expand_dims( @@ -249,15 +242,27 @@ def forward(self, starts=[0, 0], sizes=seq_len_2d) tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL mask = tmp_position_ids < tmp_input_lengths # BxL mask = cast(mask, 'int32') + # self.register_network_output('attention_mask', mask) + # create position ids like create_position_ids_from_input_ids in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py + if position_ids is None: + position_ids = (tmp_position_ids + 1) * mask + position_ids = position_ids + self.padding_idx + # self.register_network_output('position_ids', position_ids) + + # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py extended_attention_mask = unsqueeze(mask, 1) extended_attention_mask = unsqueeze(extended_attention_mask, 1) # Bx1x1xL extended_attention_mask = (1 - extended_attention_mask) * -214748364 # a small negative number in int32 range extended_attention_mask = cast(extended_attention_mask, self.dtype) + + hidden_states = self.embedding(input_ids, position_ids, token_type_ids) + for layer in self.layers: hidden_states = layer(hidden_states=hidden_states, input_lengths=input_lengths, @@ -266,7 +271,7 @@ def forward(self, return hidden_states -class BertForQuestionAnswering(Module): +class RobertaForQuestionAnswering(Module): def __init__(self, num_layers, @@ -276,17 +281,19 @@ def __init__(self, hidden_act, max_position_embeddings, type_vocab_size, + pad_token_id, num_labels=2, mapping=Mapping(), dtype=None): super().__init__() - self.bert = BertModel(num_layers=num_layers, + self.roberta= RobertaModel(num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, 
hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, mapping=mapping, dtype=dtype) self.num_labels = num_labels @@ -299,7 +306,7 @@ def forward(self, position_ids=None, hidden_states=None): - hidden_states = self.bert.forward(input_ids=input_ids, + hidden_states = self.roberta.forward(input_ids=input_ids, input_lengths=input_lengths, token_type_ids=token_type_ids, position_ids=position_ids, @@ -309,7 +316,7 @@ def forward(self, return logits -class BertPooler(Module): +class RobertaPooler(Module): def __init__(self, hidden_size, dtype): super().__init__() self.dense = Linear(hidden_size, hidden_size, dtype=dtype) @@ -323,8 +330,22 @@ def forward(self, hidden_states): pooled_output = self.activation(pooled_output) return pooled_output +class RobertaClassificationHead(Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size, dtype, num_labels): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.out_proj = Linear(hidden_size, num_labels) + + def forward(self, features, **kwargs): + x = select(features, 1, 0) + x = self.dense(x) + x = ACT2FN['tanh'](x) + x = self.out_proj(x) + return x -class BertForSequenceClassification(Module): +class RobertaForSequenceClassification(Module): def __init__(self, num_layers, @@ -334,22 +355,23 @@ def __init__(self, hidden_act, max_position_embeddings, type_vocab_size, + pad_token_id, num_labels=2, mapping=Mapping(), dtype=None): super().__init__() - self.bert = BertModel(num_layers=num_layers, + self.roberta= RobertaModel(num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, mapping=mapping, dtype=dtype) - self.num_labels = num_labels - self.pooler = BertPooler(hidden_size=hidden_size, dtype=dtype) - self.classifier = Linear(hidden_size, num_labels, dtype=dtype) + + self.classifier = RobertaClassificationHead(hidden_size=hidden_size, num_labels=num_labels, dtype=dtype) def forward(self, input_ids=None, @@ -358,12 +380,11 @@ def forward(self, position_ids=None, hidden_states=None): - hidden_states = self.bert.forward(input_ids=input_ids, + hidden_states = self.roberta.forward(input_ids=input_ids, input_lengths=input_lengths, token_type_ids=token_type_ids, position_ids=position_ids, hidden_states=hidden_states) - pooled_output = self.pooler(hidden_states) - logits = self.classifier(pooled_output) + logits = self.classifier(hidden_states) return logits diff --git a/tests/model/test_bert.py b/tests/model/test_bert.py index 29f1c4d04..89af12a39 100644 --- a/tests/model/test_bert.py +++ b/tests/model/test_bert.py @@ -397,7 +397,7 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, input_ids_with_padding['attention_mask'], device=input_ids.device, dtype=torch.int32) else: - input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + input_ids = torch.randint(bert_config.vocab_size, (batch_size, input_len)).int().cuda() input_lengths = input_len * torch.ones( (batch_size, ), dtype=torch.int32, device='cuda') diff --git a/tests/model/test_roberta.py b/tests/model/test_roberta.py index 29f1c4d04..629ab6f0f 100644 --- a/tests/model/test_roberta.py +++ b/tests/model/test_roberta.py @@ -25,9 +25,10 @@ import tensorrt as trt # isort: on from parameterized import parameterized -from transformers import BertConfig, 
BertForQuestionAnswering, BertModel, BertForSequenceClassification, AutoTokenizer +from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification, AutoTokenizer import tensorrt_llm +from tensorrt_llm.models import roberta import tensorrt_llm.runtime from tensorrt_llm import Builder from tensorrt_llm._utils import trt_dtype_to_torch, str_dtype_to_trt @@ -54,63 +55,63 @@ def split(v, tp_size, idx, dim=0): return None -def load_from_hf_bert(tensorrt_llm_bert, - hf_bert, - hf_bert_config, +def load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + hf_roberta_config, rank=0, tensor_parallel=1, fp16=False): qkv_weight = [[None, None, None] - for _ in range(hf_bert_config.num_hidden_layers)] + for _ in range(hf_roberta_config.num_hidden_layers)] qkv_bias = [[None, None, None] - for _ in range(hf_bert_config.num_hidden_layers)] + for _ in range(hf_roberta_config.num_hidden_layers)] torch_dtype = torch.float16 if fp16 else torch.float32 - for k, v in hf_bert.state_dict().items(): + for k, v in hf_roberta.state_dict().items(): v = v.to(torch_dtype).cpu().numpy() if 'embeddings.word_embeddings.weight' in k: - tensorrt_llm_bert.embedding.vocab_embedding.weight.value = v + tensorrt_llm_roberta.embedding.vocab_embedding.weight.value = v elif 'embeddings.position_embeddings.weight' in k: - tensorrt_llm_bert.embedding.position_embedding.weight.value = v + tensorrt_llm_roberta.embedding.position_embedding.weight.value = v elif 'embeddings.token_type_embeddings.weight' in k: - tensorrt_llm_bert.embedding.token_embedding.weight.value = v + tensorrt_llm_roberta.embedding.token_embedding.weight.value = v elif 'embeddings.LayerNorm.weight' in k: - tensorrt_llm_bert.embedding.embedding_ln.weight.value = v + tensorrt_llm_roberta.embedding.embedding_ln.weight.value = v elif 'embeddings.LayerNorm.bias' in k: - tensorrt_llm_bert.embedding.embedding_ln.bias.value = v + tensorrt_llm_roberta.embedding.embedding_ln.bias.value = v else: layer_idx = extract_layer_idx(k) if layer_idx is None: continue idx = int(layer_idx) if 'attention.output.dense.weight' in k: - tensorrt_llm_bert.layers[ + tensorrt_llm_roberta.layers[ idx].attention.dense.weight.value = split(v, tensor_parallel, rank, dim=1) elif 'attention.output.dense.bias' in k: - tensorrt_llm_bert.layers[idx].attention.dense.bias.value = v + tensorrt_llm_roberta.layers[idx].attention.dense.bias.value = v elif 'attention.output.LayerNorm.weight' in k: - tensorrt_llm_bert.layers[idx].input_layernorm.weight.value = v + tensorrt_llm_roberta.layers[idx].input_layernorm.weight.value = v elif 'attention.output.LayerNorm.bias' in k: - tensorrt_llm_bert.layers[idx].input_layernorm.bias.value = v + tensorrt_llm_roberta.layers[idx].input_layernorm.bias.value = v elif 'intermediate.dense.weight' in k: - tensorrt_llm_bert.layers[idx].mlp.fc.weight.value = split( + tensorrt_llm_roberta.layers[idx].mlp.fc.weight.value = split( v, tensor_parallel, rank) elif 'intermediate.dense.bias' in k: - tensorrt_llm_bert.layers[idx].mlp.fc.bias.value = split( + tensorrt_llm_roberta.layers[idx].mlp.fc.bias.value = split( v, tensor_parallel, rank) elif 'output.dense.weight' in k: - tensorrt_llm_bert.layers[idx].mlp.proj.weight.value = split( + tensorrt_llm_roberta.layers[idx].mlp.proj.weight.value = split( v, tensor_parallel, rank, dim=1) elif 'output.dense.bias' in k: - tensorrt_llm_bert.layers[idx].mlp.proj.bias.value = v + tensorrt_llm_roberta.layers[idx].mlp.proj.bias.value = v elif 'output.LayerNorm.weight' in k: - 
tensorrt_llm_bert.layers[idx].post_layernorm.weight.value = v + tensorrt_llm_roberta.layers[idx].post_layernorm.weight.value = v elif 'output.LayerNorm.bias' in k: - tensorrt_llm_bert.layers[idx].post_layernorm.bias.value = v + tensorrt_llm_roberta.layers[idx].post_layernorm.bias.value = v elif 'attention.self.query.weight' in k: qkv_weight[idx][0] = v elif 'attention.self.query.bias' in k: @@ -124,58 +125,58 @@ def load_from_hf_bert(tensorrt_llm_bert, elif 'attention.self.value.bias' in k: qkv_bias[idx][2] = v - for i in range(hf_bert_config.num_hidden_layers): - tensorrt_llm_bert.layers[i].attention.qkv.weight.value = split( + for i in range(hf_roberta_config.num_hidden_layers): + tensorrt_llm_roberta.layers[i].attention.qkv.weight.value = split( np.concatenate(qkv_weight[i]), tensor_parallel, rank) - tensorrt_llm_bert.layers[i].attention.qkv.bias.value = split( + tensorrt_llm_roberta.layers[i].attention.qkv.bias.value = split( np.concatenate(qkv_bias[i]), tensor_parallel, rank) -def load_from_hf_qa_bert(tensorrt_llm_qa_bert, - hf_qa_bert, - hf_bert_config, +def load_from_hf_qa_roberta(tensorrt_llm_qa_roberta, + hf_qa_roberta, + hf_roberta_config, rank=0, tensor_parallel=1, fp16=False): - load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + load_from_hf_roberta(tensorrt_llm_qa_roberta.roberta, hf_qa_roberta, hf_roberta_config, rank, tensor_parallel, fp16) - states = hf_qa_bert.state_dict() + states = hf_qa_roberta.state_dict() torch_dtype = torch.float16 if fp16 else torch.float32 - tensorrt_llm_qa_bert.qa_outputs.weight.value = states[ + tensorrt_llm_qa_roberta.qa_outputs.weight.value = states[ 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() - tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( + tensorrt_llm_qa_roberta.qa_outputs.bias.value = states['qa_outputs.bias'].to( torch_dtype).cpu().numpy() -def load_from_hf_cls_bert(tensorrt_llm_qa_bert, - hf_qa_bert, - hf_bert_config, +def load_from_hf_cls_roberta(tensorrt_llm_cls_roberta, + hf_qa_roberta, + hf_roberta_config, rank=0, tensor_parallel=1, fp16=False): - load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + load_from_hf_roberta(tensorrt_llm_cls_roberta.roberta, hf_qa_roberta, hf_roberta_config, rank, tensor_parallel, fp16) - states = hf_qa_bert.state_dict() + states = hf_qa_roberta.state_dict() torch_dtype = torch.float16 if fp16 else torch.float32 - tensorrt_llm_qa_bert.pooler.dense.weight.value = states[ - 'bert.pooler.dense.weight'].to(torch_dtype).cpu().numpy() - tensorrt_llm_qa_bert.pooler.dense.bias.value = states[ - 'bert.pooler.dense.bias'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.dense.weight.value = states[ + 'classifier.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.dense.bias.value = states[ + 'classifier.dense.bias'].to(torch_dtype).cpu().numpy() - tensorrt_llm_qa_bert.classifier.weight.value = states[ - 'classifier.weight'].to(torch_dtype).cpu().numpy() - tensorrt_llm_qa_bert.classifier.bias.value = states[ - 'classifier.bias'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.out_proj.weight.value = states[ + 'classifier.out_proj.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.out_proj.bias.value = states[ + 'classifier.out_proj.bias'].to(torch_dtype).cpu().numpy() -class TestBert(unittest.TestCase): +class TestRoberta(unittest.TestCase): def load_test_cases(): - models = [BertForSequenceClassification.__name__, BertModel.__name__, 
BertForQuestionAnswering.__name__] - model_dirs = ['', 'bert-base-uncased'] # add more tests for read data. + models = [RobertaForSequenceClassification.__name__, RobertaModel.__name__, RobertaForQuestionAnswering.__name__] + model_dirs = ['', 'roberta-base'] # add more tests for read data. test_cases = [] test_cases += product(models, [False], [False], [False], [ContextFMHAType.disabled], ['float32'], model_dirs) @@ -193,7 +194,7 @@ def custom_name_func(testcase_func, param_num, param): ) @parameterized.expand(load_test_cases, name_func=custom_name_func) - def test_bert(self, model, use_refit, use_plugin, fast_building, + def test_roberta(self, model, use_refit, use_plugin, fast_building, context_fmha_type, dtype, model_dir): tensorrt_llm.logger.set_level('error') fp16 = (dtype == 'float16') @@ -248,16 +249,16 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, ])) # Initialize model if model_dir: - bert_config = BertConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) - vocab_size = bert_config.vocab_size - hidden_size = bert_config.hidden_size - num_layers = bert_config.num_hidden_layers - num_heads = bert_config.num_attention_heads - hidden_size = bert_config.intermediate_size // 4 - hidden_act = bert_config.hidden_act - max_position_embeddings = bert_config.max_position_embeddings + roberta_config = RobertaConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) + vocab_size = roberta_config.vocab_size + hidden_size = roberta_config.hidden_size + num_layers = roberta_config.num_hidden_layers + num_heads = roberta_config.num_attention_heads + hidden_size = roberta_config.intermediate_size // 4 + hidden_act = roberta_config.hidden_act + max_position_embeddings = roberta_config.max_position_embeddings else: - bert_config = BertConfig( + roberta_config = RobertaConfig( vocab_size=vocab_size, hidden_size=hidden_size, num_hidden_layers=num_layers, @@ -269,49 +270,51 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, ) output_name = "hidden_states" - if model == BertModel.__name__: + if model == RobertaModel.__name__: if model_dir: - hf_bert = BertModel.from_pretrained( + hf_roberta = RobertaModel.from_pretrained( model_dir).cuda().to(torch_dtype).eval() else: - hf_bert = BertModel( - bert_config, + hf_roberta = RobertaModel( + roberta_config, add_pooling_layer=False).cuda().to(torch_dtype).eval() - tensorrt_llm_bert = tensorrt_llm.models.BertModel( + tensorrt_llm_roberta = tensorrt_llm.models.RobertaModel( num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, - type_vocab_size=bert_config.type_vocab_size, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, mapping=tensorrt_llm.Mapping( world_size=world_size, rank=rank, tp_size=world_size), # TP only dtype=trt_dtype) - load_from_hf_bert(tensorrt_llm_bert, - hf_bert, - bert_config, + load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, rank=rank, tensor_parallel=world_size, fp16=fp16) - elif model == BertForQuestionAnswering.__name__: + elif model == RobertaForQuestionAnswering.__name__: if model_dir: - hf_bert = BertForQuestionAnswering.from_pretrained( + hf_roberta = RobertaForQuestionAnswering.from_pretrained( model_dir).cuda().to(torch_dtype).eval() else: - hf_bert = BertForQuestionAnswering(bert_config).cuda().to( + hf_roberta = RobertaForQuestionAnswering(roberta_config).cuda().to( torch_dtype).eval() output_name = "logits" - 
tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForQuestionAnswering( num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, - type_vocab_size=bert_config.type_vocab_size, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, num_labels= 2, # just make it a const here, seems to me not worth as a config mapping=tensorrt_llm.Mapping( @@ -319,28 +322,29 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, rank=rank, tp_size=world_size), # TP only dtype=trt_dtype) - load_from_hf_qa_bert(tensorrt_llm_bert, - hf_bert, - bert_config, + load_from_hf_qa_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, rank=rank, tensor_parallel=world_size, fp16=fp16) - elif model == BertForSequenceClassification.__name__: + elif model == RobertaForSequenceClassification.__name__: if model_dir: - hf_bert = BertForSequenceClassification.from_pretrained( + hf_roberta = RobertaForSequenceClassification.from_pretrained( model_dir).cuda().to(torch_dtype).eval() else: - hf_bert = BertForSequenceClassification(bert_config).cuda().to( + hf_roberta = RobertaForSequenceClassification(roberta_config).cuda().to( torch_dtype).eval() output_name = "logits" - tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification( + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForSequenceClassification( num_layers=num_layers, num_heads=num_heads, hidden_size=hidden_size, vocab_size=vocab_size, hidden_act=hidden_act, max_position_embeddings=max_position_embeddings, - type_vocab_size=bert_config.type_vocab_size, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, num_labels= 2, # just make it a const here, seems to me not worth as a config mapping=tensorrt_llm.Mapping( @@ -348,9 +352,9 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, rank=rank, tp_size=world_size), # TP only dtype=trt_dtype) - load_from_hf_cls_bert(tensorrt_llm_bert, - hf_bert, - bert_config, + load_from_hf_cls_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, rank=rank, tensor_parallel=world_size, fp16=fp16) @@ -359,18 +363,21 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, assert False, f"Unknown model {model}" # Prepare network.set_named_parameters( - tensorrt_llm_bert.named_parameters()) + tensorrt_llm_roberta.named_parameters()) # Forward - output = tensorrt_llm_bert(input_ids=input_ids, + output = tensorrt_llm_roberta(input_ids=input_ids, input_lengths=input_lengths) # Mark outputs output_dtype = trt.float16 if fp16 else trt.float32 output.mark_output(output_name, output_dtype) - for k, v in tensorrt_llm_bert.named_network_outputs(): - network._mark_output(v, k, str_dtype_to_trt(dtype)) + for k, v in tensorrt_llm_roberta.named_network_outputs(): + if '_ids' in k or 'attention_mask' in k: + network._mark_output(v, k, str_dtype_to_trt('int32')) + else: + network._mark_output(v, k, str_dtype_to_trt(dtype)) # Build engine engine_buffer = builder.build_engine(network, builder_config) @@ -397,7 +404,8 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, input_ids_with_padding['attention_mask'], device=input_ids.device, dtype=torch.int32) else: - input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() + # since 1 is default padding toekn id, we need to be careful about use 1. 
+ input_ids = torch.randint(2, roberta_config.vocab_size, (batch_size, input_len)).int().cuda() input_lengths = input_len * torch.ones( (batch_size, ), dtype=torch.int32, device='cuda') @@ -426,14 +434,14 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, with torch.no_grad(): if model_dir: - hf_outputs = hf_bert.forward( + hf_outputs = hf_roberta.forward( input_ids=input_ids, attention_mask=attention_mask) else: - hf_outputs = hf_bert.forward(input_ids) + hf_outputs = hf_roberta.forward(input_ids) torch.cuda.synchronize() - if model == BertModel.__name__: + if model == RobertaModel.__name__: ref = hf_outputs.last_hidden_state if use_plugin and model_dir: # when we use_plugin and have real-data model_dir and input @@ -446,7 +454,7 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, res.cpu().numpy(), atol=1e-2, rtol=1e-2) - elif model == BertForQuestionAnswering.__name__: + elif model == RobertaForQuestionAnswering.__name__: res_start_logits, res_end_logits = torch.split(res, 1, -1) res_start_logits = res_start_logits.squeeze() res_end_logits = res_end_logits.squeeze() @@ -467,7 +475,7 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, np.testing.assert_allclose(ref_end_logits.cpu().numpy(), res_end_logits.cpu().numpy(), atol=1.5e-2) - elif model == BertForSequenceClassification.__name__: + elif model == RobertaForSequenceClassification.__name__: ref = hf_outputs.logits np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), From 2446406c40fcaa54290921bb5dd7cf4af5117b60 Mon Sep 17 00:00:00 2001 From: erenup Date: Sun, 31 Dec 2023 11:59:24 +0000 Subject: [PATCH 4/4] fix args in build.py --- examples/roberta/build.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/roberta/build.py b/examples/roberta/build.py index dd3107d05..8e8595219 100644 --- a/examples/roberta/build.py +++ b/examples/roberta/build.py @@ -91,7 +91,8 @@ def parse_arguments(): default=tensorrt_llm.models.RobertaModel.__name__, choices=[ tensorrt_llm.models.RobertaModel.__name__, - tensorrt_llm.models.RobertaForQuestionAnswering.__name__ + tensorrt_llm.models.RobertaForQuestionAnswering.__name__, + tensorrt_llm.models.RobertaForSequenceClassification.__name__ ]) return parser.parse_args()
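As a closing note on the RoBERTa-specific pieces of this series: `RobertaModel.forward` builds its position ids from `input_lengths` (offset by `pad_token_id`, mirroring `create_position_ids_from_input_ids` in Hugging Face) and turns the padding mask into an additive attention mask. The PyTorch sketch below reproduces that arithmetic on a toy right-padded batch; it is an illustration of the scheme, not engine or network-definition code.

```python
import torch

pad_token_id = 1                             # RoBERTa's default padding index
input_lengths = torch.tensor([5, 3])         # valid tokens per sequence
batch, seq_len = input_lengths.shape[0], 6

positions = torch.arange(seq_len).unsqueeze(0).expand(batch, seq_len)  # BxL
mask = (positions < input_lengths.unsqueeze(1)).int()                  # 1 for real tokens, 0 for padding

# Same formula as in RobertaModel.forward: real tokens count up from
# pad_token_id + 1, padded slots collapse onto pad_token_id.
position_ids = (positions + 1) * mask + pad_token_id
print(position_ids)
# tensor([[2, 3, 4, 5, 6, 1],
#         [2, 3, 4, 1, 1, 1]])

# Additive mask analogous to extended_attention_mask: 0 where attention is
# allowed, a large negative number where the token is padding.
extended_attention_mask = (1 - mask)[:, None, None, :] * -214748364
print(extended_attention_mask.shape)  # torch.Size([2, 1, 1, 6])
```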