Skip to content

Commit

Permalink
feat(model): Integrate GPT-4 text generation in Rago (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
esloch authored Oct 30, 2024
1 parent e50b64e commit 42078a3
Show file tree
Hide file tree
Showing 13 changed files with 503 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .env.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
HF_TOKEN=${HF_TOKEN}
OPENAI_API_KEY=${OPENAI_API_KEY}
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ jobs:

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

steps:
- uses: actions/checkout@v4
Expand Down
344 changes: 343 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ torchvision = [
{version = ">=0.20.0", markers="extra=='gpu' and extra!='cpu'"},
]
langdetect = ">=1"
openai = "^1.52.2"

[tool.poetry.extras]
cpu = ["torch", "torchvision"]
Expand Down
2 changes: 2 additions & 0 deletions src/rago/augmented/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from rago.augmented.base import AugmentedBase
from rago.augmented.hugging_face import HuggingFaceAug
from rago.augmented.openai_aug import OpenAIAug

__all__ = [
'AugmentedBase',
'HuggingFaceAug',
'OpenAIAug',
]
35 changes: 35 additions & 0 deletions src/rago/augmented/openai_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""OpenAIAug class for query augmentation using OpenAI API."""

from __future__ import annotations

import openai

from typeguard import typechecked

from rago.augmented.base import AugmentedBase


@typechecked
class OpenAIAug(AugmentedBase):
    """Query augmentation backed by the OpenAI chat completions API.

    The query and the retrieved documents are merged into a single user
    message; the model's reply is used as the augmented query.
    """

    def __init__(
        self, model_name: str = 'gpt-4', k: int = 1, api_key: str = ''
    ) -> None:
        """Initialize the OpenAIAug class.

        Parameters
        ----------
        model_name : str
            Name of the OpenAI chat model to query.
        k : int
            Number of copies of the augmented query returned by ``search``.
        api_key : str
            Optional OpenAI API key. When empty, the ``openai`` module
            configuration is left untouched, so a key set elsewhere or the
            ``OPENAI_API_KEY`` environment variable is used.
        """
        # NOTE(review): ``AugmentedBase.__init__`` is not called here (same
        # as before this change); confirm the base class needs no setup.
        self.model_name = model_name
        self.k = k
        if api_key:
            # Only override the module-level key when one was really given;
            # assigning '' would shadow the environment fallback.
            openai.api_key = api_key

    def search(
        self, query: str, documents: list[str], k: int = 1
    ) -> list[str]:
        """Augment the query by expanding or rephrasing it using OpenAI.

        Parameters
        ----------
        query : str
            The user query to augment.
        documents : list[str]
            Retrieved documents, joined with spaces to form the context.
        k : int
            Accepted for interface compatibility; the result size is
            driven by ``self.k`` set at construction time.

        Returns
        -------
        list[str]
            The stripped model reply repeated ``self.k`` times.
        """
        prompt = f"Retrieval: '{query}'\nContext: {' '.join(documents)}"

        # openai>=1.0 removed ``openai.Completion``; chat-style ``messages``
        # must go through the chat completions endpoint.
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=50,
            temperature=0.5,
        )

        # Responses are typed objects in openai>=1.0 (attribute access, not
        # dict indexing) and ``content`` may be ``None``.
        augmented_query = (response.choices[0].message.content or '').strip()
        return [augmented_query] * self.k
2 changes: 2 additions & 0 deletions src/rago/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from rago.generation.base import GenerationBase
from rago.generation.hugging_face import HuggingFaceGen
from rago.generation.llama import LlamaGen
from rago.generation.openai_gpt import OpenAIGPTGen

__all__ = [
'GenerationBase',
'HuggingFaceGen',
'LlamaGen',
'OpenAIGPTGen',
]
8 changes: 4 additions & 4 deletions src/rago/generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class GenerationBase:
"""Generic Generation class."""

apikey: str = ''
api_key: str = ''
device_name: str = 'cpu'
device: torch.device
model: Any
Expand All @@ -26,7 +26,7 @@ class GenerationBase:
def __init__(
self,
model_name: str = '',
apikey: str = '',
api_key: str = '',
temperature: float = 0.5,
output_max_length: int = 500,
device: str = 'auto',
Expand All @@ -37,13 +37,13 @@ def __init__(
----------
model_name : str
The name of the model to use.
apikey : str
api_key : str
temperature : float
output_max_length : int
Maximum length of the generated output.
device: str (default=auto)
"""
self.apikey = apikey
self.api_key = api_key
self.model_name = model_name
self.output_max_length = output_max_length
self.temperature = temperature
Expand Down
4 changes: 2 additions & 2 deletions src/rago/generation/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class HuggingFaceGen(GenerationBase):
def __init__(
self,
model_name: str = 't5-small',
apikey: str = '',
api_key: str = '',
temperature: float = 0.5,
output_max_length: int = 500,
device: str = 'auto',
Expand All @@ -28,7 +28,7 @@ def __init__(

super().__init__(
model_name=model_name,
apikey=apikey,
api_key=api_key,
temperature=temperature,
output_max_length=output_max_length,
device=device,
Expand Down
8 changes: 4 additions & 4 deletions src/rago/generation/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class LlamaGen(GenerationBase):
def __init__(
self,
model_name: str = 'meta-llama/Llama-3.2-1B',
apikey: str = '',
api_key: str = '',
temperature: float = 0.5,
output_max_length: int = 500,
device: str = 'auto',
Expand All @@ -31,19 +31,19 @@ def __init__(

super().__init__(
model_name=model_name,
apikey=apikey,
api_key=api_key,
temperature=temperature,
output_max_length=output_max_length,
device=device,
)

self.tokenizer = AutoTokenizer.from_pretrained(
model_name, token=apikey
model_name, token=api_key
)

self.model = AutoModelForCausalLM.from_pretrained(
model_name,
token=apikey,
token=api_key,
torch_dtype=torch.float16
if self.device_name == 'cuda'
else torch.float32,
Expand Down
52 changes: 52 additions & 0 deletions src/rago/generation/openai_gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""OpenAI Generation Model class for flexible GPT-based text generation."""

from __future__ import annotations

from typing import cast

import openai

from typeguard import typechecked

from rago.generation.base import GenerationBase


@typechecked
class OpenAIGPTGen(GenerationBase):
    """OpenAI generation model for text generation.

    Sends the question and its context to an OpenAI chat model and returns
    the model's reply as plain text.
    """

    def __init__(
        self,
        model_name: str = 'gpt-4',
        output_max_tokens: int = 500,
        api_key: str = '',
        temperature: float = 0.7,
    ) -> None:
        """Initialize OpenAIGPTGen with OpenAI's model.

        Parameters
        ----------
        model_name : str
            Name of the OpenAI chat model to use.
        output_max_tokens : int
            Maximum number of tokens to generate; stored by the base class
            as ``output_max_length``.
        api_key : str
            OpenAI API key. When empty, the ``openai`` module configuration
            is left untouched, so a key set elsewhere or the
            ``OPENAI_API_KEY`` environment variable is used.
        temperature : float
            Sampling temperature sent with every request.
        """
        super().__init__(
            model_name=model_name,
            api_key=api_key,
            temperature=temperature,
            output_max_length=output_max_tokens,
        )
        if api_key:
            # Assigning '' would shadow the environment fallback, so only
            # override the module-level key when one was really given.
            openai.api_key = api_key

    def generate(
        self,
        query: str,
        context: list[str],
        language: str = 'en',
    ) -> str:
        """Generate text using OpenAI's API with dynamic model support.

        Parameters
        ----------
        query : str
            The question to answer.
        context : list[str]
            Context passages, joined with spaces inside the prompt.
        language : str
            Language the model is asked to answer in.

        Returns
        -------
        str
            The stripped reply of the first completion choice, or an empty
            string when the API returns no text content.
        """
        input_text = (
            f"Question: {query}\nContext: {' '.join(context)}\n"
            f"Answer in {language}:"
        )

        # openai>=1.0 removed ``openai.Completion``; chat-style ``messages``
        # must go through the chat completions endpoint.
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=[{'role': 'user', 'content': input_text}],
            max_tokens=self.output_max_length,
            temperature=self.temperature,
            top_p=0.9,
            frequency_penalty=0.5,
            presence_penalty=0.3,
        )

        # Responses are typed objects in openai>=1.0 (attribute access, not
        # dict indexing) and ``content`` may be ``None``.
        message = response.choices[0].message
        return cast(str, message.content or '').strip()
6 changes: 4 additions & 2 deletions tests/test_llama.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for rago package."""

import os

from pathlib import Path

import pytest
Expand All @@ -22,12 +24,12 @@ def animals_data() -> list[str]:
@pytest.mark.skip_on_ci
def test_llama(env, animals_data: list[str], device: str = 'auto') -> None:
"""Test RAG with hugging face."""
HF_TOKEN = env.get('HF_TOKEN', '')
HF_TOKEN = os.getenv('HF_TOKEN', '')

rag = Rago(
retrieval=StringRet(animals_data),
augmented=HuggingFaceAug(k=3),
generation=LlamaGen(apikey=HF_TOKEN, device=device),
generation=LlamaGen(api_key=HF_TOKEN, device=device),
)

query = 'Is there any animals larger than a dinosaur?'
Expand Down
51 changes: 51 additions & 0 deletions tests/test_openai_gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Tests for Rago package using OpenAI GPT-4."""

import os

from pathlib import Path

import pytest

from rago import Rago
from rago.augmented import OpenAIAug
from rago.generation.openai_gpt import (
OpenAIGPTGen,
)
from rago.retrieval import StringRet


@pytest.fixture
def animals_data() -> list[str]:
    """Provide the stripped, non-empty lines of ``data/animals.txt``."""
    dataset_file = Path(__file__).parent / 'data' / 'animals.txt'
    with open(dataset_file) as handle:
        # Strip each line first, then drop the ones that end up empty.
        stripped = (raw.strip() for raw in handle)
        return [entry for entry in stripped if entry]


@pytest.fixture
def openai_api_key() -> str:
    """Return the OpenAI API key read from ``OPENAI_API_KEY``.

    Raises ``EnvironmentError`` when the variable is unset or empty.
    """
    token = os.environ.get('OPENAI_API_KEY')
    if token:
        return token
    raise EnvironmentError(
        'Please set the OPENAI_API_KEY environment variable.'
    )


@pytest.mark.skip_on_ci
def test_openai_gpt4(animals_data: list[str], openai_api_key: str) -> None:
    """Run the full RAG pipeline against GPT-4 and check the answer."""
    # Build the three pipeline stages in the order Rago receives them.
    retrieval = StringRet(animals_data)
    augmented = OpenAIAug(k=3)
    generation = OpenAIGPTGen(api_key=openai_api_key, model_name='gpt-4')
    rag = Rago(
        retrieval=retrieval,
        augmented=augmented,
        generation=generation,
    )

    answer = rag.prompt('Is there any animal larger than a dinosaur?')

    assert (
        'Blue Whale' in answer
    ), 'Expected response to mention Blue Whale as a larger animal.'

0 comments on commit 42078a3

Please sign in to comment.