Skip to content

Commit

Permalink
fix model_name issue
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Nov 9, 2024
1 parent c349f18 commit b26a878
Show file tree
Hide file tree
Showing 24 changed files with 163 additions and 80 deletions.
2 changes: 0 additions & 2 deletions autointent/context/vector_index_client/vector_index_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@


class VectorIndexClient:
model_name: str

def __init__(
self,
device: str,
Expand Down
2 changes: 1 addition & 1 deletion autointent/datafiles/default-multiclass-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ nodes:
search_space:
- module_type: vector_db
k: [10]
model_name:
embedder_name:
- avsolatorio/GIST-small-Embedding-v0
- infgrad/stella-base-en-v2
- node_type: scoring
Expand Down
3 changes: 3 additions & 0 deletions autointent/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,6 @@ def predict(self, *args: list[str] | npt.NDArray[Any], **kwargs: dict[str, Any])
@abstractmethod
def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> Self:
pass

def get_embedder_name(self) -> str | None:
return None
14 changes: 7 additions & 7 deletions autointent/modules/retrieval/vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ class VectorDBModule(RetrievalModule):
def __init__(
self,
k: int,
model_name: str,
embedder_name: str,
db_dir: str | None = None,
device: str = "cpu",
batch_size: int = 32,
max_length: int | None = None,
) -> None:
if db_dir is None:
db_dir = str(get_db_dir())
self.model_name = model_name
self.embedder_name = embedder_name
self.device = device
self.db_dir = db_dir
self.batch_size = batch_size
Expand All @@ -47,11 +47,11 @@ def from_context(
cls,
context: Context,
k: int,
model_name: str,
embedder_name: str,
) -> Self:
return cls(
k=k,
model_name=model_name,
embedder_name=embedder_name,
db_dir=str(context.get_db_dir()),
device=context.get_device(),
batch_size=context.get_batch_size(),
Expand All @@ -63,7 +63,7 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
self.device, self.db_dir, embedder_batch_size=self.batch_size, embedder_max_length=self.max_length
)

self.vector_index = vector_index_client.create_index(self.model_name, utterances, labels)
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)

def score(self, context: Context, metric_fn: RetrievalMetricFn) -> float:
labels_pred, _, _ = self.vector_index.query(
Expand All @@ -73,7 +73,7 @@ def score(self, context: Context, metric_fn: RetrievalMetricFn) -> float:
return metric_fn(context.data_handler.labels_test, labels_pred)

def get_assets(self) -> RetrieverArtifact:
return RetrieverArtifact(embedder_name=self.model_name)
return RetrieverArtifact(embedder_name=self.embedder_name)

def clear_cache(self) -> None:
self.vector_index.delete()
Expand Down Expand Up @@ -101,7 +101,7 @@ def load(self, path: str) -> None:
embedder_batch_size=self.metadata["batch_size"],
embedder_max_length=self.metadata["max_length"],
)
self.vector_index = vector_index_client.get_index(self.model_name)
self.vector_index = vector_index_client.get_index(self.embedder_name)

