diff --git a/examples/roberta/.gitignore b/examples/roberta/.gitignore
new file mode 100644
index 000000000..70df3ea68
--- /dev/null
+++ b/examples/roberta/.gitignore
@@ -0,0 +1,2 @@
+roberta*
+*.log
diff --git a/examples/roberta/README.md b/examples/roberta/README.md
new file mode 100644
index 000000000..bf3dd5952
--- /dev/null
+++ b/examples/roberta/README.md
@@ -0,0 +1,55 @@
+# Roberta
+
+This document explains how to build the [Roberta](https://huggingface.co/docs/transformers/model_doc/roberta) model using TensorRT-LLM. It also describes how to run the model on a single GPU and on two GPUs.
+
+## Overview
+
+The TensorRT-LLM Roberta implementation can be found in [`tensorrt_llm/models/roberta/model.py`](../../tensorrt_llm/models/roberta/model.py). The TensorRT-LLM Roberta example
+code is located in [`examples/roberta`](./). There are three main files in that folder:
+
+ * [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Roberta model,
+ * [`run.py`](./run.py) to run inference with the built engine(s),
+ * [`weight.py`](./weight.py) to load the weights from a HuggingFace Roberta checkpoint into the TensorRT-LLM model.
+
+## Build and run Roberta on a single GPU
+
+In this example, TensorRT-LLM builds TensorRT engine(s) from the [HuggingFace Roberta](https://huggingface.co/docs/transformers/model_doc/roberta) model.
+Use the following command to build the TensorRT engine:
+
+```bash
+python3 build.py --dtype=float16 --log_level=verbose
+
+# Enable the TensorRT-LLM BERT attention plugin (--use_bert_attention_plugin), which Roberta reuses, to increase runtime performance.
+python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16
+# Enable half accumulation for attention BMM1 (applied to unfused MHA plugins)
+python3 build.py --dtype=float16 --log_level=verbose --use_bert_attention_plugin float16 --enable_qk_half_accum
+```
+
+The following command can be used to run the Roberta model on a single GPU:
+
+```bash
+python3 run.py
+```
+
+#### Fused MultiHead Attention (FMHA)
+
+You can enable the FMHA kernels for Roberta by adding `--enable_context_fmha` to the invocation of `build.py`. Note that it is disabled by default because of possible accuracy issues due to the use of Flash Attention.
+
+If the default fp16 accumulation (`--enable_context_fmha`) does not meet your accuracy requirements, you can enable fp32 accumulation by adding `--enable_context_fmha_fp32_acc` instead. However, a performance drop is expected.
+
+Note that `--enable_context_fmha` / `--enable_context_fmha_fp32_acc` must be used together with `--use_bert_attention_plugin float16`.
+
+## Build and run Roberta on two GPUs
+
+The following two commands build the TensorRT engines needed to run Roberta on two GPUs: the first command builds the engine for the first GPU (rank 0), the second builds the engine for the second GPU (rank 1).
+
+```bash
+python3 build.py --world_size=2 --rank=0
+python3 build.py --world_size=2 --rank=1
+```
+
+The following command runs the inference on two GPUs. It uses MPI with `mpirun`.
+
+```bash
+mpirun -n 2 python3 run.py
+```
diff --git a/examples/roberta/build.py b/examples/roberta/build.py
new file mode 100644
index 000000000..8e8595219
--- /dev/null
+++ b/examples/roberta/build.py
@@ -0,0 +1,281 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from collections import OrderedDict + +# isort: off +import torch +import tensorrt as trt +# isort: on +from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification + +import tensorrt_llm +from tensorrt_llm.builder import Builder +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType + +from weight import load_from_hf_roberta, load_from_hf_qa_roberta, load_from_hf_cls_roberta # isort:skip + + +def get_engine_name(model, dtype, tp_size, rank): + return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', + type=int, + default=1, + help='Tensor parallelism size') + parser.add_argument('--rank', type=int, default=0) + parser.add_argument('--dtype', + type=str, + default='float16', + choices=['float16', 'float32']) + parser.add_argument('--timing_cache', type=str, default='model.cache') + parser.add_argument('--log_level', type=str, default='info') + parser.add_argument('--vocab_size', type=int, default=51200) + parser.add_argument('--n_labels', type=int, default=2) + parser.add_argument('--n_layer', type=int, default=24) + parser.add_argument('--n_positions', type=int, default=1024) + parser.add_argument('--n_embd', type=int, default=1024) + parser.add_argument('--n_head', type=int, default=16) + parser.add_argument('--hidden_act', type=str, default='gelu') + parser.add_argument('--max_batch_size', type=int, default=256) + parser.add_argument('--max_input_len', type=int, default=512) + parser.add_argument('--gpus_per_node', type=int, default=8) + parser.add_argument('--output_dir', type=str, default='roberta_outputs') + parser.add_argument('--use_bert_attention_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--use_gemm_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--use_layernorm_plugin', + nargs='?', + const='float16', + type=str, + default=False, + choices=['float16', 'float32']) + parser.add_argument('--enable_qk_half_accum', + default=False, + action='store_true') + parser.add_argument('--enable_context_fmha', + default=False, + action='store_true') + parser.add_argument('--enable_context_fmha_fp32_acc', + default=False, + action='store_true') + parser.add_argument( + '--model', + default=tensorrt_llm.models.RobertaModel.__name__, + choices=[ + tensorrt_llm.models.RobertaModel.__name__, + tensorrt_llm.models.RobertaForQuestionAnswering.__name__, + tensorrt_llm.models.RobertaForSequenceClassification.__name__ + ]) + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + tensorrt_llm.logger.set_level(args.log_level) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + bs_range = [1, (args.max_batch_size + 1) // 2, args.max_batch_size] + inlen_range = [1, (args.max_input_len + 1) // 2, 
args.max_input_len] + torch_dtype = torch.float16 if args.dtype == 'float16' else torch.float32 + trt_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 + + builder = Builder() + builder_config = builder.create_builder_config( + name=args.model, + precision=args.dtype, + timing_cache=args.timing_cache, + tensor_parallel=args.world_size, # TP only + max_batch_size=args.max_batch_size, + max_input_len=args.max_input_len, + ) + # Initialize model + + roberta_config = RobertaConfig( + vocab_size=args.vocab_size, + hidden_size=args.n_embd, + num_hidden_layers=args.n_layer, + num_attention_heads=args.n_head, + intermediate_size=4 * args.n_embd, + hidden_act=args.hidden_act, + max_position_embeddings=args.n_positions, + torch_dtype=torch_dtype, + ) + + output_name = 'hidden_states' + if args.model == tensorrt_llm.models.RobertaModel.__name__: + hf_roberta = RobertaModel(roberta_config, add_pooling_layer=False) + tensorrt_llm_roberta = tensorrt_llm.models.RobertaModel( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + mapping=Mapping(world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_roberta( + tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16'), + ) + + elif args.model == tensorrt_llm.models.RobertaForQuestionAnswering.__name__: + hf_roberta = RobertaForQuestionAnswering(roberta_config) + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForQuestionAnswering( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + num_labels=args. 
+ n_labels, # TODO: this might just need to be a constant + mapping=Mapping(world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_qa_roberta( + tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16'), + ) + output_name = 'logits' + elif args.model == tensorrt_llm.models.RobertaForSequenceClassification.__name__: + hf_roberta = RobertaForSequenceClassification(roberta_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForSequenceClassification( + num_layers=roberta_config.num_hidden_layers, + num_heads=roberta_config.num_attention_heads, + hidden_size=roberta_config.hidden_size, + vocab_size=roberta_config.vocab_size, + hidden_act=roberta_config.hidden_act, + max_position_embeddings=roberta_config.max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=args.world_size, + rank=args.rank, + tp_size=args.world_size), # TP only + dtype=trt_dtype) + load_from_hf_cls_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=args.rank, + tensor_parallel=args.world_size, + fp16=(args.dtype == 'float16')) + + else: + assert False, f"Unknown Roberta model {args.model}" + + # Module -> Network + network = builder.create_network() + if args.use_bert_attention_plugin: + network.plugin_config.set_bert_attention_plugin( + dtype=args.use_bert_attention_plugin) + if args.use_gemm_plugin: + network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) + if args.use_layernorm_plugin: + network.plugin_config.set_layernorm_plugin( + dtype=args.use_layernorm_plugin) + if args.enable_qk_half_accum: + network.plugin_config.enable_qk_half_accum() + assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) + if args.enable_context_fmha: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) + if args.enable_context_fmha_fp32_acc: + network.plugin_config.set_context_fmha( + ContextFMHAType.enabled_with_fp32_acc) + if args.world_size > 1: + network.plugin_config.set_nccl_plugin(args.dtype) + with net_guard(network): + # Prepare + network.set_named_parameters(tensorrt_llm_roberta.named_parameters()) + + # Forward + input_ids = tensorrt_llm.Tensor( + name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + + # also called segment_ids + token_type_ids = tensorrt_llm.Tensor( + name='token_type_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([('batch_size', [bs_range]), + ('input_len', [inlen_range])]), + ) + + input_lengths = tensorrt_llm.Tensor(name='input_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]) + ])) + + # logits for QA BERT, or hidden_state for vanila BERT + output = tensorrt_llm_roberta(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids) + + # Mark outputs + output_dtype = trt.float16 if args.dtype == 'float16' else trt.float32 + output.mark_output(output_name, output_dtype) + + # Network -> Engine + engine = builder.build_engine(network, builder_config) + assert engine is not None, 'Failed to build engine.' 
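+    # Persist the serialized engine and the builder config under --output_dir so run.py can reload them.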
+ engine_file = os.path.join( + args.output_dir, + get_engine_name(args.model, args.dtype, args.world_size, args.rank)) + with open(engine_file, 'wb') as f: + f.write(engine) + builder.save_config(builder_config, + os.path.join(args.output_dir, 'config.json')) diff --git a/examples/roberta/run.py b/examples/roberta/run.py new file mode 100644 index 000000000..3a03dc2ee --- /dev/null +++ b/examples/roberta/run.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +# isort: off +import torch +import tensorrt as trt +# isort: on + +import tensorrt_llm +from tensorrt_llm import logger +from tensorrt_llm.models import RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification +from tensorrt_llm.runtime import Session, TensorInfo + +from build import get_engine_name # isort:skip + + +def trt_dtype_to_torch(dtype): + if dtype == trt.float16: + return torch.float16 + elif dtype == trt.float32: + return torch.float32 + elif dtype == trt.int32: + return torch.int32 + else: + raise TypeError("%s is not supported" % dtype) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--log_level', type=str, default='info') + parser.add_argument('--engine_dir', type=str, default='roberta_outputs') + + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_arguments() + + tensorrt_llm.logger.set_level(args.log_level) + + config_path = os.path.join(args.engine_dir, 'config.json') + with open(config_path, 'r') as f: + config = json.load(f) + dtype = config['builder_config']['precision'] + world_size = config['builder_config']['tensor_parallel'] + assert world_size == tensorrt_llm.mpi_world_size(), \ + f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})' + + model_name = config['builder_config']['name'] + runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0 + + runtime_mapping = tensorrt_llm.Mapping(world_size, + runtime_rank, + tp_size=world_size) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + + serialize_path = get_engine_name(model_name, dtype, world_size, + runtime_rank) + serialize_path = os.path.join(args.engine_dir, serialize_path) + + stream = torch.cuda.current_stream().cuda_stream + logger.info(f'Loading engine from {serialize_path}') + with open(serialize_path, 'rb') as f: + engine_buffer = f.read() + logger.info(f'Creating session from engine') + session = Session.from_serialized_engine(engine_buffer) + + for i in range(3): + batch_size = (i + 1) * 4 + seq_len = (i + 1) * 32 + input_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() + input_lengths = seq_len * torch.ones( + (batch_size, ), dtype=torch.int32, device='cuda') + token_type_ids = torch.randint(100, (batch_size, seq_len)).int().cuda() + + inputs = { + 'input_ids': input_ids, + 'input_lengths': 
input_lengths, + 'token_type_ids': token_type_ids + } + output_info = session.infer_shapes([ + TensorInfo('input_ids', trt.DataType.INT32, input_ids.shape), + TensorInfo('input_lengths', trt.DataType.INT32, + input_lengths.shape), + TensorInfo('token_type_ids', trt.DataType.INT32, + token_type_ids.shape), + ]) + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + if (model_name == RobertaModel.__name__): + output_name = 'hidden_states' + elif (model_name == RobertaForQuestionAnswering.__name__): + output_name = 'logits' + elif (model_name == RobertaForSequenceClassification.__name__): + output_name = 'logits' + else: + assert False, f"Unknown BERT model {model_name}" + + assert output_name in outputs, f'{output_name} not found in outputs, check if build.py set the name correctly' + + ok = session.run(inputs, outputs, stream) + assert ok, "Runtime execution failed" + torch.cuda.synchronize() + res = outputs[output_name] diff --git a/examples/roberta/weight.py b/examples/roberta/weight.py new file mode 100644 index 000000000..97cd2c349 --- /dev/null +++ b/examples/roberta/weight.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
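+# Utilities to load weights from a HuggingFace Roberta checkpoint into the TensorRT-LLM Roberta modules.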
+import numpy as np +import torch + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return np.ascontiguousarray(np.split(v, tp_size)[idx]) + elif len(v.shape) == 2: + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + return None + + +def load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + qkv_weight = [[None, None, None] + for _ in range(hf_roberta_config.num_hidden_layers)] + + qkv_bias = [[None, None, None] + for _ in range(hf_roberta_config.num_hidden_layers)] + + torch_dtype = torch.float16 if fp16 else torch.float32 + for k, v in hf_roberta.state_dict().items(): + v = v.to(torch_dtype).cpu().numpy() + if 'embeddings.word_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.vocab_embedding.weight.value = v + elif 'embeddings.position_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.position_embedding.weight.value = v + elif 'embeddings.token_type_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.token_embedding.weight.value = v + elif 'embeddings.LayerNorm.weight' in k: + tensorrt_llm_roberta.embedding.embedding_ln.weight.value = v + elif 'embeddings.LayerNorm.bias' in k: + tensorrt_llm_roberta.embedding.embedding_ln.bias.value = v + else: + layer_idx = extract_layer_idx(k) + if layer_idx is None: + continue + idx = int(layer_idx) + if 'attention.output.dense.weight' in k: + tensorrt_llm_roberta.layers[ + idx].attention.dense.weight.value = split(v, + tensor_parallel, + rank, + dim=1) + elif 'attention.output.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].attention.dense.bias.value = v + elif 'attention.output.LayerNorm.weight' in k: + tensorrt_llm_roberta.layers[idx].input_layernorm.weight.value = v + elif 'attention.output.LayerNorm.bias' in k: + tensorrt_llm_roberta.layers[idx].input_layernorm.bias.value = v + elif 'intermediate.dense.weight' in k: + tensorrt_llm_roberta.layers[idx].mlp.fc.weight.value = split( + v, tensor_parallel, rank) + elif 'intermediate.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].mlp.fc.bias.value = split( + v, tensor_parallel, rank) + elif 'output.dense.weight' in k: + tensorrt_llm_roberta.layers[idx].mlp.proj.weight.value = split( + v, tensor_parallel, rank, dim=1) + elif 'output.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].mlp.proj.bias.value = v + elif 'output.LayerNorm.weight' in k: + tensorrt_llm_roberta.layers[idx].post_layernorm.weight.value = v + elif 'output.LayerNorm.bias' in k: + tensorrt_llm_roberta.layers[idx].post_layernorm.bias.value = v + elif 'attention.self.query.weight' in k: + qkv_weight[idx][0] = v + elif 'attention.self.query.bias' in k: + qkv_bias[idx][0] = v + elif 'attention.self.key.weight' in k: + qkv_weight[idx][1] = v + elif 'attention.self.key.bias' in k: + qkv_bias[idx][1] = v + elif 'attention.self.value.weight' in k: + qkv_weight[idx][2] = v + elif 'attention.self.value.bias' in k: + qkv_bias[idx][2] = v + + for i in range(hf_roberta_config.num_hidden_layers): + tensorrt_llm_roberta.layers[i].attention.qkv.weight.value = split( + np.concatenate(qkv_weight[i]), tensor_parallel, rank) + tensorrt_llm_roberta.layers[i].attention.qkv.bias.value = split( + np.concatenate(qkv_bias[i]), tensor_parallel, rank) + + +def load_from_hf_qa_roberta(tensorrt_llm_qa_roberta, + hf_qa_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + 
fp16=False): + load_from_hf_roberta(tensorrt_llm_qa_roberta.roberta, hf_qa_roberta, hf_roberta_config, + rank, tensor_parallel, fp16) + states = hf_qa_roberta.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_roberta.qa_outputs.weight.value = states[ + 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_roberta.qa_outputs.bias.value = states['qa_outputs.bias'].to( + torch_dtype).cpu().numpy() + +def load_from_hf_cls_roberta(tensorrt_llm_cls_roberta, + hf_qa_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_roberta(tensorrt_llm_cls_roberta.roberta, hf_qa_roberta, hf_roberta_config, + rank, tensor_parallel, fp16) + states = hf_qa_roberta.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_cls_roberta.classifier.dense.weight.value = states[ + 'classifier.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.dense.bias.value = states[ + 'classifier.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_cls_roberta.classifier.out_proj.weight.value = states[ + 'classifier.out_proj.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.out_proj.bias.value = states[ + 'classifier.out_proj.bias'].to(torch_dtype).cpu().numpy() diff --git a/tensorrt_llm/models/__init__.py b/tensorrt_llm/models/__init__.py index ee886121e..e6c3837ad 100755 --- a/tensorrt_llm/models/__init__.py +++ b/tensorrt_llm/models/__init__.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from .baichuan.model import BaichuanForCausalLM -from .bert.model import BertForQuestionAnswering, BertModel +from .bert.model import BertForQuestionAnswering, BertModel, BertForSequenceClassification +from .roberta.model import RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification from .bloom.model import BloomForCausalLM, BloomModel from .chatglm.model import ChatGLMHeadModel, ChatGLMModel from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder @@ -31,6 +32,10 @@ __all__ = [ 'BertModel', 'BertForQuestionAnswering', + 'BertForSequenceClassification', + 'RobertaModel', + 'RobertaForQuestionAnswering', + 'RobertaForSequenceClassification', 'BloomModel', 'BloomForCausalLM', 'FalconForCausalLM', diff --git a/tensorrt_llm/models/bert/model.py b/tensorrt_llm/models/bert/model.py index 7d24a664c..3e519a0ed 100644 --- a/tensorrt_llm/models/bert/model.py +++ b/tensorrt_llm/models/bert/model.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from email.quoprimime import unquote import math import numpy as np from ..._common import default_net from ...functional import (bert_attention, concat, constant, expand, matmul, - shape, slice, softmax, split) + shape, slice, softmax, split, cast, unsqueeze, select, ACT2FN) from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear from ...mapping import Mapping from ...module import Module, ModuleList @@ -212,7 +213,8 @@ def __init__(self, mapping=Mapping(), dtype=None): super().__init__() - + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype self.embedding = BertEmbedding( vocab_size=vocab_size, hidden_size=hidden_size, @@ -238,9 +240,28 @@ def forward(self, hidden_states=None): hidden_states = self.embedding(input_ids, position_ids, token_type_ids) + # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py + seq_len_2d = concat([1, shape(input_ids, 1)]) + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = cast(mask, 'int32') + extended_attention_mask = unsqueeze(mask, 1) + extended_attention_mask = unsqueeze(extended_attention_mask, 1) # Bx1x1xL + extended_attention_mask = (1 - extended_attention_mask) * -214748364 # a small negative number in int32 range + extended_attention_mask = cast(extended_attention_mask, self.dtype) + for layer in self.layers: hidden_states = layer(hidden_states=hidden_states, - input_lengths=input_lengths) + input_lengths=input_lengths, + attention_mask=extended_attention_mask) return hidden_states @@ -287,3 +308,62 @@ def forward(self, logits = self.qa_outputs(hidden_states) return logits + +class BertPooler(Module): + def __init__(self, hidden_size, dtype): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.activation = ACT2FN['tanh'] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
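+        # select(hidden_states, 1, 0) picks index 0 along dim 1 (the sequence dimension), i.e. the [CLS] token.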
+ first_token_tensor = select(hidden_states, 1, 0) + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertForSequenceClassification(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.bert = BertModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + mapping=mapping, + dtype=dtype) + self.num_labels = num_labels + self.pooler = BertPooler(hidden_size=hidden_size, dtype=dtype) + self.classifier = Linear(hidden_size, num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.bert.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + pooled_output = self.pooler(hidden_states) + logits = self.classifier(pooled_output) + + return logits diff --git a/tensorrt_llm/models/roberta/__init__.py b/tensorrt_llm/models/roberta/__init__.py new file mode 100644 index 000000000..2a36ca922 --- /dev/null +++ b/tensorrt_llm/models/roberta/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorrt_llm/models/roberta/model.py b/tensorrt_llm/models/roberta/model.py new file mode 100644 index 000000000..e40306294 --- /dev/null +++ b/tensorrt_llm/models/roberta/model.py @@ -0,0 +1,390 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
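+# TensorRT-LLM network definition for Roberta: encoder, question-answering head and sequence-classification head.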
+from email.quoprimime import unquote +import math + +import numpy as np + +from ..._common import default_net +from ...functional import (bert_attention, concat, constant, expand, matmul, + shape, slice, softmax, split, cast, unsqueeze, select, ACT2FN) +from ...layers import MLP, ColumnLinear, Embedding, LayerNorm, Linear, RowLinear +from ...mapping import Mapping +from ...module import Module, ModuleList + + +class RobertaEmbedding(Module): + + def __init__(self, + vocab_size, + hidden_size, + max_position_embeddings, + type_vocab_size, + pad_token_id, + dtype=None): + super().__init__() + self.padding_idx = pad_token_id, + self.vocab_embedding = Embedding(vocab_size, hidden_size, dtype=dtype) + self.position_embedding = Embedding(max_position_embeddings, + hidden_size, + dtype=dtype) + self.token_embedding = Embedding(type_vocab_size, + hidden_size, + dtype=dtype) + self.max_position_embeddings = max_position_embeddings + + self.embedding_ln = LayerNorm(normalized_shape=hidden_size, dtype=dtype) + + def forward(self, input_ids, position_ids=None, token_type_ids=None): + token_type_ids_buffer = constant( + np.expand_dims( + np.zeros(self.max_position_embeddings).astype(np.int32), 0)) + + seq_len_2d = concat([1, shape(input_ids, 1)]) + + if token_type_ids is None: + # slice + token_type_ids = slice(token_type_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + token_type_ids = expand(token_type_ids, shape(input_ids)) + + x = self.vocab_embedding(input_ids) + x = x + self.position_embedding(position_ids) + x = x + self.token_embedding(token_type_ids) + x = self.embedding_ln(x) + return x + + +class RobertaAttention(Module): + + def __init__(self, + hidden_size, + num_attention_heads, + max_position_embeddings, + dtype=None, + tp_group=None, + tp_size=1): + super().__init__() + + self.attention_head_size = hidden_size // num_attention_heads + self.num_attention_heads = num_attention_heads // tp_size + self.hidden_size = hidden_size // tp_size + self.max_position_embeddings = max_position_embeddings + self.norm_factor = math.sqrt(self.attention_head_size) + + self.qkv = ColumnLinear(hidden_size, + hidden_size * 3, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size, + gather_output=False) + self.dense = RowLinear(hidden_size, + hidden_size, + dtype=dtype, + tp_group=tp_group, + tp_size=tp_size) + + def forward(self, hidden_states, attention_mask=None, input_lengths=None): + qkv = self.qkv(hidden_states) + + # attention + if default_net().plugin_config.bert_attention_plugin: + assert input_lengths is not None + context = bert_attention(qkv, input_lengths, + self.num_attention_heads, + self.attention_head_size, 1.0) + else: + + def transpose_for_scores(x): + new_x_shape = concat([ + shape(x, 0), + shape(x, 1), self.num_attention_heads, + self.attention_head_size + ]) + return x.view(new_x_shape).permute([0, 2, 1, 3]) + + query, key, value = split(qkv, self.hidden_size, dim=2) + query = transpose_for_scores(query) + key = transpose_for_scores(key) + value = transpose_for_scores(value) + + key = key.permute([0, 1, 3, 2]) + attention_scores = matmul(query, key) + attention_scores = attention_scores / self.norm_factor + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + attention_probs = softmax(attention_scores, dim=-1) + + context = matmul(attention_probs, value).permute([0, 2, 1, 3]) + context = context.view( + concat([shape(context, 0), + shape(context, 1), self.hidden_size])) + + context = self.dense(context) + + return context + + +class 
RobertaEncoderLayer(Module): + + def __init__(self, + hidden_size, + num_attention_heads, + max_position_embeddings, + hidden_act='relu', + tp_group=None, + tp_size=1, + dtype=None): + super().__init__() + self.input_layernorm = LayerNorm(normalized_shape=hidden_size, + dtype=dtype) + + self.attention = RobertaAttention(hidden_size, + num_attention_heads, + max_position_embeddings, + tp_group=tp_group, + tp_size=tp_size, + dtype=dtype) + self.mlp = MLP(hidden_size=hidden_size, + ffn_hidden_size=hidden_size * 4, + hidden_act=hidden_act, + tp_group=tp_group, + tp_size=tp_size, + dtype=dtype) + self.post_layernorm = LayerNorm(normalized_shape=hidden_size, + dtype=dtype) + + def forward(self, hidden_states, attention_mask=None, input_lengths=None): + residual = hidden_states + + attention_output = self.attention(hidden_states, + attention_mask=attention_mask, + input_lengths=input_lengths) + + hidden_states = residual + attention_output + + hidden_states = self.input_layernorm(hidden_states) + + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + hidden_states = self.post_layernorm(hidden_states) + + return hidden_states + + +class RobertaModel(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + pad_token_id, + mapping=Mapping(), + dtype=None): + super().__init__() + self.max_position_embeddings = max_position_embeddings + self.dtype = dtype + self.padding_idx = pad_token_id + + self.embedding = RobertaEmbedding( + vocab_size=vocab_size, + hidden_size=hidden_size, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, + dtype=dtype) + + self.layers = ModuleList([ + RobertaEncoderLayer(hidden_size=hidden_size, + num_attention_heads=num_heads, + max_position_embeddings=max_position_embeddings, + hidden_act=hidden_act, + tp_group=mapping.tp_group, + tp_size=mapping.tp_size, + dtype=dtype) for _ in range(num_layers) + ]) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + seq_len_2d = concat([1, shape(input_ids, 1)]) + position_ids_buffer = constant( + np.expand_dims( + np.arange(self.max_position_embeddings).astype(np.int32), 0)) + tmp_position_ids = slice(position_ids_buffer, + starts=[0, 0], + sizes=seq_len_2d) + tmp_position_ids = expand(tmp_position_ids, shape(input_ids)) #BxL + + tmp_input_lengths = unsqueeze(input_lengths, 1) #Bx1 + tmp_input_lengths = expand(tmp_input_lengths, shape(input_ids)) #BxL + mask = tmp_position_ids < tmp_input_lengths # BxL + mask = cast(mask, 'int32') + # self.register_network_output('attention_mask', mask) + # create position ids like create_position_ids_from_input_ids in https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py + if position_ids is None: + position_ids = (tmp_position_ids + 1) * mask + position_ids = position_ids + self.padding_idx + # self.register_network_output('position_ids', position_ids) + + # creat extended_attention_mask as https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py + extended_attention_mask = unsqueeze(mask, 1) + extended_attention_mask = unsqueeze(extended_attention_mask, 1) # Bx1x1xL + extended_attention_mask = (1 - extended_attention_mask) * -214748364 # a small negative number in int32 range + extended_attention_mask = 
cast(extended_attention_mask, self.dtype) + + + hidden_states = self.embedding(input_ids, position_ids, token_type_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states=hidden_states, + input_lengths=input_lengths, + attention_mask=extended_attention_mask) + + return hidden_states + + +class RobertaForQuestionAnswering(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + pad_token_id, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.roberta= RobertaModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, + mapping=mapping, + dtype=dtype) + self.num_labels = num_labels + self.qa_outputs = Linear(hidden_size, num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.roberta.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + + logits = self.qa_outputs(hidden_states) + + return logits + +class RobertaPooler(Module): + def __init__(self, hidden_size, dtype): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.activation = ACT2FN['tanh'] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = select(hidden_states, 1, 0) + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class RobertaClassificationHead(Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, hidden_size, dtype, num_labels): + super().__init__() + self.dense = Linear(hidden_size, hidden_size, dtype=dtype) + self.out_proj = Linear(hidden_size, num_labels) + + def forward(self, features, **kwargs): + x = select(features, 1, 0) + x = self.dense(x) + x = ACT2FN['tanh'](x) + x = self.out_proj(x) + return x + +class RobertaForSequenceClassification(Module): + + def __init__(self, + num_layers, + num_heads, + hidden_size, + vocab_size, + hidden_act, + max_position_embeddings, + type_vocab_size, + pad_token_id, + num_labels=2, + mapping=Mapping(), + dtype=None): + super().__init__() + self.roberta= RobertaModel(num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + pad_token_id=pad_token_id, + mapping=mapping, + dtype=dtype) + + self.classifier = RobertaClassificationHead(hidden_size=hidden_size, num_labels=num_labels, dtype=dtype) + + def forward(self, + input_ids=None, + input_lengths=None, + token_type_ids=None, + position_ids=None, + hidden_states=None): + + hidden_states = self.roberta.forward(input_ids=input_ids, + input_lengths=input_lengths, + token_type_ids=token_type_ids, + position_ids=position_ids, + hidden_states=hidden_states) + logits = self.classifier(hidden_states) + + return logits diff --git a/tests/model/test_bert.py b/tests/model/test_bert.py index ddfa6525e..89af12a39 100644 --- a/tests/model/test_bert.py +++ b/tests/model/test_bert.py @@ -25,12 +25,12 @@ import tensorrt as trt # 
isort: on from parameterized import parameterized -from transformers import BertConfig, BertForQuestionAnswering, BertModel +from transformers import BertConfig, BertForQuestionAnswering, BertModel, BertForSequenceClassification, AutoTokenizer import tensorrt_llm import tensorrt_llm.runtime from tensorrt_llm import Builder -from tensorrt_llm._utils import trt_dtype_to_torch +from tensorrt_llm._utils import trt_dtype_to_torch, str_dtype_to_trt from tensorrt_llm.network import net_guard from tensorrt_llm.plugin.plugin import ContextFMHAType from tensorrt_llm.runtime import TensorInfo @@ -148,18 +148,41 @@ def load_from_hf_qa_bert(tensorrt_llm_qa_bert, tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to( torch_dtype).cpu().numpy() +def load_from_hf_cls_bert(tensorrt_llm_qa_bert, + hf_qa_bert, + hf_bert_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config, + rank, tensor_parallel, fp16) + states = hf_qa_bert.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_bert.pooler.dense.weight.value = states[ + 'bert.pooler.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.pooler.dense.bias.value = states[ + 'bert.pooler.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_qa_bert.classifier.weight.value = states[ + 'classifier.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_bert.classifier.bias.value = states[ + 'classifier.bias'].to(torch_dtype).cpu().numpy() + class TestBert(unittest.TestCase): def load_test_cases(): - models = [BertModel.__name__, BertForQuestionAnswering.__name__] + models = [BertForSequenceClassification.__name__, BertModel.__name__, BertForQuestionAnswering.__name__] + model_dirs = ['', 'bert-base-uncased'] # add more tests for read data. 
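+        # Each test case is (model, use_refit, use_plugin, fast_building, context_fmha_type, dtype, model_dir).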
test_cases = [] test_cases += product(models, [False], [False], [False], - [ContextFMHAType.disabled], ['float32']) + [ContextFMHAType.disabled], ['float32'], model_dirs) test_cases += product(models, [False], [True], [True], [ ContextFMHAType.disabled, ContextFMHAType.enabled, ContextFMHAType.enabled_with_fp32_acc - ], ['float16']) + ], ['float16'], model_dirs) return test_cases @@ -171,7 +194,7 @@ def custom_name_func(testcase_func, param_num, param): @parameterized.expand(load_test_cases, name_func=custom_name_func) def test_bert(self, model, use_refit, use_plugin, fast_building, - context_fmha_type, dtype): + context_fmha_type, dtype, model_dir): tensorrt_llm.logger.set_level('error') fp16 = (dtype == 'float16') world_size = 1 @@ -224,22 +247,36 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, [bs_range]) ])) # Initialize model - bert_config = BertConfig( - vocab_size=vocab_size, - hidden_size=hidden_size, - num_hidden_layers=num_layers, - num_attention_heads=num_heads, - intermediate_size=4 * hidden_size, - hidden_act=hidden_act, - max_position_embeddings=max_position_embeddings, - torch_dtype=torch_dtype, - ) + if model_dir: + bert_config = BertConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) + vocab_size = bert_config.vocab_size + hidden_size = bert_config.hidden_size + num_layers = bert_config.num_hidden_layers + num_heads = bert_config.num_attention_heads + hidden_size = bert_config.intermediate_size // 4 + hidden_act = bert_config.hidden_act + max_position_embeddings = bert_config.max_position_embeddings + else: + bert_config = BertConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + intermediate_size=4 * hidden_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + torch_dtype=torch_dtype, + ) output_name = "hidden_states" if model == BertModel.__name__: - hf_bert = BertModel( - bert_config, - add_pooling_layer=False).cuda().to(torch_dtype).eval() + if model_dir: + hf_bert = BertModel.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertModel( + bert_config, + add_pooling_layer=False).cuda().to(torch_dtype).eval() tensorrt_llm_bert = tensorrt_llm.models.BertModel( num_layers=num_layers, num_heads=num_heads, @@ -260,8 +297,12 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, tensor_parallel=world_size, fp16=fp16) elif model == BertForQuestionAnswering.__name__: - hf_bert = BertForQuestionAnswering(bert_config).cuda().to( - torch_dtype).eval() + if model_dir: + hf_bert = BertForQuestionAnswering.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForQuestionAnswering(bert_config).cuda().to( + torch_dtype).eval() output_name = "logits" tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering( num_layers=num_layers, @@ -284,6 +325,36 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, rank=rank, tensor_parallel=world_size, fp16=fp16) + elif model == BertForSequenceClassification.__name__: + if model_dir: + hf_bert = BertForSequenceClassification.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_bert = BertForSequenceClassification(bert_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_bert = tensorrt_llm.models.BertForSequenceClassification( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + 
max_position_embeddings=max_position_embeddings, + type_vocab_size=bert_config.type_vocab_size, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_cls_bert(tensorrt_llm_bert, + hf_bert, + bert_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + else: assert False, f"Unknown model {model}" # Prepare @@ -298,6 +369,9 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, output_dtype = trt.float16 if fp16 else trt.float32 output.mark_output(output_name, output_dtype) + for k, v in tensorrt_llm_bert.named_network_outputs(): + network._mark_output(v, k, str_dtype_to_trt(dtype)) + # Build engine engine_buffer = builder.build_engine(network, builder_config) session = tensorrt_llm.runtime.Session.from_serialized_engine( @@ -307,9 +381,25 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, # Inference # The dtype of input_ids should be queried from the engine, # for testing purpose, int32 is fine for now. - input_ids = torch.randint(100, (batch_size, input_len)).int().cuda() - input_lengths = input_len * torch.ones( - (batch_size, ), dtype=torch.int32, device='cuda') + attention_mask = None + if model_dir: + hf_tokenizer = AutoTokenizer.from_pretrained(model_dir) + input_strings = ['Hello world!' for _ in range(batch_size)] + input_ids_with_padding = hf_tokenizer( + input_strings, padding='max_length', max_length=input_len) + input_ids_without_padding = hf_tokenizer( + input_strings) + input_ids = torch.tensor(input_ids_with_padding['input_ids']).int().cuda() + input_lengths = [len(x) for x in input_ids_without_padding['input_ids']] + input_lengths = torch.tensor( + input_lengths, device=input_ids.device, dtype=torch.int32) + attention_mask = torch.tensor( + input_ids_with_padding['attention_mask'], + device=input_ids.device, dtype=torch.int32) + else: + input_ids = torch.randint(bert_config.vocab_size, (batch_size, input_len)).int().cuda() + input_lengths = input_len * torch.ones( + (batch_size, ), dtype=torch.int32, device='cuda') output_info = session.infer_shapes([ TensorInfo('input_ids', trt.DataType.INT32, @@ -335,11 +425,23 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, res = outputs[output_name] with torch.no_grad(): - hf_outputs = hf_bert.forward(input_ids) + if model_dir: + hf_outputs = hf_bert.forward( + input_ids=input_ids, + attention_mask=attention_mask) + else: + hf_outputs = hf_bert.forward(input_ids) torch.cuda.synchronize() if model == BertModel.__name__: ref = hf_outputs.last_hidden_state + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + attention_mask_tmp = attention_mask.unsqueeze(-1) + ref = ref * attention_mask_tmp + res = res * attention_mask_tmp + np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), atol=1e-2, @@ -351,6 +453,13 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, ref_start_logits = hf_outputs.start_logits ref_end_logits = hf_outputs.end_logits + if use_plugin and model_dir: + # when we use_plugin and have real-data model_dir and input + # We do not need to care about the output of padding positions: + ref_start_logits = ref_start_logits * attention_mask + ref_end_logits = ref_end_logits * attention_mask + res_start_logits = res_start_logits * attention_mask + res_end_logits = res_end_logits * 
attention_mask np.testing.assert_allclose(ref_start_logits.cpu().numpy(), res_start_logits.cpu().numpy(), @@ -358,6 +467,12 @@ def test_bert(self, model, use_refit, use_plugin, fast_building, np.testing.assert_allclose(ref_end_logits.cpu().numpy(), res_end_logits.cpu().numpy(), atol=1.5e-2) + elif model == BertForSequenceClassification.__name__: + ref = hf_outputs.logits + np.testing.assert_allclose(ref.cpu().numpy(), + res.cpu().numpy(), + atol=1e-2, + rtol=1e-2) if __name__ == '__main__': diff --git a/tests/model/test_roberta.py b/tests/model/test_roberta.py new file mode 100644 index 000000000..629ab6f0f --- /dev/null +++ b/tests/model/test_roberta.py @@ -0,0 +1,487 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest +from collections import OrderedDict +from itertools import product + +import numpy as np +import parameterized + +# isort: off +import torch +import tensorrt as trt +# isort: on +from parameterized import parameterized +from transformers import RobertaConfig, RobertaForQuestionAnswering, RobertaModel, RobertaForSequenceClassification, AutoTokenizer + +import tensorrt_llm +from tensorrt_llm.models import roberta +import tensorrt_llm.runtime +from tensorrt_llm import Builder +from tensorrt_llm._utils import trt_dtype_to_torch, str_dtype_to_trt +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.runtime import TensorInfo + + +def extract_layer_idx(name): + ss = name.split('.') + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return np.ascontiguousarray(np.split(v, tp_size)[idx]) + elif len(v.shape) == 2: + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + return None + + +def load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + qkv_weight = [[None, None, None] + for _ in range(hf_roberta_config.num_hidden_layers)] + + qkv_bias = [[None, None, None] + for _ in range(hf_roberta_config.num_hidden_layers)] + + torch_dtype = torch.float16 if fp16 else torch.float32 + for k, v in hf_roberta.state_dict().items(): + v = v.to(torch_dtype).cpu().numpy() + if 'embeddings.word_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.vocab_embedding.weight.value = v + elif 'embeddings.position_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.position_embedding.weight.value = v + elif 'embeddings.token_type_embeddings.weight' in k: + tensorrt_llm_roberta.embedding.token_embedding.weight.value = v + elif 'embeddings.LayerNorm.weight' in k: + tensorrt_llm_roberta.embedding.embedding_ln.weight.value = v + elif 'embeddings.LayerNorm.bias' in k: + tensorrt_llm_roberta.embedding.embedding_ln.bias.value = v + else: + layer_idx = extract_layer_idx(k) + if 
layer_idx is None: + continue + idx = int(layer_idx) + if 'attention.output.dense.weight' in k: + tensorrt_llm_roberta.layers[ + idx].attention.dense.weight.value = split(v, + tensor_parallel, + rank, + dim=1) + elif 'attention.output.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].attention.dense.bias.value = v + elif 'attention.output.LayerNorm.weight' in k: + tensorrt_llm_roberta.layers[idx].input_layernorm.weight.value = v + elif 'attention.output.LayerNorm.bias' in k: + tensorrt_llm_roberta.layers[idx].input_layernorm.bias.value = v + elif 'intermediate.dense.weight' in k: + tensorrt_llm_roberta.layers[idx].mlp.fc.weight.value = split( + v, tensor_parallel, rank) + elif 'intermediate.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].mlp.fc.bias.value = split( + v, tensor_parallel, rank) + elif 'output.dense.weight' in k: + tensorrt_llm_roberta.layers[idx].mlp.proj.weight.value = split( + v, tensor_parallel, rank, dim=1) + elif 'output.dense.bias' in k: + tensorrt_llm_roberta.layers[idx].mlp.proj.bias.value = v + elif 'output.LayerNorm.weight' in k: + tensorrt_llm_roberta.layers[idx].post_layernorm.weight.value = v + elif 'output.LayerNorm.bias' in k: + tensorrt_llm_roberta.layers[idx].post_layernorm.bias.value = v + elif 'attention.self.query.weight' in k: + qkv_weight[idx][0] = v + elif 'attention.self.query.bias' in k: + qkv_bias[idx][0] = v + elif 'attention.self.key.weight' in k: + qkv_weight[idx][1] = v + elif 'attention.self.key.bias' in k: + qkv_bias[idx][1] = v + elif 'attention.self.value.weight' in k: + qkv_weight[idx][2] = v + elif 'attention.self.value.bias' in k: + qkv_bias[idx][2] = v + + for i in range(hf_roberta_config.num_hidden_layers): + tensorrt_llm_roberta.layers[i].attention.qkv.weight.value = split( + np.concatenate(qkv_weight[i]), tensor_parallel, rank) + tensorrt_llm_roberta.layers[i].attention.qkv.bias.value = split( + np.concatenate(qkv_bias[i]), tensor_parallel, rank) + + +def load_from_hf_qa_roberta(tensorrt_llm_qa_roberta, + hf_qa_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_roberta(tensorrt_llm_qa_roberta.roberta, hf_qa_roberta, hf_roberta_config, + rank, tensor_parallel, fp16) + states = hf_qa_roberta.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_qa_roberta.qa_outputs.weight.value = states[ + 'qa_outputs.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_qa_roberta.qa_outputs.bias.value = states['qa_outputs.bias'].to( + torch_dtype).cpu().numpy() + +def load_from_hf_cls_roberta(tensorrt_llm_cls_roberta, + hf_qa_roberta, + hf_roberta_config, + rank=0, + tensor_parallel=1, + fp16=False): + load_from_hf_roberta(tensorrt_llm_cls_roberta.roberta, hf_qa_roberta, hf_roberta_config, + rank, tensor_parallel, fp16) + states = hf_qa_roberta.state_dict() + + torch_dtype = torch.float16 if fp16 else torch.float32 + + tensorrt_llm_cls_roberta.classifier.dense.weight.value = states[ + 'classifier.dense.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.dense.bias.value = states[ + 'classifier.dense.bias'].to(torch_dtype).cpu().numpy() + + tensorrt_llm_cls_roberta.classifier.out_proj.weight.value = states[ + 'classifier.out_proj.weight'].to(torch_dtype).cpu().numpy() + tensorrt_llm_cls_roberta.classifier.out_proj.bias.value = states[ + 'classifier.out_proj.bias'].to(torch_dtype).cpu().numpy() + + +class TestRoberta(unittest.TestCase): + + def load_test_cases(): + models = [RobertaForSequenceClassification.__name__, RobertaModel.__name__, 
RobertaForQuestionAnswering.__name__] + model_dirs = ['', 'roberta-base'] # add more tests for read data. + test_cases = [] + test_cases += product(models, [False], [False], [False], + [ContextFMHAType.disabled], ['float32'], model_dirs) + test_cases += product(models, [False], [True], [True], [ + ContextFMHAType.disabled, ContextFMHAType.enabled, + ContextFMHAType.enabled_with_fp32_acc + ], ['float16'], model_dirs) + + return test_cases + + def custom_name_func(testcase_func, param_num, param): + return "%s_%s" % ( + testcase_func.__name__, + parameterized.to_safe_name("_".join(str(x) for x in param.args)), + ) + + @parameterized.expand(load_test_cases, name_func=custom_name_func) + def test_roberta(self, model, use_refit, use_plugin, fast_building, + context_fmha_type, dtype, model_dir): + tensorrt_llm.logger.set_level('error') + fp16 = (dtype == 'float16') + world_size = 1 + rank = 0 + batch_size = 8 + input_len = 128 + vocab_size = 51200 + num_layers = 12 + num_heads = 12 + hidden_act = 'gelu' + max_position_embeddings = 512 + hidden_size = 768 + bs_range = [1, (batch_size + 1) // 2, batch_size] + inlen_range = [1, (input_len + 1) // 2, input_len] + torch_dtype = torch.float16 if fp16 else torch.float32 + trt_dtype = trt.float16 if fp16 else trt.float32 + timing_cache = 'model.cache' + + torch.manual_seed(0) + + builder = Builder() + with tempfile.TemporaryDirectory() as tmpdirname: + builder_config = builder.create_builder_config( + name=model, + precision='float16' if fp16 else 'float32', + timing_cache=timing_cache, + tensor_parallel=world_size, # TP only + use_refit=use_refit) + network = builder.create_network() + if use_plugin: + network.plugin_config.set_bert_attention_plugin(dtype) + if fast_building: + network.plugin_config.set_gemm_plugin(dtype) + network.plugin_config.set_context_fmha(context_fmha_type) + with net_guard(network): + # Prepare inputs + # TODO: could class be better than dict for profiles? 
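+                # Dynamic shapes: bs_range and inlen_range give (min, opt, max) values for batch_size and input_len.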
+ input_ids = tensorrt_llm.Tensor(name='input_ids', + dtype=trt.int32, + shape=[-1, -1], + dim_range=OrderedDict([ + ('batch_size', [bs_range]), + ('input_len', [inlen_range]) + ])) + input_lengths = tensorrt_llm.Tensor(name='input_lengths', + dtype=trt.int32, + shape=[-1], + dim_range=OrderedDict([ + ('batch_size', + [bs_range]) + ])) + # Initialize model + if model_dir: + roberta_config = RobertaConfig.from_pretrained(model_dir, torch_dtype=torch_dtype) + vocab_size = roberta_config.vocab_size + hidden_size = roberta_config.hidden_size + num_layers = roberta_config.num_hidden_layers + num_heads = roberta_config.num_attention_heads + hidden_size = roberta_config.intermediate_size // 4 + hidden_act = roberta_config.hidden_act + max_position_embeddings = roberta_config.max_position_embeddings + else: + roberta_config = RobertaConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + intermediate_size=4 * hidden_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + torch_dtype=torch_dtype, + ) + + output_name = "hidden_states" + if model == RobertaModel.__name__: + if model_dir: + hf_roberta = RobertaModel.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_roberta = RobertaModel( + roberta_config, + add_pooling_layer=False).cuda().to(torch_dtype).eval() + tensorrt_llm_roberta = tensorrt_llm.models.RobertaModel( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + elif model == RobertaForQuestionAnswering.__name__: + if model_dir: + hf_roberta = RobertaForQuestionAnswering.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_roberta = RobertaForQuestionAnswering(roberta_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForQuestionAnswering( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + type_vocab_size=roberta_config.type_vocab_size, + pad_token_id=roberta_config.pad_token_id, + num_labels= + 2, # just make it a const here, seems to me not worth as a config + mapping=tensorrt_llm.Mapping( + world_size=world_size, + rank=rank, + tp_size=world_size), # TP only + dtype=trt_dtype) + load_from_hf_qa_roberta(tensorrt_llm_roberta, + hf_roberta, + roberta_config, + rank=rank, + tensor_parallel=world_size, + fp16=fp16) + elif model == RobertaForSequenceClassification.__name__: + if model_dir: + hf_roberta = RobertaForSequenceClassification.from_pretrained( + model_dir).cuda().to(torch_dtype).eval() + else: + hf_roberta = RobertaForSequenceClassification(roberta_config).cuda().to( + torch_dtype).eval() + output_name = "logits" + tensorrt_llm_roberta = tensorrt_llm.models.RobertaForSequenceClassification( + num_layers=num_layers, + num_heads=num_heads, + hidden_size=hidden_size, + vocab_size=vocab_size, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + 
type_vocab_size=roberta_config.type_vocab_size,
+ pad_token_id=roberta_config.pad_token_id,
+ num_labels=
+ 2, # just make it a const here, seems to me not worth as a config
+ mapping=tensorrt_llm.Mapping(
+ world_size=world_size,
+ rank=rank,
+ tp_size=world_size), # TP only
+ dtype=trt_dtype)
+ load_from_hf_cls_roberta(tensorrt_llm_roberta,
+ hf_roberta,
+ roberta_config,
+ rank=rank,
+ tensor_parallel=world_size,
+ fp16=fp16)
+
+ else:
+ assert False, f"Unknown model {model}"
+ # Prepare
+ network.set_named_parameters(
+ tensorrt_llm_roberta.named_parameters())
+
+ # Forward
+ output = tensorrt_llm_roberta(input_ids=input_ids,
+ input_lengths=input_lengths)
+
+ # Mark outputs
+ output_dtype = trt.float16 if fp16 else trt.float32
+ output.mark_output(output_name, output_dtype)
+
+ for k, v in tensorrt_llm_roberta.named_network_outputs():
+ if '_ids' in k or 'attention_mask' in k:
+ network._mark_output(v, k, str_dtype_to_trt('int32'))
+ else:
+ network._mark_output(v, k, str_dtype_to_trt(dtype))
+
+ # Build engine
+ engine_buffer = builder.build_engine(network, builder_config)
+ session = tensorrt_llm.runtime.Session.from_serialized_engine(
+ engine_buffer)
+ stream = torch.cuda.current_stream().cuda_stream
+
+ # Inference
+ # The dtype of input_ids should be queried from the engine;
+ # for testing purposes, int32 is fine for now.
+ attention_mask = None
+ if model_dir:
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_dir)
+ input_strings = ['Hello world!' for _ in range(batch_size)]
+ input_ids_with_padding = hf_tokenizer(
+ input_strings, padding='max_length', max_length=input_len)
+ input_ids_without_padding = hf_tokenizer(
+ input_strings)
+ input_ids = torch.tensor(input_ids_with_padding['input_ids']).int().cuda()
+ input_lengths = [len(x) for x in input_ids_without_padding['input_ids']]
+ input_lengths = torch.tensor(
+ input_lengths, device=input_ids.device, dtype=torch.int32)
+ attention_mask = torch.tensor(
+ input_ids_with_padding['attention_mask'],
+ device=input_ids.device, dtype=torch.int32)
+ else:
+ # Since 1 is the default padding token id for RoBERTa, avoid using it here:
+ # the low bound of 2 in torch.randint below skips the bos (0) and pad (1) ids.
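+ # With random ids there is no padding, so every sequence reports the full
+ # input_len as its length.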
+ input_ids = torch.randint(2, roberta_config.vocab_size, (batch_size, input_len)).int().cuda()
+ input_lengths = input_len * torch.ones(
+ (batch_size, ), dtype=torch.int32, device='cuda')
+
+ output_info = session.infer_shapes([
+ TensorInfo('input_ids', trt.DataType.INT32,
+ (batch_size, input_len)),
+ TensorInfo('input_lengths', trt.DataType.INT32, (batch_size, ))
+ ])
+ session._print_engine_info()
+
+ outputs = {
+ t.name: torch.empty(tuple(t.shape),
+ dtype=trt_dtype_to_torch(t.dtype),
+ device='cuda')
+ for t in output_info
+ }
+ assert output_name in outputs, f'{output_name} not found in outputs'
+ session.run(inputs={
+ 'input_ids': input_ids,
+ 'input_lengths': input_lengths
+ },
+ outputs=outputs,
+ stream=stream)
+ torch.cuda.synchronize()
+ res = outputs[output_name]
+
+ with torch.no_grad():
+ if model_dir:
+ hf_outputs = hf_roberta.forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask)
+ else:
+ hf_outputs = hf_roberta.forward(input_ids)
+ torch.cuda.synchronize()
+
+ if model == RobertaModel.__name__:
+ ref = hf_outputs.last_hidden_state
+ if use_plugin and model_dir:
+ # When use_plugin is set and a real model_dir provides tokenized input,
+ # the outputs at padding positions are undefined, so mask them out before comparing:
+ attention_mask_tmp = attention_mask.unsqueeze(-1)
+ ref = ref * attention_mask_tmp
+ res = res * attention_mask_tmp
+
+ np.testing.assert_allclose(ref.cpu().numpy(),
+ res.cpu().numpy(),
+ atol=1e-2,
+ rtol=1e-2)
+ elif model == RobertaForQuestionAnswering.__name__:
+ res_start_logits, res_end_logits = torch.split(res, 1, -1)
+ res_start_logits = res_start_logits.squeeze()
+ res_end_logits = res_end_logits.squeeze()
+
+ ref_start_logits = hf_outputs.start_logits
+ ref_end_logits = hf_outputs.end_logits
+ if use_plugin and model_dir:
+ # When use_plugin is set and a real model_dir provides tokenized input,
+ # the outputs at padding positions are undefined, so mask them out before comparing:
+ ref_start_logits = ref_start_logits * attention_mask
+ ref_end_logits = ref_end_logits * attention_mask
+ res_start_logits = res_start_logits * attention_mask
+ res_end_logits = res_end_logits * attention_mask
+
+ np.testing.assert_allclose(ref_start_logits.cpu().numpy(),
+ res_start_logits.cpu().numpy(),
+ atol=1.5e-2)
+ np.testing.assert_allclose(ref_end_logits.cpu().numpy(),
+ res_end_logits.cpu().numpy(),
+ atol=1.5e-2)
+ elif model == RobertaForSequenceClassification.__name__:
+ ref = hf_outputs.logits
+ np.testing.assert_allclose(ref.cpu().numpy(),
+ res.cpu().numpy(),
+ atol=1e-2,
+ rtol=1e-2)
+
+
+if __name__ == '__main__':
+ unittest.main()
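For reference, a minimal sketch of how this test might be run locally, assuming the file above is saved as `test_roberta.py`, a CUDA-capable GPU is available, and `torch`, `tensorrt`, `tensorrt_llm`, `transformers`, and `parameterized` are installed (the file name and the pytest `-k` filter are illustrative, not part of this change):

```bash
# Run the full parameterized suite; building the TensorRT engines makes this slow.
python3 test_roberta.py -v

# Or select a subset by name with pytest, e.g. only the float32 configurations.
python3 -m pytest test_roberta.py -k "float32" -v
```

Note that the cases with a non-empty `model_dir` pull `roberta-base` from the HuggingFace Hub, so they also need network access or a pre-populated cache.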