From d242b945323b286e0239ed3f3b61cb0fff7f3bb1 Mon Sep 17 00:00:00 2001 From: Vertex MG Team Date: Thu, 3 Oct 2024 23:17:42 -0700 Subject: [PATCH] Update Gemma deployment notebook's prediction samples. PiperOrigin-RevId: 682184683 --- ...el_garden_gemma_deployment_on_vertex.ipynb | 189 +++++++++++------- 1 file changed, 121 insertions(+), 68 deletions(-) diff --git a/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb b/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb index bc6885f7d2..555b1ec540 100644 --- a/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb +++ b/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb @@ -206,7 +206,7 @@ "# @markdown Accept the model agreement to access the models:\n", "# @markdown 1. Open the [Gemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335) from [Vertex AI Model Garden](https://cloud.google.com/model-garden).\n", "# @markdown 1. Review and accept the agreement in the pop-up window on the model card page. If you have previously accepted the model agreement, there will not be a pop-up window on the model card page and this step is not needed.\n", - "# @markdown 1. After accepting the agreement of Gemma, a `https://` link containing Gemma pretrained and finetuned models will be shared.\n", + "# @markdown 1. After accepting the agreement of Gemma, a `https://` link containing Gemma pretrained and instruction-tuned models will be shared.\n", "# @markdown 1. Paste the link in the `VERTEX_AI_MODEL_GARDEN_GEMMA` field below.\n", "# @markdown **Note:** This will unzip and copy the Gemma model artifacts to your Cloud Storage bucket, which will take around 1 hour.\n", "\n", @@ -465,78 +465,66 @@ " print(prediction)" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "f615c03d6638" - }, - "source": [ - "### Build chat applications with Gemma\n", - "\n", - "You can build chat applications with the instruction finetuned Gemma models.\n", - "\n", - "The instruction tuned Gemma models were trained with a specific formatter that annotates instruction tuning examples with extra information, both during training and inference. The annotations (1) indicate roles in a conversation, and (2) delineate tunes in a conversation. Below we show a sample code snippet for formatting the model prompt using the user and model chat templates for a multi-turn conversation. The relevant tokens are:\n", - "- `user`: user turn\n", - "- `model`: model turn\n", - "- ``: beginning of dialogue turn\n", - "- ``: end of dialogue turn\n", - "\n", - "An example set of dialogues is:\n", - "```\n", - "user\n", - "knock knock\n", - "model\n", - "who is there\n", - "user\n", - "LaMDA\n", - "model\n", - "LaMDA who?\n", - "```\n", - "where `\\n` is the turn separator and `model\\n` is the prompt prefix. 
This means if we would like to prompt the model with a question like, `What is Cramer's Rule?`, we should use:\n", - "```\n", - "user\n", - "What is Cramer's Rule?\n", - "model\n", - "```" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", - "id": "e59377392346" + "id": "89ad18133ace" }, "outputs": [], "source": [ - "# Chat templates.\n", - "USER_CHAT_TEMPLATE = \"user\\n{prompt}\\n\"\n", - "MODEL_CHAT_TEMPLATE = \"model\\n{prompt}\\n\"\n", - "\n", - "# Sample formatted prompt.\n", - "prompt = (\n", - " USER_CHAT_TEMPLATE.format(prompt=\"What is a good place for travel in the US?\")\n", - " + MODEL_CHAT_TEMPLATE.format(prompt=\"California.\")\n", - " + USER_CHAT_TEMPLATE.format(prompt=\"What can I do in California?\")\n", - " + \"model\\n\"\n", + "# @title Chat completion\n", + "\n", + "# @markdown You can build chat applications with the instruction-tuned Gemma models.\n", + "\n", + "_region = REGION\n", + "REGION = TPU_DEPLOYMENT_REGION\n", + "\n", + "ENDPOINT_RESOURCE_NAME = \"projects/{}/locations/{}/endpoints/{}\".format(\n", + " PROJECT_ID, REGION, endpoints[\"hexllm_tpu\"].name\n", ")\n", - "print(\"Chat prompt:\\n\", prompt)\n", "\n", - "instances = [\n", - " {\n", - " \"prompt\": prompt,\n", - " \"max_tokens\": 50,\n", - " \"temperature\": 1.0,\n", - " \"top_p\": 1.0,\n", - " \"top_k\": 1,\n", - " },\n", - "]\n", - "response = endpoints[\"hexllm_tpu\"].predict(\n", - " instances=instances, use_dedicated_endpoint=use_dedicated_endpoint\n", + "# @title Chat Completions Inference\n", + "\n", + "# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.\n", + "\n", + "# @markdown First you will need to install the SDK and some auth-related dependencies.\n", + "\n", + "! pip install -qU openai google-auth requests\n", + "\n", + "# @markdown Next fill out some request parameters:\n", + "\n", + "user_message = \"How is your day going?\" # @param {type: \"string\"}\n", + "# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n", + "max_tokens = 50 # @param {type: \"integer\"}\n", + "temperature = 1.0 # @param {type: \"number\"}\n", + "\n", + "# @markdown Now we can send a request.\n", + "\n", + "import google.auth\n", + "import openai\n", + "\n", + "creds, project = google.auth.default()\n", + "auth_req = google.auth.transport.requests.Request()\n", + "creds.refresh(auth_req)\n", + "\n", + "BASE_URL = (\n", + " f\"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}\"\n", + ")\n", + "client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)\n", + "\n", + "model_response = client.chat.completions.create(\n", + " model=\"\",\n", + " messages=[{\"role\": \"user\", \"content\": user_message}],\n", + " temperature=temperature,\n", + " max_tokens=max_tokens,\n", ")\n", + "print(model_response)\n", "\n", - "prediction = response.predictions[0]\n", - "print(prediction)" + "REGION = _region\n", + "\n", + "# @markdown Click \"Show Code\" to see more details." ] }, { @@ -557,7 +545,6 @@ { "cell_type": "code", "execution_count": null, - "language": "python", "metadata": { "cellView": "form", "id": "03d504bcd60b" @@ -756,7 +743,7 @@ "id": "RRR11SWykYaX" }, "source": [ - "Once deployment succeeds, you can send requests to the endpoint with text prompts. 
Sampling parameters supported by vLLM can be found [here](https://github.com/vllm-project/vllm/blob/2e8e49fce3775e7704d413b2f02da6d7c99525c9/vllm/sampling_params.py#L23-L64). Setting `raw_response` to `True` allows you to obtain raw outputs." + "Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by vLLM can be found [here](https://docs.vllm.ai/en/latest/dev/sampling_params.html)." ] }, { @@ -800,6 +787,7 @@ "temperature = 1.0 # @param {type:\"number\"}\n", "top_p = 1.0 # @param {type:\"number\"}\n", "top_k = 1 # @param {type:\"integer\"}\n", + "# @markdown Set `raw_response` to `True` to obtain the raw model output. Set `raw_response` to `False` to apply additional formatting in the structure of `\"Prompt:\\n{prompt.strip()}\\nOutput:\\n{output}\"`.\n", "raw_response = False # @param {type:\"boolean\"}\n", "\n", "# Overrides parameters for inferences.\n", @@ -822,14 +810,60 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { - "id": "104fe2c03812" + "cellView": "form", + "id": "4f9eb6b3242b" }, + "outputs": [], "source": [ - "### Apply chat templates\n", + "# @title Chat completion\n", + "\n", + "# @markdown You can build chat applications with the instruction-tuned Gemma models.\n", + "\n", + "ENDPOINT_RESOURCE_NAME = \"projects/{}/locations/{}/endpoints/{}\".format(\n", + " PROJECT_ID, REGION, endpoints[\"vllm_gpu\"].name\n", + ")\n", + "\n", + "# @title Chat Completions Inference\n", + "\n", + "# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.\n", + "\n", + "# @markdown First you will need to install the SDK and some auth-related dependencies.\n", + "\n", + "! pip install -qU openai google-auth requests\n", + "\n", + "# @markdown Next fill out some request parameters:\n", "\n", - "Chat templates can be applied to model predictions generated by the vLLM endpoint as well. You may use the same code snippets as for the Hex-LLM endpoint. They are not repeated here for brevity." + "user_message = \"How is your day going?\" # @param {type: \"string\"}\n", + "# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n", + "max_tokens = 50 # @param {type: \"integer\"}\n", + "temperature = 1.0 # @param {type: \"number\"}\n", + "\n", + "# @markdown Now we can send a request.\n", + "\n", + "import google.auth\n", + "import openai\n", + "\n", + "creds, project = google.auth.default()\n", + "auth_req = google.auth.transport.requests.Request()\n", + "creds.refresh(auth_req)\n", + "\n", + "BASE_URL = (\n", + " f\"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}\"\n", + ")\n", + "client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)\n", + "\n", + "model_response = client.chat.completions.create(\n", + " model=\"\",\n", + " messages=[{\"role\": \"user\", \"content\": user_message}],\n", + " temperature=temperature,\n", + " max_tokens=max_tokens,\n", + ")\n", + "print(model_response)\n", + "\n", + "# @markdown Click \"Show Code\" to see more details." 
] }, { @@ -872,11 +906,30 @@ "name": "model_garden_gemma_deployment_on_vertex.ipynb", "toc_visible": true }, + "environment": { + "kernel": "python3", + "name": ".m114", + "type": "gcloud", + "uri": "gcr.io/deeplearning-platform-release/:m114" + }, "kernelspec": { "display_name": "Python 3", + "language": "python", "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 4 }
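
For reference, the request pattern used by the new chat-completion cells condenses to the sketch below. It is a minimal illustration only: PROJECT_ID, REGION, and ENDPOINT_ID are placeholder values, and an instruction-tuned Gemma endpoint is assumed to already be deployed by the notebook (for example the "hexllm_tpu" or "vllm_gpu" endpoint).

# Minimal sketch of the chat-completions request pattern used by the new cells.
# PROJECT_ID, REGION, and ENDPOINT_ID are placeholders for an endpoint that the
# notebook has already deployed; application default credentials are assumed.
import google.auth
import google.auth.transport.requests
import openai

PROJECT_ID = "your-project"   # placeholder
REGION = "us-central1"        # placeholder
ENDPOINT_ID = "1234567890"    # placeholder

# Obtain an OAuth access token for Vertex AI.
creds, _ = google.auth.default()
creds.refresh(google.auth.transport.requests.Request())

# Point the OpenAI client at the endpoint's OpenAI-compatible route on Vertex AI.
endpoint_resource = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{ENDPOINT_ID}"
base_url = f"https://{REGION}-aiplatform.googleapis.com/v1beta1/{endpoint_resource}"
client = openai.OpenAI(base_url=base_url, api_key=creds.token)

# Send a single-turn chat request.
response = client.chat.completions.create(
    model="",  # placeholder; the endpoint determines which deployed model serves the request
    messages=[{"role": "user", "content": "How is your day going?"}],
    temperature=1.0,
    max_tokens=50,
)
print(response)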