def predict(self, utterances: list[str]) -> tuple[list[list[int | list[int]]], list[list[float]], list[list[str]]]:
"""
Expand Down
24 changes: 15 additions & 9 deletions autointent/modules/scoring/description/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DescriptionScorer(ScoringModule):

def __init__(
self,
model_name: str,
embedder_name: str,
db_dir: Path | None = None,
temperature: float = 1.0,
device: str = "cpu",
Expand All @@ -46,7 +46,7 @@ def __init__(
self.temperature = temperature
self.device = device
self.db_dir = db_dir
self.model_name = model_name
self.embedder_name = embedder_name
self.batch_size = batch_size
self.max_length = max_length

Expand All @@ -55,23 +55,26 @@ def from_context(
cls,
context: Context,
temperature: float,
model_name: str | None = None,
embedder_name: str | None = None,
) -> Self:
if model_name is None:
model_name = context.optimization_info.get_best_embedder()
if embedder_name is None:
embedder_name = context.optimization_info.get_best_embedder()
precomputed_embeddings = True
else:
precomputed_embeddings = context.vector_index_client.exists(model_name)
precomputed_embeddings = context.vector_index_client.exists(embedder_name)

instance = cls(
temperature=temperature,
device=context.get_device(),
db_dir=context.get_db_dir(),
model_name=model_name,
embedder_name=embedder_name,
)
instance.precomputed_embeddings = precomputed_embeddings
return instance

def get_embedder_name(self) -> str:
return self.embedder_name

def fit(
self,
utterances: list[str],
Expand All @@ -88,15 +91,18 @@ def fit(
if self.precomputed_embeddings:
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
vector_index = vector_index_client.get_index(self.model_name)
vector_index = vector_index_client.get_index(self.embedder_name)
features = vector_index.get_all_embeddings()
if len(features) != len(utterances):
msg = "Vector index mismatches provided utterances"
raise ValueError(msg)
embedder = vector_index.embedder
else:
embedder = Embedder(
device=self.device, model_name=self.model_name, batch_size=self.batch_size, max_length=self.max_length
device=self.device,
model_name=self.embedder_name,
batch_size=self.batch_size,
max_length=self.max_length,
)
features = embedder.embed(utterances)

Expand Down
20 changes: 10 additions & 10 deletions autointent/modules/scoring/dnnc/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DNNCScorer(ScoringModule):
def __init__(
self,
cross_encoder_name: str,
search_model_name: str,
embedder_name: str,
k: int,
db_dir: str | None = None,
device: str = "cpu",
Expand All @@ -56,7 +56,7 @@ def __init__(
db_dir = str(get_db_dir())

self.cross_encoder_name = cross_encoder_name
self.search_model_name = search_model_name
self.embedder_name = embedder_name
self.k = k
self.train_head = train_head
self.device = device
Expand All @@ -70,18 +70,18 @@ def from_context(
context: Context,
cross_encoder_name: str,
k: int,
search_model_name: str | None = None,
embedder_name: str | None = None,
train_head: bool = False,
) -> Self:
if search_model_name is None:
search_model_name = context.optimization_info.get_best_embedder()
if embedder_name is None:
embedder_name = context.optimization_info.get_best_embedder()
prebuilt_index = True
else:
prebuilt_index = context.vector_index_client.exists(search_model_name)
prebuilt_index = context.vector_index_client.exists(embedder_name)

instance = cls(
cross_encoder_name=cross_encoder_name,
search_model_name=search_model_name,
embedder_name=embedder_name,
k=k,
train_head=train_head,
device=context.get_device(),
Expand All @@ -101,12 +101,12 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:

if self.prebuilt_index:
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
self.vector_index = vector_index_client.get_index(self.search_model_name)
self.vector_index = vector_index_client.get_index(self.embedder_name)
if len(utterances) != len(self.vector_index.texts):
msg = "Vector index mismatches provided utterances"
raise ValueError(msg)
else:
self.vector_index = vector_index_client.create_index(self.search_model_name, utterances, labels)
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)

if self.train_head:
model = CrossEncoderWithLogreg(self.model)
Expand Down Expand Up @@ -207,7 +207,7 @@ def load(self, path: str) -> None:
embedder_batch_size=self.metadata["batch_size"],
embedder_max_length=self.metadata["max_length"],
)
self.vector_index = vector_index_client.get_index(self.search_model_name)
self.vector_index = vector_index_client.get_index(self.embedder_name)

crossencoder_dir = str(dump_dir / self.crossencoder_subdir)
if self.train_head:
Expand Down
23 changes: 13 additions & 10 deletions autointent/modules/scoring/knn/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class KNNScorer(ScoringModule):

def __init__(
self,
model_name: str,
embedder_name: str,
k: int,
weights: WEIGHT_TYPES,
db_dir: str | None = None,
Expand All @@ -51,7 +51,7 @@ def __init__(
"""
if db_dir is None:
db_dir = str(get_db_dir())
self.model_name = model_name
self.embedder_name = embedder_name
self.k = k
self.weights = weights
self.db_dir = db_dir
Expand All @@ -65,16 +65,16 @@ def from_context(
context: Context,
k: int,
weights: WEIGHT_TYPES,
model_name: str | None = None,
embedder_name: str | None = None,
) -> Self:
if model_name is None:
model_name = context.optimization_info.get_best_embedder()
if embedder_name is None:
embedder_name = context.optimization_info.get_best_embedder()
prebuilt_index = True
else:
prebuilt_index = context.vector_index_client.exists(model_name)
prebuilt_index = context.vector_index_client.exists(embedder_name)

