Update Gemma deployment notebook's prediction samples.
PiperOrigin-RevId: 682184683
vertex-mg-bot authored and copybara-github committed Oct 4, 2024
1 parent 6a574ea commit d242b94
Showing 1 changed file with 121 additions and 68 deletions.
@@ -206,7 +206,7 @@
"# @markdown Accept the model agreement to access the models:\n",
"# @markdown 1. Open the [Gemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335) from [Vertex AI Model Garden](https://cloud.google.com/model-garden).\n",
"# @markdown 1. Review and accept the agreement in the pop-up window on the model card page. If you have previously accepted the model agreement, there will not be a pop-up window on the model card page and this step is not needed.\n",
"# @markdown 1. After accepting the agreement of Gemma, a `https://` link containing Gemma pretrained and finetuned models will be shared.\n",
"# @markdown 1. After accepting the agreement of Gemma, a `https://` link containing Gemma pretrained and instruction-tuned models will be shared.\n",
"# @markdown 1. Paste the link in the `VERTEX_AI_MODEL_GARDEN_GEMMA` field below.\n",
"# @markdown **Note:** This will unzip and copy the Gemma model artifacts to your Cloud Storage bucket, which will take around 1 hour.\n",
"\n",
@@ -465,78 +465,66 @@
" print(prediction)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f615c03d6638"
},
"source": [
"### Build chat applications with Gemma\n",
"\n",
"You can build chat applications with the instruction finetuned Gemma models.\n",
"\n",
"The instruction tuned Gemma models were trained with a specific formatter that annotates instruction tuning examples with extra information, both during training and inference. The annotations (1) indicate roles in a conversation, and (2) delineate tunes in a conversation. Below we show a sample code snippet for formatting the model prompt using the user and model chat templates for a multi-turn conversation. The relevant tokens are:\n",
"- `user`: user turn\n",
"- `model`: model turn\n",
"- `<start_of_turn>`: beginning of dialogue turn\n",
"- `<end_of_turn>`: end of dialogue turn\n",
"\n",
"An example set of dialogues is:\n",
"```\n",
"<start_of_turn>user\n",
"knock knock<end_of_turn>\n",
"<start_of_turn>model\n",
"who is there<end_of_turn>\n",
"<start_of_turn>user\n",
"LaMDA<end_of_turn>\n",
"<start_of_turn>model\n",
"LaMDA who?<end_of_turn>\n",
"```\n",
"where `<end_of_turn>\\n` is the turn separator and `<start_of_turn>model\\n` is the prompt prefix. This means if we would like to prompt the model with a question like, `What is Cramer's Rule?`, we should use:\n",
"```\n",
"<start_of_turn>user\n",
"What is Cramer's Rule?<end_of_turn>\n",
"<start_of_turn>model\n",
"```"
]
},
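As a quick illustration of the formatting above, here is a minimal sketch (assuming only the turn tokens documented in this cell) that assembles the `What is Cramer's Rule?` prompt:

```python
# Build a single-turn Gemma chat prompt from the documented turn tokens.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"

prompt = (
    USER_CHAT_TEMPLATE.format(prompt="What is Cramer's Rule?")
    # Prompt prefix: generation continues as the model's turn.
    + "<start_of_turn>model\n"
)
print(prompt)
```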
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "e59377392346"
"id": "89ad18133ace"
},
"outputs": [],
"source": [
"# Chat templates.\n",
"USER_CHAT_TEMPLATE = \"<start_of_turn>user\\n{prompt}<end_of_turn>\\n\"\n",
"MODEL_CHAT_TEMPLATE = \"<start_of_turn>model\\n{prompt}<end_of_turn>\\n\"\n",
"\n",
"# Sample formatted prompt.\n",
"prompt = (\n",
" USER_CHAT_TEMPLATE.format(prompt=\"What is a good place for travel in the US?\")\n",
" + MODEL_CHAT_TEMPLATE.format(prompt=\"California.\")\n",
" + USER_CHAT_TEMPLATE.format(prompt=\"What can I do in California?\")\n",
" + \"<start_of_turn>model\\n\"\n",
"# @title Chat completion\n",
"\n",
"# @markdown You can build chat applications with the instruction-tuned Gemma models.\n",
"\n",
"_region = REGION\n",
"REGION = TPU_DEPLOYMENT_REGION\n",
"\n",
"ENDPOINT_RESOURCE_NAME = \"projects/{}/locations/{}/endpoints/{}\".format(\n",
" PROJECT_ID, REGION, endpoints[\"hexllm_tpu\"].name\n",
")\n",
"print(\"Chat prompt:\\n\", prompt)\n",
"\n",
"instances = [\n",
" {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": 50,\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 1.0,\n",
" \"top_k\": 1,\n",
" },\n",
"]\n",
"response = endpoints[\"hexllm_tpu\"].predict(\n",
" instances=instances, use_dedicated_endpoint=use_dedicated_endpoint\n",
"# @title Chat Completions Inference\n",
"\n",
"# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.\n",
"\n",
"# @markdown First you will need to install the SDK and some auth-related dependencies.\n",
"\n",
"! pip install -qU openai google-auth requests\n",
"\n",
"# @markdown Next fill out some request parameters:\n",
"\n",
"user_message = \"How is your day going?\" # @param {type: \"string\"}\n",
"# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n",
"max_tokens = 50 # @param {type: \"integer\"}\n",
"temperature = 1.0 # @param {type: \"number\"}\n",
"\n",
"# @markdown Now we can send a request.\n",
"\n",
"import google.auth\n",
"import openai\n",
"\n",
"creds, project = google.auth.default()\n",
"auth_req = google.auth.transport.requests.Request()\n",
"creds.refresh(auth_req)\n",
"\n",
"BASE_URL = (\n",
" f\"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}\"\n",
")\n",
"client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)\n",
"\n",
"model_response = client.chat.completions.create(\n",
" model=\"\",\n",
" messages=[{\"role\": \"user\", \"content\": user_message}],\n",
" temperature=temperature,\n",
" max_tokens=max_tokens,\n",
")\n",
"print(model_response)\n",
"\n",
"prediction = response.predictions[0]\n",
"print(prediction)"
"REGION = _region\n",
"\n",
"# @markdown Click \"Show Code\" to see more details."
]
},
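The cell above prints the entire `ChatCompletion` object. If you only need the generated text, it can be read from the first choice; a small follow-up sketch using the `model_response` from the cell above:

```python
# Extract just the assistant's reply text from the response.
reply = model_response.choices[0].message.content
print(reply)
```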
{
@@ -557,7 +545,6 @@
{
"cell_type": "code",
"execution_count": null,
"language": "python",
"metadata": {
"cellView": "form",
"id": "03d504bcd60b"
@@ -756,7 +743,7 @@
"id": "RRR11SWykYaX"
},
"source": [
"Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by vLLM can be found [here](https://github.com/vllm-project/vllm/blob/2e8e49fce3775e7704d413b2f02da6d7c99525c9/vllm/sampling_params.py#L23-L64). Setting `raw_response` to `True` allows you to obtain raw outputs."
"Once deployment succeeds, you can send requests to the endpoint with text prompts. Sampling parameters supported by vLLM can be found [here](https://docs.vllm.ai/en/latest/dev/sampling_params.html)."
]
},
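Part of the request cell that follows is collapsed in this diff. As a rough sketch, a raw-predict request with vLLM sampling parameters mirrors the Hex-LLM request shown earlier; this assumes `endpoints["vllm_gpu"]` and `use_dedicated_endpoint` are defined by the deployment cells, and uses an illustrative prompt:

```python
# Sketch: send a text prompt with vLLM sampling parameters to the GPU endpoint.
instances = [
    {
        "prompt": "What is a car?",  # illustrative prompt
        "max_tokens": 50,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 1,
        "raw_response": False,
    },
]
response = endpoints["vllm_gpu"].predict(
    instances=instances, use_dedicated_endpoint=use_dedicated_endpoint
)
for prediction in response.predictions:
    print(prediction)
```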
{
@@ -800,6 +787,7 @@
"temperature = 1.0 # @param {type:\"number\"}\n",
"top_p = 1.0 # @param {type:\"number\"}\n",
"top_k = 1 # @param {type:\"integer\"}\n",
"# @markdown Set `raw_response` to `True` to obtain the raw model output. Set `raw_response` to `False` to apply additional formatting in the structure of `\"Prompt:\\n{prompt.strip()}\\nOutput:\\n{output}\"`.\n",
"raw_response = False # @param {type:\"boolean\"}\n",
"\n",
"# Overrides parameters for inferences.\n",
@@ -822,14 +810,60 @@
]
},
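As a toy illustration of the `raw_response` flag described in the cell above (hypothetical strings; the real formatting is applied server-side by the serving container):

```python
prompt = "What is a car?"             # hypothetical request prompt
output = "A car is a motor vehicle."  # hypothetical model output

raw_response = False
if raw_response:
    prediction = output  # raw model output only
else:
    # The formatted structure quoted in the cell above.
    prediction = f"Prompt:\n{prompt.strip()}\nOutput:\n{output}"
print(prediction)
```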
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "104fe2c03812"
"cellView": "form",
"id": "4f9eb6b3242b"
},
"outputs": [],
"source": [
"### Apply chat templates\n",
"# @title Chat completion\n",
"\n",
"# @markdown You can build chat applications with the instruction-tuned Gemma models.\n",
"\n",
"ENDPOINT_RESOURCE_NAME = \"projects/{}/locations/{}/endpoints/{}\".format(\n",
" PROJECT_ID, REGION, endpoints[\"vllm_gpu\"].name\n",
")\n",
"\n",
"# @title Chat Completions Inference\n",
"\n",
"# @markdown Once deployment succeeds, you can send requests to the endpoint using the OpenAI SDK.\n",
"\n",
"# @markdown First you will need to install the SDK and some auth-related dependencies.\n",
"\n",
"! pip install -qU openai google-auth requests\n",
"\n",
"# @markdown Next fill out some request parameters:\n",
"\n",
"Chat templates can be applied to model predictions generated by the vLLM endpoint as well. You may use the same code snippets as for the Hex-LLM endpoint. They are not repeated here for brevity."
"user_message = \"How is your day going?\" # @param {type: \"string\"}\n",
"# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n",
"max_tokens = 50 # @param {type: \"integer\"}\n",
"temperature = 1.0 # @param {type: \"number\"}\n",
"\n",
"# @markdown Now we can send a request.\n",
"\n",
"import google.auth\n",
"import openai\n",
"\n",
"creds, project = google.auth.default()\n",
"auth_req = google.auth.transport.requests.Request()\n",
"creds.refresh(auth_req)\n",
"\n",
"BASE_URL = (\n",
" f\"https://{REGION}-aiplatform.googleapis.com/v1beta1/{ENDPOINT_RESOURCE_NAME}\"\n",
")\n",
"client = openai.OpenAI(base_url=BASE_URL, api_key=creds.token)\n",
"\n",
"model_response = client.chat.completions.create(\n",
" model=\"\",\n",
" messages=[{\"role\": \"user\", \"content\": user_message}],\n",
" temperature=temperature,\n",
" max_tokens=max_tokens,\n",
")\n",
"print(model_response)\n",
"\n",
"# @markdown Click \"Show Code\" to see more details."
]
},
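The same OpenAI-compatible endpoint also supports multi-turn conversations. A sketch reusing `client`, `temperature`, and `max_tokens` from the cell above, with the travel dialogue from the earlier chat-template section:

```python
# Multi-turn chat: prior turns are passed as alternating user/assistant messages.
messages = [
    {"role": "user", "content": "What is a good place for travel in the US?"},
    {"role": "assistant", "content": "California."},
    {"role": "user", "content": "What can I do in California?"},
]
model_response = client.chat.completions.create(
    model="",
    messages=messages,
    temperature=temperature,
    max_tokens=max_tokens,
)
print(model_response.choices[0].message.content)
```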
{
@@ -872,11 +906,30 @@
"name": "model_garden_gemma_deployment_on_vertex.ipynb",
"toc_visible": true
},
"environment": {
"kernel": "python3",
"name": ".m114",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/:m114"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
"nbformat_minor": 4
}
