Skip to content

Commit

Permalink
[0.7.1] add structured generation
Browse files Browse the repository at this point in the history
  • Loading branch information
yashbonde committed Jan 5, 2025
1 parent 73823b9 commit aa6ab55
Show file tree
Hide file tree
Showing 11 changed files with 374 additions and 50 deletions.
111 changes: 111 additions & 0 deletions cookbooks/structured_generation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tuneapi import tu, tt, ta\n",
"from dataclasses import dataclass\n",
"from pydantic import BaseModel\n",
"from typing import List, Optional, Dict, Any"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class MedicalRecord(BaseModel):\n",
" date: str\n",
" diagnosis: str\n",
" treatment: str\n",
"\n",
"class Dog(BaseModel):\n",
" name: str\n",
" breed: str\n",
" records: Optional[List[MedicalRecord]] = None\n",
"\n",
"class Dogs(BaseModel):\n",
" dogs: List[Dog]\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dog: Buddy, Breed: Golden Retriever\n",
" Date: 2023-10-26, Diagnosis: Mild ear infection, Treatment: Ear drops\n",
"\n",
"Dog: Luna, Breed: Beagle\n",
" Date: 2023-10-25, Diagnosis: Routine check-up, Treatment: No treatment needed\n",
" Date: 2023-10-28, Diagnosis: Upset tummy, Treatment: Bland diet and probiotics\n",
"\n",
"Dog: Rocky, Breed: Terrier Mix\n",
" Date: 2023-10-29, Diagnosis: Cut on paw, Treatment: Cleaned and antibiotic ointment\n",
"\n",
"Dog: Daisy, Breed: Poodle\n",
" No medical records on file.\n",
"\n"
]
}
],
"source": [
"# As of this moment we have tested it with the following LLMs:\n",
"\n",
"# model = ta.Openai()\n",
"model = ta.Gemini()\n",
"\n",
"out: Dogs = model.chat(tt.Thread(\n",
" tt.human(\"\"\"\n",
" At the Sunny Paws Animal Clinic, we keep detailed records of all our furry patients. Today, we saw a few dogs.\n",
" There was 'Buddy,' a golden retriever, who visited on '2023-10-26' and was diagnosed with a 'mild ear infection,'\n",
" which we treated with 'ear drops.' Then, there was 'Luna,' a playful beagle, who came in on '2023-10-25' for a\n",
" 'routine check-up,' and no treatment was needed, but we also had her back on '2023-10-28' with a 'upset tummy'\n",
" which we treated with 'bland diet and probiotics.' Finally, a third dog named 'Rocky', a small terrier mix,\n",
" showed up on '2023-10-29' with a small 'cut on his paw,' we cleaned it and used an 'antibiotic ointment'. We\n",
" also have 'Daisy,' a fluffy poodle, who doesn't have any medical records yet, thankfully!\n",
" \"\"\"),\n",
" schema=Dogs,\n",
"))\n",
"\n",
"for dog in out.dogs:\n",
" print(f\"Dog: {dog.name}, Breed: {dog.breed}\")\n",
" if dog.records:\n",
" for record in dog.records:\n",
" print(f\" Date: {record.date}, Diagnosis: {record.diagnosis}, Treatment: {record.treatment}\")\n",
" else:\n",
" print(\" No medical records on file.\")\n",
" print()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
50 changes: 50 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,56 @@ minor versions.

All relevant steps to be taken will be mentioned here.

0.7.1
-----

- Add structured generation support for Gemini and OpenAI APIs. You can just pass ``schema`` to ``Thread``. ``model.chat``
will take care of it automatically. Here's an example:

.. code-block:: python
from tuneapi import tt, ta
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
class MedicalRecord(BaseModel):
date: str
diagnosis: str
treatment: str
class Dog(BaseModel):
name: str
breed: str
records: Optional[List[MedicalRecord]] = None
class Dogs(BaseModel):
dogs: List[Dog]
model = ta.Gemini()
out: Dogs = model.chat(tt.Thread(
tt.human("""
At the Sunny Paws Animal Clinic, we keep detailed records of all our furry patients. Today, we saw a few dogs.
There was 'Buddy,' a golden retriever, who visited on '2023-10-26' and was diagnosed with a 'mild ear infection,'
which we treated with 'ear drops.' Then, there was 'Luna,' a playful beagle, who came in on '2023-10-25' for a
'routine check-up,' and no treatment was needed, but we also had her back on '2023-10-28' with a 'upset tummy'
which we treated with 'bland diet and probiotics.' Finally, a third dog named 'Rocky', a small terrier mix,
showed up on '2023-10-29' with a small 'cut on his paw,' we cleaned it and used an 'antibiotic ointment'. We
also have 'Daisy,' a fluffy poodle, who doesn't have any medical records yet, thankfully!
"""),
schema=Dogs,
))
for dog in out.dogs:
print(f"Dog: {dog.name}, Breed: {dog.breed}")
if dog.records:
for record in dog.records:
print(f" Date: {record.date}, Diagnosis: {record.diagnosis}, Treatment: {record.treatment}")
else:
print(" No medical records on file.")
print()
- Add ``pydantic`` as a dependency in the package.

0.7.0
-----

Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "tuneapi"
copyright = "2024, Frello Technologies"
author = "Frello Technologies"
release = "0.5.13"
release = "0.7.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tuneapi"
version = "0.7.0"
version = "0.7.1"
description = "Tune AI APIs."
authors = ["Frello Technology Private Limited <[email protected]>"]
license = "MIT"
Expand All @@ -18,6 +18,7 @@ snowflake_id = "1.0.2"
nutree = "0.8.0"
pillow = "^10.2.0"
httpx = "^0.28.1"
pydantic = "^2.6.4"
protobuf = { version = "^5.27.3", optional = true }
boto3 = { version = "1.29.6", optional = true }

Expand Down
2 changes: 1 addition & 1 deletion tuneapi/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from tuneapi.apis.model_groq import Groq
from tuneapi.apis.model_mistral import Mistral
from tuneapi.apis.model_gemini import Gemini
from tuneapi.apis.turbo import distributed_chat
from tuneapi.apis.turbo import distributed_chat, distributed_chat_async
123 changes: 113 additions & 10 deletions tuneapi/apis/model_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import httpx
import requests
from typing import Optional, Any, Dict, List
from pydantic import BaseModel
from typing import get_args, get_origin, List, Optional, Dict, Any, Union

import tuneapi.utils as tu
import tuneapi.types as tt
Expand Down Expand Up @@ -110,6 +111,106 @@ def _process_header(self):
"Content-Type": "application/json",
}

