Skip to content

Commit

Permalink
feat: Introduce InstructionFinetuningDataRepository (#1033)
Browse files Browse the repository at this point in the history
* feat: Support for finetuning dataset creation

WIP: implement initial interface

WIP: minimal working implementation

WIP: store multiple samples for postgres repo

WIP: poetry lock, linting

WIP: actually running poetry lock

WIP: separate functions for single and batch storing

WIP: test sample validations

WIP: `InstructionFinetuningDataHandler`

WIP: Support filtering

WIP: linting

feat: `FileInstructionFinetuningDataRepository`

WIP: user-facing functions

poetry install

* `instruction_finetuning_handler_builder` for easier handler construction

temp commit

bugfix in samples_with_filter

poetry update

* feat: use session pooling & pagination in `PostgresInstructionFinetuningDataRepository`

poetry lock

* docs: Add docstrings to added classes

poetry lock

* fix: Pagination for `PostgresInstructionFinetuningDataRepository`

poetry lock

fix pagination

fix test
  • Loading branch information
NickyHavoc authored Oct 29, 2024
1 parent 89031b4 commit 097a152
Show file tree
Hide file tree
Showing 19 changed files with 3,987 additions and 1,444 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/sdk-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ jobs:
ARGILLA_API_KEY: "argilla.apikey"
CLIENT_URL: ${{ secrets.CLIENT_URL }}
STUDIO_URL: "http://localhost:8000/"
POSTGRES_HOST: "localhost:5433"
POSTGRES_DB: "il_sdk"
POSTGRES_USER: "il_sdk"
POSTGRES_PASSWORD: "test"
run: |
./scripts/test.sh
run-notebooks:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repos:
"-L",
"newyorker,te,responde,ist,als,oder,technik,sie,rouge,unter,juli,fiel,couldn,mke, vor,fille,ans",
]
exclude: '^(poetry\.lock|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/examples/qa/multiple_chunk_qa.py|src/intelligence_layer/examples/summarize/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/examples/classify/keyword_extract.py|tests/examples/summarize/test_single_chunk_few_shot_summarize.py|tests/examples/summarize/very_long_text.txt)$'
exclude: '^(poetry\.lock|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/examples/qa/multiple_chunk_qa.py|src/intelligence_layer/examples/summarize/.*|tests/connectors/retrievers/test_document_index_retriever\.py|src/intelligence_layer/examples/classify/keyword_extract.py|tests/examples/summarize/test_single_chunk_few_shot_summarize.py|tests/examples/summarize/very_long_text.txt|src/intelligence_layer/learning/enrich.py)$'
- repo: https://github.com/akaihola/darglint2
rev: v1.8.2
hooks:
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,17 @@
- Add support for Llama3InstructModel in PromptBasedClassify
- Add TextControl to 'to_instruct_prompt' for instruct models
- Add 'attention_manipulation_with_text_controls.ipynb' to tutorial notebooks
- Introduced `InstructionFinetuningDataHandler` to provide methods for storing, retrieving and updating finetuning data samples given an `InstructionFinetuningDataRepository`. Also has methods for filtered sample retrieval and for dataset formatting.
- Introduced `InstructionFinetuningDataRepository` for storing and retrieving finetuning samples. Comes in two implementations:
- `PostgresInstructionFinetuningDataRepository` to work with data stored in a Postgres database.
- `FileInstructionFinetuningDataRepository` to work with data stored in the local file-system.


### Fixes
...
### Deprecations
...

### Breaking Changes
...

Expand Down
3,336 changes: 1,900 additions & 1,436 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ rouge-score = "^0.1.2"
sacrebleu = "^2.4.3"
lingua-language-detector = "^2.0.2"
argilla = "^2.3.0"
dict-hash = "^1.3.5"
dict-hash = "^1.3.4"
sqlalchemy = "^2.0.35"
psycopg2-binary = "^2.9.9"

