diff --git a/.github/workflows/tpu-tgi-release.yml b/.github/workflows/tpu-tgi-release.yml
index 91e5f101..68657691 100644
--- a/.github/workflows/tpu-tgi-release.yml
+++ b/.github/workflows/tpu-tgi-release.yml
@@ -74,7 +74,7 @@ jobs:
           labels: ${{ steps.meta.outputs.labels }}
           build-args: |
             VERSION=${{ steps.version.outputs.version }}
-            TGI_VERSION=v2.0.3
+            TGI_VERSION=v2.2.0
 
 
       - name: Generate artifact attestation for TGI
@@ -95,7 +95,7 @@ jobs:
           labels: ${{ steps.meta-ie.outputs.labels }}
           build-args: |
             VERSION=${{ steps.version.outputs.version }}
-            TGI_VERSION=v2.0.3
+            TGI_VERSION=v2.2.0
           target: inference-endpoint
 
 
diff --git a/Makefile b/Makefile
index b7723d34..46b9084a 100644
--- a/Makefile
+++ b/Makefile
@@ -19,7 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))
 
 .PHONY: build_dist style style_check clean
 
-TGI_VERSION ?= v2.0.3
+TGI_VERSION ?= v2.2.0
 
 rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))
 
diff --git a/examples/language-modeling/gemma_tuning.ipynb b/examples/language-modeling/gemma_tuning.ipynb
index fc7f4717..eebb8478 100644
--- a/examples/language-modeling/gemma_tuning.ipynb
+++ b/examples/language-modeling/gemma_tuning.ipynb
@@ -35,12 +35,12 @@
      "languageId": "shellscript"
     }
    },
-   "outputs": [],
    "source": [
     "gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n",
     "    --zone=$ZONE \\\n",
     "    -- -L 8888:localhost:8888"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -59,7 +59,6 @@
      "languageId": "shellscript"
     }
    },