instance = cls(
model_name=model_name,
embedder_name=embedder_name,
k=k,
weights=weights,
db_dir=str(context.get_db_dir()),
Expand All @@ -85,6 +85,9 @@ def from_context(
instance.prebuilt_index = prebuilt_index
return instance

def get_embedder_name(self) -> str:
return self.embedder_name

def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
if isinstance(labels[0], list):
self.n_classes = len(labels[0])
Expand All @@ -96,12 +99,12 @@ def fit(self, utterances: list[str], labels: list[LabelType]) -> None:

if self.prebuilt_index:
# this happens only after RetrievalNode optimization
self._vector_index = vector_index_client.get_index(self.model_name)
self._vector_index = vector_index_client.get_index(self.embedder_name)
if len(utterances) != len(self._vector_index.texts):
msg = "Vector index mismatches provided utterances"
raise ValueError(msg)
else:
self._vector_index = vector_index_client.create_index(self.model_name, utterances, labels)
self._vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)

def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
labels, distances, _ = self._vector_index.query(utterances, self.k)
Expand Down Expand Up @@ -141,4 +144,4 @@ def load(self, path: str) -> None:
embedder_batch_size=self.metadata["batch_size"],
embedder_max_length=self.metadata["max_length"],
)
self._vector_index = vector_index_client.get_index(self.model_name)
self._vector_index = vector_index_client.get_index(self.embedder_name)
24 changes: 15 additions & 9 deletions autointent/modules/scoring/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class LinearScorer(ScoringModule):

def __init__(
self,
model_name: str,
embedder_name: str,
cv: int = 3,
n_jobs: int = -1,
device: str = "cpu",
Expand All @@ -60,24 +60,24 @@ def __init__(
self.n_jobs = n_jobs
self.device = device
self.seed = seed
self.model_name = model_name
self.embedder_name = embedder_name
self.batch_size = batch_size
self.max_length = max_length

@classmethod
def from_context(
cls,
context: Context,
model_name: str | None = None,
embedder_name: str | None = None,
) -> Self:
if model_name is None:
model_name = context.optimization_info.get_best_embedder()
if embedder_name is None:
embedder_name = context.optimization_info.get_best_embedder()
precomputed_embeddings = True
else:
precomputed_embeddings = context.vector_index_client.exists(model_name)
precomputed_embeddings = context.vector_index_client.exists(embedder_name)

instance = cls(
model_name=model_name,
embedder_name=embedder_name,
device=context.get_device(),
seed=context.seed,
batch_size=context.get_batch_size(),
Expand All @@ -87,6 +87,9 @@ def from_context(
instance.db_dir = str(context.get_db_dir())
return instance

def get_embedder_name(self) -> str:
return self.embedder_name

def fit(
self,
utterances: list[str],
Expand All @@ -97,15 +100,18 @@ def fit(
if self.precomputed_embeddings:
# this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
vector_index_client = VectorIndexClient(self.device, self.db_dir, self.batch_size, self.max_length)
vector_index = vector_index_client.get_index(self.model_name)
vector_index = vector_index_client.get_index(self.embedder_name)
features = vector_index.get_all_embeddings()
if len(features) != len(utterances):
msg = "Vector index mismatches provided utterances"
raise ValueError(msg)
embedder = vector_index.embedder
else:
embedder = Embedder(
device=self.device, model_name=self.model_name, batch_size=self.batch_size, max_length=self.max_length
device=self.device,
model_name=self.embedder_name,
batch_size=self.batch_size,
max_length=self.max_length,
)
features = embedder.embed(utterances)

Expand Down
Loading

0 comments on commit b26a878

Please sign in to comment.