@staticmethod
def get_structured_schema(model: type[BaseModel]) -> Dict[str, Any]:
    """
    Convert a Pydantic ``BaseModel`` class to a JSON schema dict suitable for
    the Gemini API's ``response_schema`` generation-config field.

    Supports nested models (handled recursively), ``List[...]``, ``Dict[...]``,
    and ``Optional[...]`` / ``Union[..., None]`` — emitted as ``anyOf`` with a
    ``{"type": "null"}`` member. Scalar types map to their JSON-schema names;
    any unrecognized annotation falls back to ``{"type": "string"}``.

    Args:
        model: The Pydantic BaseModel class to convert.

    Returns:
        A dictionary representing the JSON schema.
    """
    # Local import keeps module-level imports untouched; `types.UnionType`
    # is the origin of PEP 604 unions (``X | None``) on Python 3.10+.
    import types

    union_origins = (Union, getattr(types, "UnionType", Union))

    def _type_schema(tp: Any) -> Dict[str, Any]:
        """Build the schema fragment for a single type annotation."""
        origin = get_origin(tp)
        args = get_args(tp)

        # Scalars first — the common case.
        if tp is str:
            return {"type": "string"}
        if tp is int:
            return {"type": "integer"}
        if tp is float:
            return {"type": "number"}
        if tp is bool:
            return {"type": "boolean"}
        if tp is type(None):
            return {"type": "null"}

        # Nested Pydantic model -> recurse into a full object schema.
        if isinstance(tp, type) and issubclass(tp, BaseModel):
            return _model_schema(tp)

        if origin is list:
            schema: Dict[str, Any] = {"type": "array"}
            if args:
                item_schema = _type_schema(args[0])
                schema["items"] = item_schema
                # Gemini expects items to carry a type; default opaque
                # fragments to "object".
                if "type" not in item_schema and "anyOf" not in item_schema:
                    item_schema["type"] = "object"
            else:
                schema["items"] = {}
            return schema

        if origin is dict:
            schema = {"type": "object"}
            if len(args) == 2:
                schema["additionalProperties"] = _type_schema(args[1])
            return schema

        # NOTE: `get_origin` never returns `typing.Optional` — `Optional[X]`
        # normalizes to `Union[X, None]` — so optionality MUST be detected
        # here by looking for NoneType among the union members.
        if origin in union_origins:
            non_null = [a for a in args if a is not type(None)]
            has_null = len(non_null) != len(args)
            if not non_null:
                return {"type": "null"}
            variants = [_type_schema(a) for a in non_null]
            if has_null:
                variants.append({"type": "null"})
            if len(variants) == 1:
                return variants[0]
            return {"anyOf": variants}

        # Default any unknown annotation to string.
        return {"type": "string"}

    def _model_schema(m: type[BaseModel]) -> Dict[str, Any]:
        """Build the object schema for a BaseModel subclass."""
        schema: Dict[str, Any] = {"type": "object", "properties": {}, "required": []}
        for field_name, field in m.model_fields.items():
            field_schema = _type_schema(field.annotation)
            if field.description:
                field_schema["description"] = field.description
            if field.is_required():
                schema["required"].append(field_name)
            schema["properties"][field_name] = field_schema
        if m.__doc__:
            schema["description"] = m.__doc__.strip()
        return schema

    return _model_schema(model)

def chat(
self,
chats: tt.Thread | str,
Expand Down Expand Up @@ -139,11 +240,13 @@ def chat(
output = x
else:
output += x
except Exception as e:
if not x:
raise e
else:
raise ValueError(x)
except requests.HTTPError as e:
print(e.response.text)
raise e

if chats.schema:
output = chats.schema(**tu.from_json(output))
return output
return output

def stream_chat(
Expand Down Expand Up @@ -198,11 +301,11 @@ def stream_chat(
"stopSequences": [],
}

if chats.gen_schema:
if chats.schema:
generation_config.update(
{
"response_mime_type": "application/json",
"response_schema": chats.gen_schema,
"response_schema": self.get_structured_schema(chats.schema),
}
)
data["generationConfig"] = generation_config
Expand Down Expand Up @@ -376,11 +479,11 @@ async def stream_chat_async(
"stopSequences": [],
}

if chats.gen_schema:
if chats.schema:
generation_config.update(
{
"response_mime_type": "application/json",
"response_schema": chats.gen_schema,
"response_schema": chats.schema,
}
)
data["generationConfig"] = generation_config
Expand Down
1 change: 1 addition & 0 deletions tuneapi/apis/model_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tuneapi.utils as tu
import tuneapi.types as tt
from tuneapi.apis.turbo import distributed_chat
from tuneapi.apis.model_openai import Openai as _Openai


class Mistral(tt.ModelInterface):
Expand Down
Loading

0 comments on commit aa6ab55

Please sign in to comment.