[tool.poetry.group.dev.dependencies]
# lint & format
Expand Down
1 change: 1 addition & 0 deletions src/intelligence_layer/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .model import ControlModel as ControlModel
from .model import ExplainInput as ExplainInput
from .model import ExplainOutput as ExplainOutput
from .model import FinetuningMessage as FinetuningMessage
from .model import LanguageModel as LanguageModel
from .model import Llama2InstructModel as Llama2InstructModel
from .model import Llama3ChatModel as Llama3ChatModel
Expand Down
93 changes: 88 additions & 5 deletions src/intelligence_layer/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,25 @@ def echo(
...


class Message(BaseModel, frozen=True):
    """A single chat message, consisting of a role and its text content.

    Declared frozen so instances are immutable after construction.
    """

    role: Literal["system", "user", "assistant"]
    content: str


class FinetuningMessage(BaseModel, frozen=True):
    """A prompt message in a finetuning sample, as required to finetune an LLM
    using [scaling](https://github.com/Aleph-Alpha/scaling).

    Args:
        has_loss: Flag indicating whether loss should be applied to the message
            during training.
        content: The text in the message.
        type: Should always be "text".
    """

    has_loss: bool
    content: str
    type: str = "text"


class ChatModel(LanguageModel):
"""Abstract base class to implement any model that supports chat."""

Expand Down Expand Up @@ -611,24 +625,39 @@ class AlephAlphaChatModel(ChatModel, ControlModel):

CHAT_PROMPT_TEMPLATE: PromptTemplate

@abstractmethod
def to_finetuning_sample(
    self, messages: Sequence[Message]
) -> Sequence[FinetuningMessage]:
    """Abstract function allowing a user to define what the model's finetuning samples should look like.

    Args:
        messages: The messages making up the finetuning sample.

    Returns:
        A finetuning sample containing the input messages.
    """
    ...

def to_chat_prompt(
    self, messages: Sequence[Message], response_prefix: str | None = None
) -> RichPrompt:
    """Build a chat-`RichPrompt` object to use with any `AlephAlphaModel`.

    Args:
        messages: A number of messages to use as prompt for the model.
        response_prefix: Append the given string to the beginning of the final
            agent message to steer the generation. Defaults to None.

    Returns:
        A RichPrompt object to be consumed by the Aleph Alpha client.
    """
    dumped_messages = [message.model_dump() for message in messages]
    return self.CHAT_PROMPT_TEMPLATE.to_rich_prompt(
        messages=dumped_messages, response_prefix=response_prefix
    )

def generate_chat(
self, messages: list[Message], response_prefix: str | None, tracer: Tracer
self, messages: Sequence[Message], response_prefix: str | None, tracer: Tracer
) -> str:
"""Generate a raw completion to messages for any `AlephAlphaChatModel`.
Expand All @@ -637,6 +666,9 @@ def generate_chat(
response_prefix: Optional argument to append a string to the beginning of the
final agent message to steer the generation
tracer: Valid instance of a tracer
Returns:
An LLM completion
"""
prompt = self.to_chat_prompt(messages, response_prefix)
prompt_item = prompt.items[0]
Expand Down Expand Up @@ -700,6 +732,47 @@ def to_instruct_prompt(
)


def to_llama_3_finetuning_sample(
    messages: Sequence[Message], eot_token: str
) -> Sequence[FinetuningMessage]:
    """Turn a sequence of messages into a finetuning training sample using the llama-3 format.

    Loss is applied only to assistant messages. The assistant header is
    appended to the *preceding* message, so that each assistant message
    carries only the generated text plus the end-of-turn token.

    Args:
        messages: The messages making up the finetuning sample.
        eot_token: The end-of-turn token used to separate the messages.

    Returns:
        A sequence of formatted messages for finetuning.
    """

    def get_content(
        message: Message, is_first_message: bool, is_preceding_assistant_message: bool
    ) -> str:
        # Fix: `<|begin_of_text|>` must appear exactly once, at the very start
        # of the sample. Previously it was also embedded in the header of every
        # non-assistant message, duplicating it for the first message and
        # injecting it mid-conversation for later ones.
        prompt = "<|begin_of_text|>" if is_first_message else ""
        prompt += (
            f"<|start_header_id|>{message.role}<|end_header_id|>\n\n{message.content}{eot_token}"
            if message.role != "assistant"
            else f"{message.content}{eot_token}"
        )
        if is_preceding_assistant_message:
            # Open the assistant turn here so the next (assistant) message is pure completion text.
            prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt

    return [
        FinetuningMessage(
            has_loss=message.role == "assistant",
            content=get_content(
                message,
                index == 0,
                messages[index + 1].role == "assistant"
                if index + 1 < len(messages)
                else False,
            ),
        )
        for index, message in enumerate(messages)
    ]


class Pharia1ChatModel(AlephAlphaChatModel):
"""Chat model to be used for any `"pharia-1-llm-*` model.
Expand Down Expand Up @@ -739,6 +812,11 @@ def complete(self, input: CompleteInput, tracer: Tracer) -> CompleteOutput:
def eot_token(self) -> str:
return "<|endoftext|>"

def to_finetuning_sample(
    self, messages: Sequence[Message]
) -> Sequence[FinetuningMessage]:
    # Delegates to the shared llama-3 sample formatter, supplying this model's end-of-turn token.
    return to_llama_3_finetuning_sample(messages, self.eot_token)


class Llama3ChatModel(AlephAlphaChatModel):
"""Chat model to be used for `llama-3-*` and `llama-3.1-*` models.
Expand Down Expand Up @@ -769,3 +847,8 @@ def __init__(
@property
def eot_token(self) -> str:
    # End-of-turn token used to separate messages for llama-3 models.
    return "<|eot_id|>"

def to_finetuning_sample(
    self, messages: Sequence[Message]
) -> Sequence[FinetuningMessage]:
    # Delegates to the shared llama-3 sample formatter, supplying this model's end-of-turn token.
    return to_llama_3_finetuning_sample(messages, self.eot_token)
30 changes: 30 additions & 0 deletions src/intelligence_layer/learning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from .enrich import EnrichDomain as EnrichDomain
from .enrich import EnrichQuality as EnrichQuality
from .file_instruction_finetuning_data_repository import (
FileInstructionFinetuningDataRepository as FileInstructionFinetuningDataRepository,
)
from .instruction_finetuning_data_handler import EnrichAction as EnrichAction
from .instruction_finetuning_data_handler import (
InstructionFinetuningDataHandler as InstructionFinetuningDataHandler,
)
from .instruction_finetuning_data_handler import (
instruction_finetuning_handler_builder as instruction_finetuning_handler_builder,
)
from .instruction_finetuning_data_repository import (
InstructionFinetuningDataRepository as InstructionFinetuningDataRepository,
)
from .models import InstructionFinetuningSample as InstructionFinetuningSample
from .models import (
InstructionFinetuningSample_ as InstructionFinetuningSample_,
)
from .models import (
InstructionFinetuningSampleAttributes as InstructionFinetuningSampleAttributes,
)
from .models import InvalidSampleError as InvalidSampleError
from .models import RawInstructionFinetuningSample as RawInstructionFinetuningSample
from .models import TripletTransformation as TripletTransformation
from .postgres_instruction_finetuning_data_repository import (
PostgresInstructionFinetuningDataRepository as PostgresInstructionFinetuningDataRepository,
)

__all__ = [symbol for symbol in dir()]
Loading

0 comments on commit 097a152

Please sign in to comment.