feat!: Support custom embedding models in the DocumentIndexClient (#1151)

* feat!: Support custom embedding models in the DocumentIndexClient

- Allow custom embedding models when creating new indexes
- Support new InstructableEmbed embedding strategy
- Update test fixture to cover new configuration
- Update CHANGELOG.md with user-facing changes
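In brief, index creation with the new embedding configuration looks like this (a condensed sketch based on the notebook changes in this commit; the client construction, token, and namespace are assumed):

```python
from intelligence_layer.connectors import (
    DocumentIndexClient,
    IndexConfiguration,
    IndexPath,
    SemanticEmbed,
)

# Assumed setup: a valid token and an existing namespace (see the notebook
# for the full client configuration).
document_index = DocumentIndexClient(token="<token>")
index_path = IndexPath(namespace="<namespace>", index="my-index")

# The embedding model and representation are now explicit.
index_configuration = IndexConfiguration(
    chunk_size=64,
    chunk_overlap=0,
    embedding=SemanticEmbed(model_name="luminous-base", representation="asymmetric"),
)

document_index.create_index(index_path, index_configuration)
```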

* test: Explicitly test new index configuration structure

The changes introduced in the previous commit were covered by a test
that randomly selects either the semantic or the instructable variant.
A single run therefore exercises only one of the two code paths, so
running the test suite many times in a row was the only way to confirm
with high probability that the new structure (which now includes a
variant) is correct.

This commit splits the test into two sub-tests that check the semantic
and instructable variants individually, with the usual field-level
randomisation.
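To make the coverage argument concrete, here is a simplified, hypothetical sketch of the before/after test shape (the helper names are illustrative, not from this commit):

```python
import random


def make_semantic_config() -> dict:
    # Illustrative stand-in for a randomised semantic embedding config.
    return {
        "strategy": "semantic_embed",
        "representation": random.choice(["symmetric", "asymmetric"]),
    }


def make_instructable_config() -> dict:
    # Illustrative stand-in for a randomised instructable embedding config.
    return {
        "strategy": "instructable_embed",
        "query_instruction": "q",
        "document_instruction": "d",
    }


def roundtrip(config: dict) -> dict:
    # Illustrative stand-in for create-index-then-fetch-configuration.
    return dict(config)


# Before: a 50/50 random choice per run; the chance that n runs never
# exercise the instructable variant is 0.5 ** n (~3% even after 5 runs).
def test_index_roundtrip_random() -> None:
    config = random.choice([make_semantic_config(), make_instructable_config()])
    assert roundtrip(config) == config


# After: one explicit test per variant, keeping field-level randomisation.
def test_semantic_index_roundtrip() -> None:
    config = make_semantic_config()
    assert roundtrip(config) == config


def test_instructable_index_roundtrip() -> None:
    config = make_instructable_config()
    assert roundtrip(config) == config
```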
Michael-JB authored Nov 26, 2024
1 parent 2a8d0cb commit 9bcc58b
Showing 5 changed files with 176 additions and 21 deletions.
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -2,7 +2,8 @@
## Unreleased

### Features
...
- You can now customise the embedding model when creating an index using the `DocumentIndexClient`.
- You can now use the `InstructableEmbed` embedding strategy when creating an index using the `DocumentIndexClient`. See the `document_index.ipynb` notebook for more information and an example.

### Fixes
...
@@ -11,7 +12,10 @@
...

### Breaking Changes
...
- The way you configure indexes in the `DocumentIndexClient` has changed. See the `document_index.ipynb` notebook for more information.
- The `EmbeddingType` alias has been renamed to `Representation` to better align with the underlying API.
- The `embedding_type` field has been removed from the `IndexConfiguration` class. You now configure embedding-related parameters via the `embedding` field.
- You now always need to specify an embedding model when creating an index. Previously, this was always `luminous-base`.
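A minimal before/after migration sketch for the breaking change (the old call reflects the 7.3.1 API shown in the diffs below):

```python
from intelligence_layer.connectors import IndexConfiguration, SemanticEmbed

# Before (7.3.1) -- the embedding model was implicitly luminous-base:
#   IndexConfiguration(chunk_size=64, chunk_overlap=0, embedding_type="asymmetric")

# After -- the model is explicit and `representation` replaces `embedding_type`:
index_configuration = IndexConfiguration(
    chunk_size=64,
    chunk_overlap=0,
    embedding=SemanticEmbed(model_name="luminous-base", representation="asymmetric"),
)
```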

## 7.3.1
### Features
51 changes: 49 additions & 2 deletions src/documentation/document_index.ipynb
@@ -21,7 +21,9 @@
" DocumentPath,\n",
" IndexConfiguration,\n",
" IndexPath,\n",
" InstructableEmbed,\n",
" LimitedConcurrencyClient,\n",
" SemanticEmbed,\n",
")\n",
"from intelligence_layer.core import InMemoryTracer\n",
"from intelligence_layer.examples import MultipleChunkRetrieverQa, RetrieverBasedQaInput\n",
@@ -133,7 +135,9 @@
"\n",
"# customise the parameters of the index here\n",
"index_configuration = IndexConfiguration(\n",
" chunk_size=64, chunk_overlap=0, embedding_type=\"asymmetric\"\n",
" chunk_size=64,\n",
" chunk_overlap=0,\n",
" embedding=SemanticEmbed(model_name=\"luminous-base\", representation=\"asymmetric\"),\n",
")\n",
"\n",
"# create the namespace-wide index resource\n",
@@ -314,7 +318,10 @@
"\n",
"# customise the parameters of the index here\n",
"index_configuration = IndexConfiguration(\n",
" chunk_size=64, chunk_overlap=0, embedding_type=\"asymmetric\", hybrid_index=\"bm25\"\n",
" chunk_size=64,\n",
" chunk_overlap=0,\n",
" hybrid_index=\"bm25\",\n",
" embedding=SemanticEmbed(model_name=\"luminous-base\", representation=\"asymmetric\"),\n",
")\n",
"\n",
"# create the namespace-wide index resource\n",
@@ -349,6 +356,46 @@
"document_index_retriever.get_relevant_documents_with_scores(query=\"25 April\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Instructable Embeddings\n",
"\n",
"As well as supporting custom embedding models, the Document Index also supports instructable embeddings. This lets you prompt embedding models like `pharia-1-embedding-4608-control` with custom instructions for queries and documents. Steering the model like this can help the model understand nuances of your specific data and ultimately lead to embeddings that are more useful for your use-case. To use default instructions, leave the instruction fields unspecified.\n",
"\n",
"To use an instructable embedding model, create an index as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# change this value if you want to use an index of a different name\n",
"INSTRUCTABLE_EMBEDDING_INDEX = \"intelligence-layer-sdk-demo-instructable-embedding\"\n",
"\n",
"index_path = IndexPath(namespace=NAMESPACE, index=INSTRUCTABLE_EMBEDDING_INDEX)\n",
"\n",
"# customise the parameters of the index here\n",
"index_configuration = IndexConfiguration(\n",
" chunk_size=64,\n",
" chunk_overlap=0,\n",
" embedding=InstructableEmbed(\n",
" model_name=\"pharia-1-embedding-4608-control\",\n",
" query_instruction=\"Represent the user's question about rivers to find a relevant wikipedia paragraph\",\n",
" document_instruction=\"Represent the document so that it can be matched to a user's question about rivers\",\n",
" ),\n",
")\n",
"\n",
"# create the namespace-wide index resource\n",
"document_index.create_index(index_path, index_configuration)\n",
"\n",
"# assign the index to the collection\n",
"document_index.assign_index_to_collection(collection_path, INSTRUCTABLE_EMBEDDING_INDEX)"
]
},
{
"cell_type": "markdown",
"metadata": {},
2 changes: 2 additions & 0 deletions src/intelligence_layer/connectors/__init__.py
@@ -63,10 +63,12 @@
from .document_index.document_index import Filters as Filters
from .document_index.document_index import IndexConfiguration as IndexConfiguration
from .document_index.document_index import IndexPath as IndexPath
from .document_index.document_index import InstructableEmbed as InstructableEmbed
from .document_index.document_index import InternalError as InternalError
from .document_index.document_index import InvalidInput as InvalidInput
from .document_index.document_index import ResourceNotFound as ResourceNotFound
from .document_index.document_index import SearchQuery as SearchQuery
from .document_index.document_index import SemanticEmbed as SemanticEmbed
from .limited_concurrency_client import (
AlephAlphaClientProtocol as AlephAlphaClientProtocol,
)
52 changes: 46 additions & 6 deletions src/intelligence_layer/connectors/document_index/document_index.py
@@ -8,15 +8,16 @@
from urllib.parse import quote, urljoin

import requests
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic.types import StringConstraints
from requests import HTTPError
from typing_extensions import Self

from intelligence_layer.connectors.base.json_serializable import JsonSerializable

EmbeddingType: TypeAlias = Literal["symmetric", "asymmetric"]
Representation: TypeAlias = Literal["symmetric", "asymmetric"]
HybridIndex: TypeAlias = Literal["bm25"] | None
EmbeddingConfig: TypeAlias = Union["SemanticEmbed", "InstructableEmbed"]


class IndexPath(BaseModel, frozen=True):
@@ -31,21 +32,60 @@ class IndexPath(BaseModel, frozen=True):
index: str


class SemanticEmbed(BaseModel):
"""Semantic embedding configuration.
Args:
model_name: Name of the model to use.
representation: The embedding representation to use: "symmetric" or "asymmetric".
Use "symmetric" when the queries and documents are the same, e.g., for classification tasks.
Use "asymmetric" when the queries and documents are different, e.g., for search tasks.
"""

# `model_name` conflicts with the default protected `model_*` namespace
model_config = ConfigDict(protected_namespaces=())

strategy: Literal["semantic_embed"] = "semantic_embed"
model_name: str
representation: Representation


class InstructableEmbed(BaseModel):
"""Instructable embedding configuration.
Args:
model_name: Name of the model to use.
query_instruction: Instruction to apply when embedding queries.
document_instruction: Instruction to apply when embedding documents.
"""

# `model_name` conflicts with the default protected `model_*` namespace
model_config = ConfigDict(protected_namespaces=())

strategy: Literal["instructable_embed"] = "instructable_embed"
model_name: str
query_instruction: str = ""
document_instruction: str = ""


class IndexConfiguration(BaseModel):
"""Configuration of an index.
Args:
embedding_type: "symmetric" or "asymmetric" embedding type.
chunk_overlap: The maximum number of tokens of overlap between consecutive chunks. Must be
less than `chunk_size`.
chunk_size: The maximum size of the chunks in tokens to be used for the index.
hybrid_index: If set to "bm25", combine vector search and keyword search (bm25) results.
embedding: Configuration for the embedding of chunks.
"""

embedding_type: EmbeddingType
# `model_name` in `embedding` conflicts with the default protected `model_*` namespace
model_config = ConfigDict(protected_namespaces=())

chunk_overlap: int = Field(default=0, ge=0)
chunk_size: int = Field(..., gt=0, le=2046)
hybrid_index: HybridIndex = None
embedding: EmbeddingConfig

@model_validator(mode="after")
def validate_chunk_overlap(self) -> Self:
@@ -535,8 +575,8 @@ def create_index(
data = {
"chunk_size": index_configuration.chunk_size,
"chunk_overlap": index_configuration.chunk_overlap,
"embedding_type": index_configuration.embedding_type,
"hybrid_index": index_configuration.hybrid_index,
"embedding": index_configuration.embedding.model_dump(),
}
response = requests.put(url, data=dumps(data), headers=self.headers)
self._raise_for_status(response)
@@ -596,10 +636,10 @@ def index_configuration(self, index_path: IndexPath) -> IndexConfiguration:
self._raise_for_status(response)
response_json: Mapping[str, Any] = response.json()
return IndexConfiguration(
embedding_type=response_json["embedding_type"],
chunk_overlap=response_json["chunk_overlap"],
chunk_size=response_json["chunk_size"],
hybrid_index=response_json.get("hybrid_index"),
embedding=response_json["embedding"],
)

def assign_index_to_collection(
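One note on `embedding=response_json["embedding"]` above: pydantic can validate the raw mapping against the `SemanticEmbed | InstructableEmbed` union because each variant pins a distinct `strategy` literal. A self-contained sketch of the same pattern (class names are illustrative):

```python
from typing import Literal, Union

from pydantic import BaseModel


class SemanticLike(BaseModel):
    strategy: Literal["semantic_embed"] = "semantic_embed"
    representation: str


class InstructableLike(BaseModel):
    strategy: Literal["instructable_embed"] = "instructable_embed"
    query_instruction: str = ""


class ConfigLike(BaseModel):
    embedding: Union[SemanticLike, InstructableLike]


# The "strategy" value rules out the non-matching variant during validation.
parsed = ConfigLike(embedding={"strategy": "instructable_embed", "query_instruction": "q"})
assert isinstance(parsed.embedding, InstructableLike)
```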
84 changes: 73 additions & 11 deletions tests/connectors/document_index/test_document_index.py
@@ -2,6 +2,7 @@
import re
import string
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from functools import wraps
from http import HTTPStatus
@@ -19,16 +20,19 @@
DocumentFilterQueryParams,
DocumentIndexClient,
DocumentPath,
EmbeddingType,
EmbeddingConfig,
FilterField,
FilterOps,
Filters,
HybridIndex,
IndexConfiguration,
IndexPath,
InstructableEmbed,
InvalidInput,
Representation,
ResourceNotFound,
SearchQuery,
SemanticEmbed,
)

P = ParamSpec("P")
@@ -72,8 +76,12 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
return decorator(func)


def random_alphanumeric_string(length: int = 20) -> str:
return "".join(random.choices(string.ascii_letters + string.digits, k=length))


def random_identifier() -> str:
name = "".join(random.choices(string.ascii_letters + string.digits, k=20))
name = random_alphanumeric_string(20)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
return f"ci-il-{name}-{timestamp}"

@@ -92,6 +100,25 @@ def is_outdated_identifier(identifier: str, timestamp_threshold: datetime) -> bool:
return not timestamp > timestamp_threshold


def random_semantic_embed() -> EmbeddingConfig:
return SemanticEmbed(
representation=random.choice(get_args(Representation)),
model_name="luminous-base",
)


def random_instructable_embed() -> EmbeddingConfig:
return InstructableEmbed(
model_name="pharia-1-embedding-4608-control",
query_instruction=random_alphanumeric_string(),
document_instruction=random_alphanumeric_string(),
)


def random_embedding_config() -> EmbeddingConfig:
return random.choice([random_semantic_embed(), random_instructable_embed()])


@fixture(scope="session")
def document_index_namespace() -> str:
return "team-document-index"
@@ -203,24 +230,27 @@ def read_only_collection_path(
document_index.delete_collection(collection_path)


@fixture
def random_index(
document_index: DocumentIndexClient, document_index_namespace: str
@contextmanager
def random_index_with_embedding_config(
document_index: DocumentIndexClient,
document_index_namespace: str,
embedding_config: EmbeddingConfig,
) -> Iterator[tuple[IndexPath, IndexConfiguration]]:
name = random_identifier()

chunk_size, chunk_overlap = sorted(
random.sample([0, 32, 64, 128, 256, 512, 1024], 2), reverse=True
)
embedding_type = random.choice(get_args(EmbeddingType))

hybrid_index_choices: list[HybridIndex] = ["bm25", None]
hybrid_index = random.choice(hybrid_index_choices)

index = IndexPath(namespace=document_index_namespace, index=name)
index_configuration = IndexConfiguration(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
embedding_type=embedding_type,
hybrid_index=hybrid_index,
embedding=embedding_config,
)
try:
document_index.create_index(index, index_configuration)
Expand All @@ -229,6 +259,26 @@ def random_index(
document_index.delete_index(index)


@fixture
def random_instructable_index(
document_index: DocumentIndexClient, document_index_namespace: str
) -> Iterator[tuple[IndexPath, IndexConfiguration]]:
with random_index_with_embedding_config(
document_index, document_index_namespace, random_instructable_embed()
) as index:
yield index


@fixture
def random_semantic_index(
document_index: DocumentIndexClient, document_index_namespace: str
) -> Iterator[tuple[IndexPath, IndexConfiguration]]:
with random_index_with_embedding_config(
document_index, document_index_namespace, random_semantic_embed()
) as index:
yield index


@fixture
def document_contents() -> DocumentContents:
text = """John Stith Pemberton, the inventor of the world-renowned beverage Coca-Cola, was a figure whose life was marked by creativity, entrepreneurial spirit, and the turbulent backdrop of 19th-century America. Born on January 8, 1831, in Knoxville, Georgia, Pemberton grew up in an era of profound transformation and change.
@@ -547,19 +597,31 @@ def test_document_path_is_immutable() -> None:
def test_index_configuration_rejects_invalid_chunk_overlap() -> None:
try:
IndexConfiguration(
chunk_size=128, chunk_overlap=128, embedding_type="asymmetric"
chunk_size=128,
chunk_overlap=128,
embedding=random_embedding_config(),
)
except ValidationError as e:
assert "chunk_overlap must be less than chunk_size" in str(e)
else:
raise AssertionError("ValidationError was not raised")


def test_indexes_in_namespace_are_returned(
def test_semantic_indexes_in_namespace_are_returned(
document_index: DocumentIndexClient,
random_semantic_index: tuple[IndexPath, IndexConfiguration],
) -> None:
index_path, index_configuration = random_semantic_index
retrieved_index_configuration = document_index.index_configuration(index_path)

assert retrieved_index_configuration == index_configuration


def test_instructable_indexes_in_namespace_are_returned(
document_index: DocumentIndexClient,
random_index: tuple[IndexPath, IndexConfiguration],
random_instructable_index: tuple[IndexPath, IndexConfiguration],
) -> None:
index_path, index_configuration = random_index
index_path, index_configuration = random_instructable_index
retrieved_index_configuration = document_index.index_configuration(index_path)

assert retrieved_index_configuration == index_configuration