-   "outputs": [],
    "source": [
     "git clone https://github.com/huggingface/optimum-tpu.git\n",
     "# Install Optimum tpu\n",
@@ -73,7 +72,8 @@
     "# Change directory and launch Jupyter notebook\n",
     "cd optimum-tpu/examples/language-modeling\n",
     "jupyter notebook --port 8888"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -94,10 +94,10 @@
    "execution_count": null,
    "id": "37bccce7-1ce4-4470-9e81-c15b120ef294",
    "metadata": {},
-   "outputs": [],
    "source": [
     "!huggingface-cli login --token YOUR_HF_TOKEN"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -115,13 +115,13 @@
    "execution_count": null,
    "id": "6d3c7bc2",
    "metadata": {},
-   "outputs": [],
    "source": [
     "from optimum.tpu import fsdp_v2\n",
     "\n",
     "\n",
     "fsdp_v2.use_fsdp_v2()"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -140,13 +140,13 @@
    "execution_count": null,
    "id": "f0196b5d",
    "metadata": {},
-   "outputs": [],
    "source": [
     "from datasets import load_dataset\n",
     "\n",
     "\n",
     "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -161,10 +161,10 @@
    "execution_count": null,
    "id": "12409299",
    "metadata": {},
-   "outputs": [],
    "source": [
     "dataset[321]"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -179,13 +179,13 @@
    "execution_count": null,
    "id": "9c24e0b1",
    "metadata": {},
-   "outputs": [],
    "source": [
     "{'instruction': 'When was the 8088 processor released?',\n",
     " 'context': 'The 8086 (also called iAPX 86) is a 16-bit microprocessor chip designed by Intel between early 1976 and June 8, 1978, when it was released. The Intel 8088, released July 1, 1979, is a slightly modified chip with an external 8-bit data bus (allowing the use of cheaper and fewer supporting ICs),[note 1] and is notable as the processor used in the original IBM PC design.',\n",
     " 'response': 'The Intel 8088 processor was released July 1, 1979.',\n",
     " 'category': 'information_extraction'}"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -200,7 +200,6 @@
    "execution_count": null,
    "id": "f1497e0f",
    "metadata": {},
-   "outputs": [],
    "source": [
     "from transformers import AutoTokenizer\n",
     "\n",
@@ -218,7 +217,8 @@
     "    prompt += tokenizer.eos_token\n",
     "    sample[\"prompt\"] = prompt\n",
     "    return sample"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -233,10 +233,10 @@
    "execution_count": null,
    "id": "16b44a9b",
    "metadata": {},
-   "outputs": [],
    "source": [
     "data = dataset.map(preprocess_function, remove_columns=list(dataset.features))"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -253,14 +253,14 @@
    "execution_count": null,
    "id": "f18472ce",
    "metadata": {},
-   "outputs": [],
    "source": [
     "import torch\n",
     "from transformers import AutoModelForCausalLM\n",
     "\n",
     "\n",
     "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -275,7 +275,6 @@
    "execution_count": null,
    "id": "4a01f651",
    "metadata": {},
-   "outputs": [],
    "source": [
     "from peft import LoraConfig\n",
     "\n",
@@ -286,7 +285,8 @@
     "    target_modules=[\"k_proj\", \"v_proj\"],\n",
     "    task_type=\"CAUSAL_LM\",\n",
     ")"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -301,7 +301,6 @@
    "execution_count": null,
    "id": "780f1033",
    "metadata": {},
-   "outputs": [],
    "source": [
     "from transformers import TrainingArguments\n",
     "from trl import SFTTrainer\n",
@@ -329,7 +328,8 @@
     "    max_seq_length=1024,\n",
     "    packing=True,\n",
     ")"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
@@ -344,10 +344,10 @@
    "execution_count": null,
    "id": "4c437a81",
    "metadata": {},
-   "outputs": [],
    "source": [
     "trainer.train()"
-   ]
+   ],
+   "outputs": []
   },
   {
    "cell_type": "markdown",
diff --git a/optimum/tpu/modeling_llama.py b/optimum/tpu/modeling_llama.py
index c935bdce..53ca53ab 100644
--- a/optimum/tpu/modeling_llama.py
+++ b/optimum/tpu/modeling_llama.py
@@ -340,7 +340,7 @@ def _init_rope(self):
                 base=self.rope_theta,
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
+            scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
@@ -349,7 +349,7 @@ def _init_rope(self):
                     scaling_factor=scaling_factor,
                     base=self.rope_theta,
                 )
-            elif scaling_type == "dynamic":
+            elif scaling_type in ["dynamic", "llama3"]:
                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                     self.head_dim,
                     max_position_embeddings=self.max_position_embeddings,
diff --git a/pyproject.toml b/pyproject.toml
index 30c1ef5b..f0d863d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ keywords = [
 ]
 
 dependencies = [
-    "transformers == 4.41.1",
+    "transformers == 4.43.3",
     "torch == 2.4.0",
     "torch-xla[tpu] == 2.4.0",
     "loguru == 0.6.0",
diff --git a/text-generation-inference/docker/Dockerfile b/text-generation-inference/docker/Dockerfile
index 5632449c..68273709 100644
--- a/text-generation-inference/docker/Dockerfile
+++ b/text-generation-inference/docker/Dockerfile
@@ -8,7 +8,7 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1
 
 # Build cargo components (adapted from TGI original Dockerfile)
 # Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
-FROM lukemathwalker/cargo-chef:latest-rust-1.77-bookworm AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
 WORKDIR /usr/src
 
 ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -95,7 +95,7 @@ RUN apt-get update -y \
 RUN pip install --upgrade pip
 
 # Install HuggingFace packages
-ARG TRANSFORMERS_VERSION='4.41.1'
+ARG TRANSFORMERS_VERSION='4.43.3'
 ARG ACCELERATE_VERSION='0.27.2'
 ARG SAFETENSORS_VERSION='0.4.2'
 
diff --git a/text-generation-inference/server/Makefile b/text-generation-inference/server/Makefile
index b26225d3..904bf128 100644
--- a/text-generation-inference/server/Makefile
+++ b/text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@ pkg_name := text_generation_server
 
 BUILDDIR ?= $(CURDIR)/build
 VERSION ?= 0.0.1
-TGI_VERSION ?= v2.0.3
+TGI_VERSION ?= v2.2.0
 mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 pkg_dir := $(BUILDDIR)/$(pkg_name)
diff --git a/text-generation-inference/server/pyproject.toml b/text-generation-inference/server/pyproject.toml
index 5a3d4070..7f146b0b 100644
--- a/text-generation-inference/server/pyproject.toml
+++ b/text-generation-inference/server/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
     'grpc-interceptor == 0.15.2',
     'typer == 0.6.1',
     'safetensors == 0.4.2',
-    'transformers == 4.41.1',
+    'transformers == 4.43.3',
     'loguru == 0.6.0',
     "sentencepiece == 0.2.0",
     "numpy<2.0",
diff --git a/text-generation-inference/server/text_generation_server/cli.py b/text-generation-inference/server/text_generation_server/cli.py
index 54cf8b20..4c654578 100644
--- a/text-generation-inference/server/text_generation_server/cli.py
+++ b/text-generation-inference/server/text_generation_server/cli.py
@@ -18,6 +18,8 @@ def serve(
     uds_path: str = "/tmp/text-generation-server",
     logger_level: str = "INFO",
     json_output: bool = False,
+    otlp_service_name: str = "text-generation-inference.server",
+    max_input_tokens: Optional[int] = None,
 ):
     """This is the main entry-point for the server CLI.
 
@@ -54,6 +56,10 @@ def serve(
 
     if trust_remote_code is not None:
         logger.warning("'trust_remote_code' argument is not supported and will be ignored.")
+    if otlp_service_name is not None:
+        logger.warning("'otlp_service_name' argument is not supported and will be ignored.")
+    if max_input_tokens is not None:
+        logger.warning("'max_input_tokens' argument is not supported and will be ignored.")
 
     # Import here after the logger is added to log potential import exceptions
     from optimum.tpu.model import fetch_model
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 4fdcdd19..e8d6d85c 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -164,6 +164,8 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
         self._generation_config.typical_p = request.parameters.typical_p
         self._generation_config.do_sample = request.parameters.do_sample
         self._generation_config.repetition_penalty = request.parameters.repetition_penalty
+        # Workaround to avoid bug in token_utils in transformers.
+        self._generation_config._eos_token_tensor = getattr(self._generation_config, "_eos_token_tensor", None)
         self._truncate = request.truncate
         self.seed = request.parameters.seed
         # TODO: watermark
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 59291bf7..9b8e2cc4 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -133,6 +133,8 @@ def assign(self, batch_id: int, request: Request, generation_config: GenerationC
         self._generation_config.typical_p = request.parameters.typical_p
         self._generation_config.do_sample = request.parameters.do_sample
         self._generation_config.repetition_penalty = request.parameters.repetition_penalty
+        # Workaround to avoid bug in token_utils in transformers.
+        self._generation_config._eos_token_tensor = getattr(self._generation_config, "_eos_token_tensor", None)
         self._truncate = request.truncate
         self.seed = request.parameters.seed
         # TODO: watermark
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index 1f431a3e..732ed081 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -39,6 +39,11 @@ def test_decode_single(params):
 @pytest.mark.slow
 @pytest.mark.parametrize("params",
     [
+        DecodeTestParams(
+            model_id="meta-llama/Meta-Llama-3.1-8B",
+            sequence_length=256,
+            expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
+        ),
         DecodeTestParams(
             model_id="meta-llama/Meta-Llama-3-8B",
             sequence_length=256,
@@ -55,7 +60,7 @@ def test_decode_single(params):
             expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the v",
         ),
     ],
-    ids=["Meta-Llama-3-8B", "gemma-7b", "Mistral-7B-v0.3"],
+    ids=["Meta-Llama-3.1-8B", "Meta-Llama-3-8B", "gemma-7b", "Mistral-7B-v0.3"],
 )
 def test_decode_single_slow(params):
     _test_decode_single(params)